Trying to fix basic functionality again.
This commit is contained in:
83
module_generator/README_module_generator.md
Normal file
83
module_generator/README_module_generator.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# Synthea Module Generator
|
||||
|
||||
This tool automates the creation of disease modules for Synthea based on `disease_list.json`. It uses Claude 3.7 to generate appropriate JSON structures for each disease, leveraging existing modules as templates.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. Python 3.6+
|
||||
2. Required Python packages:
|
||||
```
|
||||
pip install anthropic tqdm
|
||||
```
|
||||
3. Anthropic API key:
|
||||
```
|
||||
export ANTHROPIC_API_KEY=your_api_key
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Generate 10 modules (default limit)
|
||||
python module_generator.py
|
||||
|
||||
# Generate modules only for specific ICD-10 code categories
|
||||
python module_generator.py --diseases I20,I21,I22
|
||||
|
||||
# Generate up to 50 modules
|
||||
python module_generator.py --limit 50
|
||||
|
||||
# Prioritize high-prevalence diseases (recommended)
|
||||
python module_generator.py --prioritize
|
||||
|
||||
# Combine options for best results
|
||||
python module_generator.py --diseases I20,I21,I22 --limit 100 --prioritize
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
1. The script loads the complete disease list from `disease_list.json`
|
||||
2. It filters out diseases that already have modules
|
||||
3. If `--prioritize` is enabled, it:
|
||||
- Estimates the prevalence of each disease using a heuristic scoring system
|
||||
- Prioritizes diseases based on common conditions, ICD-10 chapter, and name specificity
|
||||
- Selects the highest-scoring diseases first
|
||||
4. For each selected disease:
|
||||
- Finds the most relevant existing module as a template (based on ICD-10 code)
|
||||
- Sends a prompt to Claude with the disease details and template
|
||||
- Validates the generated JSON
|
||||
- Saves the new module to the appropriate location
|
||||
- Updates the progress tracking file
|
||||
|
||||
## Configuration
|
||||
|
||||
- `CLAUDE_MODEL`: Set the Claude model to use (default: `claude-3-7-sonnet-20250219`)
|
||||
- `SYNTHEA_ROOT`: Path to the Synthea root directory (auto-detected)
|
||||
|
||||
## Cost Estimation
|
||||
|
||||
The script uses Claude 3.7 Sonnet, which costs approximately:
|
||||
- Input: $3 per million tokens
|
||||
- Output: $15 per million tokens
|
||||
|
||||
A typical generation will use:
|
||||
- ~10K input tokens (template + prompt)
|
||||
- ~5K output tokens (generated module)
|
||||
|
||||
At this rate, generating 1,000 modules would cost approximately:
|
||||
- Input: 10M tokens = $30
|
||||
- Output: 5M tokens = $75
|
||||
- Total: ~$105
|
||||
|
||||
## Logging
|
||||
|
||||
The script logs all activity to both the console and to `module_generation.log` in the current directory.
|
||||
|
||||
## Notes
|
||||
|
||||
- The script includes a 1-second delay between API calls to avoid rate limits
|
||||
- Generated modules should be manually reviewed for quality and accuracy
|
||||
- You may want to run the script incrementally (e.g., by disease category) to review results
|
||||
- The script optimizes API usage by:
|
||||
- Checking if a module already exists before generating (by filename or ICD-10 code)
|
||||
- Only using Claude when a new module genuinely needs to be created
|
||||
- Prioritizing high-prevalence diseases when using the `--prioritize` flag
|
||||
732
module_generator/module_generator.py
Executable file
732
module_generator/module_generator.py
Executable file
@@ -0,0 +1,732 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Module Generator for Synthea
|
||||
|
||||
This script automates the creation of new disease modules for Synthea based on
|
||||
disease_list.json. It uses Claude 3.7 to generate appropriate JSON structures
|
||||
for each disease, leveraging existing modules as templates.
|
||||
|
||||
Usage:
|
||||
python module_generator.py [--diseases DISEASES] [--limit LIMIT]
|
||||
|
||||
Arguments:
|
||||
--diseases DISEASES Comma-separated list of specific diseases to process (ICD-10 codes)
|
||||
--limit LIMIT Maximum number of modules to generate (default: 10)
|
||||
|
||||
Example:
|
||||
    python module_generator.py --diseases I20,I21,I22 --limit 10 --prioritize
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import glob
|
||||
import re
|
||||
import argparse
|
||||
import time
|
||||
import anthropic
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging
|
||||
def setup_logging(log_file_path=None):
    """Set up INFO-level logging that writes to both a file and stdout.

    Args:
        log_file_path: Destination log file; when None, defaults to
            "module_generation.log" in the current working directory.

    Returns:
        The logger for this module.
    """
    target = log_file_path if log_file_path is not None else "module_generation.log"

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(target),
            logging.StreamHandler(sys.stdout),
        ],
    )

    return logging.getLogger(__name__)
|
||||
|
||||
# Module-wide logger; main() replaces the default configuration via setup_logging().
logger = logging.getLogger(__name__)

# Constants
# Synthea repository root: assumed to be one directory above this script.
SYNTHEA_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
# Master list of diseases to generate modules for.
DISEASE_LIST_PATH = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_list.json")

# Output directory for generated modules.
# NOTE(review): the original comment claimed MODULES_DIR can be overridden with
# an environment variable (for Docker), but no env var is read here — confirm
# whether that override was lost.
MODULES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/modules")
if not os.path.exists(MODULES_DIR):
    os.makedirs(MODULES_DIR, exist_ok=True)
    logger.info(f"Created modules directory at {MODULES_DIR}")

# Markdown file tracking which modules have been generated.
PROGRESS_FILE = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_modules_progress.md")
# Fallback templates used when no related existing module can be found.
TEMPLATES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/templates/modules")

# Check if directories exist, create if they don't
if not os.path.exists(TEMPLATES_DIR):
    os.makedirs(TEMPLATES_DIR, exist_ok=True)
    logger.info(f"Created templates directory at {TEMPLATES_DIR}")

# Check if progress file exists, create if it doesn't. The two "##" headings
# are the section markers that update_progress_file() searches for later.
if not os.path.exists(PROGRESS_FILE):
    with open(PROGRESS_FILE, 'w') as f:
        f.write("# Disease Module Progress\n\n")
        f.write("## Completed Modules\n\n")
        f.write("## Planned Modules\n\n")
    logger.info(f"Created progress file at {PROGRESS_FILE}")

# Initialize Claude client (API key should be stored in ANTHROPIC_API_KEY environment variable)
api_key = os.getenv('ANTHROPIC_API_KEY')
if not api_key:
    # Fail fast at import time: nothing in this script works without a key.
    logger.error("ANTHROPIC_API_KEY environment variable is not set")
    sys.exit(1)

# Create client without the proxies parameter which causes issues with older versions
client = anthropic.Client(api_key=api_key)
CLAUDE_MODEL = "claude-3-7-sonnet-20250219"  # Updated to the current model version
MAX_TOKENS = 4096  # Maximum allowed tokens for Claude 3.7 Sonnet
|
||||
|
||||
def validate_condition_format(module_json):
    """Check that condition objects in a module use Synthea's flat layout.

    Synthea expects ``"condition": {"condition_type": "<string>", ...}``;
    a dict nested under ``condition_type`` is flagged as invalid.

    Args:
        module_json: Module as a JSON string or an already-parsed structure.

    Returns:
        Tuple ``(is_valid, issues)`` where ``issues`` is a list of messages.
        Any unexpected error is reported as ``(False, ["Validation error: ..."])``.
    """
    try:
        data = json.loads(module_json) if isinstance(module_json, str) else module_json

        def walk(node):
            # Recursively collect nested-condition_type problems under node.
            found = []
            if isinstance(node, dict):
                cond = node.get("condition")
                if isinstance(cond, dict) and isinstance(cond.get("condition_type"), dict):
                    found.append("Found nested condition_type in a condition object")
                for child in node.values():
                    found.extend(walk(child))
            elif isinstance(node, list):
                for child in node:
                    found.extend(walk(child))
            return found

        problems = walk(data)
        return not problems, problems

    except Exception as e:
        return False, [f"Validation error: {str(e)}"]
|
||||
|
||||
def load_disease_list() -> Dict[str, Dict[str, str]]:
    """Read disease_list.json and index it by disease name.

    Returns:
        Mapping of disease name to a dict with keys ``icd_10``, ``snomed``
        and ``ICD-10_name``. Entries without a ``disease_name`` are skipped.

    Exits the process with status 1 if the file cannot be read or parsed.
    """
    try:
        with open(DISEASE_LIST_PATH, 'r') as f:
            raw_entries = json.load(f)

        # Re-key the flat list by disease name, keeping only the fields
        # the generator needs downstream.
        return {
            entry['disease_name']: {
                'icd_10': entry.get('id', ''),
                'snomed': entry.get('snomed', ''),
                'ICD-10_name': entry.get('ICD-10_name', ''),
            }
            for entry in raw_entries
            if entry.get('disease_name', '')
        }
    except Exception as e:
        logger.error(f"Error loading disease list: {e}")
        sys.exit(1)
|
||||
|
||||
def get_existing_modules() -> List[str]:
    """Return lowercase basenames (without extension) of all module JSON files."""
    names = []
    for path in glob.glob(os.path.join(MODULES_DIR, "*.json")):
        stem, _ = os.path.splitext(os.path.basename(path))
        names.append(stem.lower())
    return names
|
||||
|
||||
def find_most_relevant_module(disease_info: Dict[str, str], existing_modules: List[str]) -> str:
    """Pick the best existing module to use as a generation template.

    Strategy: prefer the largest existing module whose text mentions the
    disease's 3-character ICD-10 prefix; otherwise fall back to a canonical
    module for the ICD-10 chapter, and finally to a generic template.

    Args:
        disease_info: Disease record; only the ``icd_10`` field is consulted.
        existing_modules: Kept for interface compatibility; not used here.

    Returns:
        Path to the template module JSON file.
    """
    # First 3 chars of the ICD-10 code identify the category (e.g. "I21").
    icd_code_prefix = disease_info.get("icd_10", "")[:3]

    # Scan modules for ones touching the same ICD-10 category. Guard against
    # an empty prefix: the original code searched for f'"{prefix}' which,
    # with an empty prefix, degenerates to '"' and matches EVERY JSON file.
    related_modules = []
    if icd_code_prefix:
        needle = f'"{icd_code_prefix}'
        for module_path in glob.glob(os.path.join(MODULES_DIR, "*.json")):
            try:
                with open(module_path, 'r') as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError):
                # Skip unreadable files; a narrow except avoids hiding real
                # bugs (the original used a bare `except: pass`).
                continue
            if needle in content:
                related_modules.append(module_path)

    if related_modules:
        # Largest related file is used as a "most complex" heuristic.
        return max(related_modules, key=os.path.getsize)

    # No ICD match: fall back to a representative module per body system.
    if icd_code_prefix.startswith('I'):  # Circulatory system
        return os.path.join(MODULES_DIR, "hypertensive_renal_disease.json")
    elif icd_code_prefix.startswith('J'):  # Respiratory system
        return os.path.join(MODULES_DIR, "asthma.json")
    elif icd_code_prefix.startswith('K'):  # Digestive system
        return os.path.join(MODULES_DIR, "appendicitis.json")
    else:
        # Default to a simple generic template.
        return os.path.join(TEMPLATES_DIR, "prevalence.json")
|
||||
|
||||
def _strip_code_fences(text: str) -> str:
    """Drop markdown code-fence markers Claude sometimes wraps around JSON."""
    text = re.sub(r'^```json\s*', '', text)
    text = re.sub(r'^```\s*', '', text)
    text = re.sub(r'```$', '', text)
    return text


