Files
pocketminer/entrypoint.py
Olamide Isreal 42d4e6cb87 Add WES pipeline configuration for pocketminer
- Add Nextflow pipeline (main.nf) with Harbor container image
- Add nextflow.config with k8s/k8s_gpu/standard profiles
- Add params.json for TRS/WES parameter discovery
- Add Dockerfile, entrypoint.py, meta.yml from original implementation
- Update paths to use /omic/eureka/Pocketminer/ convention
- Update .gitignore to allow params.json
2026-03-23 13:27:40 +01:00

294 lines
9.1 KiB
Python

#!/usr/bin/env python3
"""
PocketMiner Entrypoint - Command-line wrapper for cryptic pocket prediction
This script wraps the PocketMiner xtal_predict.py functionality with a proper
command-line interface for Nextflow/Docker integration.
"""
import argparse
import json
import os
import sys
import numpy as np
from pathlib import Path
import warnings
# Lower TensorFlow's C++ log verbosity. It is set before tensorflow is
# imported below, presumably so the value is picked up at import time.
# NOTE(review): by TensorFlow's convention '1' filters INFO messages only;
# '2' would also filter WARNING messages. The intent stated here
# ("suppress warnings") and the value disagree -- confirm which is wanted.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
# Put the GVP source tree (hard-coded container path) first on sys.path so
# that the `models`, `util` and `validate_performance_on_xtals` imports
# below resolve to it.
sys.path.insert(0, '/workspace/gvp/src')
# Import TensorFlow, mdtraj and the PocketMiner/GVP helpers up front. If any
# of them is missing, print a diagnostic to stderr and exit with status 1
# instead of surfacing a raw ImportError traceback.
try:
    import tensorflow as tf
    import mdtraj as md
    from models import MQAModel
    from util import load_checkpoint
    from validate_performance_on_xtals import process_strucs, predict_on_xtals
except ImportError as e:
    print(f"Error importing PocketMiner modules: {e}", file=sys.stderr)
    print("Please ensure the GVP repository is properly cloned and models are available.", file=sys.stderr)
    sys.exit(1)
def load_model(model_path, dropout=0.1, num_layers=4, hidden_dim=100):
    """Build the PocketMiner network and restore its pre-trained weights.

    Args:
        model_path: Checkpoint prefix handed to ``load_checkpoint``.
        dropout: Dropout rate passed to ``MQAModel``.
        num_layers: Number of layers passed to ``MQAModel``.
        hidden_dim: Second component of the ``hidden_dim`` tuple given to
            ``MQAModel``; the first component is fixed at 16.

    Returns:
        The ``MQAModel`` instance after ``load_checkpoint`` has run on it.
    """
    # The layer shapes must match the saved checkpoint exactly; (16, 100)
    # is the hidden size of the published PocketMiner checkpoint.
    architecture = dict(
        node_features=(8, 50),
        edge_features=(1, 32),
        hidden_dim=(16, hidden_dim),
        num_layers=num_layers,
        dropout=dropout,
    )
    network = MQAModel(**architecture)
    # load_checkpoint takes an optimizer alongside the model; a fresh Adam
    # instance is supplied (presumably only so optimizer state in the
    # checkpoint has a target -- TODO confirm against util.load_checkpoint).
    load_checkpoint(network, tf.keras.optimizers.Adam(), model_path)
    return network
def _as_numpy(value):
    """Return *value* as a NumPy array, unwrapping objects that have ``.numpy()``.

    TensorFlow eager tensors expose ``.numpy()``; anything else is passed
    straight to ``np.asarray``.
    """
    if hasattr(value, 'numpy'):
        value = value.numpy()
    return np.asarray(value)


