Trying to fix basic functionality again.
This commit is contained in:
478
module_generator/run_module_generator.py
Executable file
478
module_generator/run_module_generator.py
Executable file
@@ -0,0 +1,478 @@
|
||||
#!/usr/bin/env python3
"""
Disease Module Generator Loop for Synthea

This script runs the module_generator.py script in a loop to generate
modules for diseases in disease_list.json that don't have modules yet.

Usage:
python run_module_generator.py [--batch-size BATCH_SIZE] [--max-modules MAX_MODULES]

Arguments:
--batch-size BATCH_SIZE Number of modules to generate in each batch (default: 5)
--max-modules MAX_MODULES Maximum total number of modules to generate (default: no limit)
--prioritize Prioritize high-prevalence diseases first
--max-cost MAX_COST Maximum cost in USD to spend (default: no limit)
--strict Fail immediately on module validation errors instead of trying to fix them
"""

import os
import sys
import json
import subprocess
import argparse
import time
import logging
import anthropic  # NOTE(review): not referenced anywhere in this file — TODO confirm it is needed
import re
from pathlib import Path

# Configure logging: INFO and above go both to a log file and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("module_generation_runner.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Constants
# Repository root, assuming this file lives three directory levels below it.
SYNTHEA_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
# Support both Docker (/app) and local development paths
if os.path.exists("/app/src/main/resources/disease_list.json"):
    DISEASE_LIST_PATH = "/app/src/main/resources/disease_list.json"
    MODULES_DIR = "/app/src/main/resources/modules"
    LOG_FILE_PATH = "/app/module_generation_runner.log"
else:
    DISEASE_LIST_PATH = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_list.json")
    MODULES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/modules")
    LOG_FILE_PATH = os.path.join(SYNTHEA_ROOT, "module_generation_runner.log")

# The generator script is expected to sit next to this runner.
MODULE_GENERATOR_PATH = os.path.join(os.path.dirname(__file__), "module_generator.py")

# API costs - Claude 3.7 Sonnet (as of March 2025)
# These are approximate costs based on current pricing
CLAUDE_INPUT_COST_PER_1K_TOKENS = 0.003  # $0.003 per 1K input tokens
CLAUDE_OUTPUT_COST_PER_1K_TOKENS = 0.015  # $0.015 per 1K output tokens

# Initialize cost tracking (mutated in place by parse_api_usage via `global`)
total_input_tokens = 0
total_output_tokens = 0
total_cost_usd = 0.0
|
||||
|
||||
def count_existing_modules():
    """Return how many module JSON files currently exist in MODULES_DIR."""
    return sum(1 for _ in Path(MODULES_DIR).glob("*.json"))
|
||||
|
||||
def count_remaining_diseases():
    """Count how many diseases in disease_list.json don't have modules yet.

    Returns:
        A ``(remaining, total)`` tuple: the number of diseases without a
        matching module file, and the total number of diseases listed.
    """
    with open(DISEASE_LIST_PATH, 'r') as handle:
        disease_entries = json.load(handle)

    # Stems of existing module files, lowercased for case-insensitive matching.
    present = {path.stem.lower() for path in Path(MODULES_DIR).glob("*.json")}

    def to_module_name(name):
        # Mirror the generator's filename scheme: lowercase, map every
        # non-alphanumeric character to '_', trim edge underscores, then
        # collapse runs of underscores down to a single one.
        slug = ''.join(ch if ch.isalnum() else '_' for ch in name.lower())
        slug = slug.strip('_')
        while '__' in slug:
            slug = slug.replace('__', '_')
        return slug

    missing = sum(
        1 for entry in disease_entries
        if to_module_name(entry['disease_name']) not in present
    )
    return missing, len(disease_entries)
