412 lines
18 KiB
Python
Executable File
412 lines
18 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Interface-adaptive binding free energy calculator for protein-protein complexes.
|
|
Automatically adjusts parameters based on interface characteristics.
|
|
"""
|
|
|
|
import argparse
|
|
import mdtraj as md
|
|
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.cluster import KMeans
|
|
import matplotlib
|
|
import sys
|
|
import os
|
|
import traceback
|
|
from scipy.spatial.distance import pdist, squareform
|
|
matplotlib.use('Agg') # Use non-interactive backend
|
|
|
|
def parse_args():
    """Build the command-line interface and return the parsed namespace.

    Arguments are read from ``sys.argv``. The four sample/topology paths are
    mandatory; every other option falls back to the default shown below.
    """
    cli = argparse.ArgumentParser(description="Calculate binding free energy between two proteins")

    # Mandatory inputs as (flag, help text) pairs, registered in display order.
    mandatory_inputs = (
        ("--protein1_samples", "XTC file with protein1 samples"),
        ("--protein1_topology", "PDB file with protein1 topology"),
        ("--protein2_samples", "XTC file with protein2 samples"),
        ("--protein2_topology", "PDB file with protein2 topology"),
    )
    for flag, help_text in mandatory_inputs:
        cli.add_argument(flag, required=True, help=help_text)

    # Optional tuning knobs and output locations.
    cli.add_argument("--temperature", type=float, default=300.0, help="Temperature in Kelvin")
    cli.add_argument("--n_clusters", type=int, default=3, help="Number of clusters for conformational states")
    cli.add_argument("--output", default="binding_energy.csv", help="Output CSV file")
    cli.add_argument("--plot", default="energy_comparison.png", help="Output plot file")
    cli.add_argument("--debug", action="store_true", help="Enable debug mode")

    namespace = cli.parse_args()
    return namespace
|
|
|
|
def analyze_interface(conf1, conf2, cutoff=1.0):
    """
    Analyze the interface between two protein conformations and determine
    appropriate energy calculation parameters based on the interface characteristics.

    Only the first frame of each conformation is inspected. Heavy atoms are
    picked with the ``"mass > 2"`` selection and every heavy-atom pair across
    the two proteins is binned by distance: below 0.5, from 0.5 to 0.8, and
    from 0.8 to ``cutoff``. The 0.5 and 0.8 thresholds are fixed; ``cutoff``
    only bounds the outermost bin and the residue-level interface test.
    Distances are in the units of ``conf.xyz`` (presumably nanometres, as for
    mdtraj trajectories).

    Parameter selection, applied in this order (later rules overwrite
    earlier ones):
      1. large interface (> 20000 pairs or > 100 interface residues), else
      2. small interface (< 5000 pairs or < 40 interface residues);
      3. MHC-TCR-like sizes (protein1 > 250 residues, protein2 in 81..149)
         overwrite ``offset`` and ``scaling_factor``;
      4. antibody-antigen-like sizes (protein1 in 201..299 residues,
         protein2 < 200) overwrite ``offset`` and ``scaling_factor`` again.

    Args:
        conf1: conformation of the first protein; must expose ``topology``
            (with ``select``, ``atom`` and ``atoms``) and ``xyz``.
        conf2: conformation of the second protein, same requirements.
        cutoff: upper distance bound for an atom pair to count as interface.

    Returns:
        dict with the keys ``'scaling_factor'``, ``'solvation_weight'`` and
        ``'offset'``. The defaults are returned unchanged when either
        protein has no heavy atoms.
    """
    # Single source for the defaults (used by the early exit and as the base).
    params = {
        'scaling_factor': 10000.0,  # Default scaling factor
        'solvation_weight': 0.0005,  # Default solvation weight
        'offset': -10.0,  # Default offset
    }

    # Select heavy atoms for interface analysis (all non-hydrogen atoms)
    heavy_atoms1 = np.asarray(conf1.topology.select("mass > 2"))
    heavy_atoms2 = np.asarray(conf2.topology.select("mass > 2"))

    if len(heavy_atoms1) == 0 or len(heavy_atoms2) == 0:
        print(f"WARNING: No heavy atoms found. Protein1: {len(heavy_atoms1)}, Protein2: {len(heavy_atoms2)}")
        # Return default parameters
        return params

    # Extract first-frame coordinates of the heavy atoms
    coords1 = conf1.xyz[0, heavy_atoms1]
    coords2 = conf2.xyz[0, heavy_atoms2]

    # Pairwise distance matrix, filled row by row so that no
    # (n1, n2, 3) intermediate array has to be held in memory.
    distances = np.zeros((len(coords1), len(coords2)))
    for row, point in enumerate(coords1):
        distances[row] = np.sqrt(np.sum((point - coords2) ** 2, axis=1))

    # Count interface atom pairs at different distance thresholds
    close_pairs = np.sum(distances < 0.5)
    medium_pairs = np.sum((distances >= 0.5) & (distances < 0.8))
    distant_pairs = np.sum((distances >= 0.8) & (distances < cutoff))
    total_interface_pairs = close_pairs + medium_pairs + distant_pairs

    # Residue-level interface: an atom belongs to the interface when it has
    # at least one partner within the cutoff. Reducing the boolean mask along
    # each axis replaces a Python loop over every atom pair.
    within_cutoff = distances < cutoff
    contact_atoms1 = heavy_atoms1[within_cutoff.any(axis=1)]
    contact_atoms2 = heavy_atoms2[within_cutoff.any(axis=0)]
    interface_residues1 = {conf1.topology.atom(idx).residue.index for idx in contact_atoms1}
    interface_residues2 = {conf2.topology.atom(idx).residue.index for idx in contact_atoms2}

    # Count interface residues
    n_interface_residues = len(interface_residues1) + len(interface_residues2)

    # Calculate the relative size of the interface compared to protein size
    protein1_size = len({atom.residue.index for atom in conf1.topology.atoms})
    protein2_size = len({atom.residue.index for atom in conf2.topology.atoms})
    relative_interface_size = n_interface_residues / (protein1_size + protein2_size)

    # Adjust parameters based on interface characteristics

    # 1. Large interfaces need more scaling and different offset
    if total_interface_pairs > 20000 or n_interface_residues > 100:
        print("Detected large interface - adjusting parameters")
        params['scaling_factor'] = 15000.0
        params['offset'] = -15.0
        params['solvation_weight'] = 0.0003

    # 2. Small interfaces need less scaling and less offset
    elif total_interface_pairs < 5000 or n_interface_residues < 40:
        print("Detected small interface - adjusting parameters")
        params['scaling_factor'] = 8000.0
        params['offset'] = -6.0
        params['solvation_weight'] = 0.0008

    # 3. Adjust for TCR-MHC type interfaces which have specific recognition patterns
    if protein1_size > 250 and 80 < protein2_size < 150:
        # This roughly corresponds to MHC (larger) and TCR (smaller) size ranges
        print("Detected potential MHC-TCR-like complex - adjusting parameters")
        params['offset'] = -18.0  # MHC-TCR requires larger offset
        params['scaling_factor'] = 12000.0

    # 4. Adjust for antibody-antigen complexes which have more specific binding
    if 200 < protein1_size < 300 and protein2_size < 200:
        # This roughly corresponds to antibody and antigen size ranges
        print("Detected potential antibody-antigen complex - adjusting parameters")
        params['offset'] = -10.0  # Antibody binding is usually stronger
        params['scaling_factor'] = 10000.0

    # Log the results of the interface analysis
    print(f"Interface analysis results:")
    print(f" Interface atom pairs: {total_interface_pairs}")
    print(f" Interface residues: {n_interface_residues}")
    print(f" Protein1 size: {protein1_size} residues")
    print(f" Protein2 size: {protein2_size} residues")
    print(f" Relative interface size: {relative_interface_size:.2f}")
    print(f"Selected parameters:")
    print(f" Scaling factor: {params['scaling_factor']}")
    print(f" Solvation weight: {params['solvation_weight']}")
    print(f" Energy offset: {params['offset']}")

    return params
|
|
|
|
def calculate_interaction_energy(conf1, conf2, cutoff=1.0):
    """
    Score the interaction between two protein conformations.

    The scaling factor, solvation weight and offset come from
    analyze_interface(conf1, conf2, cutoff). Heavy-atom pairs in the first
    frame of each conformation are binned by distance (below 0.5, 0.5-0.8,
    0.8-cutoff) and weighted -1.0, -0.5 and -0.2 respectively. The weighted
    sum is divided by the scaling factor, then a contact-count penalty and
    the offset are added.

    Returns the final energy, or 0.0 when either protein has no heavy atoms.
    """
    # Interface-dependent parameters (this call also logs the interface analysis).
    adaptive = analyze_interface(conf1, conf2, cutoff)

    # Non-hydrogen atoms of each partner.
    sel1 = conf1.topology.select("mass > 2")
    sel2 = conf2.topology.select("mass > 2")

    n_sel1, n_sel2 = len(sel1), len(sel2)
    if n_sel1 == 0 or n_sel2 == 0:
        print(f"WARNING: No heavy atoms found. Protein1: {n_sel1}, Protein2: {n_sel2}")
        return 0.0

    # First-frame coordinates of the selected atoms.
    xyz_a = conf1.xyz[0, sel1]
    xyz_b = conf2.xyz[0, sel2]

    # Distance matrix, one row per atom of the first protein.
    distances = np.zeros((len(xyz_a), len(xyz_b)))
    for row, point in enumerate(xyz_a):
        offsets = point.reshape(1, -1) - xyz_b
        distances[row] = np.sqrt(np.sum(offsets ** 2, axis=1))

    # Contacts per distance shell; the closest shell (below 0.5 nm) weighs most.
    in_close_shell = distances < 0.5
    in_medium_shell = (distances >= 0.5) & (distances < 0.8)
    in_distant_shell = (distances >= 0.8) & (distances < cutoff)
    close_contacts = np.sum(in_close_shell)
    medium_contacts = np.sum(in_medium_shell)
    distant_contacts = np.sum(in_distant_shell)

    # Per-contact contributions in kcal/mol: close -1.0, medium -0.5,
    # distant -0.2. Negative values denote favourable interactions.
    interaction_energy = (-1.0 * close_contacts) - (0.5 * medium_contacts) - (0.2 * distant_contacts)

    # Bring the raw sum down by the adaptive scaling factor.
    energy_scaled = interaction_energy / adaptive['scaling_factor']

    # Penalty proportional to the total number of contacts, with adaptive weight.
    total_contacts = close_contacts + medium_contacts + distant_contacts
    solvation_penalty = adaptive['solvation_weight'] * total_contacts

    # Apply the adaptive offset last.
    base_energy = energy_scaled + solvation_penalty
    final_energy = base_energy + adaptive['offset']

    print(f"Close contacts: {close_contacts}, Medium: {medium_contacts}, Distant: {distant_contacts}")
    print(f"Raw interaction energy: {interaction_energy:.2f}, Scaled: {energy_scaled:.2f}, Solvation: {solvation_penalty:.2f}")
    print(f"Base energy: {base_energy:.2f}, Offset: {adaptive['offset']}")
    print(f"Final binding energy: {final_energy:.2f} kcal/mol")

    return final_energy
|
|
|
|
def calculate_binding_free_energy(binding_energies, weights1, weights2, temperature):
    """
    Calculate binding free energy using a numerically stable approach.
    Sign convention: negative free energy indicates favorable binding.

    Three quantities are computed from the state-pair energy matrix:

    * ``avg_energy``: average of ``binding_energies[i, j]`` weighted by
      ``weights1[i] * weights2[j] * exp(-(E_ij - E_min) / kT)``. Falls back
      to the minimum energy (with a warning) when the total weight is zero.
    * ``entropy``: ``kT * ln(sum_ij exp(-E_ij / kT))``, the log of the
      unweighted Boltzmann sum over every state pair, scaled by kT.
    * ``binding_free_energy``: ``avg_energy - entropy``.

    The Boltzmann sum includes every state pair. Pairs that tie for the
    minimum energy are each counted, not collapsed into a single term.

    Args:
        binding_energies: 2-D array of energies in kcal/mol, indexed
            ``[state of protein1, state of protein2]``.
        weights1: population weight of each protein1 state (row order).
        weights2: population weight of each protein2 state (column order).
        temperature: temperature in Kelvin.

    Returns:
        Tuple ``(avg_energy, entropy, binding_free_energy)`` in kcal/mol.
    """
    kT = 0.0019872041 * temperature  # kcal/mol (R*T)

    energies = np.asarray(binding_energies, dtype=float)
    n_rows, n_cols = energies.shape

    # Reference energy that keeps every Boltzmann factor below at or under 1.
    min_energy = np.min(energies)

    # log(sum(exp(-E/kT))) over ALL states, accumulated pairwise in log space
    # so that large |E|/kT cannot overflow.
    log_Z = np.logaddexp.reduce(-energies.ravel() / kT)

    # Population-weighted Boltzmann average. Weights are truncated to the
    # matrix shape, matching index-based access to the first rows/columns.
    w1 = np.asarray(weights1, dtype=float)[:n_rows]
    w2 = np.asarray(weights2, dtype=float)[:n_cols]
    boltz_factors = np.exp(-(energies - min_energy) / kT)
    state_weights = np.outer(w1, w2) * boltz_factors
    total_weight = state_weights.sum()

    if total_weight > 0:
        avg_energy = (energies * state_weights).sum() / total_weight
    else:
        # Instead of using a default, use the minimum energy found - a genuine value from the system
        print("WARNING: Total weight is zero. Using minimum binding energy.")
        avg_energy = min_energy

    # Entropy-like term (computed from the log-space partition sum)
    entropy = kT * log_Z

    # Binding free energy = Energy - T*S; negative values indicate favorable binding
    binding_free_energy = avg_energy - entropy

    return avg_energy, entropy, binding_free_energy
|
|
|
|
def check_file(filepath):
    """Return True when *filepath* exists and is non-empty.

    Otherwise an ERROR line naming the problem is printed and False is
    returned.
    """
    if not os.path.exists(filepath):
        complaint = f"ERROR: File not found: {filepath}"
    elif os.path.getsize(filepath) == 0:
        complaint = f"ERROR: File is empty: {filepath}"
    else:
        return True

    print(complaint)
    return False
|
|
|
|
def _load_trajectory(label, samples_path, topology_path):
    """Load one trajectory with mdtraj and log its size; return None if loading fails."""
    print(f"Loading {label} samples: {samples_path}")
    try:
        traj = md.load(samples_path, top=topology_path)
        print(f"Loaded {traj.n_frames} frames for {label}")
        print(f"{label.capitalize()} has {traj.n_atoms} atoms")
    except Exception as e:
        print(f"ERROR loading {label}: {str(e)}")
        traceback.print_exc()
        return None
    return traj


def _cluster_representatives(traj, ca_indices, n_clusters):
    """Cluster frames on CA coordinates; return (representative frames, matching population weights)."""
    features = traj.xyz[:, ca_indices, :].reshape(traj.n_frames, -1)
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(features)

    representatives = []
    populations = []
    for label, centroid in enumerate(kmeans.cluster_centers_):
        members = np.where(kmeans.labels_ == label)[0]
        if len(members) == 0:
            # An empty cluster has no representative; skipping its weight as
            # well keeps weights[k] aligned with representatives[k].
            continue
        sq_dist = np.sum((features[members] - centroid) ** 2, axis=1)
        # The frame closest to the centroid stands in for the cluster.
        representatives.append(traj[members[np.argmin(sq_dist)]])
        populations.append(len(members))

    weights = np.asarray(populations, dtype=float) / traj.n_frames
    return representatives, weights


def _save_energy_plot(plot_path, avg_energy, entropy, binding_free_energy):
    """Write a bar chart of the three energy components to *plot_path*."""
    fig = plt.figure(figsize=(10, 6))
    try:
        bar_colors = ['#4285F4', '#34A853', '#EA4335']
        bars = plt.bar(['Average Energy', 'Entropy', 'Binding Free Energy'],
                       [avg_energy, entropy, binding_free_energy],
                       color=bar_colors)
        plt.axhline(y=0, color='k', linestyle='-')
        plt.ylabel('Energy (kcal/mol)')
        plt.title('Components of Binding Free Energy')

        # Add value labels on bars (above positive bars, below the others)
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2.,
                     height + (0.1 if height > 0 else -0.1),
                     f'{height:.2f}',
                     ha='center', va='bottom' if height > 0 else 'top')

        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)


def _compute_binding_components(args):
    """Run validation, loading, clustering and scoring; return the energy triple or None on a reported failure."""
    # Check input files
    input_files = [args.protein1_samples, args.protein1_topology,
                   args.protein2_samples, args.protein2_topology]
    for filepath in input_files:
        if not check_file(filepath):
            print(f"Failed file check for {filepath}")
            return None

    # Load protein trajectories
    traj1 = _load_trajectory("protein1", args.protein1_samples, args.protein1_topology)
    if traj1 is None:
        return None
    traj2 = _load_trajectory("protein2", args.protein2_samples, args.protein2_topology)
    if traj2 is None:
        return None

    # Ensure we have enough samples
    n_samples = min(traj1.n_frames, traj2.n_frames)
    if n_samples < 1:
        print("ERROR: Not enough frames in trajectory files")
        return None

    n_clusters = args.n_clusters
    if n_clusters < 1:
        print(f"ERROR: --n_clusters must be at least 1 (got {n_clusters})")
        return None

    if n_samples < n_clusters:
        print(f"Warning: Number of samples ({n_samples}) is less than requested clusters ({n_clusters})")
        n_clusters = max(1, n_samples - 1)  # at most n_samples-1 clusters

    # Use fewer clusters if sample count is low. The request is part of the
    # min() so this step can only lower the count, never raise it.
    if n_samples < 10:
        adjusted_clusters = min(n_clusters, 3, max(1, n_samples - 1))
        if adjusted_clusters != n_clusters:
            print(f"Adjusting clusters from {n_clusters} to {adjusted_clusters} due to sample count")
        n_clusters = adjusted_clusters

    # Extract CA atom indices for clustering
    ca1_indices = traj1.topology.select("name CA")
    if len(ca1_indices) == 0:
        print("ERROR: No CA atoms found in protein1")
        return None

    ca2_indices = traj2.topology.select("name CA")
    if len(ca2_indices) == 0:
        print("ERROR: No CA atoms found in protein2")
        return None

    # Cluster both proteins and pick one representative frame per cluster
    print(f"Clustering protein1 into {n_clusters} states")
    centers1, weights1 = _cluster_representatives(traj1, ca1_indices, n_clusters)
    print(f"Clustering protein2 into {n_clusters} states")
    centers2, weights2 = _cluster_representatives(traj2, ca2_indices, n_clusters)

    if len(centers1) == 0 or len(centers2) == 0:
        print(f"ERROR: Failed to determine cluster centers. Centers1: {len(centers1)}, Centers2: {len(centers2)}")
        return None

    # Calculate binding energies for all combinations of cluster centers
    binding_energies = np.zeros((len(centers1), len(centers2)))
    for i, frame1 in enumerate(centers1):
        for j, frame2 in enumerate(centers2):
            binding_energies[i, j] = calculate_interaction_energy(frame1, frame2)
            print(f"Binding energy between cluster {i} and {j}: {binding_energies[i, j]:.2f} kcal/mol")

    # Combine state energies with the population weights
    return calculate_binding_free_energy(
        binding_energies, weights1, weights2, args.temperature
    )


def main():
    """Main function to calculate binding free energy.

    Parses the command line, runs the pipeline, and always writes the
    results CSV to ``args.output``. The CSV holds NaN values when no result
    could be computed.

    Returns:
        0 when the energies were computed and the plot was written, 1 on any
        failure (bad input, loading error, or unexpected exception).
    """
    args = parse_args()

    metric_names = ["average_energy", "entropy_contribution", "binding_free_energy"]

    # Start with NaN values to indicate no calculation has been performed yet
    results = pd.DataFrame({
        "metric": metric_names,
        "value": [np.nan, np.nan, np.nan]
    })
    exit_code = 1

    try:
        components = _compute_binding_components(args)
        if components is not None:
            avg_energy, entropy, binding_free_energy = components

            print(f"Average binding energy: {avg_energy:.2f} kcal/mol")
            print(f"Entropy contribution: {entropy:.2f} kcal/mol")
            print(f"Binding free energy: {binding_free_energy:.2f} kcal/mol")

            # Store the values before plotting so that a plotting failure
            # still leaves the computed numbers in the CSV.
            results = pd.DataFrame({
                "metric": metric_names,
                "value": [avg_energy, entropy, binding_free_energy]
            })

            _save_energy_plot(args.plot, avg_energy, entropy, binding_free_energy)
            exit_code = 0

    except Exception as e:
        # Top-level boundary: report, then fall through to write the CSV.
        print(f"ERROR in main function: {str(e)}")
        traceback.print_exc()

    # Always save results, even if there was an error
    results.to_csv(args.output, index=False)
    return exit_code
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s integer status to the interpreter as the process exit code.
    raise SystemExit(main())
|