import argparse
import copy
import json
import os.path
import random
import sys

import numpy as np
import torch
from data_utils import (
    alphabet,
    element_dict_rev,
    featurize,
    get_score,
    get_seq_rec,
    parse_PDB,
    restype_1to3,
    restype_int_to_str,
    restype_str_to_int,
    write_full_PDB,
)
from model_utils import ProteinMPNN
from prody import writePDB
from sc_utils import Packer, pack_side_chains

# Supported values of --model_type. The checkpoint path for each one is read
# from the argument named "checkpoint_" + model_type.
_MODEL_TYPES = (
    "protein_mpnn",
    "ligand_mpnn",
    "per_residue_label_membrane_mpnn",
    "global_label_membrane_mpnn",
    "soluble_mpnn",
)


def _join_chain_sequences(seq, chain_masks, separator):
    """Split ``seq`` into per-chain strings and join them with ``separator``.

    ``chain_masks`` is iterated once; every item is moved to CPU, converted to
    a numpy array and used to index the characters of ``seq``.  Joining with
    ``str.join`` (instead of appending the separator after every chain and
    cutting one trailing character) stays correct for empty or
    multi-character separators.
    """
    seq_np = np.array(list(seq))
    return separator.join(
        "".join(seq_np[mask.cpu().numpy()]) for mask in chain_masks
    )