|
||||
|
||||
def parse_api_usage(log_content):
    """Extract token-usage stats from generator log text and update running totals.

    Args:
        log_content: Full text of a batch log file.

    Returns:
        A ``(batch_input_tokens, batch_output_tokens, batch_cost_usd)`` tuple.
    """
    global total_input_tokens, total_output_tokens, total_cost_usd

    # Log lines look like:
    # "Estimated input tokens: 7031, Estimated output tokens: 3225"
    in_matches = re.findall(r'Estimated input tokens: (\d+)', log_content)
    out_matches = re.findall(r'Estimated output tokens: (\d+)', log_content)

    in_tokens = sum(map(int, in_matches))
    out_tokens = sum(map(int, out_matches))

    # Cost for this batch at the per-1K-token rates.
    in_cost = in_tokens * CLAUDE_INPUT_COST_PER_1K_TOKENS / 1000
    out_cost = out_tokens * CLAUDE_OUTPUT_COST_PER_1K_TOKENS / 1000
    batch_cost = in_cost + out_cost

    # Accumulate into the module-level running totals.
    total_input_tokens += in_tokens
    total_output_tokens += out_tokens
    total_cost_usd += batch_cost

    # Record the raw matches to aid debugging of the regex patterns.
    logger.info(f"Found token counts in log - Input: {in_matches}, Output: {out_matches}")

    return in_tokens, out_tokens, batch_cost
|
||||
|
||||
def run_module_generator(batch_size=5, prioritize=False, timeout=600, strict=False):
    """Run the module generator script once for a batch of modules.

    Args:
        batch_size: Maximum number of modules the generator may create.
        prioritize: Pass ``--prioritize`` so high-prevalence diseases go first.
        timeout: Seconds to wait before killing the generator subprocess.
        strict: Pass ``--strict`` so the generator fails on validation errors.

    Returns:
        True if the generator subprocess completed successfully, False on
        timeout or failure.
    """
    try:
        # Create a unique log file for this batch to capture token usage
        batch_log_file = f"module_generator_batch_{int(time.time())}.log"

        cmd = [
            sys.executable,
            MODULE_GENERATOR_PATH,
            '--limit', str(batch_size),
            '--log-file', batch_log_file
        ]

        if prioritize:
            cmd.append('--prioritize')

        if strict:
            cmd.append('--strict')

        logger.info(f"Running command: {' '.join(cmd)}")
        print(f"\n[INFO] Starting batch, generating up to {batch_size} modules...")

        # Use timeout to prevent indefinite hanging
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=timeout)
            success = True
        except subprocess.TimeoutExpired:
            logger.error(f"Module generator timed out after {timeout} seconds")
            print(f"\n[ERROR] Process timed out after {timeout} seconds!")
            success = False
            result = None

        # Process output only if the command completed successfully
        if success and result:
            if result.stdout:
                for line in result.stdout.splitlines():
                    logger.info(f"Generator output: {line}")

                    # Print progress-related lines directly to console
                    if "Generated module for" in line or "modules complete" in line:
                        print(f"[PROGRESS] {line}")

            # Relay stderr, distinguishing real errors from informational output.
            if result.stderr:
                for line in result.stderr.splitlines():
                    if "Error" in line or "Exception" in line:
                        logger.error(f"Generator error: {line}")
                        print(f"[ERROR] {line}")
                    else:
                        logger.info(f"Generator message: {line}")

            # Read the batch log ONCE and derive both the token usage and the
            # success count from the same content. (Previously the file was
            # deleted right after the usage parse, so the success-count code
            # that re-opened it afterwards could never run.)
            if os.path.exists(batch_log_file):
                with open(batch_log_file, 'r') as f:
                    log_content = f.read()

                # Parse usage statistics and fold them into the running totals.
                batch_input_tokens, batch_output_tokens, batch_cost = parse_api_usage(log_content)

                logger.info(f"Batch API Usage - Input tokens: {batch_input_tokens:,}, "
                            f"Output tokens: {batch_output_tokens:,}, Cost: ${batch_cost:.4f}")
                logger.info(f"Total API Usage - Input tokens: {total_input_tokens:,}, "
                            f"Output tokens: {total_output_tokens:,}, Total Cost: ${total_cost_usd:.4f}")

                # Print to console in a prominent way
                print("\n" + "="*80)
                print(f"BATCH API USAGE - Cost: ${batch_cost:.4f}")
                print(f"TOTAL API USAGE - Cost: ${total_cost_usd:.4f}")
                print("="*80 + "\n")

                # Count created modules to verify success.
                modules_created = log_content.count("Successfully created module for")
                if modules_created > 0:
                    print(f"[SUCCESS] Created {modules_created} module(s) in this batch")

        # Clean up the batch log file even if processing didn't complete
        # (e.g. on timeout). Missing-file errors are expected and ignored.
        try:
            os.remove(batch_log_file)
        except OSError:
            pass

        return success
    except subprocess.CalledProcessError as e:
        # The generator exited non-zero (check=True raised).
        logger.error(f"Module generator failed with exit code {e.returncode}")
        logger.error(f"Error output: {e.stderr}")
        print(f"[ERROR] Module generator process failed with exit code {e.returncode}")
        return False
    except Exception as e:
        # Catch-all boundary so one bad batch doesn't crash the whole loop.
        logger.error(f"Failed to run module generator: {e}")
        print(f"[ERROR] Failed to run module generator: {e}")
        return False
