732 lines
33 KiB
Python
Executable File
732 lines
33 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
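Module Generator for Synthea

This script automates the creation of new disease modules for Synthea based on
disease_list.json. It uses Claude 3.7 to generate appropriate JSON structures
for each disease, leveraging existing modules as templates.

Usage:
    python module_generator.py [--diseases DISEASES] [--limit LIMIT] [--prioritize]
                               [--log-file LOG_FILE] [--strict]
    python module_generator.py --disease DISEASE --output OUTPUT

Arguments:
    --diseases DISEASES   Comma-separated list of specific diseases to process (ICD-10 codes)
    --limit LIMIT         Maximum number of modules to generate (default: 10)
    --prioritize          Prioritize high-prevalence diseases
    --log-file LOG_FILE   Path to log file for token usage tracking
    --strict              Fail immediately on validation errors instead of trying to fix them
    --disease DISEASE     Single disease name to generate (used together with --output)
    --output OUTPUT       Output path for the generated module (used together with --disease)

Example:
    python module_generator.py --diseases I10,J45 --limit 5 --prioritize

    The batch wrapper is invoked separately; its --batch-size and --max-cost
    options are not defined by this script:

    python src/main/python/run_module_generator.py --batch-size 3 --max-cost 10.0 --prioritize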
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import glob
|
|
import re
|
|
import argparse
|
|
import time
|
|
import anthropic
|
|
import logging
|
|
from tqdm import tqdm
|
|
from typing import Dict, List, Any, Optional, Tuple
|
|
from dotenv import load_dotenv
|
|
|
|
# Load environment variables from .env file
|
|
load_dotenv()
|
|
|
|
# Configure logging
|
|
def setup_logging(log_file_path=None):
    """Send INFO-and-above log records to a log file and to stdout.

    Args:
        log_file_path: Destination of the file handler. When ``None`` the
            file ``module_generation.log`` in the working directory is used.

    Returns:
        The logger named after this module.
    """
    destination = "module_generation.log" if log_file_path is None else log_file_path

    file_handler = logging.FileHandler(destination)
    console_handler = logging.StreamHandler(sys.stdout)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )

    return logging.getLogger(__name__)
|
|
|
|
# Module-wide logger; handlers are attached when setup_logging() is called.
logger = logging.getLogger(__name__)

# Filesystem layout, resolved relative to the parent of this script's directory.
SYNTHEA_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
DISEASE_LIST_PATH = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_list.json")
# NOTE(review): an earlier comment here mentioned overriding MODULES_DIR through
# an environment variable (for Docker); no such override is implemented.
MODULES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/modules")
PROGRESS_FILE = os.path.join(SYNTHEA_ROOT, "src/main/resources/disease_modules_progress.md")
TEMPLATES_DIR = os.path.join(SYNTHEA_ROOT, "src/main/resources/templates/modules")

# Make sure the module and template directories exist before anything reads them.
if not os.path.exists(MODULES_DIR):
    os.makedirs(MODULES_DIR, exist_ok=True)
    logger.info(f"Created modules directory at {MODULES_DIR}")

if not os.path.exists(TEMPLATES_DIR):
    os.makedirs(TEMPLATES_DIR, exist_ok=True)
    logger.info(f"Created templates directory at {TEMPLATES_DIR}")

# Seed an empty progress tracker the first time the script runs.
if not os.path.exists(PROGRESS_FILE):
    with open(PROGRESS_FILE, 'w') as progress_handle:
        progress_handle.writelines([
            "# Disease Module Progress\n\n",
            "## Completed Modules\n\n",
            "## Planned Modules\n\n",
        ])
    logger.info(f"Created progress file at {PROGRESS_FILE}")

# The Claude client needs ANTHROPIC_API_KEY; importing this module exits the
# process with status 1 when the variable is missing or empty.
api_key = os.environ.get('ANTHROPIC_API_KEY')
if not api_key:
    logger.error("ANTHROPIC_API_KEY environment variable is not set")
    sys.exit(1)

# Only api_key is passed (no proxies argument), which older SDK versions reject.
client = anthropic.Client(api_key=api_key)
CLAUDE_MODEL = "claude-3-7-sonnet-20250219"  # Model name sent with every request
MAX_TOKENS = 4096  # Passed as max_tokens_to_sample on every completion request
|
|
|
|
def validate_condition_format(module_json):
    """Check that no ``condition`` object carries a dict-valued ``condition_type``.

    Args:
        module_json: The module either as a JSON string or as already-parsed
            Python data (dicts/lists).

    Returns:
        A ``(is_valid, issues)`` tuple. ``issues`` holds one message per
        offending condition object; when parsing or traversal raises, the
        result is ``(False, ["Validation error: ..."])``.
    """
    def collect(node, found):
        # Walk dicts and lists depth-first, recording each offending dict
        # before descending into its children.
        if isinstance(node, dict):
            condition = node.get("condition")
            if isinstance(condition, dict) and isinstance(condition.get("condition_type"), dict):
                found.append("Found nested condition_type in a condition object")
            for child in node.values():
                collect(child, found)
        elif isinstance(node, list):
            for child in node:
                collect(child, found)

    try:
        if isinstance(module_json, str):
            module = json.loads(module_json)
        else:
            module = module_json

        problems = []
        collect(module, problems)
        return not problems, problems

    except Exception as e:
        return False, [f"Validation error: {str(e)}"]
|
|
|
|
def load_disease_list() -> Dict[str, Dict[str, str]]:
    """Read DISEASE_LIST_PATH and index its entries by disease name.

    Entries without a ``disease_name`` are skipped. Each value carries the
    entry's ``id`` (stored as ``icd_10``), ``snomed`` and ``ICD-10_name``,
    defaulting to an empty string when a field is absent.

    Any failure while reading or converting the file is logged and ends the
    process with exit status 1.
    """
    try:
        with open(DISEASE_LIST_PATH, 'r') as f:
            raw_entries = json.load(f)

        # Later entries with the same disease_name replace earlier ones.
        return {
            entry.get('disease_name', ''): {
                'icd_10': entry.get('id', ''),
                'snomed': entry.get('snomed', ''),
                'ICD-10_name': entry.get('ICD-10_name', ''),
            }
            for entry in raw_entries
            if entry.get('disease_name', '')
        }
    except Exception as e:
        logger.error(f"Error loading disease list: {e}")
        sys.exit(1)
|
|
|
|
def get_existing_modules() -> List[str]:
    """Return the lower-cased base names (no extension) of the ``*.json`` files in MODULES_DIR."""
    names = []
    for module_file in glob.glob(os.path.join(MODULES_DIR, "*.json")):
        stem, _extension = os.path.splitext(os.path.basename(module_file))
        names.append(stem.lower())
    return names
|
|
|
|
def find_most_relevant_module(disease_info: Dict[str, str], existing_modules: List[str]) -> str:
    """Pick an existing module file to use as the generation template.

    Modules whose text contains a quoted string starting with the first three
    characters of the disease's ICD-10 code are preferred; among those the
    largest file is returned (file size as a rough proxy for complexity).
    Otherwise a fixed default is chosen from the first letter of the code.

    Args:
        disease_info: Disease metadata; only ``icd_10`` is read.
        existing_modules: Unused; kept so existing callers keep working.

    Returns:
        Path of the chosen template. The fallback paths are not checked for
        existence here.
    """
    icd_code_prefix = (disease_info.get("icd_10") or "")[:3]  # First 3 chars of ICD-10 code

    related_modules = []
    # An empty prefix would turn the needle into a lone double quote, which
    # matches every JSON file, so the text search only runs with a real prefix.
    if icd_code_prefix:
        needle = f'"{icd_code_prefix}'
        for module_path in glob.glob(os.path.join(MODULES_DIR, "*.json")):
            try:
                with open(module_path, 'r') as f:
                    if needle in f.read():
                        related_modules.append(module_path)
            except (OSError, UnicodeDecodeError) as e:
                # Best effort: an unreadable module is simply not a candidate.
                logger.debug(f"Skipping unreadable module {module_path}: {e}")

    if related_modules:
        # Return the most complex related module (largest file size as a heuristic)
        return max(related_modules, key=os.path.getsize)

    # If no ICD match, return a default template based on disease type
    if icd_code_prefix.startswith('I'):  # Circulatory system
        return os.path.join(MODULES_DIR, "hypertensive_renal_disease.json")
    elif icd_code_prefix.startswith('J'):  # Respiratory system
        return os.path.join(MODULES_DIR, "asthma.json")
    elif icd_code_prefix.startswith('K'):  # Digestive system
        return os.path.join(MODULES_DIR, "appendicitis.json")
    else:
        # Default to a simple template
        return os.path.join(TEMPLATES_DIR, "prevalence.json")
|
|
|
|
# Regexes tried in order to pull a JSON object out of a non-JSON reply:
# a ```json fenced block, any fenced block, then the outermost {...} span.
_JSON_EXTRACTION_PATTERNS = (
    r'```json([\s\S]*?)```',
    r'```([\s\S]*?)```',
    r'({[\s\S]*})',
)


def _resolve_strict(strict: Optional[bool]) -> bool:
    """Return the effective strict flag, falling back to a module-level ``args`` if one exists."""
    if strict is not None:
        return bool(strict)
    cli_args = globals().get("args")
    return bool(getattr(cli_args, "strict", False))


def _build_module_prompt(
    disease_name: str,
    disease_info: Dict[str, str],
    template_name: str,
    template_content: str
) -> str:
    """Return the generation prompt for one disease, embedding the template module."""
    icd_10 = disease_info.get('icd_10', 'Unknown')
    snomed = disease_info.get('snomed', 'Unknown')

    # Literal JSON braces are doubled ({{ }}) so the f-string emits them as
    # text instead of parsing them as replacement fields.
    return f"""You are a medical expert and software developer tasked with creating disease modules for Synthea, an open-source synthetic patient generator.

TASK: Create a valid JSON module for the disease: "{disease_name}" (ICD-10 code: {icd_10}).

CRITICAL: Your response MUST ONLY contain valid, parseable JSON that starts with {{ and ends with }}. No explanations, text, markdown formatting, or code blocks.

Disease Information:
- Name: {disease_name}
- ICD-10 Code: {icd_10}
- SNOMED-CT Code: {snomed}

I'm providing an example module structure based on {template_name}.json as a reference:

```json
{template_content}
```

Technical Requirements (MUST follow ALL of these):
1. JSON Format:
   - Use valid JSON syntax with no errors
   - No trailing commas
   - All property names in double quotes
   - Always add a "gmf_version": 2 field at the top level
   - Check that all brackets and braces are properly matched

2. Module Structure:
   - Include "name" field with {disease_name}
   - Include "remarks" array with at least 3 items
   - Include "states" object with complete set of disease states
   - Add at least 2-3 reference URLs in the "remarks" section
   - Every state needs a unique name and valid type (like "Initial", "Terminal", etc.)
   - All transitions must be valid (direct_transition, distributed_transition, etc.)

3. Medical Content:
   - Include accurate medical codes (SNOMED-CT, ICD-10, LOINC, RxNorm)
   - Model realistic disease prevalence based on age, gender, race
   - Include relevant symptoms, diagnostic criteria, treatments
   - Only include states that make clinical sense for this specific disease

4. CRITICAL CONDITION STRUCTURE REQUIREMENTS:
   - In conditional statements, the 'condition_type' MUST be a top-level key within the condition object
   - INCORRECT: "condition": {{"condition_type": {{"nested": "value"}}, "name": "SomeState"}}
   - CORRECT: "condition": {{"condition_type": "PriorState", "name": "SomeState"}}
   - Common condition types: "Age", "Gender", "Race", "Attribute", "And", "Or", "Not", "PriorState", "Active Condition"
   - For PriorState conditions, use: {{"condition_type": "PriorState", "name": "StateName"}}
   - For attribute checks, use: {{"condition_type": "Attribute", "attribute": "attr_name", "operator": "==", "value": true}}

IMPORTANT REMINDER: Respond with ONLY valid JSON, not explanations. The entire response should be a single JSON object that can be directly parsed.
"""


def _request_completion(prompt: str) -> str:
    """Send *prompt* through the client's completion call and return the reply text."""
    # NOTE(review): client.completion / max_tokens_to_sample is the legacy
    # text-completion call shape; confirm the installed anthropic SDK exposes
    # it and that CLAUDE_MODEL is served by that endpoint.
    response = client.completion(
        model=CLAUDE_MODEL,
        max_tokens_to_sample=MAX_TOKENS,
        temperature=0.1,  # Lower temperature for more consistent, predictable output
        prompt=f"\n\nHuman: {prompt}\n\nAssistant: ",
        stream=False
    )
    text = response.completion

    # Coerce before any regex work; re.sub would raise on a non-str value.
    if not isinstance(text, str):
        text = text.decode('utf-8') if hasattr(text, 'decode') else str(text)
    return text


def _strip_code_fences(text: str) -> str:
    """Remove a leading ```json / ``` marker and a trailing ``` marker from *text*."""
    # Surrounding whitespace is dropped first so the ^-anchored patterns can
    # match a reply that starts with a space or newline.
    text = text.strip()
    text = re.sub(r'^```json\s*', '', text)
    text = re.sub(r'^```\s*', '', text)
    return re.sub(r'```$', '', text)


def _repair_json_candidate(candidate: str) -> str:
    """Apply heuristic fixes (trailing commas, unbalanced closers) to a JSON-like string."""
    # Remove trailing commas before closing braces/brackets
    candidate = re.sub(r',\s*}', '}', candidate)
    candidate = re.sub(r',\s*]', ']', candidate)

    # Plain character counts: braces inside string values are counted too.
    open_braces = candidate.count('{')
    close_braces = candidate.count('}')
    open_brackets = candidate.count('[')
    close_brackets = candidate.count(']')

    if open_braces > close_braces:
        candidate += '}' * (open_braces - close_braces)
        logger.info(f"Added {open_braces - close_braces} missing closing braces")
    elif close_braces > open_braces:
        # Each pass strips every trailing '}' and puts a single one back.
        for _ in range(close_braces - open_braces):
            candidate = candidate.rstrip().rstrip('}') + '}'
        logger.info(f"Removed {close_braces - open_braces} excess closing braces")

    if open_brackets > close_brackets:
        candidate += ']' * (open_brackets - close_brackets)
        logger.info(f"Added {open_brackets - close_brackets} missing closing brackets")
    elif close_brackets > open_brackets:
        # Drop the right-most ']' once per excess bracket
        for _ in range(close_brackets - open_brackets):
            last_bracket = candidate.rfind(']')
            if last_bracket >= 0:
                candidate = candidate[:last_bracket] + candidate[last_bracket+1:]
        logger.info(f"Removed {close_brackets - open_brackets} excess closing brackets")

    return candidate


def _salvage_json(content: str, disease_name: str) -> Optional[str]:
    """Try to recover a JSON document from *content*; return it pretty-printed, or None."""
    for pattern in _JSON_EXTRACTION_PATTERNS:
        match = re.search(pattern, content)
        if not match:
            continue
        try:
            parsed = json.loads(_repair_json_candidate(match.group(1).strip()))
        except json.JSONDecodeError:
            continue
        logger.info(f"Successfully extracted valid JSON for '{disease_name}' after extraction attempt")
        return json.dumps(parsed, indent=2)

    # Last resort: manual repair of common issues
    try:
        # Remove any fenced sections, then any remaining backticks
        cleaned = re.sub(r'```.*?```', '', content, flags=re.DOTALL)
        cleaned = re.sub(r'`', '', cleaned)
        # Replace single-quoted keys with double-quoted keys
        cleaned = re.sub(r"'([^']*?)':", r'"\1":', cleaned)
        # Fix trailing commas before closing braces/brackets
        cleaned = re.sub(r',\s*}', '}', cleaned)
        cleaned = re.sub(r',\s*]', ']', cleaned)

        module_match = re.search(r'({[\s\S]*})', cleaned)
        if module_match:
            parsed = json.loads(module_match.group(1))
            logger.info(f"Successfully repaired and extracted JSON for '{disease_name}'")
            return json.dumps(parsed, indent=2)
    except (ValueError, RecursionError) as repair_error:
        logger.error(f"Failed to repair JSON: {repair_error}")

    return None


def _regenerate_with_issue_feedback(
    prompt: str,
    parsed: Any,
    issues: List[str],
    disease_name: str,
    strict: bool
) -> str:
    """Ask Claude once more, listing the validation issues, and return the JSON text to keep."""
    logger.warning(f"Generated module for {disease_name} has condition structure issues: {issues}")
    logger.warning("Requesting regeneration with corrected instructions...")

    retry_prompt = prompt + "\n\nPLEASE FIX THESE ISSUES: " + ", ".join(issues)
    retry_prompt += "\n\nREMINDER: All condition objects must have 'condition_type' as a top-level key, NOT nested."

    logger.info(f"Retrying generation for '{disease_name}' with specific instructions about issues")
    retry_content = _strip_code_fences(_request_completion(retry_prompt))

    try:
        retry_parsed = json.loads(retry_content)
    except json.JSONDecodeError as e:
        logger.error(f"Retry response is still not valid JSON: {e}")
        if strict:
            raise ValueError(f"Failed to generate valid JSON after retry: {e}")
        # Fall back to the original response in non-strict mode
        logger.warning("Using original response despite issues (non-strict mode)")
        return json.dumps(parsed, indent=2)

    retry_valid, retry_issues = validate_condition_format(retry_parsed)
    if retry_valid:
        logger.info(f"Successfully fixed condition structure issues for '{disease_name}'")
    elif strict:
        logger.error(f"Failed to fix condition structure issues after retry: {retry_issues}")
        raise ValueError(f"Module format validation failed: {retry_issues}")
    else:
        logger.warning(f"Retry still has issues, but proceeding due to non-strict mode: {retry_issues}")

    return json.dumps(retry_parsed, indent=2)


def generate_module_with_claude(
    disease_name: str,
    disease_info: Dict[str, str],
    template_path: str,
    strict: Optional[bool] = None
) -> Tuple[str, int, int]:
    """Use Claude to generate a new module based on the template.

    The reply is parsed as JSON and checked with validate_condition_format().
    A reply with condition-structure issues triggers one retry that lists the
    issues. A reply that is not valid JSON goes through heuristic extraction
    and repair instead (the repaired result is not re-validated).

    Args:
        disease_name: Name of the disease to model.
        disease_info: Disease metadata; ``icd_10`` and ``snomed`` are read.
        template_path: Path of the existing module embedded as an example.
        strict: When true, a retry that is still invalid raises instead of
            being accepted. When ``None`` the flag is taken from a
            module-level ``args.strict`` if such an object exists, else False.

    Returns:
        ``(formatted_json, input_token_estimate, output_token_estimate)``.
        Token counts are estimated as characters // 4 for the first request
        only; a retry is not added to the estimate.

    Raises:
        ValueError: If no valid JSON can be obtained, or validation still
            fails after the retry in strict mode.
        OSError: If the template file cannot be read.
    """
    strict_mode = _resolve_strict(strict)

    # Load the template module
    logger.info(f"Loading template module from {os.path.basename(template_path)}")
    with open(template_path, 'r') as f:
        template_content = f.read()

    template_name = os.path.splitext(os.path.basename(template_path))[0]
    prompt = _build_module_prompt(disease_name, disease_info, template_name, template_content)

    logger.info(f"Sending request to Claude API for '{disease_name}'")
    try:
        content = _strip_code_fences(_request_completion(prompt))
        logger.debug(f"Content type: {type(content)}")

        # Rough estimate: 1 token ≈ 4 characters
        input_token_estimate = len(prompt) // 4
        output_token_estimate = len(content) // 4
        logger.info(f"API call for '{disease_name}' - Estimated input tokens: {input_token_estimate}, "
                    f"Estimated output tokens: {output_token_estimate}")

        try:
            parsed = json.loads(content)
        except json.JSONDecodeError as e:
            logger.error(f"Generated content is not valid JSON: {e}")
            logger.debug(f"Generated content: {content[:500]}...")  # First 500 chars for debugging

            formatted_json = _salvage_json(content, disease_name)
            if formatted_json is None:
                debug_sample = content[:500] + "..." if len(content) > 500 else content
                logger.error(f"Could not extract valid JSON from Claude's response for '{disease_name}'")
                logger.error(f"Response sample (first 500 chars): {debug_sample}")
                logger.error(f"Full content length: {len(content)} characters")
                raise ValueError("Could not extract valid JSON from Claude's response")
            return formatted_json, input_token_estimate, output_token_estimate

        valid, issues = validate_condition_format(parsed)
        if valid:
            formatted_json = json.dumps(parsed, indent=2)
        else:
            formatted_json = _regenerate_with_issue_feedback(
                prompt, parsed, issues, disease_name, strict_mode
            )

        logger.info(f"Successfully generated valid JSON for '{disease_name}'")
        return formatted_json, input_token_estimate, output_token_estimate

    except Exception as e:
        # Boundary log: record the failure, then let the caller decide.
        logger.error(f"Error generating module with Claude: {e}")
        raise
|
|
|
|
def update_progress_file(disease_name: str, icd_code: str) -> None:
    """Append a numbered entry for the new module under "## Completed Modules".

    The entry number is one more than the count of lines in that section that
    start with a digit. The entry is placed directly after the last such line,
    or directly after the section header when there is none, so blank or
    non-numbered lines inside the section no longer shift it into the middle
    of the list.

    Failures are logged and not raised; progress tracking is best effort.

    Args:
        disease_name: Name written into the entry.
        icd_code: ICD-10 code written into the entry.
    """
    try:
        with open(PROGRESS_FILE, 'r') as f:
            content = f.readlines()

        for i, line in enumerate(content):
            if not line.startswith('## Completed Modules'):
                continue

            # Count the section's numbered entries and remember where the last one ends
            count = 0
            insert_at = i + 1
            for j in range(i + 1, len(content)):
                if content[j].strip() and content[j][0].isdigit():
                    count += 1
                    insert_at = j + 1
                elif content[j].startswith('##'):
                    break

            # A final line without a newline would otherwise be glued to the new entry.
            if not content[insert_at - 1].endswith('\n'):
                content[insert_at - 1] += '\n'

            content.insert(insert_at, f"{count+1}. {disease_name} ({icd_code}) - Generated by module_generator.py\n")
            break
        else:
            logger.warning(f"No '## Completed Modules' section found in {PROGRESS_FILE}; progress not recorded")
            return

        with open(PROGRESS_FILE, 'w') as f:
            f.writelines(content)

    except (OSError, UnicodeError) as e:
        logger.error(f"Error updating progress file: {e}")
|
|
|
|
def normalize_filename(disease_name: str) -> str:
    """Turn a disease name into a lower-case, underscore-separated file stem.

    Every character other than an ASCII letter or digit acts as a separator;
    runs of separators become a single underscore and none is left at either
    end.
    """
    keep = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

    # Map each disallowed character to an underscore
    mapped = "".join(ch if ch in keep else "_" for ch in disease_name.lower())

    # Splitting on underscores and discarding empty pieces collapses repeated
    # underscores and trims leading/trailing ones in one step.
    return "_".join(piece for piece in mapped.split("_") if piece)
|
|
|
|
def get_existing_module_path(disease_name: str, disease_info: Dict[str, str]) -> Optional[str]:
    """Return the path of an existing module for this disease, or None.

    Two checks are made in order: a file in MODULES_DIR named after the
    normalized disease name, then any module in MODULES_DIR whose text
    contains the exact snippet ``"code": "<icd_10>"``.

    Args:
        disease_name: Disease name, normalized with normalize_filename().
        disease_info: Disease metadata; only ``icd_10`` is read.
    """
    # Method 1: Check by normalized filename
    candidate_path = os.path.join(MODULES_DIR, f"{normalize_filename(disease_name)}.json")
    if os.path.exists(candidate_path):
        return candidate_path

    # Method 2: Check by ICD-10 code in existing modules
    icd_code = disease_info.get('icd_10', '')
    if icd_code:
        needle = f'"code": "{icd_code}"'
        for module_path in glob.glob(os.path.join(MODULES_DIR, "*.json")):
            try:
                with open(module_path, 'r') as f:
                    if needle in f.read():
                        return module_path
            except (OSError, UnicodeDecodeError) as e:
                # Best effort: a module that cannot be read is skipped rather
                # than aborting the whole scan.
                logger.debug(f"Skipping unreadable module {module_path}: {e}")

    # No existing module found
    return None
|
|
|
|
def estimate_disease_prevalence(disease_name: str, icd_code: str) -> float:
    """Score a disease for prioritization; a higher score means "generate sooner".

    This is a heuristic, not prevalence data. The score starts at 1.0 and adds
    a bonus when the name contains a common-condition keyword, a weight keyed
    on the first letter of the ICD-10 code, and a bonus for names of at most
    three words.
    """
    common_keywords = (
        "hypertension", "diabetes", "arthritis", "asthma", "depression",
        "anxiety", "obesity", "cancer", "heart", "copd", "stroke", "pneumonia",
        "bronchitis", "influenza", "infection", "pain", "fracture",
    )
    # (ICD-10 chapter letter, weight), checked in order; first match wins
    chapter_weights = (
        ('I', 5.0),  # Circulatory system
        ('J', 4.0),  # Respiratory system
        ('K', 3.5),  # Digestive system
        ('M', 3.0),  # Musculoskeletal system
        ('E', 4.0),  # Endocrine, nutritional and metabolic diseases
        ('F', 3.5),  # Mental and behavioral disorders
    )

    score = 1.0

    # A keyword hit counts once, however many keywords match
    lowered_name = disease_name.lower()
    if any(keyword in lowered_name for keyword in common_keywords):
        score += 2.0

    chapter_bonus = next(
        (weight for letter, weight in chapter_weights if icd_code.startswith(letter)),
        None,
    )
    if chapter_bonus is not None:
        score += chapter_bonus

    # Short names get a small bonus
    word_count = len(disease_name.split())
    if word_count <= 3:
        score += 1.0

    return score
|
|
|
|
def _run_single_disease(disease_name: str, output_path: str) -> None:
    """Generate one module for *disease_name* and write it to *output_path*; exit(1) on failure."""
    try:
        logger.info(f"Generating module for single disease: {disease_name}")

        # No codes are known in this mode, only the name
        disease_info = {
            'icd_10': '',
            'snomed': '',
            'ICD-10_name': disease_name
        }

        templates = glob.glob(os.path.join(TEMPLATES_DIR, "*.json"))
        if templates:
            template_path = templates[0]  # Use the first template found
        else:
            logger.warning("No template found, using a minimal default")
            template_path = os.path.join(MODULES_DIR, "appendicitis.json")
            if not os.path.exists(template_path):
                raise ValueError("Cannot find any suitable template for generation")

        module_content, _, _ = generate_module_with_claude(disease_name, disease_info, template_path)

        with open(output_path, 'w') as f:
            f.write(module_content)

        logger.info(f"Successfully generated module for {disease_name}")
        print(f"Successfully generated module for {disease_name}")

    except Exception as e:
        logger.error(f"Error generating single disease module: {e}")
        print(f"Error: {e}")
        sys.exit(1)


def _select_diseases_to_create(
    to_process: Dict[str, Dict[str, str]],
    existing_modules: List[str],
    limit: int,
    prioritize: bool
) -> Dict[str, Dict[str, str]]:
    """Return at most *limit* diseases that have no module file named after them yet."""
    # A zero or negative limit selects nothing. The non-prioritized path used
    # to add one disease before checking the limit.
    limit = max(limit, 0)
    known = set(existing_modules)
    pending = [(name, info) for name, info in to_process.items()
               if normalize_filename(name) not in known]

    if not prioritize:
        return dict(pending[:limit])
    if not pending:
        return {}

    logger.info("Prioritizing diseases by estimated prevalence")
    scored = [
        (name, info, estimate_disease_prevalence(name, info.get('icd_10', '').split('.')[0]))
        for name, info in pending
    ]
    scored.sort(key=lambda item: item[2], reverse=True)

    # Log top candidates for transparency
    logger.info("Top candidates by estimated prevalence:")
    for rank, (name, info, score) in enumerate(scored[:10], start=1):
        logger.info(f"  {rank}. {name} (ICD-10: {info.get('icd_10', 'Unknown')}) - Score: {score:.2f}")

    return {name: info for name, info, _ in scored[:limit]}


def _create_or_reuse_module(
    disease_name: str,
    disease_info: Dict[str, str],
    existing_modules: List[str]
) -> None:
    """Copy an existing module for the disease if one is found, otherwise generate one."""
    output_path = os.path.join(MODULES_DIR, normalize_filename(disease_name) + ".json")
    icd_code = disease_info.get('icd_10', 'Unknown')

    # First check if a module already exists - no need to use the LLM if it does
    existing_module_path = get_existing_module_path(disease_name, disease_info)

    if existing_module_path:
        logger.info(f"Module for {disease_name} already exists at {existing_module_path}")

        with open(existing_module_path, 'r') as f:
            module_content = f.read()

        # Save under the normalized filename if the module lives elsewhere
        if output_path != existing_module_path:
            with open(output_path, 'w') as f:
                f.write(module_content)
            logger.info(f"Copied existing module to {output_path}")
        else:
            logger.info("Existing module already has correct filename")

        update_progress_file(disease_name, icd_code)
        return

    # No existing module: pick a template and generate with Claude
    template_path = find_most_relevant_module(disease_info, existing_modules)
    logger.info(f"Using {os.path.basename(template_path)} as template for {disease_name}")

    module_content, _input_tokens, _output_tokens = generate_module_with_claude(
        disease_name, disease_info, template_path
    )

    with open(output_path, 'w') as f:
        f.write(module_content)
    logger.info(f"Successfully created module for {disease_name}")

    update_progress_file(disease_name, icd_code)

    # Sleep to avoid hitting API rate limits
    time.sleep(1)


def main():
    """Command-line entry point: generate one module, or a batch from the disease list.

    With both ``--disease`` and ``--output`` a single module is generated and
    written to the given path. Otherwise the disease list is loaded, filtered
    by ``--diseases`` if given, reduced to at most ``--limit`` diseases that
    have no module yet, and each one is copied from an existing module or
    generated. A failure for one disease is logged and the batch continues.
    """
    # The parsed options are published at module level because
    # generate_module_with_claude() reads ``args.strict``; kept local, that
    # lookup would raise NameError.
    global logger, args

    parser = argparse.ArgumentParser(description='Generate Synthea disease modules')
    parser.add_argument('--diseases', type=str, help='Comma-separated list of ICD-10 codes to process')
    parser.add_argument('--limit', type=int, default=10, help='Maximum number of modules to generate')
    parser.add_argument('--prioritize', action='store_true', help='Prioritize high-prevalence diseases')
    parser.add_argument('--log-file', type=str, help='Path to log file for token usage tracking')
    parser.add_argument('--strict', action='store_true',
                        help='Fail immediately on validation errors instead of trying to fix them')
    # Support for direct disease generation
    parser.add_argument('--disease', type=str, help='Single disease name to generate')
    parser.add_argument('--output', type=str, help='Output path for the generated module')
    args = parser.parse_args()

    # An empty or missing --log-file falls back to setup_logging()'s default
    logger = setup_logging(args.log_file or None)

    # Direct mode (single disease)
    if args.disease and args.output:
        _run_single_disease(args.disease, args.output)
        return

    all_diseases = load_disease_list()
    logger.info(f"Loaded {len(all_diseases)} diseases from disease_list.json")

    existing_modules = get_existing_modules()
    logger.info(f"Found {len(existing_modules)} existing modules")

    # Filter diseases to process
    if args.diseases:
        disease_codes = [code.strip() for code in args.diseases.split(',')]
        # Match on the part of the ICD-10 code before the first dot
        to_process = {name: info for name, info in all_diseases.items()
                      if info.get('icd_10', '').split('.')[0] in disease_codes}
        logger.info(f"Filtered to {len(to_process)} diseases matching specified codes")
    else:
        to_process = all_diseases

    # Filename-based pre-filter only; get_existing_module_path() does the
    # more thorough check per disease later.
    diseases_to_create = _select_diseases_to_create(
        to_process, existing_modules, args.limit, args.prioritize
    )
    logger.info(f"Will generate modules for {len(diseases_to_create)} diseases")

    for disease_name, disease_info in tqdm(diseases_to_create.items(), desc="Generating modules"):
        try:
            _create_or_reuse_module(disease_name, disease_info, existing_modules)
        except Exception as e:
            # Batch boundary: one failed disease must not stop the rest.
            logger.error(f"Failed to generate module for {disease_name}: {e}")

    logger.info("Module generation complete")
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()