def _call_claude(prompt: str) -> str:
    """Send one prompt to Claude (legacy completion API) and return clean text."""
    response = client.completion(
        model=CLAUDE_MODEL,
        max_tokens_to_sample=MAX_TOKENS,
        temperature=0.1,  # low temperature for consistent, predictable output
        prompt=f"\n\nHuman: {prompt}\n\nAssistant: ",
        stream=False
    )
    # For anthropic v0.8.1 the response object exposes a .completion property.
    content = _strip_code_fences(response.completion)
    if not isinstance(content, str):
        content = content.decode('utf-8') if hasattr(content, 'decode') else str(content)
    return content


def _balance_delimiters(text: str) -> str:
    """Append missing (or strip excess) closing braces/brackets in *text*."""
    open_braces, close_braces = text.count('{'), text.count('}')
    if open_braces > close_braces:
        text += '}' * (open_braces - close_braces)
        logger.info(f"Added {open_braces - close_braces} missing closing braces")
    elif close_braces > open_braces:
        for _ in range(close_braces - open_braces):
            text = text.rstrip().rstrip('}') + '}'
        logger.info(f"Removed {close_braces - open_braces} excess closing braces")

    open_brackets, close_brackets = text.count('['), text.count(']')
    if open_brackets > close_brackets:
        text += ']' * (open_brackets - close_brackets)
        logger.info(f"Added {open_brackets - close_brackets} missing closing brackets")
    elif close_brackets > open_brackets:
        for _ in range(close_brackets - open_brackets):
            last = text.rfind(']')
            if last >= 0:
                text = text[:last] + text[last + 1:]
        logger.info(f"Removed {close_brackets - open_brackets} excess closing brackets")
    return text


