#!/usr/bin/env python3
|
|
"""
|
|
Disease Module Generator Loop for Synthea
|
|
|
|
This script runs the module_generator.py script in a loop to generate
|
|
modules for diseases in disease_list.json that don't have modules yet.
|
|
|
|
Usage:
    python run_module_generator.py [--batch-size BATCH_SIZE] [--max-modules MAX_MODULES]

Arguments:
    --batch-size BATCH_SIZE    Number of modules to generate in each batch (default: 5)
    --max-modules MAX_MODULES  Maximum total number of modules to generate (default: no limit)
    --max-batches MAX_BATCHES  Maximum number of batches to run (default: no limit)
    --prioritize               Prioritize high-prevalence diseases first
    --max-cost MAX_COST        Maximum cost in USD to spend (default: no limit)
    --timeout TIMEOUT          Timeout in seconds for each batch process (default: 600)
    --fix-modules              Check and fix existing modules for JSON validity
    --strict                   Fail immediately on module validation errors instead of trying to fix them
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import subprocess
|
|
import argparse
|
|
import time
|
|
import logging
|
|
import anthropic
|
|
import re
|
|
from pathlib import Path
|
|
|
|
# Configure logging: everything goes both to a log file and to stdout so the
# operator can watch progress live while keeping a persistent record.
# NOTE(review): the filename here is hardcoded; the LOG_FILE_PATH constant
# computed below is never passed to this config — confirm which is intended.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("module_generation_runner.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
|
|
|
|
# Constants
# Repository root, assuming this script lives three directories below it.
SYNTHEA_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
# Support both Docker (/app) and local development paths
if os.path.exists("/app/src/main/resources/disease_list.json"):
    # Running inside the Docker image, where the repo is mounted at /app.
    DISEASE_LIST_PATH = "/app/src/main/resources/disease_list.json"
    MODULES_DIR = "/app/src/main/resources/modules"
    LOG_FILE_PATH = "/app/module_generation_runner.log"
else:
    # Local development: resolve everything relative to the repo root.
    DISEASE_LIST_PATH = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_list.json")
    MODULES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/modules")
    LOG_FILE_PATH = os.path.join(SYNTHEA_ROOT, "module_generation_runner.log")

# The generator script this runner invokes as a subprocess, one batch at a time.
MODULE_GENERATOR_PATH = os.path.join(os.path.dirname(__file__), "module_generator.py")

# API costs - Claude 3.7 Sonnet (as of March 2025)
# These are approximate costs based on current pricing
CLAUDE_INPUT_COST_PER_1K_TOKENS = 0.003  # $0.003 per 1K input tokens
CLAUDE_OUTPUT_COST_PER_1K_TOKENS = 0.015  # $0.015 per 1K output tokens

# Initialize cost tracking.
# These module-level counters are mutated by parse_api_usage() after each batch
# and read by main() to enforce the --max-cost limit.
total_input_tokens = 0
total_output_tokens = 0
total_cost_usd = 0.0
|
|
|
|
def count_existing_modules():
    """Return the number of module JSON files currently present in MODULES_DIR."""
    # Count lazily instead of materializing the whole file list.
    return sum(1 for _ in Path(MODULES_DIR).glob("*.json"))
|
|
|
|
def count_remaining_diseases():
    """Count diseases from disease_list.json that have no module file yet.

    Returns:
        tuple: (number of diseases still lacking a module, total disease count).
    """
    with open(DISEASE_LIST_PATH, 'r') as handle:
        disease_entries = json.load(handle)

    # Lowercased stems of module files already on disk, for O(1) membership tests.
    existing = {path.stem.lower() for path in Path(MODULES_DIR).glob("*.json")}

    def slugify(name):
        # Mirror the generator's filename convention: lowercase, every
        # non-alphanumeric run collapsed to a single underscore, no
        # leading/trailing underscores.
        slug = ''.join(ch if ch.isalnum() else '_' for ch in name.lower())
        slug = slug.strip('_')
        while '__' in slug:
            slug = slug.replace('__', '_')
        return slug

    missing = sum(
        1 for entry in disease_entries
        if slugify(entry['disease_name']) not in existing
    )
    return missing, len(disease_entries)
|
|
|
|
def parse_api_usage(log_content):
    """Extract token counts from a batch log and fold them into the running totals.

    Side effects: increments the module-level total_input_tokens,
    total_output_tokens and total_cost_usd counters.

    Args:
        log_content: Full text of a batch log file.

    Returns:
        tuple: (input tokens, output tokens, cost in USD) for this batch.
    """
    global total_input_tokens, total_output_tokens, total_cost_usd

    # The generator logs lines like:
    #   "Estimated input tokens: 7031, Estimated output tokens: 3225"
    input_token_matches = re.findall(r'Estimated input tokens: (\d+)', log_content)
    output_token_matches = re.findall(r'Estimated output tokens: (\d+)', log_content)

    batch_in = sum(map(int, input_token_matches))
    batch_out = sum(map(int, output_token_matches))

    # Per-1K-token pricing; keep the two products separate before summing so the
    # float result matches the established accounting exactly.
    input_cost = batch_in * CLAUDE_INPUT_COST_PER_1K_TOKENS / 1000
    output_cost = batch_out * CLAUDE_OUTPUT_COST_PER_1K_TOKENS / 1000
    batch_cost = input_cost + output_cost

    total_input_tokens += batch_in
    total_output_tokens += batch_out
    total_cost_usd += batch_cost

    # Leave the raw match lists in the log for debugging the regexes.
    logger.info(f"Found token counts in log - Input: {input_token_matches}, Output: {output_token_matches}")

    return batch_in, batch_out, batch_cost
|
|
|
|
def run_module_generator(batch_size=5, prioritize=False, timeout=600, strict=False):
    """Run module_generator.py once as a subprocess to produce one batch.

    Args:
        batch_size: Maximum number of modules the generator may create.
        prioritize: Pass --prioritize so high-prevalence diseases go first.
        timeout: Seconds to wait before killing the generator process.
        strict: Pass --strict so validation errors fail instead of being fixed.

    Returns:
        bool: True if the generator process completed, False on timeout or error.
    """
    # Create a unique log file for this batch to capture token usage.
    batch_log_file = f"module_generator_batch_{int(time.time())}.log"
    try:
        cmd = [
            sys.executable,
            MODULE_GENERATOR_PATH,
            '--limit', str(batch_size),
            '--log-file', batch_log_file
        ]

        if prioritize:
            cmd.append('--prioritize')

        if strict:
            cmd.append('--strict')

        logger.info(f"Running command: {' '.join(cmd)}")
        print(f"\n[INFO] Starting batch, generating up to {batch_size} modules...")

        # Use a timeout to prevent indefinite hanging.
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=timeout)
            success = True
        except subprocess.TimeoutExpired:
            logger.error(f"Module generator timed out after {timeout} seconds")
            print(f"\n[ERROR] Process timed out after {timeout} seconds!")
            success = False
            result = None

        if success and result:
            # Mirror generator stdout into our log, surfacing progress lines.
            if result.stdout:
                for line in result.stdout.splitlines():
                    logger.info(f"Generator output: {line}")
                    if "Generated module for" in line or "modules complete" in line:
                        print(f"[PROGRESS] {line}")

            # stderr carries both real errors and informational messages.
            if result.stderr:
                for line in result.stderr.splitlines():
                    if "Error" in line or "Exception" in line:
                        logger.error(f"Generator error: {line}")
                        print(f"[ERROR] {line}")
                    else:
                        logger.info(f"Generator message: {line}")

        # Read the batch log ONCE: both the token usage and the created-module
        # count come from the same content. (Previously the file was deleted
        # right after the usage parse, so the "[SUCCESS] Created N module(s)"
        # count never ran on a successful batch.)
        if os.path.exists(batch_log_file):
            with open(batch_log_file, 'r') as f:
                log_content = f.read()

            if success:
                # Accumulate token usage / cost into the module-level totals.
                batch_input_tokens, batch_output_tokens, batch_cost = parse_api_usage(log_content)

                logger.info(f"Batch API Usage - Input tokens: {batch_input_tokens:,}, "
                            f"Output tokens: {batch_output_tokens:,}, Cost: ${batch_cost:.4f}")
                logger.info(f"Total API Usage - Input tokens: {total_input_tokens:,}, "
                            f"Output tokens: {total_output_tokens:,}, Total Cost: ${total_cost_usd:.4f}")

                # Print to console in a prominent way.
                print("\n" + "="*80)
                print(f"BATCH API USAGE - Cost: ${batch_cost:.4f}")
                print(f"TOTAL API USAGE - Cost: ${total_cost_usd:.4f}")
                print("="*80 + "\n")

            # Count created modules even after a timeout — partial batches
            # may still have produced some.
            modules_created = log_content.count("Successfully created module for")
            if modules_created > 0:
                print(f"[SUCCESS] Created {modules_created} module(s) in this batch")

        return success
    except subprocess.CalledProcessError as e:
        logger.error(f"Module generator failed with exit code {e.returncode}")
        logger.error(f"Error output: {e.stderr}")
        print(f"[ERROR] Module generator process failed with exit code {e.returncode}")
        return False
    except Exception as e:
        logger.error(f"Failed to run module generator: {e}")
        print(f"[ERROR] Failed to run module generator: {e}")
        return False
    finally:
        # Always remove the temporary batch log (best-effort); previously it
        # leaked when the subprocess exited non-zero.
        try:
            os.remove(batch_log_file)
        except OSError:
            pass
|
|
|
|
def _repair_json_text(content):
    """Best-effort textual repair of common generator JSON glitches.

    Handles trailing commas, an unterminated "states" section, and unbalanced
    braces. Returns the repaired text; the caller must re-validate it.
    """
    # Fix trailing commas before a closing brace/bracket.
    fixed = re.sub(r',\s*}', '}', content)
    fixed = re.sub(r',\s*]', ']', fixed)

    # Check for an incomplete states section (opened but never closed).
    if '"states": {' in fixed:
        states_match = re.search(r'"states":\s*{(.*)}(?=\s*,|\s*})', fixed, re.DOTALL)
        if not states_match:
            print(" - Module appears to have an incomplete states section")

            # Find the end of the last complete state definition.
            last_state_end = None
            state_pattern = re.compile(r'("[^"]+"\s*:\s*{[^{]*?}),?', re.DOTALL)
            for match in state_pattern.finditer(fixed, fixed.find('"states": {') + 12):
                last_state_end = match.end()

            if last_state_end:
                # Close the states object and the module object after the
                # last valid state.
                fixed = fixed[:last_state_end] + '\n    }\n  }' + fixed[last_state_end:]
                print(" - Added closing braces for states section")

    # Balance braces.
    open_braces = fixed.count('{')
    close_braces = fixed.count('}')

    if open_braces > close_braces:
        fixed += '}' * (open_braces - close_braces)
        print(f" - Added {open_braces - close_braces} missing closing braces")
    elif close_braces > open_braces:
        # Remove exactly ONE trailing brace per excess. The previous
        # rstrip('}') approach stripped EVERY trailing brace and then added
        # one back, deleting braces that valid nesting still needed.
        for _ in range(close_braces - open_braces):
            trimmed = fixed.rstrip()
            if trimmed.endswith('}'):
                fixed = trimmed[:-1]
        print(f" - Removed {close_braces - open_braces} excess closing braces")

    return fixed


def validate_and_fix_existing_modules(strict=False):
    """Check every module file in MODULES_DIR for JSON validity; optionally repair.

    Args:
        strict: When True, only report invalid modules — never rewrite them.

    Returns:
        int: Number of modules that were successfully repaired (0 when all valid).
    """
    import shutil  # hoisted: was re-imported inside the per-module loop

    print("Validating existing modules...")

    if strict:
        print("Running in strict mode - modules will NOT be fixed automatically")

    fixed_count = 0
    module_files = list(Path(MODULES_DIR).glob("*.json"))

    # First, show a summary of all modules.
    print(f"Found {len(module_files)} module files")

    invalid_modules = []

    # Pass 1: find the modules that fail to parse.
    for module_path in module_files:
        try:
            with open(module_path, 'r') as f:
                content = f.read()
            try:
                json.loads(content)
            except json.JSONDecodeError as e:
                invalid_modules.append((module_path, e))
                print(f"⚠️ Invalid JSON found in {module_path.name}: {e}")
        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    if not invalid_modules:
        print("✅ All modules are valid JSON")
        return 0

    print(f"\nFound {len(invalid_modules)} invalid module(s) to fix")

    # Pass 2: attempt to repair each invalid module.
    for module_path, error in invalid_modules:
        try:
            print(f"\nAttempting to fix {module_path.name}...")

            if strict:
                print(f"❌ Module {module_path.name} has validation issues that must be fixed manually")
                continue  # Skip fixing in strict mode

            with open(module_path, 'r') as f:
                content = f.read()

            fixed_content = _repair_json_text(content)

            # Re-validate the repaired text before touching the file.
            try:
                parsed_json = json.loads(fixed_content)
            except json.JSONDecodeError as e:
                print(f"❌ Could not fix module {module_path.name}: {e}")
                continue

            # Check for required top-level fields.
            if 'name' not in parsed_json:
                print(" - Warning: Module missing 'name' field")
            if 'states' not in parsed_json:
                print(" - Warning: Module missing 'states' field")
            if 'gmf_version' not in parsed_json:
                print(" - Adding missing gmf_version field")
                # Add on the parsed object and re-serialize. The previous text
                # splice (rstrip('}') + ', "gmf_version": 2}') stripped every
                # trailing brace and wrote the result without re-validating,
                # which could corrupt any module ending in nested braces.
                parsed_json['gmf_version'] = 2
                fixed_content = json.dumps(parsed_json, indent=2)

            # Backup the original before overwriting.
            backup_path = str(module_path) + '.bak'
            shutil.copy2(str(module_path), backup_path)

            with open(module_path, 'w') as f:
                f.write(fixed_content)

            print(f"✅ Fixed module {module_path.name}")
            fixed_count += 1

        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    print(f"\nValidation complete. Fixed {fixed_count} of {len(invalid_modules)} invalid modules.")
    return fixed_count
|
|
|
|
def main():
    """Parse CLI arguments and run module generation batches until a stop condition.

    Stop conditions, checked at the top of each iteration:
    --max-modules reached, --max-batches reached, --max-cost reached (via the
    module-level total_cost_usd updated by parse_api_usage), a failed batch,
    or no diseases left without modules. With --fix-modules, only validates
    and repairs existing module files, then exits.
    """
    parser = argparse.ArgumentParser(description='Run the module generator in a loop')
    parser.add_argument('--batch-size', type=int, default=5,
                        help='Number of modules to generate in each batch (default: 5)')
    parser.add_argument('--max-modules', type=int, default=None,
                        help='Maximum total number of modules to generate (default: no limit)')
    parser.add_argument('--max-batches', type=int, default=None,
                        help='Maximum number of batches to run (default: no limit)')
    parser.add_argument('--prioritize', action='store_true',
                        help='Prioritize high-prevalence diseases first')
    parser.add_argument('--max-cost', type=float, default=None,
                        help='Maximum cost in USD to spend (default: no limit)')
    parser.add_argument('--timeout', type=int, default=600,
                        help='Timeout in seconds for each batch process (default: 600)')
    parser.add_argument('--fix-modules', action='store_true',
                        help='Check and fix existing modules for JSON validity')
    parser.add_argument('--strict', action='store_true',
                        help='Fail immediately on module validation errors instead of trying to fix them')
    args = parser.parse_args()

    # If fix-modules flag is set, validate and fix existing modules, then exit
    # without generating anything.
    if args.fix_modules:
        validate_and_fix_existing_modules(strict=args.strict)
        return

    # Initial counts — modules_created below is measured as the on-disk delta
    # from initial_modules, not by trusting the generator's own reporting.
    initial_modules = count_existing_modules()
    initial_remaining, total_diseases = count_remaining_diseases()

    logger.info(f"Starting module generation loop")
    logger.info(f"Currently have {initial_modules} modules")
    logger.info(f"Remaining diseases without modules: {initial_remaining} of {total_diseases}")

    # Set up counters
    modules_created = 0
    batch_count = 0

    # Loop until we've hit our maximum or there are no more diseases to process
    while True:
        # Break if we've reached the maximum modules to generate
        if args.max_modules is not None and modules_created >= args.max_modules:
            logger.info(f"Reached maximum modules limit ({args.max_modules})")
            print(f"\n{'!'*80}")
            print(f"MODULE LIMIT REACHED: {modules_created} >= {args.max_modules}")
            print(f"{'!'*80}\n")
            break

        # Break if we've reached the maximum batches
        if args.max_batches is not None and batch_count >= args.max_batches:
            logger.info(f"Reached maximum batch limit ({args.max_batches})")
            print(f"\n{'!'*80}")
            print(f"BATCH LIMIT REACHED: {batch_count} >= {args.max_batches}")
            print(f"{'!'*80}\n")
            break

        # Break if we've reached the maximum cost.
        # total_cost_usd is the module-level accumulator that parse_api_usage
        # bumps after each batch, so this check lags one batch behind spending.
        if args.max_cost is not None and total_cost_usd >= args.max_cost:
            logger.info(f"Reached maximum cost limit (${args.max_cost:.2f})")
            print(f"\n{'!'*80}")
            print(f"COST LIMIT REACHED: ${total_cost_usd:.4f} >= ${args.max_cost:.2f}")
            print(f"{'!'*80}\n")
            break

        # Count how many modules we need to generate in this batch
        if args.max_modules is not None:
            # Adjust batch size if we're near the maximum
            current_batch_size = min(args.batch_size, args.max_modules - modules_created)
        else:
            current_batch_size = args.batch_size

        if current_batch_size <= 0:
            logger.info("No more modules to generate in this batch")
            print("[INFO] No more modules to generate (reached limit)")
            break

        # Run the generator for this batch
        batch_count += 1
        logger.info(f"Starting batch {batch_count} (generating up to {current_batch_size} modules)")

        # Run the generator
        success = run_module_generator(
            batch_size=current_batch_size,
            prioritize=args.prioritize,
            timeout=args.timeout,
            strict=args.strict
        )

        if not success:
            logger.error(f"Batch {batch_count} failed, stopping")
            print(f"[ERROR] Batch {batch_count} failed, stopping")
            break

        # Count how many modules we've created so far (on-disk delta).
        new_module_count = count_existing_modules()
        modules_created = new_module_count - initial_modules

        # Count remaining diseases
        remaining_diseases, _ = count_remaining_diseases()

        logger.info(f"Completed batch {batch_count}")
        logger.info(f"Total modules created: {modules_created}")
        logger.info(f"Remaining diseases without modules: {remaining_diseases}")

        print(f"[STATUS] Batch {batch_count} complete")
        print(f"[STATUS] Total modules created: {modules_created}")
        print(f"[STATUS] Remaining diseases: {remaining_diseases} of {total_diseases}")

        # Break if no diseases left
        if remaining_diseases <= 0:
            logger.info("All diseases have modules now, finished!")
            print("[SUCCESS] All diseases have modules now, finished!")
            break

        # Sleep briefly between batches to avoid system overload
        time.sleep(2)

    # Final status
    logger.info(f"Module generation complete!")
    logger.info(f"Started with {initial_modules} modules, now have {count_existing_modules()}")
    logger.info(f"Created {modules_created} new modules in {batch_count} batches")

    remaining, total = count_remaining_diseases()
    logger.info(f"Remaining diseases without modules: {remaining} of {total}")
|
|
|
|
# Script entry point: only run the generation loop when executed directly.
if __name__ == "__main__":
    main()