Add BioEmu Nextflow pipeline implementation

This commit is contained in:
2025-03-04 09:38:55 -08:00
commit 2cfbf64e92
12 changed files with 565 additions and 0 deletions

121
calculate_gibbs.py Executable file
View File

@@ -0,0 +1,121 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import mdtraj as md
from sklearn.cluster import KMeans
import pandas as pd
import matplotlib.pyplot as plt
import os
def _positive_int(text):
    """argparse ``type=`` callable: parse *text* as an integer >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}"
        )
    return value


def _positive_float(text):
    """argparse ``type=`` callable: parse *text* as a finite float > 0."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive number, got {text!r}"
        ) from None
    # The chained comparison is False for NaN, +/-inf, zero and negatives.
    if not (0.0 < value < float("inf")):
        raise argparse.ArgumentTypeError(
            f"expected a positive finite number, got {text!r}"
        )
    return value


def parse_args(argv=None):
    """Parse the command-line options of the free-energy script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with the attributes ``samples``, ``topology``,
        ``temperature``, ``output``, ``n_clusters`` and ``plot``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            when ``--temperature`` / ``--n_clusters`` is not a positive value.
    """
    parser = argparse.ArgumentParser(
        description='Calculate Gibbs free energy from protein structure ensembles'
    )
    parser.add_argument('--samples', required=True,
                        help='Path to XTC trajectory file with structure samples')
    parser.add_argument('--topology', required=True,
                        help='Path to PDB topology file')
    # A non-positive temperature would flip or zero every ΔG, so reject it
    # at the command line instead of producing a meaningless table.
    parser.add_argument('--temperature', type=_positive_float, default=300.0,
                        help='Temperature in Kelvin')
    parser.add_argument('--output', required=True,
                        help='Output CSV file for free energy data')
    parser.add_argument('--n_clusters', type=_positive_int, default=5,
                        help='Number of conformational clusters')
    parser.add_argument('--plot', help='Path to save energy plot (optional)')
    return parser.parse_args(argv)
def _compute_ca_rmsd(traj):
    """Return the C-alpha RMSD of every frame of *traj* to its first frame.

    The trajectory is superposed onto frame 0 in place (as the original
    script did) before the RMSD is computed.

    Raises:
        ValueError: if the topology selection ``name CA`` is empty.
    """
    ca_indices = traj.topology.select('name CA')
    if len(ca_indices) == 0:
        raise ValueError("Topology contains no CA atoms; cannot compute C-alpha RMSD")
    # Align to the first frame
    traj.superpose(traj, 0)
    # One vectorized call against frame 0 replaces the per-frame Python loop
    # that built two single-frame trajectory slices on every iteration.
    return md.rmsd(traj, traj, 0, atom_indices=ca_indices)


def _summarize_clusters(labels, features, centers, temperature):
    """Build the per-cluster population / free-energy table.

    Only clusters that actually own at least one frame get a row, so every
    column has the same length even if the clustering left a centroid empty.

    Args:
        labels: 1-D integer array with the cluster label of each frame.
        features: 2-D array, one feature row per frame (same order as labels).
        centers: 2-D array of cluster centroids, indexed by cluster label.
        temperature: Temperature in Kelvin.

    Returns:
        pandas.DataFrame with the columns ``Cluster``, ``Population``,
        ``DeltaG_kcal_mol`` and ``Representative_Frame``, in ascending
        cluster-label order.
    """
    gas_constant = 0.0019872041  # kcal/(mol·K)
    rt = gas_constant * temperature

    cluster_ids, counts = np.unique(labels, return_counts=True)
    populations = counts / len(labels)

    # Reference state is the most populated one.  Writing the ratio as
    # reference/population (instead of negating log(population/reference))
    # yields +0.0 rather than -0.0 for the reference state itself.
    free_energies = rt * np.log(populations.max() / populations)

    # Representative structure: the member frame closest to its own centroid.
    # The centroid is looked up by the actual label, not by loop position.
    representatives = []
    for cluster_id in cluster_ids:
        members = np.flatnonzero(labels == cluster_id)
        offsets = np.linalg.norm(features[members] - centers[cluster_id], axis=1)
        representatives.append(int(members[np.argmin(offsets)]))

    return pd.DataFrame({
        'Cluster': cluster_ids,
        'Population': populations,
        'DeltaG_kcal_mol': free_energies,
        'Representative_Frame': representatives,
    })


def _plot_free_energies(summary, path):
    """Save a bar chart of ΔG per cluster, annotated with populations, to *path*."""
    positions = list(range(len(summary)))
    energies = summary['DeltaG_kcal_mol'].to_numpy()
    populations = summary['Population'].to_numpy()

    plt.figure(figsize=(10, 6))
    # Plot free energy for each cluster
    plt.bar(positions, energies, alpha=0.7)
    plt.xlabel('Cluster')
    plt.ylabel('ΔG (kcal/mol)')
    plt.title('Free Energy Landscape')
    # Label each bar with its real cluster id rather than its position.
    plt.xticks(positions, summary['Cluster'].to_numpy())
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    # Add population as text on bars
    for position, (energy, population) in enumerate(zip(energies, populations)):
        plt.text(position, energy + 0.1, f"{population*100:.1f}%", ha='center')
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()


def main():
    """Cluster a structure ensemble and report relative Gibbs free energies.

    Workflow: load the trajectory, compute the C-alpha RMSD to frame 0 and the
    radius of gyration per frame, cluster frames with KMeans on those two
    features, convert cluster populations to ΔG relative to the most populated
    cluster, write the table to the ``--output`` CSV and, if ``--plot`` was
    given, save a bar chart.

    Raises:
        ValueError: if ``--n_clusters`` is smaller than 1 or larger than the
            number of frames, or if the topology has no CA atoms.
    """
    args = parse_args()

    # Load trajectory
    print(f"Loading trajectory {args.samples} with topology {args.topology}")
    traj = md.load(args.samples, top=args.topology)

    # Fail early with a clear message instead of deep inside the clustering.
    if args.n_clusters < 1 or args.n_clusters > traj.n_frames:
        raise ValueError(
            f"n_clusters must be between 1 and the number of frames "
            f"({traj.n_frames}), got {args.n_clusters}"
        )

    # Calculate RMSD to first frame
    print("Calculating RMSD to reference structure")
    distances = _compute_ca_rmsd(traj)

    # Feature extraction for clustering:
    # use the RMSD and radius of gyration as features
    rg = md.compute_rg(traj)
    features = np.column_stack((distances, rg))

    # Cluster structures
    print(f"Clustering structures into {args.n_clusters} states")
    kmeans = KMeans(n_clusters=args.n_clusters, random_state=42)
    labels = kmeans.fit_predict(features)

    # Populations, ΔG values and representative frames, in cluster-id order
    summary = _summarize_clusters(
        labels, features, kmeans.cluster_centers_, args.temperature
    )

    # Sort by free energy
    results = summary.sort_values('DeltaG_kcal_mol')

    # Save results
    results.to_csv(args.output, index=False)
    print(f"Results saved to {args.output}")

    # Print summary
    print("\nFree Energy Summary:")
    print(results)

    # Calculate overall free energy range
    energies = summary['DeltaG_kcal_mol']
    print(f"\nFree energy range: {energies.max() - energies.min():.2f} kcal/mol")

    # Create plot if requested (bars stay in cluster-id order)
    if args.plot:
        _plot_free_energies(summary, args.plot)
        print(f"Plot saved to {args.plot}")
# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()