def _repair_and_parse(content: str, disease_name: str) -> Optional[str]:
    """Best-effort extraction of a JSON object from a malformed Claude reply.

    Returns pretty-printed JSON on success, or None if nothing parseable
    could be recovered.
    """
    extraction_attempts = [
        re.search(r'```json([\s\S]*?)```', content),   # fenced json block
        re.search(r'```([\s\S]*?)```', content),       # any fenced block
        re.search(r'({[\s\S]*})', content),            # outermost braces
        re.search(r'({.*})', content, re.DOTALL),      # greedy fallback
    ]

    for attempt in extraction_attempts:
        if not attempt:
            continue
        candidate = attempt.group(1).strip()
        # Drop trailing commas before closers, then balance delimiters.
        candidate = re.sub(r',\s*}', '}', candidate)
        candidate = re.sub(r',\s*]', ']', candidate)
        candidate = _balance_delimiters(candidate)
        try:
            return json.dumps(json.loads(candidate), indent=2)
        except json.JSONDecodeError:
            continue

    # Last resort: scrub markdown, backticks and single-quoted keys, retry.
    try:
        cleaned = re.sub(r'```.*?```', '', content, flags=re.DOTALL)
        cleaned = re.sub(r'`', '', cleaned)
        cleaned = re.sub(r"'([^']*?)':", r'"\1":', cleaned)  # 'key': -> "key":
        cleaned = re.sub(r',\s*}', '}', cleaned)
        cleaned = re.sub(r',\s*]', ']', cleaned)
        module_match = re.search(r'({[\s\S]*})', cleaned)
        if module_match:
            result = json.dumps(json.loads(module_match.group(1)), indent=2)
            logger.info(f"Successfully repaired and extracted JSON for '{disease_name}'")
            return result
    except Exception as repair_error:
        logger.error(f"Failed to repair JSON: {repair_error}")
    return None


def generate_module_with_claude(
    disease_name: str,
    disease_info: Dict[str, str],
    template_path: str,
    strict: bool = False
) -> Tuple[str, int, int]:
    """Use Claude to generate a new Synthea module based on a template.

    Args:
        disease_name: Human-readable disease name.
        disease_info: Dict with 'icd_10' and 'snomed' keys (values may be '').
        template_path: Existing module JSON used as a structural example.
        strict: When True, raise on validation/parse failures instead of
            falling back to best-effort output. (The original referenced an
            undefined global ``args.strict`` here, a NameError at runtime;
            callers may now pass strict explicitly.)

    Returns:
        Tuple of (formatted module JSON, estimated input tokens,
        estimated output tokens).

    Raises:
        ValueError: If no valid JSON could be extracted from the response.
    """
    # Load the template module.
    logger.info(f"Loading template module from {os.path.basename(template_path)}")
    with open(template_path, 'r') as f:
        template_content = f.read()

    template_name = os.path.splitext(os.path.basename(template_path))[0]

    # NOTE: literal braces in the JSON examples below are doubled ({{ }});
    # the original left them unescaped inside this f-string, which raised
    # ValueError ("invalid format spec") on every call.
    prompt = f"""You are a medical expert and software developer tasked with creating disease modules for Synthea, an open-source synthetic patient generator.

TASK: Create a valid JSON module for the disease: "{disease_name}" (ICD-10 code: {disease_info.get('icd_10', 'Unknown')}).

CRITICAL: Your response MUST ONLY contain valid, parseable JSON that starts with {{ and ends with }}. No explanations, text, markdown formatting, or code blocks.

Disease Information:
- Name: {disease_name}
- ICD-10 Code: {disease_info.get('icd_10', 'Unknown')}
- SNOMED-CT Code: {disease_info.get('snomed', 'Unknown')}

I'm providing an example module structure based on {template_name}.json as a reference:

```json
{template_content}
```

Technical Requirements (MUST follow ALL of these):
1. JSON Format:
- Use valid JSON syntax with no errors
- No trailing commas
- All property names in double quotes
- Always add a "gmf_version": 2 field at the top level
- Check that all brackets and braces are properly matched

2. Module Structure:
- Include "name" field with {disease_name}
- Include "remarks" array with at least 3 items
- Include "states" object with complete set of disease states
- Add at least 2-3 reference URLs in the "remarks" section
- Every state needs a unique name and valid type (like "Initial", "Terminal", etc.)
- All transitions must be valid (direct_transition, distributed_transition, etc.)

3. Medical Content:
- Include accurate medical codes (SNOMED-CT, ICD-10, LOINC, RxNorm)
- Model realistic disease prevalence based on age, gender, race
- Include relevant symptoms, diagnostic criteria, treatments
- Only include states that make clinical sense for this specific disease

4. CRITICAL CONDITION STRUCTURE REQUIREMENTS:
- In conditional statements, the 'condition_type' MUST be a top-level key within the condition object
- INCORRECT: "condition": {{"condition_type": {{"nested": "value"}}, "name": "SomeState"}}
- CORRECT: "condition": {{"condition_type": "PriorState", "name": "SomeState"}}
- Common condition types: "Age", "Gender", "Race", "Attribute", "And", "Or", "Not", "PriorState", "Active Condition"
- For PriorState conditions, use: {{"condition_type": "PriorState", "name": "StateName"}}
- For attribute checks, use: {{"condition_type": "Attribute", "attribute": "attr_name", "operator": "==", "value": true}}

IMPORTANT REMINDER: Respond with ONLY valid JSON, not explanations. The entire response should be a single JSON object that can be directly parsed.
"""

    logger.info(f"Sending request to Claude API for '{disease_name}'")
    try:
        content = _call_claude(prompt)

        # Rough token estimate (~4 characters per token); the legacy API
        # does not report usage directly.
        input_token_estimate = len(prompt) // 4
        output_token_estimate = len(content) // 4
        logger.info(f"API call for '{disease_name}' - Estimated input tokens: {input_token_estimate}, "
                    f"Estimated output tokens: {output_token_estimate}")

        try:
            parsed = json.loads(content)
        except json.JSONDecodeError as e:
            # Not JSON at all: try the extraction/repair pipeline before giving up.
            logger.error(f"Generated content is not valid JSON: {e}")
            logger.debug(f"Generated content: {content[:500]}...")
            repaired = _repair_and_parse(content, disease_name)
            if repaired is not None:
                logger.info(f"Successfully extracted valid JSON for '{disease_name}' after extraction attempt")
                return repaired, input_token_estimate, output_token_estimate
            debug_sample = content[:min(500, len(content))] + "..." if len(content) > 500 else content
            logger.error(f"Could not extract valid JSON from Claude's response for '{disease_name}'")
            logger.error(f"Response sample (first 500 chars): {debug_sample}")
            logger.error(f"Full content length: {len(content)} characters")
            raise ValueError("Could not extract valid JSON from Claude's response")

        valid, issues = validate_condition_format(parsed)
        if valid:
            formatted_json = json.dumps(parsed, indent=2)
            logger.info(f"Successfully generated valid JSON for '{disease_name}'")
            return formatted_json, input_token_estimate, output_token_estimate

        # Structural issues: retry once with explicit feedback.
        logger.warning(f"Generated module for {disease_name} has condition structure issues: {issues}")
        logger.warning("Requesting regeneration with corrected instructions...")
        retry_prompt = prompt + "\n\nPLEASE FIX THESE ISSUES: " + ", ".join(issues)
        retry_prompt += "\n\nREMINDER: All condition objects must have 'condition_type' as a top-level key, NOT nested."
        logger.info(f"Retrying generation for '{disease_name}' with specific instructions about issues")

        try:
            retry_parsed = json.loads(_call_claude(retry_prompt))
        except json.JSONDecodeError as e:
            logger.error(f"Retry response is still not valid JSON: {e}")
            if strict:
                raise ValueError(f"Failed to generate valid JSON after retry: {e}")
            # Fall back to the original (structurally flawed) response.
            logger.warning("Using original response despite issues (non-strict mode)")
            formatted_json = json.dumps(parsed, indent=2)
        else:
            retry_valid, retry_issues = validate_condition_format(retry_parsed)
            if not retry_valid and strict:
                logger.error(f"Failed to fix condition structure issues after retry: {retry_issues}")
                raise ValueError(f"Module format validation failed: {retry_issues}")
            if not retry_valid:
                logger.warning(f"Retry still has issues, but proceeding due to non-strict mode: {retry_issues}")
            else:
                logger.info(f"Successfully fixed condition structure issues for '{disease_name}'")
            formatted_json = json.dumps(retry_parsed, indent=2)

        logger.info(f"Successfully generated valid JSON for '{disease_name}'")
        return formatted_json, input_token_estimate, output_token_estimate

    except Exception as e:
        logger.error(f"Error generating module with Claude: {e}")
        raise
