#!/usr/bin/env python3
"""
Disease Module Generator Loop for Synthea

This script runs the module_generator.py script in a loop to generate modules
for diseases in disease_list.json that don't have modules yet.

Usage:
    python run_module_generator.py [--batch-size BATCH_SIZE] [--max-modules MAX_MODULES]

Arguments:
    --batch-size BATCH_SIZE    Number of modules to generate in each batch (default: 5)
    --max-modules MAX_MODULES  Maximum total number of modules to generate (default: no limit)
    --prioritize               Prioritize high-prevalence diseases first
    --max-cost MAX_COST        Maximum cost in USD to spend (default: no limit)
    --strict                   Fail immediately on module validation errors instead of trying to fix them
"""

import argparse
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
from pathlib import Path

import anthropic

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("module_generation_runner.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Constants
SYNTHEA_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))

# Support both Docker (/app) and local development paths
if os.path.exists("/app/src/main/resources/disease_list.json"):
    DISEASE_LIST_PATH = "/app/src/main/resources/disease_list.json"
    MODULES_DIR = "/app/src/main/resources/modules"
    LOG_FILE_PATH = "/app/module_generation_runner.log"
else:
    DISEASE_LIST_PATH = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_list.json")
    MODULES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/modules")
    LOG_FILE_PATH = os.path.join(SYNTHEA_ROOT, "module_generation_runner.log")

MODULE_GENERATOR_PATH = os.path.join(os.path.dirname(__file__), "module_generator.py")

# API costs - Claude 3.7 Sonnet (as of March 2025)
# These are approximate costs based on current pricing
CLAUDE_INPUT_COST_PER_1K_TOKENS = 0.003   # $0.003 per 1K input tokens
CLAUDE_OUTPUT_COST_PER_1K_TOKENS = 0.015  # $0.015 per 1K output tokens

# Running totals across all batches, updated in place by parse_api_usage().
total_input_tokens = 0
total_output_tokens = 0
total_cost_usd = 0.0


def count_existing_modules():
    """Count the number of modules already created (JSON files in MODULES_DIR)."""
    module_files = list(Path(MODULES_DIR).glob("*.json"))
    return len(module_files)


def count_remaining_diseases():
    """Count how many diseases in disease_list.json don't have modules yet.

    Returns:
        tuple[int, int]: (number of diseases without a module, total diseases).
    """
    # Load the disease list
    with open(DISEASE_LIST_PATH, 'r') as f:
        diseases = json.load(f)

    # Get existing module names (lowercased file stems)
    existing_modules = set()
    for module_file in Path(MODULES_DIR).glob("*.json"):
        module_name = module_file.stem.lower()
        existing_modules.add(module_name)

    # Count diseases that don't have modules
    remaining = 0
    for disease in diseases:
        # Extract disease name from the disease object
        disease_name = disease['disease_name']

        # Convert disease name to module filename format: lowercase,
        # non-alphanumerics become underscores, then collapse/trim them.
        module_name = disease_name.lower()
        module_name = ''.join(c if c.isalnum() else '_' for c in module_name)
        module_name = module_name.strip('_')
        # Replace consecutive underscores with a single one
        while '__' in module_name:
            module_name = module_name.replace('__', '_')

        if module_name not in existing_modules:
            remaining += 1

    return remaining, len(diseases)


def parse_api_usage(log_content):
    """Parse API usage stats from log content and update the global totals.

    Args:
        log_content: Text of a module_generator batch log file.

    Returns:
        tuple[int, int, float]: (input tokens, output tokens, cost in USD)
        for this batch only.
    """
    global total_input_tokens, total_output_tokens, total_cost_usd

    # Look for token counts in the log - using the updated pattern from the logs.
    # This matches the format:
    #   "Estimated input tokens: 7031, Estimated output tokens: 3225"
    input_token_matches = re.findall(r'Estimated input tokens: (\d+)', log_content)
    output_token_matches = re.findall(r'Estimated output tokens: (\d+)', log_content)

    batch_input_tokens = sum(int(count) for count in input_token_matches)
    batch_output_tokens = sum(int(count) for count in output_token_matches)

    # Calculate cost for this batch (prices are per 1K tokens)
    batch_input_cost = batch_input_tokens * CLAUDE_INPUT_COST_PER_1K_TOKENS / 1000
    batch_output_cost = batch_output_tokens * CLAUDE_OUTPUT_COST_PER_1K_TOKENS / 1000
    batch_total_cost = batch_input_cost + batch_output_cost

    # Update totals
    total_input_tokens += batch_input_tokens
    total_output_tokens += batch_output_tokens
    total_cost_usd += batch_total_cost

    # Log the token counts found for debugging
    logger.info(f"Found token counts in log - Input: {input_token_matches}, Output: {output_token_matches}")

    return batch_input_tokens, batch_output_tokens, batch_total_cost