|
||||
|
||||
def validate_and_fix_existing_modules(strict: bool = False) -> int:
    """Check all existing modules for JSON validity and fix if needed.

    Scans every ``*.json`` file in MODULES_DIR, reports files that fail to
    parse, and (unless ``strict`` is set) attempts a series of heuristic
    text-level repairs: trailing commas, an unterminated "states" section,
    and unbalanced braces. A ``.bak`` backup is written before any file is
    overwritten.

    Args:
        strict: When True, only report invalid modules; never rewrite them.

    Returns:
        The number of modules that were successfully repaired (0 when all
        modules were already valid).
    """
    print("Validating existing modules...")

    if strict:
        print("Running in strict mode - modules will NOT be fixed automatically")

    fixed_count = 0
    module_files = list(Path(MODULES_DIR).glob("*.json"))

    # First, show a summary of all modules
    print(f"Found {len(module_files)} module files")

    invalid_modules = []

    # Check all modules for validity
    for module_path in module_files:
        try:
            with open(module_path, 'r') as f:
                content = f.read()

            # Try to parse the JSON
            try:
                json.loads(content)
                # If it parses successfully, all good
                continue
            except json.JSONDecodeError as e:
                # Remember the file and its parse error for the repair pass.
                invalid_modules.append((module_path, e))
                print(f"⚠️ Invalid JSON found in {module_path.name}: {e}")
        except Exception as e:
            # Unreadable file (permissions, encoding, ...): report and move on.
            print(f"Error processing {module_path}: {e}")

    if not invalid_modules:
        print("✅ All modules are valid JSON")
        return 0

    print(f"\nFound {len(invalid_modules)} invalid module(s) to fix")

    # Now fix the invalid modules
    for module_path, error in invalid_modules:
        try:
            print(f"\nAttempting to fix {module_path.name}...")

            if strict:
                print(f"❌ Module {module_path.name} has validation issues that must be fixed manually")
                continue  # Skip fixing in strict mode

            with open(module_path, 'r') as f:
                content = f.read()

            # Try to fix common issues
            fixed_content = content

            # Fix trailing commas (",}"/",]" are invalid in strict JSON)
            fixed_content = re.sub(r',\s*}', '}', fixed_content)
            fixed_content = re.sub(r',\s*]', ']', fixed_content)

            # Check for incomplete states section
            if '"states": {' in fixed_content:
                # Make sure it has a proper closing brace
                states_match = re.search(r'"states":\s*{(.*)}(?=\s*,|\s*})', fixed_content, re.DOTALL)
                if not states_match and '"states": {' in fixed_content:
                    print(" - Module appears to have an incomplete states section")

                    # Find last complete state definition
                    # (scan starts just past the '"states": {' marker; +12 is its length)
                    last_state_end = None
                    state_pattern = re.compile(r'("[^"]+"\s*:\s*{[^{]*?}),?', re.DOTALL)
                    for match in state_pattern.finditer(fixed_content, fixed_content.find('"states": {') + 12):
                        last_state_end = match.end()

                    if last_state_end:
                        # Add closing brace after the last valid state
                        fixed_content = fixed_content[:last_state_end] + '\n }\n }' + fixed_content[last_state_end:]
                        print(" - Added closing braces for states section")

            # Fix missing/unbalanced braces
            open_braces = fixed_content.count('{')
            close_braces = fixed_content.count('}')

            if open_braces > close_braces:
                fixed_content += '}' * (open_braces - close_braces)
                print(f" - Added {open_braces - close_braces} missing closing braces")
            elif close_braces > open_braces:
                # NOTE(review): rstrip('}') strips *all* trailing braces before
                # adding one back, so the first iteration already collapses any
                # trailing run of '}' to a single brace — verify the loop count
                # matters here.
                for _ in range(close_braces - open_braces):
                    fixed_content = fixed_content.rstrip().rstrip('}') + '}'
                print(f" - Removed {close_braces - open_braces} excess closing braces")

            # Try to parse the fixed content
            try:
                parsed_json = json.loads(fixed_content)

                # Check if it has all required elements
                if 'name' not in parsed_json:
                    print(" - Warning: Module missing 'name' field")
                if 'states' not in parsed_json:
                    print(" - Warning: Module missing 'states' field")
                if 'gmf_version' not in parsed_json:
                    print(" - Adding missing gmf_version field")
                    # Make sure we have a proper JSON object
                    if fixed_content.rstrip().endswith('}'):
                        # Add the gmf_version before the final brace
                        fixed_content = fixed_content.rstrip().rstrip('}') + ',\n "gmf_version": 2\n}'

                # Backup the original file
                backup_path = str(module_path) + '.bak'
                import shutil
                shutil.copy2(str(module_path), backup_path)

                # Write the fixed content
                with open(module_path, 'w') as f:
                    f.write(fixed_content)

                print(f"✅ Fixed module {module_path.name}")
                fixed_count += 1
            except json.JSONDecodeError as e:
                # Heuristics weren't enough; leave the file untouched.
                print(f"❌ Could not fix module {module_path.name}: {e}")

        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    print(f"\nValidation complete. Fixed {fixed_count} of {len(invalid_modules)} invalid modules.")
    return fixed_count