|
||||
|
||||
def update_progress_file(disease_name: str, icd_code: str) -> None:
    """Record a newly created module under '## Completed Modules'.

    The entry is numbered after the existing numbered entries in that
    section. Any failure is logged and otherwise ignored.
    """
    try:
        with open(PROGRESS_FILE, 'r') as f:
            lines = f.readlines()

        for idx, line in enumerate(lines):
            if not line.startswith('## Completed Modules'):
                continue
            # Count numbered entries already listed under this heading,
            # stopping at the next section heading.
            completed = 0
            for follower in lines[idx + 1:]:
                if follower.startswith('##'):
                    break
                if follower.strip() and follower[0].isdigit():
                    completed += 1

            entry = f"{completed + 1}. {disease_name} ({icd_code}) - Generated by module_generator.py\n"
            lines.insert(idx + 1 + completed, entry)
            break

        with open(PROGRESS_FILE, 'w') as f:
            f.writelines(lines)

    except Exception as e:
        logger.error(f"Error updating progress file: {e}")
|
||||
|
||||
def normalize_filename(disease_name: str) -> str:
    """Turn a human-readable disease name into a safe snake_case filename.

    Runs of non-alphanumeric characters collapse into single underscores
    and edge underscores are trimmed, e.g. "Heart Disease (Acute)" ->
    "heart_disease_acute".
    """
    sanitized = re.sub(r'[^a-zA-Z0-9]+', '_', disease_name.lower())
    return sanitized.strip('_')
|
||||
|
||||
def get_existing_module_path(disease_name: str, disease_info: Dict[str, str]) -> Optional[str]:
    """Return the path of an existing module for this disease, if any.

    Two checks are applied: a module file whose name matches the normalized
    disease name, then a scan of all modules for an exact ICD-10 code match.

    Args:
        disease_name: Human-readable disease name.
        disease_info: Disease record; the 'icd_10' field drives the code scan.

    Returns:
        Path to the matching existing module, or None if nothing matches.
    """
    # Method 1: filename derived from the normalized disease name.
    candidate_path = os.path.join(MODULES_DIR, f"{normalize_filename(disease_name)}.json")
    if os.path.exists(candidate_path):
        return candidate_path

    # Method 2: exact ICD-10 code appearing in any existing module.
    icd_code = disease_info.get('icd_10', '')
    if icd_code:
        needle = f'"code": "{icd_code}"'
        for module_path in glob.glob(os.path.join(MODULES_DIR, "*.json")):
            try:
                with open(module_path, 'r') as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError):
                # Skip unreadable files; a narrow except avoids hiding real
                # bugs (the original used a bare `except: pass`).
                continue
            if needle in content:
                return module_path

    # No existing module found.
    return None
|
||||
|
||||
def estimate_disease_prevalence(disease_name: str, icd_code: str) -> float:
    """Heuristically score how common a disease is (higher = more prevalent).

    Combines three signals: keywords for well-known conditions, an
    approximate prevalence weight for the ICD-10 chapter, and a small
    bonus for short (likely more generic) disease names. This is a simple
    heuristic; real prevalence data could replace it.
    """
    common_keywords = (
        "hypertension", "diabetes", "arthritis", "asthma", "depression",
        "anxiety", "obesity", "cancer", "heart", "copd", "stroke", "pneumonia",
        "bronchitis", "influenza", "infection", "pain", "fracture",
    )
    # Approximate prevalence weight by ICD-10 chapter (first letter).
    chapter_weights = {
        'I': 5.0,  # Circulatory system
        'J': 4.0,  # Respiratory system
        'K': 3.5,  # Digestive system
        'M': 3.0,  # Musculoskeletal system
        'E': 4.0,  # Endocrine, nutritional and metabolic diseases
        'F': 3.5,  # Mental and behavioral disorders
    }

    score = 1.0

    # One-time bonus if the name mentions any common condition keyword.
    lowered = disease_name.lower()
    if any(keyword in lowered for keyword in common_keywords):
        score += 2.0

    if icd_code:
        score += chapter_weights.get(icd_code[0], 0.0)

    # Shorter names tend to be more common, less specialized conditions.
    if len(disease_name.split()) <= 3:
        score += 1.0

    return score