def run_module_generator(batch_size=5, prioritize=False, timeout=600, strict=False):
    """Run the module generator script once with the specified batch size.

    Args:
        batch_size: Number of modules the subprocess should generate.
        prioritize: Pass --prioritize to favor high-prevalence diseases.
        timeout: Seconds to wait before killing the subprocess.
        strict: Pass --strict so validation errors fail instead of being fixed.

    Returns:
        bool: True if the subprocess completed successfully.
    """
    batch_log_file = f"module_generator_batch_{int(time.time())}.log"
    try:
        # Create a unique log file for this batch to capture token usage
        cmd = [
            sys.executable,
            MODULE_GENERATOR_PATH,
            '--limit', str(batch_size),
            '--log-file', batch_log_file
        ]
        if prioritize:
            cmd.append('--prioritize')
        if strict:
            cmd.append('--strict')

        logger.info(f"Running command: {' '.join(cmd)}")
        print(f"\n[INFO] Starting batch, generating up to {batch_size} modules...")

        # Use timeout to prevent indefinite hanging
        try:
            result = subprocess.run(cmd, capture_output=True, text=True,
                                    check=True, timeout=timeout)
            success = True
        except subprocess.TimeoutExpired:
            logger.error(f"Module generator timed out after {timeout} seconds")
            print(f"\n[ERROR] Process timed out after {timeout} seconds!")
            success = False
            result = None

        # Process output only if the command completed successfully
        if success and result:
            # Log output
            if result.stdout:
                for line in result.stdout.splitlines():
                    logger.info(f"Generator output: {line}")
                    # Print progress-related lines directly to console
                    if "Generated module for" in line or "modules complete" in line:
                        print(f"[PROGRESS] {line}")

            # Check if any errors
            if result.stderr:
                for line in result.stderr.splitlines():
                    if "Error" in line or "Exception" in line:
                        logger.error(f"Generator error: {line}")
                        print(f"[ERROR] {line}")
                    else:
                        logger.info(f"Generator message: {line}")

            # Parse token usage AND count created modules from the batch log.
            # BUGFIX: previously the log file was deleted right after parsing
            # usage and then re-checked for the success count, so the
            # "[SUCCESS] Created N module(s)" report could never run.
            if os.path.exists(batch_log_file):
                with open(batch_log_file, 'r') as f:
                    log_content = f.read()

                # Parse usage statistics
                batch_input_tokens, batch_output_tokens, batch_cost = parse_api_usage(log_content)

                # Log usage for this batch
                logger.info(f"Batch API Usage - Input tokens: {batch_input_tokens:,}, "
                            f"Output tokens: {batch_output_tokens:,}, Cost: ${batch_cost:.4f}")
                logger.info(f"Total API Usage - Input tokens: {total_input_tokens:,}, "
                            f"Output tokens: {total_output_tokens:,}, Total Cost: ${total_cost_usd:.4f}")

                # Print to console in a prominent way
                print("\n" + "="*80)
                print(f"BATCH API USAGE - Cost: ${batch_cost:.4f}")
                print(f"TOTAL API USAGE - Cost: ${total_cost_usd:.4f}")
                print("="*80 + "\n")

                # Count created files to verify success
                modules_created = log_content.count("Successfully created module for")
                if modules_created > 0:
                    print(f"[SUCCESS] Created {modules_created} module(s) in this batch")

        # Clean up the batch log even if processing didn't complete
        try:
            os.remove(batch_log_file)
        except OSError:
            pass

        return success

    except subprocess.CalledProcessError as e:
        logger.error(f"Module generator failed with exit code {e.returncode}")
        logger.error(f"Error output: {e.stderr}")
        print(f"[ERROR] Module generator process failed with exit code {e.returncode}")
        return False
    except Exception as e:
        logger.error(f"Failed to run module generator: {e}")
        print(f"[ERROR] Failed to run module generator: {e}")
        return False