|
||||
|
||||
def main() -> None:
    """Drive the generation loop.

    Parses command-line arguments, optionally runs the one-off module
    validation pass, then repeatedly invokes run_module_generator() until one
    of the stop conditions hits: the --max-modules / --max-batches /
    --max-cost limit, a failed batch, or no diseases left without modules.
    """
    parser = argparse.ArgumentParser(description='Run the module generator in a loop')
    parser.add_argument('--batch-size', type=int, default=5,
                        help='Number of modules to generate in each batch (default: 5)')
    parser.add_argument('--max-modules', type=int, default=None,
                        help='Maximum total number of modules to generate (default: no limit)')
    parser.add_argument('--max-batches', type=int, default=None,
                        help='Maximum number of batches to run (default: no limit)')
    parser.add_argument('--prioritize', action='store_true',
                        help='Prioritize high-prevalence diseases first')
    parser.add_argument('--max-cost', type=float, default=None,
                        help='Maximum cost in USD to spend (default: no limit)')
    parser.add_argument('--timeout', type=int, default=600,
                        help='Timeout in seconds for each batch process (default: 600)')
    parser.add_argument('--fix-modules', action='store_true',
                        help='Check and fix existing modules for JSON validity')
    parser.add_argument('--strict', action='store_true',
                        help='Fail immediately on module validation errors instead of trying to fix them')
    args = parser.parse_args()

    # If fix-modules flag is set, validate and fix existing modules
    # (one-off mode: no generation loop runs).
    if args.fix_modules:
        validate_and_fix_existing_modules(strict=args.strict)
        return

    # Initial counts (baseline used to compute how many modules this run adds)
    initial_modules = count_existing_modules()
    initial_remaining, total_diseases = count_remaining_diseases()

    logger.info(f"Starting module generation loop")
    logger.info(f"Currently have {initial_modules} modules")
    logger.info(f"Remaining diseases without modules: {initial_remaining} of {total_diseases}")

    # Set up counters
    modules_created = 0
    batch_count = 0

    # Loop until we've hit our maximum or there are no more diseases to process
    while True:
        # Break if we've reached the maximum modules to generate
        if args.max_modules is not None and modules_created >= args.max_modules:
            logger.info(f"Reached maximum modules limit ({args.max_modules})")
            print(f"\n{'!'*80}")
            print(f"MODULE LIMIT REACHED: {modules_created} >= {args.max_modules}")
            print(f"{'!'*80}\n")
            break

        # Break if we've reached the maximum batches
        if args.max_batches is not None and batch_count >= args.max_batches:
            logger.info(f"Reached maximum batch limit ({args.max_batches})")
            print(f"\n{'!'*80}")
            print(f"BATCH LIMIT REACHED: {batch_count} >= {args.max_batches}")
            print(f"{'!'*80}\n")
            break

        # Break if we've reached the maximum cost
        # (total_cost_usd is the module-level total updated by parse_api_usage)
        if args.max_cost is not None and total_cost_usd >= args.max_cost:
            logger.info(f"Reached maximum cost limit (${args.max_cost:.2f})")
            print(f"\n{'!'*80}")
            print(f"COST LIMIT REACHED: ${total_cost_usd:.4f} >= ${args.max_cost:.2f}")
            print(f"{'!'*80}\n")
            break

        # Count how many modules we need to generate in this batch
        if args.max_modules is not None:
            # Adjust batch size if we're near the maximum
            current_batch_size = min(args.batch_size, args.max_modules - modules_created)
        else:
            current_batch_size = args.batch_size

        if current_batch_size <= 0:
            logger.info("No more modules to generate in this batch")
            print("[INFO] No more modules to generate (reached limit)")
            break

        # Run the generator for this batch
        batch_count += 1
        logger.info(f"Starting batch {batch_count} (generating up to {current_batch_size} modules)")

        # Run the generator
        success = run_module_generator(
            batch_size=current_batch_size,
            prioritize=args.prioritize,
            timeout=args.timeout,
            strict=args.strict
        )

        if not success:
            logger.error(f"Batch {batch_count} failed, stopping")
            print(f"[ERROR] Batch {batch_count} failed, stopping")
            break

        # Count how many modules we've created so far
        # (measured by counting files on disk, not by trusting the generator)
        new_module_count = count_existing_modules()
        modules_created = new_module_count - initial_modules

        # Count remaining diseases
        remaining_diseases, _ = count_remaining_diseases()

        logger.info(f"Completed batch {batch_count}")
        logger.info(f"Total modules created: {modules_created}")
        logger.info(f"Remaining diseases without modules: {remaining_diseases}")

        print(f"[STATUS] Batch {batch_count} complete")
        print(f"[STATUS] Total modules created: {modules_created}")
        print(f"[STATUS] Remaining diseases: {remaining_diseases} of {total_diseases}")

        # Break if no diseases left
        if remaining_diseases <= 0:
            logger.info("All diseases have modules now, finished!")
            print("[SUCCESS] All diseases have modules now, finished!")
            break

        # Sleep briefly between batches to avoid system overload
        time.sleep(2)

    # Final status
    logger.info(f"Module generation complete!")
    logger.info(f"Started with {initial_modules} modules, now have {count_existing_modules()}")
    logger.info(f"Created {modules_created} new modules in {batch_count} batches")

    remaining, total = count_remaining_diseases()
    logger.info(f"Remaining diseases without modules: {remaining} of {total}")
|
||||
|
||||
# Script entry point: run the generation loop when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user