|
||||
|
||||
def main():
    """Command-line entry point for the Synthea module generator.

    Two modes of operation:
      * Direct mode (both --disease and --output given): generate a single
        module for the named disease and write it to the given path.
      * Batch mode (default): walk the disease list, skip diseases that
        already have modules, and generate up to --limit new modules,
        optionally prioritized by estimated prevalence.
    """
    parser = argparse.ArgumentParser(description='Generate Synthea disease modules')
    parser.add_argument('--diseases', type=str, help='Comma-separated list of ICD-10 codes to process')
    parser.add_argument('--limit', type=int, default=10, help='Maximum number of modules to generate')
    parser.add_argument('--prioritize', action='store_true', help='Prioritize high-prevalence diseases')
    parser.add_argument('--log-file', type=str, help='Path to log file for token usage tracking')
    parser.add_argument('--strict', action='store_true',
                        help='Fail immediately on validation errors instead of trying to fix them')
    # Add support for direct disease generation
    parser.add_argument('--disease', type=str, help='Single disease name to generate')
    parser.add_argument('--output', type=str, help='Output path for the generated module')
    args = parser.parse_args()

    # Setup logging with custom log file if specified
    global logger
    if args.log_file:
        logger = setup_logging(args.log_file)
    else:
        logger = setup_logging()

    # Check if we're operating in direct mode (single disease)
    # NOTE(review): giving --disease without --output (or vice versa) silently
    # falls through to batch mode — consider warning the user in that case.
    if args.disease and args.output:
        try:
            logger.info(f"Generating module for single disease: {args.disease}")

            # Create a simple disease info dictionary
            disease_info = {
                'icd_10': '',  # Empty since we don't have this info
                'snomed': '',  # Empty since we don't have this info
                'ICD-10_name': args.disease
            }

            # Try to find a relevant template
            templates = glob.glob(os.path.join(TEMPLATES_DIR, "*.json"))
            if templates:
                template_path = templates[0]  # Use the first template found
            else:
                # Use a simple default template if none found
                logger.warning("No template found, using a minimal default")
                template_path = os.path.join(MODULES_DIR, "appendicitis.json")
                if not os.path.exists(template_path):
                    raise ValueError(f"Cannot find any suitable template for generation")

            # Generate the module
            module_content, _, _ = generate_module_with_claude(args.disease, disease_info, template_path)

            # Save the module to the specified output path
            with open(args.output, 'w') as f:
                f.write(module_content)

            logger.info(f"Successfully generated module for {args.disease}")
            print(f"Successfully generated module for {args.disease}")
            return

        except Exception as e:
            logger.error(f"Error generating single disease module: {e}")
            print(f"Error: {e}")
            sys.exit(1)

    # Load disease list
    all_diseases = load_disease_list()
    logger.info(f"Loaded {len(all_diseases)} diseases from disease_list.json")

    # Get existing modules
    existing_modules = get_existing_modules()
    logger.info(f"Found {len(existing_modules)} existing modules")

    # Filter diseases to process
    if args.diseases:
        disease_codes = [code.strip() for code in args.diseases.split(',')]
        # Match on the ICD-10 category, i.e. the code part before the dot.
        to_process = {name: info for name, info in all_diseases.items()
                      if info.get('icd_10', '').split('.')[0] in disease_codes}
        logger.info(f"Filtered to {len(to_process)} diseases matching specified codes")
    else:
        # Process all diseases up to the limit
        to_process = all_diseases

    # Only include diseases that don't already have modules by filename
    # (We'll do a more thorough check later with get_existing_module_path)
    diseases_to_create = {}
    candidate_diseases = []

    for name, info in to_process.items():
        normalized_name = normalize_filename(name)
        if normalized_name not in existing_modules:
            if args.prioritize:
                # Add to candidates for prioritization
                icd_code = info.get('icd_10', '').split('.')[0]
                prevalence_score = estimate_disease_prevalence(name, icd_code)
                candidate_diseases.append((name, info, prevalence_score))
            else:
                diseases_to_create[name] = info

                # Respect the limit for non-prioritized mode
                if len(diseases_to_create) >= args.limit:
                    break

    # If prioritizing, sort by estimated prevalence and take top N
    if args.prioritize and candidate_diseases:
        logger.info("Prioritizing diseases by estimated prevalence")
        candidate_diseases.sort(key=lambda x: x[2], reverse=True)

        # Log top candidates for transparency
        logger.info("Top candidates by estimated prevalence:")
        for i, (name, info, score) in enumerate(candidate_diseases[:min(10, len(candidate_diseases))]):
            logger.info(f"  {i+1}. {name} (ICD-10: {info.get('icd_10', 'Unknown')}) - Score: {score:.2f}")

        # Select top N diseases
        for name, info, _ in candidate_diseases[:args.limit]:
            diseases_to_create[name] = info

    logger.info(f"Will generate modules for {len(diseases_to_create)} diseases")

    # Generate modules
    for disease_name, disease_info in tqdm(diseases_to_create.items(), desc="Generating modules"):
        try:
            # First check if module already exists - no need to use LLM if it does
            existing_module_path = get_existing_module_path(disease_name, disease_info)

            if existing_module_path:
                # Module already exists, just copy it
                logger.info(f"Module for {disease_name} already exists at {existing_module_path}")

                # Read existing module
                with open(existing_module_path, 'r') as f:
                    module_content = f.read()

                # Save to normalized filename if different from existing path
                filename = normalize_filename(disease_name) + ".json"
                output_path = os.path.join(MODULES_DIR, filename)

                if output_path != existing_module_path:
                    with open(output_path, 'w') as f:
                        f.write(module_content)
                    logger.info(f"Copied existing module to {output_path}")
                else:
                    logger.info(f"Existing module already has correct filename")

                # Update progress file
                icd_code = disease_info.get('icd_10', 'Unknown')
                update_progress_file(disease_name, icd_code)

            else:
                # No existing module, generate with Claude
                # Find best template
                template_path = find_most_relevant_module(disease_info, existing_modules)
                logger.info(f"Using {os.path.basename(template_path)} as template for {disease_name}")

                # Generate module
                module_content, input_tokens, output_tokens = generate_module_with_claude(disease_name, disease_info, template_path)

                # Save module
                filename = normalize_filename(disease_name) + ".json"
                output_path = os.path.join(MODULES_DIR, filename)

                with open(output_path, 'w') as f:
                    f.write(module_content)

                logger.info(f"Successfully created module for {disease_name}")

                # Update progress file
                icd_code = disease_info.get('icd_10', 'Unknown')
                update_progress_file(disease_name, icd_code)

            # Sleep to avoid hitting API rate limits
            time.sleep(1)

        except Exception as e:
            logger.error(f"Failed to generate module for {disease_name}: {e}")

    logger.info("Module generation complete")