def validate_and_fix_existing_modules(strict=False):
    """Check all existing modules for JSON validity and fix if needed.

    Args:
        strict: When True, report invalid modules but do not modify them.

    Returns:
        int: Number of modules that were successfully fixed.
    """
    print("Validating existing modules...")
    if strict:
        print("Running in strict mode - modules will NOT be fixed automatically")
    fixed_count = 0
    module_files = list(Path(MODULES_DIR).glob("*.json"))

    # First, show a summary of all modules
    print(f"Found {len(module_files)} module files")
    invalid_modules = []

    # Check all modules for validity
    for module_path in module_files:
        try:
            with open(module_path, 'r') as f:
                content = f.read()
            # Try to parse the JSON
            try:
                json.loads(content)
                # If it parses successfully, all good
                continue
            except json.JSONDecodeError as e:
                invalid_modules.append((module_path, e))
                print(f"⚠️ Invalid JSON found in {module_path.name}: {e}")
        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    if not invalid_modules:
        print("✅ All modules are valid JSON")
        return 0

    print(f"\nFound {len(invalid_modules)} invalid module(s) to fix")

    # Now fix the invalid modules
    for module_path, error in invalid_modules:
        try:
            print(f"\nAttempting to fix {module_path.name}...")
            if strict:
                print(f"❌ Module {module_path.name} has validation issues that must be fixed manually")
                continue  # Skip fixing in strict mode

            with open(module_path, 'r') as f:
                content = f.read()

            # Try to fix common issues
            fixed_content = content

            # Fix trailing commas
            fixed_content = re.sub(r',\s*}', '}', fixed_content)
            fixed_content = re.sub(r',\s*]', ']', fixed_content)

            # Check for incomplete states section
            if '"states": {' in fixed_content:
                # Make sure it has a proper closing brace
                states_match = re.search(r'"states":\s*{(.*)}(?=\s*,|\s*})',
                                         fixed_content, re.DOTALL)
                if not states_match and '"states": {' in fixed_content:
                    print(" - Module appears to have an incomplete states section")
                    # Find last complete state definition, scanning from just
                    # past the '"states": {' marker (use len(), not a magic 12).
                    last_state_end = None
                    state_pattern = re.compile(r'("[^"]+"\s*:\s*{[^{]*?}),?', re.DOTALL)
                    search_start = fixed_content.find('"states": {') + len('"states": {')
                    for match in state_pattern.finditer(fixed_content, search_start):
                        last_state_end = match.end()
                    if last_state_end:
                        # Add closing brace after the last valid state
                        fixed_content = (fixed_content[:last_state_end]
                                         + '\n }\n }'
                                         + fixed_content[last_state_end:])
                        print(" - Added closing braces for states section")

            # Fix missing/unbalanced braces
            open_braces = fixed_content.count('{')
            close_braces = fixed_content.count('}')
            if open_braces > close_braces:
                fixed_content += '}' * (open_braces - close_braces)
                print(f" - Added {open_braces - close_braces} missing closing braces")
            elif close_braces > open_braces:
                # BUGFIX: rstrip('}') strips *all* trailing braces at once, so
                # the old loop over-stripped on the first pass and then
                # no-opped.  Remove exactly one trailing brace per excess.
                for _ in range(close_braces - open_braces):
                    trimmed = fixed_content.rstrip()
                    if trimmed.endswith('}'):
                        fixed_content = trimmed[:-1]
                print(f" - Removed {close_braces - open_braces} excess closing braces")

            # Try to parse the fixed content
            try:
                parsed_json = json.loads(fixed_content)

                # Check if it has all required elements
                if 'name' not in parsed_json:
                    print(" - Warning: Module missing 'name' field")
                if 'states' not in parsed_json:
                    print(" - Warning: Module missing 'states' field")
                if 'gmf_version' not in parsed_json:
                    print(" - Adding missing gmf_version field")
                    # Make sure we have a proper JSON object
                    if fixed_content.rstrip().endswith('}'):
                        # Add the gmf_version before the final brace.
                        # BUGFIX: strip only the single outermost brace;
                        # rstrip('}') would eat every trailing brace and
                        # corrupt files ending in '}}'.
                        fixed_content = (fixed_content.rstrip()[:-1].rstrip()
                                         + ',\n "gmf_version": 2\n}')

                # Backup the original file before overwriting it
                backup_path = str(module_path) + '.bak'
                shutil.copy2(str(module_path), backup_path)

                # Write the fixed content
                with open(module_path, 'w') as f:
                    f.write(fixed_content)
                print(f"✅ Fixed module {module_path.name}")
                fixed_count += 1
            except json.JSONDecodeError as e:
                print(f"❌ Could not fix module {module_path.name}: {e}")
        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    print(f"\nValidation complete. Fixed {fixed_count} of {len(invalid_modules)} invalid modules.")
    return fixed_count