def main(args) -> None:
    """
    Inference function

    Loads the sequence design model selected by ``args.model_type`` (and the
    side chain packer when ``args.pack_side_chains`` is set), then for every
    input PDB samples ``args.number_of_batches`` batches of sequences and
    writes the results below ``args.out_folder``:

    * ``seqs/<name>.fa``         - native and designed sequences with scores
    * ``backbones/<name>_<i>.pdb`` - backbone with the designed residue names
    * ``packed/...``             - full-atom models (only when packing)
    * ``stats/<name>.pt``        - raw tensors (only when ``args.save_stats``)

    ``args`` is the namespace produced by the argument parser at the bottom of
    this file.  A seed of 0 means "draw a random seed".  The process exits
    with status 1 when the output folder or the model type is invalid.
    """
    if args.seed:
        seed = args.seed
    else:
        seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")

    # Without this guard a missing folder crashed on base_folder[-1].
    if not args.out_folder:
        print("Provide a folder for outputs with --out_folder")
        sys.exit(1)
    base_folder = args.out_folder
    if not base_folder.endswith("/"):
        base_folder += "/"
    output_subfolders = ["seqs", "backbones", "packed"]
    if args.save_stats:
        output_subfolders.append("stats")
    os.makedirs(base_folder, exist_ok=True)
    for subfolder in output_subfolders:
        os.makedirs(base_folder + subfolder, exist_ok=True)

    if args.model_type not in _MODEL_TYPES:
        print("Choose one of the available models")
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)
    checkpoint_path = getattr(args, "checkpoint_" + args.model_type)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    k_neighbors = checkpoint["num_edges"]
    if args.model_type == "ligand_mpnn":
        atom_context_num = checkpoint["atom_context_num"]
        ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
    else:
        atom_context_num = 1
        ligand_mpnn_use_side_chain_context = 0

    model = ProteinMPNN(
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        k_neighbors=k_neighbors,
        device=device,
        atom_context_num=atom_context_num,
        model_type=args.model_type,
        ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
    )

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    if args.pack_side_chains:
        model_sc = Packer(
            node_features=128,
            edge_features=128,
            num_positional_embeddings=16,
            num_chain_embeddings=16,
            num_rbf=16,
            hidden_dim=128,
            num_encoder_layers=3,
            num_decoder_layers=3,
            atom_context_num=16,
            lower_bound=0.0,
            upper_bound=20.0,
            top_k=32,
            dropout=0.0,
            augment_eps=0.0,
            atom37_order=False,
            device=device,
            num_mix=3,
        )

        checkpoint_sc = torch.load(args.checkpoint_path_sc, map_location=device)
        model_sc.load_state_dict(checkpoint_sc["model_state_dict"])
        model_sc.to(device)
        model_sc.eval()

    if args.pdb_path_multi:
        with open(args.pdb_path_multi, "r") as fh:
            pdb_paths = list(json.load(fh))
    else:
        pdb_paths = [args.pdb_path]

    if args.fixed_residues_multi:
        with open(args.fixed_residues_multi, "r") as fh:
            fixed_residues_multi = json.load(fh)
        fixed_residues_multi = {
            key: value.split() for key, value in fixed_residues_multi.items()
        }
    else:
        fixed_residues = args.fixed_residues.split()
        fixed_residues_multi = {pdb: fixed_residues for pdb in pdb_paths}

    if args.redesigned_residues_multi:
        with open(args.redesigned_residues_multi, "r") as fh:
            redesigned_residues_multi = json.load(fh)
        redesigned_residues_multi = {
            key: value.split() for key, value in redesigned_residues_multi.items()
        }
    else:
        redesigned_residues = args.redesigned_residues.split()
        redesigned_residues_multi = {pdb: redesigned_residues for pdb in pdb_paths}

    # Global amino acid bias, e.g. "A:-1.024,P:2.34".
    bias_AA = torch.zeros([21], device=device, dtype=torch.float32)
    if args.bias_AA:
        bias_items = [item.split(":") for item in args.bias_AA.split(",")]
        for bias_item in bias_items:
            bias_AA[restype_str_to_int[bias_item[0]]] = float(bias_item[1])

    # Only read below when one of the two per-residue bias options is given.
    bias_AA_per_residue_multi = {}
    if args.bias_AA_per_residue_multi:
        with open(args.bias_AA_per_residue_multi, "r") as fh:
            bias_AA_per_residue_multi = json.load(
                fh
            )  # {"pdb_path" : {"A12": {"G": 1.1}}}
    elif args.bias_AA_per_residue:
        with open(args.bias_AA_per_residue, "r") as fh:
            bias_AA_per_residue = json.load(fh)  # {"A12": {"G": 1.1}}
        bias_AA_per_residue_multi = {pdb: bias_AA_per_residue for pdb in pdb_paths}

    # Only read below when one of the two per-residue omit options is given.
    omit_AA_per_residue_multi = {}
    if args.omit_AA_per_residue_multi:
        with open(args.omit_AA_per_residue_multi, "r") as fh:
            omit_AA_per_residue_multi = json.load(
                fh
            )  # {"pdb_path" : {"A12": "PQR", "A13": "QS"}}
    elif args.omit_AA_per_residue:
        with open(args.omit_AA_per_residue, "r") as fh:
            omit_AA_per_residue = json.load(fh)  # {"A12": "PG"}
        omit_AA_per_residue_multi = {pdb: omit_AA_per_residue for pdb in pdb_paths}

    omit_AA_list = args.omit_AA
    omit_AA = torch.tensor(
        np.array([AA in omit_AA_list for AA in alphabet]).astype(np.float32),
        device=device,
    )

    if args.parse_these_chains_only:
        parse_these_chains_only_list = args.parse_these_chains_only.split(",")
    else:
        parse_these_chains_only_list = []

    # loop over PDB paths
    for pdb in pdb_paths:
        if args.verbose:
            print("Designing protein from this path:", pdb)
        fixed_residues = fixed_residues_multi[pdb]
        redesigned_residues = redesigned_residues_multi[pdb]
        parse_all_atoms_flag = args.ligand_mpnn_use_side_chain_context or (
            args.pack_side_chains and not args.repack_everything
        )
        protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
            pdb,
            device=device,
            chains=parse_these_chains_only_list,
            parse_all_atoms=parse_all_atoms_flag,
            parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy,
        )
        # make chain_letter + residue_idx + insertion_code mapping to integers
        R_idx_list = list(protein_dict["R_idx"].cpu().numpy())  # residue indices
        chain_letters_list = list(protein_dict["chain_letters"])  # chain letters
        encoded_residues = [
            str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
            for i, R_idx_item in enumerate(R_idx_list)
        ]
        encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
        encoded_residue_dict_rev = dict(enumerate(encoded_residues))

        bias_AA_per_residue = torch.zeros(
            [len(encoded_residues), 21], device=device, dtype=torch.float32
        )
        if args.bias_AA_per_residue_multi or args.bias_AA_per_residue:
            bias_dict = bias_AA_per_residue_multi[pdb]
            for residue_name, v1 in bias_dict.items():
                # residues that are not in the parsed structure are ignored
                if residue_name in encoded_residue_dict:
                    i1 = encoded_residue_dict[residue_name]
                    for amino_acid, v2 in v1.items():
                        if amino_acid in alphabet:
                            j1 = restype_str_to_int[amino_acid]
                            bias_AA_per_residue[i1, j1] = v2

        omit_AA_per_residue = torch.zeros(
            [len(encoded_residues), 21], device=device, dtype=torch.float32
        )
        if args.omit_AA_per_residue_multi or args.omit_AA_per_residue:
            omit_dict = omit_AA_per_residue_multi[pdb]
            for residue_name, v1 in omit_dict.items():
                if residue_name in encoded_residue_dict:
                    i1 = encoded_residue_dict[residue_name]
                    for amino_acid in v1:
                        if amino_acid in alphabet:
                            j1 = restype_str_to_int[amino_acid]
                            omit_AA_per_residue[i1, j1] = 1.0

        # Sets keep the membership tests below O(1) per residue.
        fixed_residue_set = set(fixed_residues)
        redesigned_residue_set = set(redesigned_residues)
        # NOTE: both tensors hold 1 for residues that are NOT listed.
        fixed_positions = torch.tensor(
            [int(item not in fixed_residue_set) for item in encoded_residues],
            device=device,
        )
        redesigned_positions = torch.tensor(
            [int(item not in redesigned_residue_set) for item in encoded_residues],
            device=device,
        )

        # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
        if args.transmembrane_buried:
            buried_residues = set(args.transmembrane_buried.split())
            buried_positions = torch.tensor(
                [int(item in buried_residues) for item in encoded_residues],
                device=device,
            )
        else:
            buried_positions = torch.zeros_like(fixed_positions)

        if args.transmembrane_interface:
            interface_residues = set(args.transmembrane_interface.split())
            interface_positions = torch.tensor(
                [int(item in interface_residues) for item in encoded_residues],
                device=device,
            )
        else:
            interface_positions = torch.zeros_like(fixed_positions)
        # 2 - buried only, 1 - interface only, 0 - neither or both
        protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
            1 - interface_positions
        ) + 1 * interface_positions * (1 - buried_positions)

        if args.model_type == "global_label_membrane_mpnn":
            protein_dict["membrane_per_residue_labels"] = (
                args.global_transmembrane_label + 0 * fixed_positions
            )
        if args.chains_to_design:
            chains_to_design_list = args.chains_to_design.split(",")
        else:
            chains_to_design_list = protein_dict["chain_letters"]
        chain_mask = torch.tensor(
            np.array(
                [
                    item in chains_to_design_list
                    for item in protein_dict["chain_letters"]
                ],
                dtype=np.int32,
            ),
            device=device,
        )

        # create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
        # redesigned residues take precedence over fixed residues
        if redesigned_residues:
            protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
        elif fixed_residues:
            protein_dict["chain_mask"] = chain_mask * fixed_positions
        else:
            protein_dict["chain_mask"] = chain_mask

        if args.verbose:
            PDB_residues_to_be_redesigned = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 1
            ]
            PDB_residues_to_be_fixed = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 0
            ]

            print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
            print("These residues will be fixed: ", PDB_residues_to_be_fixed)

        # specify which residues are linked
        if args.symmetry_residues:
            symmetry_residues_list_of_lists = [
                x.split(",") for x in args.symmetry_residues.split("|")
            ]
            remapped_symmetry_residues = [
                [encoded_residue_dict[t] for t in t_list]
                for t_list in symmetry_residues_list_of_lists
            ]
        else:
            remapped_symmetry_residues = [[]]

        # specify linking weights
        if args.symmetry_weights:
            symmetry_weights = [
                [float(item) for item in x.split(",")]
                for x in args.symmetry_weights.split("|")
            ]
        else:
            symmetry_weights = [[]]

        if args.homo_oligomer:
            if args.verbose:
                print("Designing HOMO-OLIGOMER")
            # De-duplicate in order of appearance. A plain set() made the
            # reference chain and the order inside every symmetry group depend
            # on string hash randomization, i.e. differ between runs.
            chain_letters_set = list(dict.fromkeys(chain_letters_list))
            reference_chain = chain_letters_set[0]
            lc = len(reference_chain)
            residue_indices = [
                item[lc:] for item in encoded_residues if item[:lc] == reference_chain
            ]
            remapped_symmetry_residues = []
            symmetry_weights = []
            for res in residue_indices:
                tmp_list = []
                tmp_w_list = []
                for chain in chain_letters_set:
                    linked_residue = chain + res
                    tmp_list.append(encoded_residue_dict[linked_residue])
                    tmp_w_list.append(1 / len(chain_letters_set))
                remapped_symmetry_residues.append(tmp_list)
                symmetry_weights.append(tmp_w_list)

        # set other atom bfactors to 0.0
        if other_atoms:
            other_bfactors = other_atoms.getBetas()
            other_atoms.setBetas(other_bfactors * 0.0)

        # adjust input PDB name by dropping .pdb if it does exist
        name = pdb[pdb.rfind("/") + 1 :]
        if name[-4:] == ".pdb":
            name = name[:-4]

        with torch.no_grad():
            # run featurize to remap R_idx and add batch dimension
            if args.verbose:
                if "Y" in protein_dict:
                    atom_coords = protein_dict["Y"].cpu().numpy()
                    atom_types = list(protein_dict["Y_t"].cpu().numpy())
                    atom_mask = list(protein_dict["Y_m"].cpu().numpy())
                    number_of_atoms_parsed = np.sum(atom_mask)
                else:
                    # the message is printed once by the check below
                    number_of_atoms_parsed = 0
                    atom_types = ""
                    atom_coords = []
                if number_of_atoms_parsed == 0:
                    print("No ligand atoms parsed")
                elif args.model_type == "ligand_mpnn":
                    print(
                        f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
                    )
                    for i, atom_type in enumerate(atom_types):
                        print(
                            f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
                        )
            feature_dict = featurize(
                protein_dict,
                cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
                use_atom_context=args.ligand_mpnn_use_atom_context,
                number_of_ligand_atoms=atom_context_num,
                model_type=args.model_type,
            )
            feature_dict["batch_size"] = args.batch_size
            _, L, _, _ = feature_dict["X"].shape  # batch size should be 1 for now.
            # add additional keys to the feature dictionary
            feature_dict["temperature"] = args.temperature
            # omitted amino acids get a large negative bias
            feature_dict["bias"] = (
                (-1e8 * omit_AA[None, None, :] + bias_AA).repeat([1, L, 1])
                + bias_AA_per_residue[None]
                - 1e8 * omit_AA_per_residue[None]
            )
            feature_dict["symmetry_residues"] = remapped_symmetry_residues
            feature_dict["symmetry_weights"] = symmetry_weights

            sampling_probs_list = []
            log_probs_list = []
            decoding_order_list = []
            S_list = []
            loss_list = []
            loss_per_residue_list = []
            loss_XY_list = []
            for _ in range(args.number_of_batches):
                feature_dict["randn"] = torch.randn(
                    [feature_dict["batch_size"], feature_dict["mask"].shape[1]],
                    device=device,
                )
                output_dict = model.sample(feature_dict)

                # compute confidence scores
                loss, loss_per_residue = get_score(
                    output_dict["S"],
                    output_dict["log_probs"],
                    feature_dict["mask"] * feature_dict["chain_mask"],
                )
                if args.model_type == "ligand_mpnn":
                    combined_mask = (
                        feature_dict["mask"]
                        * feature_dict["mask_XY"]
                        * feature_dict["chain_mask"]
                    )
                else:
                    combined_mask = feature_dict["mask"] * feature_dict["chain_mask"]
                loss_XY, _ = get_score(
                    output_dict["S"], output_dict["log_probs"], combined_mask
                )
                # -----
                S_list.append(output_dict["S"])
                log_probs_list.append(output_dict["log_probs"])
                sampling_probs_list.append(output_dict["sampling_probs"])
                decoding_order_list.append(output_dict["decoding_order"])
                loss_list.append(loss)
                loss_per_residue_list.append(loss_per_residue)
                loss_XY_list.append(loss_XY)
            S_stack = torch.cat(S_list, 0)
            log_probs_stack = torch.cat(log_probs_list, 0)
            sampling_probs_stack = torch.cat(sampling_probs_list, 0)
            decoding_order_stack = torch.cat(decoding_order_list, 0)
            loss_stack = torch.cat(loss_list, 0)
            loss_per_residue_stack = torch.cat(loss_per_residue_list, 0)
            loss_XY_stack = torch.cat(loss_XY_list, 0)
            rec_mask = feature_dict["mask"][:1] * feature_dict["chain_mask"][:1]
            rec_stack = get_seq_rec(feature_dict["S"][:1], S_stack, rec_mask)

            native_seq = "".join(
                [restype_int_to_str[AA] for AA in feature_dict["S"][0].cpu().numpy()]
            )
            seq_out_str = _join_chain_sequences(
                native_seq, protein_dict["mask_c"], args.fasta_seq_separation
            )

            # base_folder already ends with "/"
            output_fasta = base_folder + "seqs/" + name + args.file_ending + ".fa"
            output_backbones = base_folder + "backbones/"
            output_packed = base_folder + "packed/"
            output_stats_path = base_folder + "stats/" + name + args.file_ending + ".pt"

            out_dict = {}
            out_dict["generated_sequences"] = S_stack.cpu()
            out_dict["sampling_probs"] = sampling_probs_stack.cpu()
            out_dict["log_probs"] = log_probs_stack.cpu()
            out_dict["decoding_order"] = decoding_order_stack.cpu()
            out_dict["native_sequence"] = feature_dict["S"][0].cpu()
            out_dict["mask"] = feature_dict["mask"][0].cpu()
            out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu()
            out_dict["seed"] = seed
            out_dict["temperature"] = args.temperature
            if args.save_stats:
                torch.save(out_dict, output_stats_path)

            if args.pack_side_chains:
                if args.verbose:
                    print("Packing side chains...")
                feature_dict_ = featurize(
                    protein_dict,
                    cutoff_for_score=8.0,
                    use_atom_context=args.pack_with_ligand_context,
                    number_of_ligand_atoms=16,
                    model_type="ligand_mpnn",
                )
                sc_feature_dict = copy.deepcopy(feature_dict_)
                B = args.batch_size
                # Repeat every tensor feature along the batch dimension. "S" is
                # skipped because it is replaced by the sampled sequences below.
                # Non-tensor values are left untouched on purpose (this replaces
                # a bare "except: pass" that also hid real errors).
                for k, v in sc_feature_dict.items():
                    if k == "S" or not torch.is_tensor(v):
                        continue
                    num_dim = len(v.shape)
                    if 2 <= num_dim <= 5:
                        sc_feature_dict[k] = v.repeat(B, *([1] * (num_dim - 1)))
                X_stack_list = []
                X_m_stack_list = []
                b_factor_stack_list = []
                for _ in range(args.number_of_packs_per_design):
                    X_list = []
                    X_m_list = []
                    b_factor_list = []
                    for c in range(args.number_of_batches):
                        sc_feature_dict["S"] = S_list[c]
                        sc_dict = pack_side_chains(
                            sc_feature_dict,
                            model_sc,
                            args.sc_num_denoising_steps,
                            args.sc_num_samples,
                            args.repack_everything,
                        )
                        X_list.append(sc_dict["X"])
                        X_m_list.append(sc_dict["X_m"])
                        b_factor_list.append(sc_dict["b_factors"])

                    X_stack = torch.cat(X_list, 0)
                    X_m_stack = torch.cat(X_m_list, 0)
                    b_factor_stack = torch.cat(b_factor_list, 0)

                    X_stack_list.append(X_stack)
                    X_m_stack_list.append(X_m_stack)
                    b_factor_stack_list.append(b_factor_stack)

            with open(output_fasta, "w") as f:
                # first record: the native sequence and the run settings
                f.write(
                    ">{}, T={}, seed={}, num_res={}, num_ligand_res={}, use_ligand_context={}, ligand_cutoff_distance={}, batch_size={}, number_of_batches={}, model_path={}\n{}\n".format(
                        name,
                        args.temperature,
                        seed,
                        torch.sum(rec_mask).cpu().numpy(),
                        torch.sum(combined_mask[:1]).cpu().numpy(),
                        bool(args.ligand_mpnn_use_atom_context),
                        float(args.ligand_mpnn_cutoff_for_score),
                        args.batch_size,
                        args.number_of_batches,
                        checkpoint_path,
                        seq_out_str,
                    )
                )
                number_of_designs = S_stack.shape[0]
                for ix in range(number_of_designs):
                    ix_suffix = ix
                    if not args.zero_indexed:
                        ix_suffix += 1
                    seq_rec_print = np.format_float_positional(
                        rec_stack[ix].cpu().numpy(), unique=False, precision=4
                    )
                    # confidences are reported as exp(-loss)
                    loss_np = np.format_float_positional(
                        np.exp(-loss_stack[ix].cpu().numpy()), unique=False, precision=4
                    )
                    loss_XY_np = np.format_float_positional(
                        np.exp(-loss_XY_stack[ix].cpu().numpy()),
                        unique=False,
                        precision=4,
                    )
                    seq = "".join(
                        [restype_int_to_str[AA] for AA in S_stack[ix].cpu().numpy()]
                    )

                    # write new sequences into PDB with backbone coordinates
                    # every value is repeated 4 times per residue
                    # presumably one copy per backbone atom; verify against parse_PDB
                    seq_prody = np.array([restype_1to3[AA] for AA in list(seq)])[
                        None,
                    ].repeat(4, 1)
                    bfactor_prody = (
                        loss_per_residue_stack[ix].cpu().numpy()[None, :].repeat(4, 1)
                    )
                    backbone.setResnames(seq_prody)
                    backbone.setBetas(
                        np.exp(-bfactor_prody)
                        * (bfactor_prody > 0.01).astype(np.float32)
                    )
                    backbone_path = (
                        output_backbones
                        + name
                        + "_"
                        + str(ix_suffix)
                        + args.file_ending
                        + ".pdb"
                    )
                    if other_atoms:
                        writePDB(backbone_path, backbone + other_atoms)
                    else:
                        writePDB(backbone_path, backbone)

                    # write full PDB files
                    if args.pack_side_chains:
                        for c_pack in range(args.number_of_packs_per_design):
                            X_stack = X_stack_list[c_pack]
                            X_m_stack = X_m_stack_list[c_pack]
                            b_factor_stack = b_factor_stack_list[c_pack]
                            write_full_PDB(
                                output_packed
                                + name
                                + args.packed_suffix
                                + "_"
                                + str(ix_suffix)
                                + "_"
                                + str(c_pack + 1)
                                + args.file_ending
                                + ".pdb",
                                X_stack[ix].cpu().numpy(),
                                X_m_stack[ix].cpu().numpy(),
                                b_factor_stack[ix].cpu().numpy(),
                                feature_dict["R_idx_original"][0].cpu().numpy(),
                                protein_dict["chain_letters"],
                                S_stack[ix].cpu().numpy(),
                                other_atoms=other_atoms,
                                icodes=icodes,
                                force_hetatm=args.force_hetatm,
                            )
                    # -----

                    # write fasta lines
                    seq_out_str = _join_chain_sequences(
                        seq, protein_dict["mask_c"], args.fasta_seq_separation
                    )
                    fasta_record = ">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}".format(
                        name,
                        ix_suffix,
                        args.temperature,
                        seed,
                        loss_np,
                        loss_XY_np,
                        seq_rec_print,
                        seq_out_str,
                    )
                    # the final record is written without a trailing newline
                    if ix != number_of_designs - 1:
                        fasta_record += "\n"
                    f.write(fasta_record)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    argparser.add_argument(
        "--model_type",
        type=str,
        default="protein_mpnn",
        help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
    )
    # protein_mpnn - original ProteinMPNN trained on the whole PDB excluding non-protein atoms
    # ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
    # per_residue_label_membrane_mpnn - ProteinMPNN model trained with an additional label per residue specifying if that residue is buried or exposed
    # global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
    # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
    argparser.add_argument(
        "--checkpoint_protein_mpnn",
        type=str,
        default="./model_params/proteinmpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_ligand_mpnn",
        type=str,
        default="./model_params/ligandmpnn_v_32_010_25.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_per_residue_label_membrane_mpnn",
        type=str,
        default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_global_label_membrane_mpnn",
        type=str,
        default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_soluble_mpnn",
        type=str,
        default="./model_params/solublempnn_v_48_020.pt",
        help="Path to model weights.",
    )

    argparser.add_argument(
        "--fasta_seq_separation",
        type=str,
        default=":",
        help="Symbol to use between sequences from different chains",
    )
    argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")

    argparser.add_argument(
        "--pdb_path", type=str, default="", help="Path to the input PDB."
    )
    argparser.add_argument(
        "--pdb_path_multi",
        type=str,
        default="",
        help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
    )

    argparser.add_argument(
        "--fixed_residues",
        type=str,
        default="",
        help="Provide fixed residues, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--fixed_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )

    argparser.add_argument(
        "--redesigned_residues",
        type=str,
        default="",
        help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--redesigned_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )

    argparser.add_argument(
        "--bias_AA",
        type=str,
        default="",
        help="Bias generation of amino acids, e.g. 'A:-1.024,P:2.34,C:-12.34'",
    )
    argparser.add_argument(
        "--bias_AA_per_residue",
        type=str,
        default="",
        help="Path to json mapping of bias {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}",
    )
    argparser.add_argument(
        "--bias_AA_per_residue_multi",
        type=str,
        default="",
        help="Path to json mapping of bias {'pdb_path': {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}}",
    )

    argparser.add_argument(
        "--omit_AA",
        type=str,
        default="",
        help="Omit amino acids from generation, e.g. 'ACG'",
    )
    argparser.add_argument(
        "--omit_AA_per_residue",
        type=str,
        default="",
        help="Path to json mapping of amino acids to omit per residue {'A12': 'APQ', 'A13': 'QST'}",
    )
    argparser.add_argument(
        "--omit_AA_per_residue_multi",
        type=str,
        default="",
        help="Path to json mapping of amino acids to omit per residue {'pdb_path': {'A12': 'QSPC', 'A13': 'AGE'}}",
    )

    argparser.add_argument(
        "--symmetry_residues",
        type=str,
        default="",
        help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
    )
    argparser.add_argument(
        "--symmetry_weights",
        type=str,
        default="",
        help="Add weights that match symmetry_residues, e.g. '1.01,1.0,1.0|-1.0,2.0|2.0,2.3'",
    )
    argparser.add_argument(
        "--homo_oligomer",
        type=int,
        default=0,
        help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
    )

    # Required: main() cannot run without an output folder.
    argparser.add_argument(
        "--out_folder",
        type=str,
        required=True,
        help="Path to a folder to output sequences, e.g. /home/out/",
    )
    argparser.add_argument(
        "--file_ending", type=str, default="", help="adding_string_to_the_end"
    )
    # type=int (was str): the string "0" is truthy, so "--zero_indexed 0"
    # used to switch zero-based numbering ON.
    argparser.add_argument(
        "--zero_indexed",
        type=int,
        default=0,
        help="1 - to start output PDB numbering with 0",
    )
    argparser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set seed for torch, numpy, and python random.",
    )
    argparser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of sequence to generate per one pass.",
    )
    argparser.add_argument(
        "--number_of_batches",
        type=int,
        default=1,
        help="Number of times to design sequence using a chosen batch size.",
    )
    argparser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Temperature to sample sequences.",
    )
    argparser.add_argument(
        "--save_stats", type=int, default=0, help="Save output statistics"
    )

    argparser.add_argument(
        "--ligand_mpnn_use_atom_context",
        type=int,
        default=1,
        help="1 - use atom context, 0 - do not use atom context.",
    )
    argparser.add_argument(
        "--ligand_mpnn_cutoff_for_score",
        type=float,
        default=8.0,
        help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
    )
    argparser.add_argument(
        "--ligand_mpnn_use_side_chain_context",
        type=int,
        default=0,
        help="Flag to use side chain atoms as ligand context for the fixed residues",
    )
    argparser.add_argument(
        "--chains_to_design",
        type=str,
        default="",
        help="Specify which chains to redesign, all others will be kept fixed, 'A,B,C,F'",
    )

    argparser.add_argument(
        "--parse_these_chains_only",
        type=str,
        default="",
        help="Provide chains letters for parsing backbones, 'A,B,C,F'",
    )

    argparser.add_argument(
        "--transmembrane_buried",
        type=str,
        default="",
        help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--transmembrane_interface",
        type=str,
        default="",
        help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )

    argparser.add_argument(
        "--global_transmembrane_label",
        type=int,
        default=0,
        help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
    )

    argparser.add_argument(
        "--parse_atoms_with_zero_occupancy",
        type=int,
        default=0,
        help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
    )

    argparser.add_argument(
        "--pack_side_chains",
        type=int,
        default=0,
        help="1 - to run side chain packer, 0 - do not run it",
    )

    argparser.add_argument(
        "--checkpoint_path_sc",
        type=str,
        default="./model_params/ligandmpnn_sc_v_32_002_16.pt",
        help="Path to model weights.",
    )

    argparser.add_argument(
        "--number_of_packs_per_design",
        type=int,
        default=4,
        help="Number of independent side chain packing samples to return per design",
    )

    argparser.add_argument(
        "--sc_num_denoising_steps",
        type=int,
        default=3,
        help="Number of denoising/recycling steps to make for side chain packing",
    )

    argparser.add_argument(
        "--sc_num_samples",
        type=int,
        default=16,
        help="Number of samples to draw from a mixture distribution and then take a sample with the highest likelihood.",
    )

    argparser.add_argument(
        "--repack_everything",
        type=int,
        default=0,
        help="1 - repacks side chains of all residues including the fixed ones; 0 - keeps the side chains fixed for fixed residues",
    )

    argparser.add_argument(
        "--force_hetatm",
        type=int,
        default=0,
        help="To force ligand atoms to be written as HETATM to PDB file after packing.",
    )

    argparser.add_argument(
        "--packed_suffix",
        type=str,
        default="_packed",
        help="Suffix for packed PDB paths",
    )

    argparser.add_argument(
        "--pack_with_ligand_context",
        type=int,
        default=1,
        help="1-pack side chains using ligand context, 0 - do not use it.",
    )

    args = argparser.parse_args()
    main(args)