|
||||
|
||||
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
478
module_generator/run_module_generator.py
Executable file
478
module_generator/run_module_generator.py
Executable file
@@ -0,0 +1,478 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Disease Module Generator Loop for Synthea
|
||||
|
||||
This script runs the module_generator.py script in a loop to generate
|
||||
modules for diseases in disease_list.json that don't have modules yet.
|
||||
|
||||
Usage:
|
||||
python run_module_generator.py [--batch-size BATCH_SIZE] [--max-modules MAX_MODULES]
|
||||
|
||||
Arguments:
|
||||
--batch-size BATCH_SIZE Number of modules to generate in each batch (default: 5)
|
||||
--max-modules MAX_MODULES Maximum total number of modules to generate (default: no limit)
|
||||
--prioritize Prioritize high-prevalence diseases first
|
||||
--max-cost MAX_COST Maximum cost in USD to spend (default: no limit)
|
||||
--strict Fail immediately on module validation errors instead of trying to fix them
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import subprocess
|
||||
import argparse
|
||||
import time
|
||||
import logging
|
||||
import anthropic
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
# Configure logging: mirror everything to both a log file and stdout so the
# long-running loop can be watched live and audited afterwards.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("module_generation_runner.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Constants
# Repository root, assuming this file sits three directory levels below it.
# NOTE(review): depends on the on-disk layout — re-verify if the file moves.
SYNTHEA_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
# Support both Docker (/app) and local development paths
if os.path.exists("/app/src/main/resources/disease_list.json"):
    DISEASE_LIST_PATH = "/app/src/main/resources/disease_list.json"
    MODULES_DIR = "/app/src/main/resources/modules"
    LOG_FILE_PATH = "/app/module_generation_runner.log"
else:
    DISEASE_LIST_PATH = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_list.json")
    MODULES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/modules")
    LOG_FILE_PATH = os.path.join(SYNTHEA_ROOT, "module_generation_runner.log")

# The generator script this runner invokes as a subprocess, per batch.
MODULE_GENERATOR_PATH = os.path.join(os.path.dirname(__file__), "module_generator.py")

# API costs - Claude 3.7 Sonnet (as of March 2025)
# These are approximate costs based on current pricing
CLAUDE_INPUT_COST_PER_1K_TOKENS = 0.003  # $0.003 per 1K input tokens
CLAUDE_OUTPUT_COST_PER_1K_TOKENS = 0.015  # $0.015 per 1K output tokens

# Initialize cost tracking (running totals, mutated by parse_api_usage
# through `global`; read by main for the --max-cost cutoff).
total_input_tokens = 0
total_output_tokens = 0
total_cost_usd = 0.0
|
||||
|
||||
def count_existing_modules():
    """Count the number of modules already created."""
    # Every disease module is a standalone *.json file directly in MODULES_DIR;
    # counting lazily avoids materializing the whole file list.
    return sum(1 for _ in Path(MODULES_DIR).glob("*.json"))
|
||||
|
||||
def _module_filename_for(disease_name):
    """Return the module filename stem for a disease name.

    Mirrors the generator's normalization: lowercase, every non-alphanumeric
    character becomes an underscore, outer underscores are stripped, and runs
    of underscores collapse to one.
    """
    stem = ''.join(c if c.isalnum() else '_' for c in disease_name.lower())
    return re.sub(r'_+', '_', stem.strip('_'))


def count_remaining_diseases():
    """Count how many diseases in disease_list.json don't have modules yet.

    Returns:
        tuple[int, int]: (diseases still lacking a module file,
                          total diseases listed in disease_list.json).
    """
    # Load the disease list
    with open(DISEASE_LIST_PATH, 'r') as f:
        diseases = json.load(f)

    # Existing module filename stems, lowercased for comparison.
    existing_modules = {p.stem.lower() for p in Path(MODULES_DIR).glob("*.json")}

    # A disease is "remaining" when its normalized name has no module file.
    remaining = sum(
        1 for disease in diseases
        if _module_filename_for(disease['disease_name']) not in existing_modules
    )

    return remaining, len(diseases)
|
||||
|
||||
def parse_api_usage(log_content):
    """Parse API usage stats from log content and update the running totals.

    Returns:
        (batch_input_tokens, batch_output_tokens, batch_cost_usd) for the
        usage found in *log_content* alone; the module-level totals are
        updated as a side effect.
    """
    global total_input_tokens, total_output_tokens, total_cost_usd

    # The generator writes lines shaped like:
    #   "Estimated input tokens: 7031, Estimated output tokens: 3225"
    in_matches = re.findall(r'Estimated input tokens: (\d+)', log_content)
    out_matches = re.findall(r'Estimated output tokens: (\d+)', log_content)

    batch_in = sum(map(int, in_matches))
    batch_out = sum(map(int, out_matches))

    # Convert token counts to dollars via the per-1K-token rates.
    batch_in_cost = batch_in * CLAUDE_INPUT_COST_PER_1K_TOKENS / 1000
    batch_out_cost = batch_out * CLAUDE_OUTPUT_COST_PER_1K_TOKENS / 1000
    batch_cost = batch_in_cost + batch_out_cost

    # Fold this batch into the module-wide running totals.
    total_input_tokens += batch_in
    total_output_tokens += batch_out
    total_cost_usd += batch_cost

    # Emit the raw matches so mis-parsed logs are easy to diagnose.
    logger.info(f"Found token counts in log - Input: {in_matches}, Output: {out_matches}")

    return batch_in, batch_out, batch_cost