def main():
    """Parse CLI arguments and run the generator loop until a limit is hit."""
    parser = argparse.ArgumentParser(description='Run the module generator in a loop')
    parser.add_argument('--batch-size', type=int, default=5,
                        help='Number of modules to generate in each batch (default: 5)')
    parser.add_argument('--max-modules', type=int, default=None,
                        help='Maximum total number of modules to generate (default: no limit)')
    parser.add_argument('--max-batches', type=int, default=None,
                        help='Maximum number of batches to run (default: no limit)')
    parser.add_argument('--prioritize', action='store_true',
                        help='Prioritize high-prevalence diseases first')
    parser.add_argument('--max-cost', type=float, default=None,
                        help='Maximum cost in USD to spend (default: no limit)')
    parser.add_argument('--timeout', type=int, default=600,
                        help='Timeout in seconds for each batch process (default: 600)')
    parser.add_argument('--fix-modules', action='store_true',
                        help='Check and fix existing modules for JSON validity')
    parser.add_argument('--strict', action='store_true',
                        help='Fail immediately on module validation errors instead of trying to fix them')
    args = parser.parse_args()

    # If fix-modules flag is set, validate and fix existing modules, then exit
    if args.fix_modules:
        validate_and_fix_existing_modules(strict=args.strict)
        return

    # Initial counts
    initial_modules = count_existing_modules()
    initial_remaining, total_diseases = count_remaining_diseases()

    logger.info("Starting module generation loop")
    logger.info(f"Currently have {initial_modules} modules")
    logger.info(f"Remaining diseases without modules: {initial_remaining} of {total_diseases}")

    # Set up counters
    modules_created = 0
    batch_count = 0

    # Loop until we've hit our maximum or there are no more diseases to process
    while True:
        # Break if we've reached the maximum modules to generate
        if args.max_modules is not None and modules_created >= args.max_modules:
            logger.info(f"Reached maximum modules limit ({args.max_modules})")
            print(f"\n{'!'*80}")
            print(f"MODULE LIMIT REACHED: {modules_created} >= {args.max_modules}")
            print(f"{'!'*80}\n")
            break

        # Break if we've reached the maximum batches
        if args.max_batches is not None and batch_count >= args.max_batches:
            logger.info(f"Reached maximum batch limit ({args.max_batches})")
            print(f"\n{'!'*80}")
            print(f"BATCH LIMIT REACHED: {batch_count} >= {args.max_batches}")
            print(f"{'!'*80}\n")
            break

        # Break if we've reached the maximum cost
        if args.max_cost is not None and total_cost_usd >= args.max_cost:
            logger.info(f"Reached maximum cost limit (${args.max_cost:.2f})")
            print(f"\n{'!'*80}")
            print(f"COST LIMIT REACHED: ${total_cost_usd:.4f} >= ${args.max_cost:.2f}")
            print(f"{'!'*80}\n")
            break

        # Count how many modules we need to generate in this batch
        if args.max_modules is not None:
            # Adjust batch size if we're near the maximum
            current_batch_size = min(args.batch_size, args.max_modules - modules_created)
        else:
            current_batch_size = args.batch_size

        if current_batch_size <= 0:
            logger.info("No more modules to generate in this batch")
            print("[INFO] No more modules to generate (reached limit)")
            break

        # Run the generator for this batch
        batch_count += 1
        logger.info(f"Starting batch {batch_count} (generating up to {current_batch_size} modules)")

        success = run_module_generator(
            batch_size=current_batch_size,
            prioritize=args.prioritize,
            timeout=args.timeout,
            strict=args.strict
        )

        if not success:
            logger.error(f"Batch {batch_count} failed, stopping")
            print(f"[ERROR] Batch {batch_count} failed, stopping")
            break

        # Count how many modules we've created so far
        new_module_count = count_existing_modules()
        modules_created = new_module_count - initial_modules

        # Count remaining diseases
        remaining_diseases, _ = count_remaining_diseases()

        logger.info(f"Completed batch {batch_count}")
        logger.info(f"Total modules created: {modules_created}")
        logger.info(f"Remaining diseases without modules: {remaining_diseases}")
        print(f"[STATUS] Batch {batch_count} complete")
        print(f"[STATUS] Total modules created: {modules_created}")
        print(f"[STATUS] Remaining diseases: {remaining_diseases} of {total_diseases}")

        # Break if no diseases left
        if remaining_diseases <= 0:
            logger.info("All diseases have modules now, finished!")
            print("[SUCCESS] All diseases have modules now, finished!")
            break

        # Sleep briefly between batches to avoid system overload
        time.sleep(2)

    # Final status
    logger.info("Module generation complete!")
    logger.info(f"Started with {initial_modules} modules, now have {count_existing_modules()}")
    logger.info(f"Created {modules_created} new modules in {batch_count} batches")

    remaining, total = count_remaining_diseases()
    logger.info(f"Remaining diseases without modules: {remaining} of {total}")


if __name__ == "__main__":
    main()