def make_predictions(pdb_file, model, model_path, output_folder, output_name, debug=False):
    """Predict cryptic-pocket scores for one structure and write the results.

    Files written under ``output_folder`` (created if needed):
      * ``{output_name}-preds.npy``       -- per-residue scores, padding removed
      * ``{output_name}-predictions.txt`` -- the same scores, one per line
      * ``{output_name}-summary.json``    -- the returned summary dict
      * ``{output_name}_X.npy`` / ``_S.npy`` / ``_mask.npy`` -- featurised
        inputs, only when ``debug`` is true

    Args:
        pdb_file: Path of the structure file, loaded with ``mdtraj.load``.
        model: Model returned by ``load_model``.
        model_path: Checkpoint prefix, forwarded to ``predict_on_xtals``.
        output_folder: Directory that receives the output files.
        output_name: Base name used for every output file.
        debug: If true, also save the X, S and mask feature arrays.

    Returns:
        Summary dict with keys ``cryptic_pocket_score`` (mean score over
        valid residues), ``high_confidence_residues`` (count > 0.7),
        ``medium_confidence_residues`` (count in (0.4, 0.7]),
        ``total_residues``, ``pocket_clusters`` (from ``cluster_residues``)
        and ``output_files``.

    Raises:
        ValueError: If mdtraj cannot load ``pdb_file``, or if the mask marks
            no residue as valid. Without this check, the mean of an empty
            array would be NaN and the summary would not be valid JSON.
    """
    # Re-raise any loader failure as ValueError, keeping the original cause.
    try:
        struc = md.load(pdb_file)
    except Exception as e:
        raise ValueError(f"Failed to load PDB file {pdb_file}: {e}") from e

    # Featurise and score; process_strucs takes a list, here of length one.
    X, S, mask = process_strucs([struc])
    predictions = predict_on_xtals(model, model_path, X, S, mask)

    # Row 0 belongs to the single structure. Assumes predictions/mask are
    # (batch, max_length) with mask > 0 on real residues -- TODO confirm
    # against validate_performance_on_xtals.
    pred_array = _as_numpy(predictions[0])
    mask_array = _as_numpy(mask[0])
    pred_valid = pred_array[mask_array > 0]
    if pred_valid.size == 0:
        raise ValueError(f"No valid residues found in PDB file {pdb_file}")

    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # Scores for valid residues only; padded positions were removed above.
    pred_file = output_path / f"{output_name}-preds.npy"
    np.save(pred_file, pred_valid)
    txt_file = output_path / f"{output_name}-predictions.txt"
    np.savetxt(txt_file, pred_valid, fmt='%.4f')

    if debug:
        np.save(output_path / f"{output_name}_X.npy", X)
        np.save(output_path / f"{output_name}_S.npy", S)
        np.save(output_path / f"{output_name}_mask.npy", mask)

    summary = {
        "cryptic_pocket_score": float(np.mean(pred_valid)),
        "high_confidence_residues": int(np.sum(pred_valid > 0.7)),
        "medium_confidence_residues": int(np.sum((pred_valid > 0.4) & (pred_valid <= 0.7))),
        "total_residues": len(pred_valid),
        "pocket_clusters": cluster_residues(pred_valid, threshold=0.5),
        "output_files": {
            "predictions_npy": str(pred_file),
            "predictions_txt": str(txt_file),
        },
    }
    summary_file = output_path / f"{output_name}-summary.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    return summary
def cluster_residues(predictions, threshold=0.5, min_cluster_size=3, max_gap=1):
    """Group high-scoring residues into runs that are contiguous in sequence.

    This is a 1-D clustering over array position only. It does not use 3-D
    coordinates, so residues that are close in space but far apart in
    sequence land in different clusters.

    Args:
        predictions: 1-D per-residue scores: a NumPy array, a sequence, or an
            object with a ``.numpy()`` method (e.g. a TensorFlow eager
            tensor). Array order is taken to be sequence order.
        threshold: Residues scoring strictly above this value are candidates.
        min_cluster_size: Runs with fewer members than this are dropped.
        max_gap: Maximum number of non-candidate residues allowed between two
            consecutive members of one cluster. The default of 1 (an index
            step of at most 2) matches the historical behaviour.

    Returns:
        List of dicts with keys ``residue_indices`` (Python ints; positions
        in ``predictions``, not PDB residue numbers), ``size`` and
        ``average_score``, sorted by ``average_score`` descending. Ties keep
        sequence order.
    """
    if hasattr(predictions, 'numpy'):
        predictions = predictions.numpy()
    predictions = np.asarray(predictions)

    high_score_idx = np.where(predictions > threshold)[0]
    if high_score_idx.size == 0:
        return []

    # Start a new run wherever consecutive candidates are more than
    # max_gap + 1 positions apart, i.e. more than max_gap residues are skipped.
    breaks = np.flatnonzero(np.diff(high_score_idx) > max_gap + 1) + 1
    clusters = []
    for run in np.split(high_score_idx, breaks):
        if run.size < min_cluster_size:
            continue
        clusters.append({
            "residue_indices": run.tolist(),  # Python ints -> JSON-serialisable
            "size": int(run.size),
            "average_score": float(np.mean(predictions[run])),
        })

    # list.sort is stable, so equal scores keep their sequence order.
    clusters.sort(key=lambda c: c["average_score"], reverse=True)
    return clusters