|
||||
|
||||
def run_module_generator(batch_size=5, prioritize=False, timeout=600, strict=False):
    """Run the module generator script once with the specified batch size.

    Args:
        batch_size: Maximum number of modules the child process may generate.
        prioritize: Pass --prioritize through to the generator.
        timeout: Seconds to wait before killing the child process.
        strict: Pass --strict through to the generator.

    Returns:
        True if the child process completed without error, False on timeout
        or failure.
    """
    # Unique per-batch log file so token-usage stats can be harvested after
    # the run; always cleaned up in the `finally` below.
    batch_log_file = f"module_generator_batch_{int(time.time())}.log"
    try:
        cmd = [
            sys.executable,
            MODULE_GENERATOR_PATH,
            '--limit', str(batch_size),
            '--log-file', batch_log_file
        ]

        if prioritize:
            cmd.append('--prioritize')

        if strict:
            cmd.append('--strict')

        logger.info(f"Running command: {' '.join(cmd)}")
        print(f"\n[INFO] Starting batch, generating up to {batch_size} modules...")

        # Use timeout to prevent indefinite hanging.
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=timeout)
            success = True
        except subprocess.TimeoutExpired:
            logger.error(f"Module generator timed out after {timeout} seconds")
            print(f"\n[ERROR] Process timed out after {timeout} seconds!")
            success = False
            result = None

        # Relay child output only if the command completed successfully.
        if success and result:
            if result.stdout:
                for line in result.stdout.splitlines():
                    logger.info(f"Generator output: {line}")

                    # Print progress-related lines directly to console
                    if "Generated module for" in line or "modules complete" in line:
                        print(f"[PROGRESS] {line}")

            # Check if any errors
            if result.stderr:
                for line in result.stderr.splitlines():
                    if "Error" in line or "Exception" in line:
                        logger.error(f"Generator error: {line}")
                        print(f"[ERROR] {line}")
                    else:
                        logger.info(f"Generator message: {line}")

        # Read the batch log exactly ONCE.  BUG FIX: the previous version
        # deleted the log on the success path and then re-checked
        # os.path.exists() to count created modules, so the per-batch
        # "[SUCCESS] Created N module(s)" line could only ever appear after
        # a timeout, never after a successful run.
        if os.path.exists(batch_log_file):
            with open(batch_log_file, 'r') as f:
                log_content = f.read()

            if success:
                # Parse usage statistics and fold them into the totals.
                batch_input_tokens, batch_output_tokens, batch_cost = parse_api_usage(log_content)

                # Log usage for this batch
                logger.info(f"Batch API Usage - Input tokens: {batch_input_tokens:,}, "
                            f"Output tokens: {batch_output_tokens:,}, Cost: ${batch_cost:.4f}")
                logger.info(f"Total API Usage - Input tokens: {total_input_tokens:,}, "
                            f"Output tokens: {total_output_tokens:,}, Total Cost: ${total_cost_usd:.4f}")

                # Print to console in a prominent way
                print("\n" + "="*80)
                print(f"BATCH API USAGE - Cost: ${batch_cost:.4f}")
                print(f"TOTAL API USAGE - Cost: ${total_cost_usd:.4f}")
                print("="*80 + "\n")

            # Count created modules from the content we already read.
            modules_created = log_content.count("Successfully created module for")
            if modules_created > 0:
                print(f"[SUCCESS] Created {modules_created} module(s) in this batch")

        return success
    except subprocess.CalledProcessError as e:
        logger.error(f"Module generator failed with exit code {e.returncode}")
        logger.error(f"Error output: {e.stderr}")
        print(f"[ERROR] Module generator process failed with exit code {e.returncode}")
        return False
    except Exception as e:
        logger.error(f"Failed to run module generator: {e}")
        print(f"[ERROR] Failed to run module generator: {e}")
        return False
    finally:
        # Single cleanup point; narrow except so real bugs aren't hidden.
        try:
            os.remove(batch_log_file)
        except OSError:
            pass
|
||||
|
||||
def validate_and_fix_existing_modules(strict=False):
    """Check all existing modules for JSON validity and fix if needed.

    Scans every *.json file in MODULES_DIR, reports the ones that fail to
    parse, and — unless *strict* is set — applies heuristic repairs (trailing
    commas, truncated "states" sections, unbalanced braces, missing
    gmf_version) and rewrites the file, keeping a .bak backup.

    Args:
        strict: When True, only report invalid modules; never rewrite them.

    Returns:
        int: Number of modules successfully repaired (0 when all are valid).
    """
    print("Validating existing modules...")

    if strict:
        print("Running in strict mode - modules will NOT be fixed automatically")

    fixed_count = 0
    module_files = list(Path(MODULES_DIR).glob("*.json"))

    # First, show a summary of all modules
    print(f"Found {len(module_files)} module files")

    invalid_modules = []

    # Pass 1: check all modules for validity, collecting failures.
    for module_path in module_files:
        try:
            with open(module_path, 'r') as f:
                content = f.read()

            # Try to parse the JSON
            try:
                json.loads(content)
                # If it parses successfully, all good
                continue
            except json.JSONDecodeError as e:
                invalid_modules.append((module_path, e))
                print(f"⚠️ Invalid JSON found in {module_path.name}: {e}")
        except Exception as e:
            # Unreadable file (permissions, encoding, ...) — report and move on.
            print(f"Error processing {module_path}: {e}")

    if not invalid_modules:
        print("✅ All modules are valid JSON")
        return 0

    print(f"\nFound {len(invalid_modules)} invalid module(s) to fix")

    # Pass 2: attempt heuristic repair of each invalid module.
    for module_path, error in invalid_modules:
        try:
            print(f"\nAttempting to fix {module_path.name}...")

            if strict:
                print(f"❌ Module {module_path.name} has validation issues that must be fixed manually")
                continue  # Skip fixing in strict mode

            with open(module_path, 'r') as f:
                content = f.read()

            # Try to fix common issues
            fixed_content = content

            # Fix trailing commas before closing brackets (invalid in JSON).
            # NOTE(review): these substitutions also run inside string values
            # that happen to contain ",}" — assumed rare in module text.
            fixed_content = re.sub(r',\s*}', '}', fixed_content)
            fixed_content = re.sub(r',\s*]', ']', fixed_content)

            # Check for incomplete states section
            if '"states": {' in fixed_content:
                # Make sure it has a proper closing brace
                states_match = re.search(r'"states":\s*{(.*)}(?=\s*,|\s*})', fixed_content, re.DOTALL)
                if not states_match and '"states": {' in fixed_content:
                    print(" - Module appears to have an incomplete states section")

                    # Find last complete state definition; the search starts
                    # just past the '"states": {' marker (12 characters long).
                    last_state_end = None
                    state_pattern = re.compile(r'("[^"]+"\s*:\s*{[^{]*?}),?', re.DOTALL)
                    for match in state_pattern.finditer(fixed_content, fixed_content.find('"states": {') + 12):
                        last_state_end = match.end()

                    if last_state_end:
                        # Add closing brace after the last valid state
                        fixed_content = fixed_content[:last_state_end] + '\n }\n }' + fixed_content[last_state_end:]
                        print(" - Added closing braces for states section")

            # Fix missing/unbalanced braces by raw counting (ignores braces
            # inside strings — a best-effort heuristic, not a real parser).
            open_braces = fixed_content.count('{')
            close_braces = fixed_content.count('}')

            if open_braces > close_braces:
                fixed_content += '}' * (open_braces - close_braces)
                print(f" - Added {open_braces - close_braces} missing closing braces")
            elif close_braces > open_braces:
                for _ in range(close_braces - open_braces):
                    fixed_content = fixed_content.rstrip().rstrip('}') + '}'
                print(f" - Removed {close_braces - open_braces} excess closing braces")

            # Try to parse the fixed content; only write back if it's valid.
            try:
                parsed_json = json.loads(fixed_content)

                # Check if it has all required elements
                if 'name' not in parsed_json:
                    print(" - Warning: Module missing 'name' field")
                if 'states' not in parsed_json:
                    print(" - Warning: Module missing 'states' field")
                if 'gmf_version' not in parsed_json:
                    print(" - Adding missing gmf_version field")
                    # Make sure we have a proper JSON object
                    if fixed_content.rstrip().endswith('}'):
                        # Add the gmf_version before the final brace
                        fixed_content = fixed_content.rstrip().rstrip('}') + ',\n "gmf_version": 2\n}'

                # Backup the original file before overwriting it.
                backup_path = str(module_path) + '.bak'
                import shutil
                shutil.copy2(str(module_path), backup_path)

                # Write the fixed content
                with open(module_path, 'w') as f:
                    f.write(fixed_content)

                print(f"✅ Fixed module {module_path.name}")
                fixed_count += 1
            except json.JSONDecodeError as e:
                print(f"❌ Could not fix module {module_path.name}: {e}")

        except Exception as e:
            print(f"Error processing {module_path}: {e}")

    print(f"\nValidation complete. Fixed {fixed_count} of {len(invalid_modules)} invalid modules.")
    return fixed_count
