Files
synthea-alldiseases/module_generator/run_module_generator.py

478 lines
20 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Disease Module Generator Loop for Synthea
This script runs the module_generator.py script in a loop to generate
modules for diseases in disease_list.json that don't have modules yet.
Usage:
python run_module_generator.py [--batch-size BATCH_SIZE] [--max-modules MAX_MODULES]
Arguments:
--batch-size BATCH_SIZE Number of modules to generate in each batch (default: 5)
--max-modules MAX_MODULES Maximum total number of modules to generate (default: no limit)
--prioritize Prioritize high-prevalence diseases first
--max-cost MAX_COST Maximum cost in USD to spend (default: no limit)
--strict Fail immediately on module validation errors instead of trying to fix them
"""
import os
import sys
import json
import subprocess
import argparse
import time
import logging
import anthropic
import re
from pathlib import Path
# Configure logging: everything goes both to a log file and to stdout so
# long-running batch output is visible live and preserved for later parsing.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("module_generation_runner.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Constants
# Repo root, assuming this file lives three levels below it
# (e.g. <root>/synthea-alldiseases/module_generator/run_module_generator.py).
SYNTHEA_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))

# Support both Docker (/app) and local development paths: prefer the Docker
# layout when its disease list exists, otherwise resolve relative to the repo.
if os.path.exists("/app/src/main/resources/disease_list.json"):
    DISEASE_LIST_PATH = "/app/src/main/resources/disease_list.json"
    MODULES_DIR = "/app/src/main/resources/modules"
    LOG_FILE_PATH = "/app/module_generation_runner.log"
else:
    DISEASE_LIST_PATH = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_list.json")
    MODULES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/modules")
    LOG_FILE_PATH = os.path.join(SYNTHEA_ROOT, "module_generation_runner.log")

# The worker script this runner invokes in batches.
MODULE_GENERATOR_PATH = os.path.join(os.path.dirname(__file__), "module_generator.py")

# API costs - Claude 3.7 Sonnet (as of March 2025)
# These are approximate costs based on current pricing.
CLAUDE_INPUT_COST_PER_1K_TOKENS = 0.003  # $0.003 per 1K input tokens
CLAUDE_OUTPUT_COST_PER_1K_TOKENS = 0.015  # $0.015 per 1K output tokens

# Initialize cost tracking (mutated by parse_api_usage(), read by main()).
total_input_tokens = 0
total_output_tokens = 0
total_cost_usd = 0.0
def count_existing_modules():
    """Return the number of generated module files (``*.json``) in MODULES_DIR."""
    return sum(1 for _ in Path(MODULES_DIR).glob("*.json"))
def count_remaining_diseases():
    """Count diseases in disease_list.json that still lack a module file.

    Returns a tuple ``(remaining, total)`` where ``total`` is the size of the
    disease list and ``remaining`` is how many entries have no matching module.
    """
    with open(DISEASE_LIST_PATH, 'r') as f:
        diseases = json.load(f)

    # Lower-cased stems of the module files already on disk.
    existing_modules = {p.stem.lower() for p in Path(MODULES_DIR).glob("*.json")}

    remaining = 0
    for entry in diseases:
        # Normalise the disease name into the module filename convention:
        # lower-case, non-alphanumerics to underscores, runs collapsed,
        # leading/trailing underscores removed.
        slug = ''.join(ch if ch.isalnum() else '_' for ch in entry['disease_name'].lower())
        slug = re.sub(r'_+', '_', slug).strip('_')
        if slug not in existing_modules:
            remaining += 1
    return remaining, len(diseases)
def parse_api_usage(log_content):
    """Extract token counts from generator log text and update running totals.

    Matches lines of the form
    ``Estimated input tokens: 7031, Estimated output tokens: 3225`` and sums
    every occurrence found in *log_content*.  Updates the module-level
    ``total_input_tokens`` / ``total_output_tokens`` / ``total_cost_usd``
    counters as a side effect.

    Returns ``(batch_input_tokens, batch_output_tokens, batch_total_cost)``.
    """
    global total_input_tokens, total_output_tokens, total_cost_usd

    input_token_matches = re.findall(r'Estimated input tokens: (\d+)', log_content)
    output_token_matches = re.findall(r'Estimated output tokens: (\d+)', log_content)
    batch_input_tokens = sum(map(int, input_token_matches))
    batch_output_tokens = sum(map(int, output_token_matches))

    # Cost for this batch at the per-1K-token rates defined above.
    batch_total_cost = (
        batch_input_tokens * CLAUDE_INPUT_COST_PER_1K_TOKENS / 1000
        + batch_output_tokens * CLAUDE_OUTPUT_COST_PER_1K_TOKENS / 1000
    )

    # Fold this batch into the running module-level totals.
    total_input_tokens += batch_input_tokens
    total_output_tokens += batch_output_tokens
    total_cost_usd += batch_total_cost

    logger.info(f"Found token counts in log - Input: {input_token_matches}, Output: {output_token_matches}")
    return batch_input_tokens, batch_output_tokens, batch_total_cost
def run_module_generator(batch_size=5, prioritize=False, timeout=600, strict=False):
    """Run module_generator.py once for a batch of up to *batch_size* modules.

    Launches the generator as a subprocess with a dedicated per-batch log
    file, relays its output, parses API token usage from the batch log, and
    reports how many modules the batch created.

    Args:
        batch_size: Maximum number of modules the generator may create.
        prioritize: Pass ``--prioritize`` (high-prevalence diseases first).
        timeout: Seconds before the subprocess is killed.
        strict: Pass ``--strict`` (fail on validation errors).

    Returns:
        True if the subprocess completed without timing out or failing.
    """
    try:
        # Unique per-batch log file so token usage can be parsed afterwards.
        batch_log_file = f"module_generator_batch_{int(time.time())}.log"
        cmd = [
            sys.executable,
            MODULE_GENERATOR_PATH,
            '--limit', str(batch_size),
            '--log-file', batch_log_file
        ]
        if prioritize:
            cmd.append('--prioritize')
        if strict:
            cmd.append('--strict')
        logger.info(f"Running command: {' '.join(cmd)}")
        print(f"\n[INFO] Starting batch, generating up to {batch_size} modules...")

        # Use timeout to prevent indefinite hanging.
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=timeout)
            success = True
        except subprocess.TimeoutExpired:
            logger.error(f"Module generator timed out after {timeout} seconds")
            print(f"\n[ERROR] Process timed out after {timeout} seconds!")
            success = False
            result = None

        # Relay captured output only if the command completed successfully.
        if success and result:
            if result.stdout:
                for line in result.stdout.splitlines():
                    logger.info(f"Generator output: {line}")
                    # Print progress-related lines directly to console.
                    if "Generated module for" in line or "modules complete" in line:
                        print(f"[PROGRESS] {line}")
            if result.stderr:
                for line in result.stderr.splitlines():
                    if "Error" in line or "Exception" in line:
                        logger.error(f"Generator error: {line}")
                        print(f"[ERROR] {line}")
                    else:
                        logger.info(f"Generator message: {line}")

        # Read the batch log ONCE: parse API usage (on success), count the
        # modules created (always), and only then delete the file.
        # BUG FIX: the original removed the log file inside the success branch
        # *before* re-checking its existence for the module count, so the
        # "[SUCCESS] Created N module(s)" report was dead code on success.
        if os.path.exists(batch_log_file):
            with open(batch_log_file, 'r') as f:
                log_content = f.read()

            if success and result:
                # Parse usage statistics and report batch + running totals.
                batch_input_tokens, batch_output_tokens, batch_cost = parse_api_usage(log_content)
                logger.info(f"Batch API Usage - Input tokens: {batch_input_tokens:,}, "
                            f"Output tokens: {batch_output_tokens:,}, Cost: ${batch_cost:.4f}")
                logger.info(f"Total API Usage - Input tokens: {total_input_tokens:,}, "
                            f"Output tokens: {total_output_tokens:,}, Total Cost: ${total_cost_usd:.4f}")
                print("\n" + "="*80)
                print(f"BATCH API USAGE - Cost: ${batch_cost:.4f}")
                print(f"TOTAL API USAGE - Cost: ${total_cost_usd:.4f}")
                print("="*80 + "\n")

            # Count created modules to verify/report success.
            modules_created = log_content.count("Successfully created module for")
            if modules_created > 0:
                print(f"[SUCCESS] Created {modules_created} module(s) in this batch")

            # Clean up the batch log; best-effort, ignore filesystem errors.
            try:
                os.remove(batch_log_file)
            except OSError:
                pass

        return success
    except subprocess.CalledProcessError as e:
        logger.error(f"Module generator failed with exit code {e.returncode}")
        logger.error(f"Error output: {e.stderr}")
        print(f"[ERROR] Module generator process failed with exit code {e.returncode}")
        return False
    except Exception as e:
        logger.error(f"Failed to run module generator: {e}")
        print(f"[ERROR] Failed to run module generator: {e}")
        return False
def validate_and_fix_existing_modules(strict=False):
    """Check all existing modules for JSON validity and fix if needed.

    Scans every ``*.json`` file in MODULES_DIR, reports invalid ones, and
    (unless *strict* is set) applies heuristic repairs: trailing commas,
    truncated ``states`` sections, unbalanced braces, and a missing
    ``gmf_version`` field.  Originals are backed up to ``<name>.json.bak``
    before being overwritten.

    Args:
        strict: When True, only report problems; never rewrite files.

    Returns:
        Number of modules successfully repaired.
    """
    import shutil  # hoisted out of the per-module loop (was imported per iteration)

    print("Validating existing modules...")
    if strict:
        print("Running in strict mode - modules will NOT be fixed automatically")
    fixed_count = 0
    module_files = list(Path(MODULES_DIR).glob("*.json"))
    # First, show a summary of all modules.
    print(f"Found {len(module_files)} module files")
    invalid_modules = []

    # Pass 1: find modules that fail to parse.
    for module_path in module_files:
        try:
            with open(module_path, 'r') as f:
                content = f.read()
            try:
                json.loads(content)
                # Parses cleanly: nothing to do.
                continue
            except json.JSONDecodeError as e:
                invalid_modules.append((module_path, e))
                print(f"⚠️ Invalid JSON found in {module_path.name}: {e}")
        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    if not invalid_modules:
        print("✅ All modules are valid JSON")
        return 0
    print(f"\nFound {len(invalid_modules)} invalid module(s) to fix")

    # Pass 2: attempt heuristic repairs on the invalid modules.
    for module_path, error in invalid_modules:
        try:
            print(f"\nAttempting to fix {module_path.name}...")
            if strict:
                print(f"❌ Module {module_path.name} has validation issues that must be fixed manually")
                continue  # Skip fixing in strict mode

            with open(module_path, 'r') as f:
                content = f.read()
            fixed_content = content

            # Fix trailing commas before closing braces/brackets.
            fixed_content = re.sub(r',\s*}', '}', fixed_content)
            fixed_content = re.sub(r',\s*]', ']', fixed_content)

            # Check for an incomplete (unterminated) states section.
            if '"states": {' in fixed_content:
                states_match = re.search(r'"states":\s*{(.*)}(?=\s*,|\s*})', fixed_content, re.DOTALL)
                if not states_match:
                    print(" - Module appears to have an incomplete states section")
                    # Find the end of the last complete state definition.
                    last_state_end = None
                    state_pattern = re.compile(r'("[^"]+"\s*:\s*{[^{]*?}),?', re.DOTALL)
                    for match in state_pattern.finditer(fixed_content, fixed_content.find('"states": {') + 12):
                        last_state_end = match.end()
                    if last_state_end:
                        # Close the states object and the module object there.
                        fixed_content = fixed_content[:last_state_end] + '\n  }\n}' + fixed_content[last_state_end:]
                        print(" - Added closing braces for states section")

            # Fix unbalanced braces.
            open_braces = fixed_content.count('{')
            close_braces = fixed_content.count('}')
            if open_braces > close_braces:
                fixed_content += '}' * (open_braces - close_braces)
                print(f" - Added {open_braces - close_braces} missing closing braces")
            elif close_braces > open_braces:
                # BUG FIX: the original used rstrip('}'), which strips ALL
                # trailing braces per iteration and re-adds only one, removing
                # more braces than the counted excess.  Remove exactly one
                # trailing '}' per excess brace instead.
                for _ in range(close_braces - open_braces):
                    stripped = fixed_content.rstrip()
                    if stripped.endswith('}'):
                        fixed_content = stripped[:-1]
                print(f" - Removed {close_braces - open_braces} excess closing braces")

            # Validate the repaired text, then write it back.
            try:
                parsed_json = json.loads(fixed_content)
                # Warn about structurally important fields that are missing.
                if 'name' not in parsed_json:
                    print(" - Warning: Module missing 'name' field")
                if 'states' not in parsed_json:
                    print(" - Warning: Module missing 'states' field")
                if 'gmf_version' not in parsed_json:
                    print(" - Adding missing gmf_version field")
                    stripped = fixed_content.rstrip()
                    if stripped.endswith('}'):
                        # Insert before the final closing brace only (the
                        # original stripped every trailing brace, which could
                        # corrupt nested objects).  Re-validate and revert the
                        # insertion if it somehow breaks the JSON.
                        candidate = stripped[:-1].rstrip().rstrip(',') + ',\n  "gmf_version": 2\n}'
                        try:
                            json.loads(candidate)
                            fixed_content = candidate
                        except json.JSONDecodeError:
                            print(" - Could not insert gmf_version cleanly; leaving field absent")

                # Backup the original file before overwriting it.
                backup_path = str(module_path) + '.bak'
                shutil.copy2(str(module_path), backup_path)
                with open(module_path, 'w') as f:
                    f.write(fixed_content)
                print(f"✅ Fixed module {module_path.name}")
                fixed_count += 1
            except json.JSONDecodeError as e:
                print(f"❌ Could not fix module {module_path.name}: {e}")
        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    print(f"\nValidation complete. Fixed {fixed_count} of {len(invalid_modules)} invalid modules.")
    return fixed_count
def main():
    """Command-line entry point.

    Parses arguments, then either repairs existing modules (``--fix-modules``)
    or runs the module generator in batches until a module/batch/cost limit is
    hit, a batch fails, or every disease has a module.
    """
    parser = argparse.ArgumentParser(description='Run the module generator in a loop')
    parser.add_argument('--batch-size', type=int, default=5,
                        help='Number of modules to generate in each batch (default: 5)')
    parser.add_argument('--max-modules', type=int, default=None,
                        help='Maximum total number of modules to generate (default: no limit)')
    parser.add_argument('--max-batches', type=int, default=None,
                        help='Maximum number of batches to run (default: no limit)')
    parser.add_argument('--prioritize', action='store_true',
                        help='Prioritize high-prevalence diseases first')
    parser.add_argument('--max-cost', type=float, default=None,
                        help='Maximum cost in USD to spend (default: no limit)')
    parser.add_argument('--timeout', type=int, default=600,
                        help='Timeout in seconds for each batch process (default: 600)')
    parser.add_argument('--fix-modules', action='store_true',
                        help='Check and fix existing modules for JSON validity')
    parser.add_argument('--strict', action='store_true',
                        help='Fail immediately on module validation errors instead of trying to fix them')
    args = parser.parse_args()

    # Validation-only mode: repair existing modules and exit.
    if args.fix_modules:
        validate_and_fix_existing_modules(strict=args.strict)
        return

    def announce_limit(message):
        # Loud console banner used whenever a stop condition triggers.
        print(f"\n{'!'*80}")
        print(message)
        print(f"{'!'*80}\n")

    # Baseline counts before any batch runs.
    initial_modules = count_existing_modules()
    initial_remaining, total_diseases = count_remaining_diseases()
    logger.info("Starting module generation loop")
    logger.info(f"Currently have {initial_modules} modules")
    logger.info(f"Remaining diseases without modules: {initial_remaining} of {total_diseases}")

    created_total = 0
    batches_run = 0

    # Run batches until a limit is hit or there is nothing left to generate.
    while True:
        if args.max_modules is not None and created_total >= args.max_modules:
            logger.info(f"Reached maximum modules limit ({args.max_modules})")
            announce_limit(f"MODULE LIMIT REACHED: {created_total} >= {args.max_modules}")
            break
        if args.max_batches is not None and batches_run >= args.max_batches:
            logger.info(f"Reached maximum batch limit ({args.max_batches})")
            announce_limit(f"BATCH LIMIT REACHED: {batches_run} >= {args.max_batches}")
            break
        if args.max_cost is not None and total_cost_usd >= args.max_cost:
            logger.info(f"Reached maximum cost limit (${args.max_cost:.2f})")
            announce_limit(f"COST LIMIT REACHED: ${total_cost_usd:.4f} >= ${args.max_cost:.2f}")
            break

        # Shrink the final batch so we never exceed --max-modules.
        if args.max_modules is None:
            current_batch_size = args.batch_size
        else:
            current_batch_size = min(args.batch_size, args.max_modules - created_total)
        if current_batch_size <= 0:
            logger.info("No more modules to generate in this batch")
            print("[INFO] No more modules to generate (reached limit)")
            break

        batches_run += 1
        logger.info(f"Starting batch {batches_run} (generating up to {current_batch_size} modules)")
        batch_ok = run_module_generator(
            batch_size=current_batch_size,
            prioritize=args.prioritize,
            timeout=args.timeout,
            strict=args.strict
        )
        if not batch_ok:
            logger.error(f"Batch {batches_run} failed, stopping")
            print(f"[ERROR] Batch {batches_run} failed, stopping")
            break

        # Progress is measured by files on disk, not by generator output.
        created_total = count_existing_modules() - initial_modules
        remaining_diseases, _ = count_remaining_diseases()
        logger.info(f"Completed batch {batches_run}")
        logger.info(f"Total modules created: {created_total}")
        logger.info(f"Remaining diseases without modules: {remaining_diseases}")
        print(f"[STATUS] Batch {batches_run} complete")
        print(f"[STATUS] Total modules created: {created_total}")
        print(f"[STATUS] Remaining diseases: {remaining_diseases} of {total_diseases}")

        if remaining_diseases <= 0:
            logger.info("All diseases have modules now, finished!")
            print("[SUCCESS] All diseases have modules now, finished!")
            break

        # Brief pause between batches to avoid system overload.
        time.sleep(2)

    # Final status summary.
    logger.info("Module generation complete!")
    logger.info(f"Started with {initial_modules} modules, now have {count_existing_modules()}")
    logger.info(f"Created {created_total} new modules in {batches_run} batches")
    remaining, total = count_remaining_diseases()
    logger.info(f"Remaining diseases without modules: {remaining} of {total}")
# Script entry point: run the generation loop when executed directly.
if __name__ == "__main__":
    main()