def _build_parser():
    """Return the argument parser for the PocketMiner command-line interface."""
    parser = argparse.ArgumentParser(
        description='PocketMiner: Predict cryptic binding pockets in protein structures'
    )
    parser.add_argument('--pdb', required=True, help='Input PDB file path')
    parser.add_argument(
        '--output-folder',
        default='.',
        help='Output directory for results (default: current directory)',
    )
    parser.add_argument('--output-name', required=True, help='Base name for output files')
    parser.add_argument(
        '--model-path',
        default='/workspace/gvp/models/pocketminer',
        help='Path to pre-trained model checkpoint',
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Save debug features (X, S, mask arrays)',
    )
    parser.add_argument('--dropout', type=float, default=0.1, help='Model dropout rate (default: 0.1)')
    parser.add_argument('--num-layers', type=int, default=4, help='Number of model layers (default: 4)')
    parser.add_argument('--hidden-dim', type=int, default=100, help='Hidden dimension size (default: 100)')
    return parser


def _fail(*lines):
    """Print each message line to stderr, then exit with status 1."""
    for line in lines:
        print(line, file=sys.stderr)
    sys.exit(1)


def _print_summary(summary, output_folder):
    """Print the human-readable report for a make_predictions() summary dict."""
    rule = "=" * 60
    print("\n" + rule)
    print("PocketMiner Prediction Summary")
    print(rule)
    print(f"Overall cryptic pocket score: {summary['cryptic_pocket_score']:.4f}")
    print(f"High confidence residues (>0.7): {summary['high_confidence_residues']}")
    print(f"Medium confidence residues (0.4-0.7): {summary['medium_confidence_residues']}")
    print(f"Total residues analyzed: {summary['total_residues']}")
    print(f"\nPocket clusters identified: {len(summary['pocket_clusters'])}")
    # Only the first five clusters are listed.
    for i, cluster in enumerate(summary['pocket_clusters'][:5], 1):
        print(f" Cluster {i}: {cluster['size']} residues, score={cluster['average_score']:.4f}")
    print(f"\nResults saved to: {output_folder}")
    print(rule + "\n")


def main(argv=None):
    """Command-line entry point: validate inputs, run PocketMiner, print a report.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (mainly for
            tests). ``None`` keeps normal command-line behaviour.

    Exits with status 1 if the PDB file or the model checkpoint is missing,
    or if the structure cannot be processed (with ``--debug`` the original
    exception propagates instead, to keep its traceback). argparse exits
    with status 2 on invalid arguments.
    """
    args = _build_parser().parse_args(argv)

    # isfile rather than exists: a directory passed as --pdb is rejected here
    # instead of failing later inside the structure loader.
    if not os.path.isfile(args.pdb):
        _fail(f"Error: PDB file not found: {args.pdb}")

    # model_path is a TensorFlow checkpoint prefix, not a directory; its
    # ".index" file is used as the existence check.
    model_index = f"{args.model_path}.index"
    if not os.path.exists(model_index):
        _fail(
            f"Error: Model checkpoint not found: {args.model_path}",
            f"Looking for: {model_index}",
            "Please ensure the pre-trained model is available.",
        )

    print(f"Loading PocketMiner model from {args.model_path}...")
    model = load_model(
        args.model_path,
        dropout=args.dropout,
        num_layers=args.num_layers,
        hidden_dim=args.hidden_dim,
    )

    print(f"Processing structure: {args.pdb}")
    try:
        summary = make_predictions(
            pdb_file=args.pdb,
            model=model,
            model_path=args.model_path,
            output_folder=args.output_folder,
            output_name=args.output_name,
            debug=args.debug,
        )
    except ValueError as e:
        if args.debug:
            raise
        _fail(f"Error: {e}")

    _print_summary(summary, args.output_folder)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()