|
||||
|
||||
def main():
    """Run module_generator.py in batches until a stop condition is hit.

    Stop conditions, checked before each batch: --max-modules reached,
    --max-batches reached, --max-cost exceeded, a batch fails, or no
    diseases remain without modules.  With --fix-modules, only validates
    and repairs existing module files, then exits.
    """
    parser = argparse.ArgumentParser(description='Run the module generator in a loop')
    parser.add_argument('--batch-size', type=int, default=5,
                        help='Number of modules to generate in each batch (default: 5)')
    parser.add_argument('--max-modules', type=int, default=None,
                        help='Maximum total number of modules to generate (default: no limit)')
    parser.add_argument('--max-batches', type=int, default=None,
                        help='Maximum number of batches to run (default: no limit)')
    parser.add_argument('--prioritize', action='store_true',
                        help='Prioritize high-prevalence diseases first')
    parser.add_argument('--max-cost', type=float, default=None,
                        help='Maximum cost in USD to spend (default: no limit)')
    parser.add_argument('--timeout', type=int, default=600,
                        help='Timeout in seconds for each batch process (default: 600)')
    parser.add_argument('--fix-modules', action='store_true',
                        help='Check and fix existing modules for JSON validity')
    parser.add_argument('--strict', action='store_true',
                        help='Fail immediately on module validation errors instead of trying to fix them')
    args = parser.parse_args()

    # If fix-modules flag is set, validate and fix existing modules, then exit.
    if args.fix_modules:
        validate_and_fix_existing_modules(strict=args.strict)
        return

    # Initial counts
    initial_modules = count_existing_modules()
    initial_remaining, total_diseases = count_remaining_diseases()

    logger.info(f"Starting module generation loop")
    logger.info(f"Currently have {initial_modules} modules")
    logger.info(f"Remaining diseases without modules: {initial_remaining} of {total_diseases}")

    # Set up counters
    modules_created = 0
    batch_count = 0

    # Loop until we've hit our maximum or there are no more diseases to process
    while True:
        # Break if we've reached the maximum modules to generate
        if args.max_modules is not None and modules_created >= args.max_modules:
            logger.info(f"Reached maximum modules limit ({args.max_modules})")
            print(f"\n{'!'*80}")
            print(f"MODULE LIMIT REACHED: {modules_created} >= {args.max_modules}")
            print(f"{'!'*80}\n")
            break

        # Break if we've reached the maximum batches
        if args.max_batches is not None and batch_count >= args.max_batches:
            logger.info(f"Reached maximum batch limit ({args.max_batches})")
            print(f"\n{'!'*80}")
            print(f"BATCH LIMIT REACHED: {batch_count} >= {args.max_batches}")
            print(f"{'!'*80}\n")
            break

        # Break if we've reached the maximum cost
        # (total_cost_usd is the module-level running total updated by
        # parse_api_usage during run_module_generator).
        if args.max_cost is not None and total_cost_usd >= args.max_cost:
            logger.info(f"Reached maximum cost limit (${args.max_cost:.2f})")
            print(f"\n{'!'*80}")
            print(f"COST LIMIT REACHED: ${total_cost_usd:.4f} >= ${args.max_cost:.2f}")
            print(f"{'!'*80}\n")
            break

        # Count how many modules we need to generate in this batch
        if args.max_modules is not None:
            # Adjust batch size if we're near the maximum
            current_batch_size = min(args.batch_size, args.max_modules - modules_created)
        else:
            current_batch_size = args.batch_size

        if current_batch_size <= 0:
            logger.info("No more modules to generate in this batch")
            print("[INFO] No more modules to generate (reached limit)")
            break

        # Run the generator for this batch
        batch_count += 1
        logger.info(f"Starting batch {batch_count} (generating up to {current_batch_size} modules)")

        # Run the generator
        success = run_module_generator(
            batch_size=current_batch_size,
            prioritize=args.prioritize,
            timeout=args.timeout,
            strict=args.strict
        )

        if not success:
            logger.error(f"Batch {batch_count} failed, stopping")
            print(f"[ERROR] Batch {batch_count} failed, stopping")
            break

        # Count how many modules we've created so far (measured from the
        # filesystem, so it also reflects partial progress within a batch).
        new_module_count = count_existing_modules()
        modules_created = new_module_count - initial_modules

        # Count remaining diseases
        remaining_diseases, _ = count_remaining_diseases()

        logger.info(f"Completed batch {batch_count}")
        logger.info(f"Total modules created: {modules_created}")
        logger.info(f"Remaining diseases without modules: {remaining_diseases}")

        print(f"[STATUS] Batch {batch_count} complete")
        print(f"[STATUS] Total modules created: {modules_created}")
        print(f"[STATUS] Remaining diseases: {remaining_diseases} of {total_diseases}")

        # Break if no diseases left
        if remaining_diseases <= 0:
            logger.info("All diseases have modules now, finished!")
            print("[SUCCESS] All diseases have modules now, finished!")
            break

        # Sleep briefly between batches to avoid system overload
        time.sleep(2)

    # Final status
    logger.info(f"Module generation complete!")
    logger.info(f"Started with {initial_modules} modules, now have {count_existing_modules()}")
    logger.info(f"Created {modules_created} new modules in {batch_count} batches")

    remaining, total = count_remaining_diseases()
    logger.info(f"Remaining diseases without modules: {remaining} of {total}")
|
||||
|
||||
# Script entry point: run the batch loop only when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user