122 lines
4.4 KiB
Python
Executable File
122 lines
4.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import numpy as np
|
|
import mdtraj as md
|
|
from sklearn.cluster import KMeans
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import os
|
|
|
|
def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the real command line, exactly as before; passing a
            list lets the parser be exercised programmatically.

    Returns:
        argparse.Namespace with ``samples``, ``topology``, ``temperature``,
        ``output``, ``n_clusters`` and ``plot`` attributes.

    Exits via ``parser.error`` (SystemExit, status 2) when ``--temperature``
    is not a positive number or ``--n_clusters`` is below 1, because RT and
    the clustering step are meaningless for such values.
    """
    parser = argparse.ArgumentParser(description='Calculate Gibbs free energy from protein structure ensembles')
    parser.add_argument('--samples', required=True, help='Path to XTC trajectory file with structure samples')
    parser.add_argument('--topology', required=True, help='Path to PDB topology file')
    parser.add_argument('--temperature', type=float, default=300.0, help='Temperature in Kelvin')
    parser.add_argument('--output', required=True, help='Output CSV file for free energy data')
    parser.add_argument('--n_clusters', type=int, default=5, help='Number of conformational clusters')
    parser.add_argument('--plot', help='Path to save energy plot (optional)')

    args = parser.parse_args(argv)

    # Written as "not > 0" so that NaN is rejected as well as 0 / negatives.
    if not args.temperature > 0:
        parser.error('--temperature must be a positive value in Kelvin')
    if args.n_clusters < 1:
        parser.error('--n_clusters must be at least 1')

    return args
|
|
|
|
def _rmsd_and_rg_features(traj):
    """Return an (n_frames, 2) feature matrix: [CA RMSD to frame 0, Rg].

    ``md.rmsd`` optimally superposes each frame onto the reference before
    measuring, and the radius of gyration is invariant to rigid-body motion,
    so no separate ``superpose`` pass is needed.

    Raises:
        ValueError: if the topology has no atoms named CA.
    """
    atom_indices = traj.topology.select('name CA')
    if atom_indices.size == 0:
        raise ValueError("Topology contains no CA atoms; cannot compute CA RMSD")

    # One vectorized call against reference frame 0 instead of a per-frame loop.
    distances = md.rmsd(traj, traj, frame=0, atom_indices=atom_indices)
    rg = md.compute_rg(traj)
    return np.column_stack((distances, rg))


def _cluster_free_energies(features, labels, centers, temperature):
    """Build the per-cluster table of population, ΔG and representative frame.

    Only clusters that actually received frames appear, so every column has
    exactly one entry per populated cluster. Rows are in ascending cluster-ID
    order.
    """
    unique_labels, counts = np.unique(labels, return_counts=True)
    populations = counts / len(labels)

    R = 0.0019872041  # kcal/(mol·K)
    RT = R * temperature

    # Reference state is the most populated cluster, so every ΔG >= 0:
    # ΔG_i = -RT ln(p_i / p_ref)
    reference_pop = populations[np.argmax(populations)]
    free_energies = -RT * np.log(populations / reference_pop)

    # Representative = member frame closest to its cluster centre in feature
    # space. Iterating over unique_labels (not range(n_clusters)) keeps this
    # list aligned with the other columns even if a cluster ended up empty.
    representatives = []
    for label in unique_labels:
        members = np.flatnonzero(labels == label)
        center_dists = np.linalg.norm(features[members] - centers[label], axis=1)
        representatives.append(int(members[np.argmin(center_dists)]))

    return pd.DataFrame({
        'Cluster': unique_labels,
        'Population': populations,
        'DeltaG_kcal_mol': free_energies,
        'Representative_Frame': representatives,
    })


def _plot_free_energies(table, path):
    """Save a bar chart of ΔG per cluster, annotated with populations, to *path*."""
    positions = np.arange(len(table))
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.bar(positions, table['DeltaG_kcal_mol'], alpha=0.7)
    ax.set_xlabel('Cluster')
    ax.set_ylabel('ΔG (kcal/mol)')
    ax.set_title('Free Energy Landscape')
    # Label ticks with the real cluster IDs rather than bar positions.
    ax.set_xticks(positions)
    ax.set_xticklabels(table['Cluster'])
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Add population as text on bars
    for i, (energy, pop) in enumerate(zip(table['DeltaG_kcal_mol'], table['Population'])):
        ax.text(i, energy + 0.1, f"{pop*100:.1f}%", ha='center')

    fig.tight_layout()
    fig.savefig(path, dpi=300)
    # Release the figure so repeated use in one process does not leak memory.
    plt.close(fig)


def main():
    """Cluster a structure ensemble and report per-cluster Gibbs free energies.

    The steps are:

    1. Load the trajectory.
    2. Featurize each frame by CA RMSD to frame 0 and radius of gyration.
    3. Cluster the features with KMeans.
    4. Convert cluster populations to ΔG relative to the most populated
       cluster.
    5. Write the table sorted by ΔG to ``--output``, optionally plot it to
       ``--plot``, and print a summary.
    """
    args = parse_args()

    # Load trajectory
    print(f"Loading trajectory {args.samples} with topology {args.topology}")
    traj = md.load(args.samples, top=args.topology)

    # Features for clustering: RMSD to the first frame and radius of gyration
    print("Calculating RMSD to reference structure")
    features = _rmsd_and_rg_features(traj)

    # Cluster structures
    print(f"Clustering structures into {args.n_clusters} states")
    kmeans = KMeans(n_clusters=args.n_clusters, random_state=42)
    labels = kmeans.fit_predict(features)

    table = _cluster_free_energies(features, labels, kmeans.cluster_centers_, args.temperature)

    # Save results sorted by free energy
    results = table.sort_values('DeltaG_kcal_mol')
    results.to_csv(args.output, index=False)
    print(f"Results saved to {args.output}")

    # Print summary
    print("\nFree Energy Summary:")
    print(results)

    delta_g = table['DeltaG_kcal_mol']
    print(f"\nFree energy range: {delta_g.max() - delta_g.min():.2f} kcal/mol")

    # Create plot if requested (bars in cluster-ID order)
    if args.plot:
        _plot_free_energies(table, args.plot)
        print(f"Plot saved to {args.plot}")


if __name__ == "__main__":
    main()
|