550 lines
20 KiB
Python
550 lines
20 KiB
Python
import argparse
|
|
import json
|
|
import os.path
|
|
import random
|
|
import sys
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from data_utils import (
|
|
element_dict_rev,
|
|
alphabet,
|
|
restype_int_to_str,
|
|
featurize,
|
|
parse_PDB,
|
|
)
|
|
from model_utils import ProteinMPNN
|
|
|
|
|
|
# Maps each supported --model_type to the argparse attribute holding its checkpoint path.
_CHECKPOINT_ARG_BY_MODEL_TYPE = {
    "protein_mpnn": "checkpoint_protein_mpnn",
    "ligand_mpnn": "checkpoint_ligand_mpnn",
    "per_residue_label_membrane_mpnn": "checkpoint_per_residue_label_membrane_mpnn",
    "global_label_membrane_mpnn": "checkpoint_global_label_membrane_mpnn",
    "soluble_mpnn": "checkpoint_soluble_mpnn",
}


def _checkpoint_path_for(args):
    """Return the checkpoint path for ``args.model_type``, or None if the type is unknown."""
    attribute = _CHECKPOINT_ARG_BY_MODEL_TYPE.get(args.model_type)
    return None if attribute is None else getattr(args, attribute)


def _as_residue_list(residues):
    """Normalize a residue specification to a list of residue codes.

    Accepts a whitespace-separated string such as ``"A12 A13 B2"`` (the format
    documented for the CLI flags and the ``*_multi`` JSON values) or an
    already-split list of codes. ``None`` yields an empty list.
    """
    if residues is None:
        return []
    if isinstance(residues, str):
        return residues.split()
    return [str(item) for item in residues]


def _residues_per_pdb(multi_json_path, single_spec, pdb_paths):
    """Map every PDB path to its list of residue codes.

    When ``multi_json_path`` is set, the per-PDB mapping is read from that JSON
    file (a missing PDB key raises ``KeyError`` later, as before). Otherwise
    ``single_spec`` is applied to every path in ``pdb_paths``.
    """
    if multi_json_path:
        with open(multi_json_path, "r") as fh:
            per_pdb = json.load(fh)
        return {pdb: _as_residue_list(spec) for pdb, spec in per_pdb.items()}
    shared = _as_residue_list(single_spec)
    return {pdb: shared for pdb in pdb_paths}


def _homo_oligomer_symmetry(chain_letters_list, encoded_residues, encoded_residue_dict):
    """Tie each residue of the first chain to the same-numbered residue in every chain.

    Chains are taken in order of first appearance so the result is
    deterministic (a ``set`` would vary with string hash randomization).
    Residues of the reference chain are selected by exact chain equality, not
    by string prefix, so chain IDs that prefix one another are not confused.
    Raises ``KeyError`` if some chain lacks a residue present in the reference chain.
    """
    chains = [str(chain) for chain in dict.fromkeys(chain_letters_list)]
    reference_chain = chains[0]
    prefix_len = len(reference_chain)
    residue_ids = [
        code[prefix_len:]
        for chain, code in zip(chain_letters_list, encoded_residues)
        if str(chain) == reference_chain
    ]
    return [
        [encoded_residue_dict[chain + res_id] for chain in chains]
        for res_id in residue_ids
    ]


def _print_ligand_atoms(protein_dict, model_type):
    """Print a summary of the parsed ligand ("Y*") atoms, for verbose mode."""
    if "Y" not in protein_dict:
        print("No ligand atoms parsed")
        return
    atom_coords = protein_dict["Y"].cpu().numpy()
    atom_types = list(protein_dict["Y_t"].cpu().numpy())
    atom_mask = list(protein_dict["Y_m"].cpu().numpy())
    number_of_atoms_parsed = np.sum(atom_mask)
    if number_of_atoms_parsed == 0:
        print("No ligand atoms parsed")
    elif model_type == "ligand_mpnn":
        print(
            f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
        )
        for atom_type, coords, mask in zip(atom_types, atom_coords, atom_mask):
            print(
                f"Type: {element_dict_rev[atom_type]}, Coords {coords}, Mask {mask}"
            )


def main(args) -> None:
    """
    Score the native sequence of one or more PDB files with an MPNN model.

    For every input structure this parses the PDB and builds the design mask
    (chains to design, fixed/redesigned residues, symmetry ties). It then runs
    ``args.number_of_batches`` scoring passes, each with ``args.batch_size``
    random draws. The stacked logits/probabilities and the per-residue
    mean/std of the probabilities are saved to
    ``<out_folder>/<pdb name><file_ending>.pt``.

    Args:
        args: Parsed command-line namespace (see the parser under ``__main__``).

    Exits via ``SystemExit`` with a non-zero status when any of these hold:
        * no output folder is given;
        * the model type is unknown;
        * neither scoring mode is enabled;
        * ``number_of_batches`` is smaller than 1.
    """
    # Validate everything that can fail cheaply before loading a model.
    if not args.out_folder:
        sys.exit("--out_folder is required")
    if not (args.autoregressive_score or args.single_aa_score):
        sys.exit("Set either autoregressive_score or single_aa_score to True")
    if args.number_of_batches < 1:
        sys.exit("--number_of_batches must be at least 1")
    checkpoint_path = _checkpoint_path_for(args)
    if checkpoint_path is None:
        sys.exit("Choose one of the available models")

    # A seed of 0 (the CLI default) means "pick a random seed"; it is recorded in the output.
    if args.seed:
        seed = args.seed
    else:
        seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    base_folder = args.out_folder
    if not base_folder.endswith("/"):
        base_folder += "/"
    os.makedirs(base_folder, exist_ok=True)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    if args.model_type == "ligand_mpnn":
        atom_context_num = checkpoint["atom_context_num"]
        ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
    else:
        # Non-ligand models are built with atom_context_num=1 and side-chain context off.
        atom_context_num = 1
        ligand_mpnn_use_side_chain_context = 0
    k_neighbors = checkpoint["num_edges"]

    model = ProteinMPNN(
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        k_neighbors=k_neighbors,
        device=device,
        atom_context_num=atom_context_num,
        model_type=args.model_type,
        ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    # autoregressive_score takes precedence when both flags are set.
    score_fn = model.score if args.autoregressive_score else model.single_aa_score

    if args.pdb_path_multi:
        with open(args.pdb_path_multi, "r") as fh:
            pdb_paths = list(json.load(fh))  # a JSON object contributes its keys
    else:
        pdb_paths = [args.pdb_path]

    fixed_residues_multi = _residues_per_pdb(
        args.fixed_residues_multi, args.fixed_residues, pdb_paths
    )
    redesigned_residues_multi = _residues_per_pdb(
        args.redesigned_residues_multi, args.redesigned_residues, pdb_paths
    )

    for pdb in pdb_paths:
        if args.verbose:
            print("Designing protein from this path:", pdb)
        fixed_residues = fixed_residues_multi[pdb]
        redesigned_residues = redesigned_residues_multi[pdb]

        protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
            pdb,
            device=device,
            chains=args.parse_these_chains_only,
            parse_all_atoms=args.ligand_mpnn_use_side_chain_context,
            parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy,
        )

        # Residue codes are chain letter + residue index + insertion code, e.g. "A12" or "B7A".
        R_idx_list = list(protein_dict["R_idx"].cpu().numpy())
        chain_letters_list = list(protein_dict["chain_letters"])
        encoded_residues = [
            str(chain) + str(r_idx) + icode
            for chain, r_idx, icode in zip(chain_letters_list, R_idx_list, icodes)
        ]
        encoded_residue_dict = {code: i for i, code in enumerate(encoded_residues)}
        encoded_residue_dict_rev = dict(enumerate(encoded_residues))

        # Sets give exact-code membership (and O(1) lookups) regardless of the input format.
        fixed_set = set(fixed_residues)
        redesigned_set = set(redesigned_residues)
        # 0 where the residue is listed, 1 elsewhere.
        fixed_positions = torch.tensor(
            [int(code not in fixed_set) for code in encoded_residues],
            device=device,
        )
        redesigned_positions = torch.tensor(
            [int(code not in redesigned_set) for code in encoded_residues],
            device=device,
        )

        # Per-residue labels for the per_residue_label_membrane_mpnn model.
        if args.transmembrane_buried:
            buried_set = set(args.transmembrane_buried.split())
            buried_positions = torch.tensor(
                [int(code in buried_set) for code in encoded_residues],
                device=device,
            )
        else:
            buried_positions = torch.zeros_like(fixed_positions)

        if args.transmembrane_interface:
            interface_set = set(args.transmembrane_interface.split())
            interface_positions = torch.tensor(
                [int(code in interface_set) for code in encoded_residues],
                device=device,
            )
        else:
            interface_positions = torch.zeros_like(fixed_positions)

        # 2 = buried only, 1 = interface only, 0 = neither or both.
        protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
            1 - interface_positions
        ) + 1 * interface_positions * (1 - buried_positions)

        if args.model_type == "global_label_membrane_mpnn":
            # Broadcast the single global label to every residue.
            protein_dict["membrane_per_residue_labels"] = (
                args.global_transmembrane_label + 0 * fixed_positions
            )

        if isinstance(args.chains_to_design, str):
            chains_to_design_list = args.chains_to_design.split(",")
        else:
            chains_to_design_list = protein_dict["chain_letters"]
        chain_mask = torch.tensor(
            np.array(
                [
                    item in chains_to_design_list
                    for item in protein_dict["chain_letters"]
                ],
                dtype=np.int32,
            ),
            device=device,
        )

        # chain_mask: 1 = designed, 0 = fixed. An explicit redesign list wins over a fixed list.
        if redesigned_residues:
            protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
        elif fixed_residues:
            protein_dict["chain_mask"] = chain_mask * fixed_positions
        else:
            protein_dict["chain_mask"] = chain_mask

        if args.verbose:
            mask_values = protein_dict["chain_mask"].cpu().tolist()
            PDB_residues_to_be_redesigned = [
                encoded_residue_dict_rev[i]
                for i, value in enumerate(mask_values)
                if value == 1
            ]
            PDB_residues_to_be_fixed = [
                encoded_residue_dict_rev[i]
                for i, value in enumerate(mask_values)
                if value == 0
            ]
            print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
            print("These residues will be fixed: ", PDB_residues_to_be_fixed)

        # Groups of residue indices that are tied together.
        if args.symmetry_residues:
            remapped_symmetry_residues = [
                [encoded_residue_dict[code] for code in group.split(",")]
                for group in args.symmetry_residues.split("|")
            ]
        else:
            remapped_symmetry_residues = [[]]

        if args.homo_oligomer:
            if args.verbose:
                print("Designing HOMO-OLIGOMER")
            remapped_symmetry_residues = _homo_oligomer_symmetry(
                chain_letters_list, encoded_residues, encoded_residue_dict
            )

        # Zero the B-factors of non-protein atoms.
        if other_atoms:
            other_bfactors = other_atoms.getBetas()
            other_atoms.setBetas(other_bfactors * 0.0)

        # Output name: the input file name with a trailing ".pdb" removed.
        name = os.path.basename(pdb)
        if name.endswith(".pdb"):
            name = name[: -len(".pdb")]

        with torch.no_grad():
            if args.verbose:
                _print_ligand_atoms(protein_dict, args.model_type)
            # featurize remaps R_idx and adds the batch dimension.
            feature_dict = featurize(
                protein_dict,
                cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
                use_atom_context=args.ligand_mpnn_use_atom_context,
                number_of_ligand_atoms=atom_context_num,
                model_type=args.model_type,
            )
            feature_dict["batch_size"] = args.batch_size
            _, L, _, _ = feature_dict["X"].shape  # batch size should be 1 for now.
            feature_dict["symmetry_residues"] = remapped_symmetry_residues

            logits_list = []
            probs_list = []
            log_probs_list = []
            decoding_order_list = []
            for _ in range(args.number_of_batches):
                feature_dict["randn"] = torch.randn(
                    [feature_dict["batch_size"], feature_dict["mask"].shape[1]],
                    device=device,
                )
                score_dict = score_fn(feature_dict, use_sequence=args.use_sequence)
                logits_list.append(score_dict["logits"])
                log_probs_list.append(score_dict["log_probs"])
                probs_list.append(torch.exp(score_dict["log_probs"]))
                decoding_order_list.append(score_dict["decoding_order"])
            log_probs_stack = torch.cat(log_probs_list, 0)
            logits_stack = torch.cat(logits_list, 0)
            probs_stack = torch.cat(probs_list, 0)
            decoding_order_stack = torch.cat(decoding_order_list, 0)

            output_stats_path = base_folder + name + args.file_ending + ".pt"
            out_dict = {}
            out_dict["logits"] = logits_stack.cpu().numpy()
            out_dict["probs"] = probs_stack.cpu().numpy()
            out_dict["log_probs"] = log_probs_stack.cpu().numpy()
            out_dict["decoding_order"] = decoding_order_stack.cpu().numpy()
            out_dict["native_sequence"] = feature_dict["S"][0].cpu().numpy()
            out_dict["mask"] = feature_dict["mask"][0].cpu().numpy()
            # chain_mask also affects the decoding order.
            out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu().numpy()
            out_dict["seed"] = seed
            out_dict["alphabet"] = alphabet
            out_dict["residue_names"] = encoded_residue_dict_rev

            # Mean/std over all scoring passes (axis 0), keyed by residue code then amino acid.
            mean_probs = np.mean(out_dict["probs"], 0)
            std_probs = np.std(out_dict["probs"], 0)
            out_dict["sequence"] = [
                restype_int_to_str[aa] for aa in out_dict["native_sequence"]
            ]
            out_dict["mean_of_probs"] = {
                encoded_residue_dict_rev[residue]: dict(zip(alphabet, mean_probs[residue]))
                for residue in range(L)
            }
            out_dict["std_of_probs"] = {
                encoded_residue_dict_rev[residue]: dict(zip(alphabet, std_probs[residue]))
                for residue in range(L)
            }
            torch.save(out_dict, output_stats_path)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    argparser.add_argument(
        "--model_type",
        type=str,
        default="protein_mpnn",
        help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
    )
    # protein_mpnn - original ProteinMPNN trained on the whole PDB, excluding non-protein atoms
    # ligand_mpnn - atomic-context-aware model trained on the whole PDB with small molecules, nucleotides, metals, etc.
    # per_residue_label_membrane_mpnn - ProteinMPNN trained with an additional per-residue label marking each residue as buried or exposed
    # global_label_membrane_mpnn - ProteinMPNN trained with a global label per PDB id specifying whether the protein is transmembrane
    # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
    argparser.add_argument(
        "--checkpoint_protein_mpnn",
        type=str,
        default="./model_params/proteinmpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_ligand_mpnn",
        type=str,
        default="./model_params/ligandmpnn_v_32_010_25.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_per_residue_label_membrane_mpnn",
        type=str,
        default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_global_label_membrane_mpnn",
        type=str,
        default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_soluble_mpnn",
        type=str,
        default="./model_params/solublempnn_v_48_020.pt",
        help="Path to model weights.",
    )

    argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")

    argparser.add_argument(
        "--pdb_path", type=str, default="", help="Path to the input PDB."
    )
    argparser.add_argument(
        "--pdb_path_multi",
        type=str,
        default="",
        help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
    )

    argparser.add_argument(
        "--fixed_residues",
        type=str,
        default="",
        help="Provide fixed residues, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--fixed_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )

    argparser.add_argument(
        "--redesigned_residues",
        type=str,
        default="",
        help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--redesigned_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )

    argparser.add_argument(
        "--symmetry_residues",
        type=str,
        default="",
        help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
    )

    argparser.add_argument(
        "--homo_oligomer",
        type=int,
        default=0,
        help="Setting this to 1 will automatically set --symmetry_residues so that equivalent residues in all chains are tied (homo-oligomer design).",
    )

    # Required: main() writes its .pt outputs into this folder and cannot pick a default.
    argparser.add_argument(
        "--out_folder",
        type=str,
        required=True,
        help="Path to a folder to output scores, e.g. /home/out/",
    )
    argparser.add_argument(
        "--file_ending", type=str, default="", help="adding_string_to_the_end"
    )
    # NOTE: not read by main(); accepted so existing command lines keep working.
    argparser.add_argument(
        "--zero_indexed",
        type=int,
        default=0,
        help="1 - to start output PDB numbering with 0",
    )
    argparser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set seed for torch, numpy, and python random. 0 picks a random seed.",
    )
    argparser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of sequence to generate per one pass.",
    )
    argparser.add_argument(
        "--number_of_batches",
        type=int,
        default=1,
        help="Number of times to design sequence using a chosen batch size.",
    )

    argparser.add_argument(
        "--ligand_mpnn_use_atom_context",
        type=int,
        default=1,
        help="1 - use atom context, 0 - do not use atom context.",
    )

    argparser.add_argument(
        "--ligand_mpnn_use_side_chain_context",
        type=int,
        default=0,
        help="Flag to use side chain atoms as ligand context for the fixed residues",
    )

    argparser.add_argument(
        "--ligand_mpnn_cutoff_for_score",
        type=float,
        default=8.0,
        help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
    )

    argparser.add_argument(
        "--chains_to_design",
        type=str,
        default=None,
        help="Specify which chains to redesign, all others will be kept fixed.",
    )

    argparser.add_argument(
        "--parse_these_chains_only",
        type=str,
        default="",
        help="Provide chains letters for parsing backbones, 'ABCF'",
    )

    argparser.add_argument(
        "--transmembrane_buried",
        type=str,
        default="",
        help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--transmembrane_interface",
        type=str,
        default="",
        help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )

    argparser.add_argument(
        "--global_transmembrane_label",
        type=int,
        default=0,
        help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
    )

    argparser.add_argument(
        "--parse_atoms_with_zero_occupancy",
        type=int,
        default=0,
        help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
    )

    argparser.add_argument(
        "--use_sequence",
        type=int,
        default=1,
        help="1 - get scores using amino acid sequence info; 0 - get scores using backbone info only",
    )

    argparser.add_argument(
        "--autoregressive_score",
        type=int,
        default=0,
        help="1 - run autoregressive scoring function; p(AA_1|backbone); p(AA_2|backbone, AA_1) etc, 0 - False",
    )

    argparser.add_argument(
        "--single_aa_score",
        type=int,
        default=1,
        help="1 - run single amino acid scoring function; p(AA_i|backbone, AA_{all except ith one}), 0 - False",
    )

    args = argparser.parse_args()
    main(args)
|