Configure ImmuneBuilder pipeline for WES execution
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
- Update container image to harbor.cluster.omic.ai/omic/immunebuilder:latest - Update input/output paths to S3 (s3://omic/eureka/immunebuilder/) - Remove local mount containerOptions (not needed in k8s) - Update homepage to Gitea repo URL - Clean history to remove large model weight blobs
This commit is contained in:
189
ImmuneBuilder/ABodyBuilder2.py
Normal file
189
ImmuneBuilder/ABodyBuilder2.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from ImmuneBuilder.models import StructureModule
|
||||
from ImmuneBuilder.util import get_encoding, to_pdb, find_alignment_transform, download_file, sequence_dict_from_fasta, add_errors_as_bfactors, are_weights_ready
|
||||
from ImmuneBuilder.refine import refine
|
||||
from ImmuneBuilder.sequence_checks import number_sequences
|
||||
|
||||
# Embedding width of each ABodyBuilder2 ensemble member.
embed_dim = {f"antibody_model_{i}": dim
             for i, dim in zip(range(1, 5), (128, 256, 256, 256))}

# Zenodo locations of the pretrained weights, keyed like `embed_dim`.
model_urls = {
    f"antibody_model_{i}": f"https://zenodo.org/record/7258553/files/antibody_model_{i}?download=1"
    for i in range(1, 5)
}

# PDB REMARK line prepended to every saved structure.
header = "REMARK ANTIBODY STRUCTURE MODELLED USING ABODYBUILDER2 \n"
|
||||
|
||||
class Antibody:
    """An ensemble of predicted antibody structures.

    On construction the predictions are aligned to each other and a
    per-residue error estimate is derived from how much the ensemble
    members disagree; models are ranked by that estimate (best first).
    """

    def __init__(self, numbered_sequences, predictions):
        self.numbered_sequences = numbered_sequences
        self.atoms = [pred[0] for pred in predictions]
        self.encodings = [pred[1] for pred in predictions]

        with torch.no_grad():
            # Align all predictions on their backbone traces.
            traces = torch.stack([atoms[:, 0] for atoms in self.atoms])
            self.R, self.t = find_alignment_transform(traces)
            self.aligned_traces = (traces - self.t) @ self.R
            deviation = self.aligned_traces - self.aligned_traces.mean(0)
            # Squared distance of each model's trace from the ensemble mean.
            self.error_estimates = deviation.square().sum(-1)
            self.ranking = [idx.item() for idx in self.error_estimates.mean(-1).argsort()]

    def save_single_unrefined(self, filename, index=0):
        """Write one aligned but unrefined prediction to *filename*."""
        aligned_atoms = (self.atoms[index] - self.t[index]) @ self.R[index]
        pdb_text = to_pdb(self.numbered_sequences, aligned_atoms)

        with open(filename, "w+") as handle:
            handle.write(pdb_text)

    def save_all(self, dirname=None, filename=None, check_for_strained_bonds=True, n_threads=-1):
        """Save every unrefined model plus the refined top-ranked model.

        Also writes the per-residue error estimates as ``error_estimates.npy``.
        """
        dirname = "ABodyBuilder2_output" if dirname is None else dirname
        filename = "final_model.pdb" if filename is None else filename
        os.makedirs(dirname, exist_ok=True)

        for rank, model_index in enumerate(self.ranking):
            unrefined_path = os.path.join(dirname, f"rank{rank}_unrefined.pdb")
            self.save_single_unrefined(unrefined_path, index=model_index)

        np.save(os.path.join(dirname, "error_estimates"), self.error_estimates.mean(0).cpu().numpy())
        final_filename = os.path.join(dirname, filename)
        refine(os.path.join(dirname, "rank0_unrefined.pdb"), final_filename,
               check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
        add_errors_as_bfactors(final_filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[header])

    def save(self, filename=None, check_for_strained_bonds=True, n_threads=-1):
        """Refine and save the best model, walking down the ranking on failure."""
        if filename is None:
            filename = "ABodyBuilder2_output.pdb"

        for model_index in self.ranking:
            self.save_single_unrefined(filename, index=model_index)
            success = refine(filename, filename,
                             check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if not success:
                # Retry this model once before falling back to the next-ranked one.
                self.save_single_unrefined(filename, index=model_index)
                success = refine(filename, filename,
                                 check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if success:
                break

        if not success:
            print("FAILED TO REFINE (unknown).\nSaving anyways.", flush=True)
        add_errors_as_bfactors(filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[header])
|
||||
|
||||
|
||||
class ABodyBuilder2:
    """Antibody structure predictor wrapping the ABodyBuilder2 ensemble.

    Loads the requested ensemble members (downloading weights on first use)
    and exposes :meth:`predict`, which turns heavy/light chain sequences
    into an :class:`Antibody` of aligned structure predictions.

    Parameters
    ----------
    model_ids : iterable of int, default (1, 2, 3, 4)
        Which ensemble members to load. (FIX: default changed from a shared
        mutable list to a tuple; behaviour for callers is unchanged.)
    weights_dir : str, optional
        Directory holding the trained weights. Defaults to the package's
        ``trained_model`` directory next to this file.
    numbering_scheme : str, default 'imgt'
        Numbering scheme applied to the output structures.
    """

    def __init__(self, model_ids=(1, 2, 3, 4), weights_dir=None, numbering_scheme='imgt'):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.scheme = numbering_scheme

        if weights_dir is None:
            weights_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "trained_model")

        self.models = {}
        # FIX: loop variable renamed from `id`, which shadowed the builtin.
        for model_id in model_ids:
            model_file = f"antibody_model_{model_id}"
            model = StructureModule(rel_pos_dim=64, embed_dim=embed_dim[model_file]).to(self.device)
            weights_path = os.path.join(weights_dir, model_file)

            try:
                if not are_weights_ready(weights_path):
                    print(f"Downloading weights for {model_file}...", flush=True)
                    download_file(model_urls[model_file], weights_path)

                model.load_state_dict(torch.load(weights_path, map_location=torch.device(self.device)))
            except Exception as e:
                print(f"ERROR {model_file} not downloaded or corrupted.", flush=True)
                raise e

            model.to(torch.get_default_dtype())
            model.eval()

            self.models[model_file] = model

    def predict(self, sequence_dict):
        """Predict a structure from ``{"H": heavy_seq, "L": light_seq}``.

        Returns
        -------
        Antibody
            One prediction per ensemble member, aligned and ranked by
            ensemble agreement.
        """
        numbered_sequences = number_sequences(sequence_dict, scheme=self.scheme)
        # Rebuild the plain sequences from the numbering so both agree exactly.
        sequence_dict = {chain: "".join(residue[1] for residue in numbered_sequences[chain])
                         for chain in numbered_sequences}

        with torch.no_grad():
            encoding = torch.tensor(get_encoding(sequence_dict), device=self.device,
                                    dtype=torch.get_default_dtype())
            full_seq = sequence_dict["H"] + sequence_dict["L"]
            outputs = [model(encoding, full_seq) for model in self.models.values()]

        return Antibody(numbered_sequences, outputs)
|
||||
|
||||
|
||||
def command_line_interface():
    """Command-line entry point for ABodyBuilder2."""
    description="""
ABodyBuilder2 \\\ //
A Method for Antibody Structure Prediction \\\ //
Author: Brennan Abanades Kenyon ||
Supervisor: Charlotte Deane ||
"""
    parser = argparse.ArgumentParser(
        prog="ABodyBuilder2",
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("-H", "--heavy_sequence", help="Heavy chain amino acid sequence", default=None)
    parser.add_argument("-L", "--light_sequence", help="Light chain amino acid sequence", default=None)
    parser.add_argument("-f", "--fasta_file", help="Fasta file containing a heavy amd light chain named H and L", default=None)

    parser.add_argument("-o", "--output", help="Path to where the output model should be saved. Defaults to the same directory as input file.", default=None)
    parser.add_argument("--to_directory", help="Save all unrefined models and the top ranked refined model to a directory. "
                        "If this flag is set the output argument will be assumed to be a directory", default=False, action="store_true")
    parser.add_argument("-n", "--numbering_scheme", help="The scheme used to number output antibody structures. Available numbering schemes are: imgt, chothia, kabat, aho, wolfguy, martin and raw. Default is imgt.", default='imgt')
    parser.add_argument("--n_threads", help="The number of CPU threads to be used. If this option is set, refinement will be performed on CPU instead of GPU. By default, all available cores will be used.", type=int, default=-1)
    parser.add_argument("-u", "--no_sidechain_bond_check", help="Don't check for strained bonds. This is a bit faster but will rarely generate unphysical side chains", default=False, action="store_true")
    parser.add_argument("-v", "--verbose", help="Verbose output", default=False, action="store_true")
    args = parser.parse_args()

    # Explicit sequences win over a fasta file; at least one source is required.
    if args.heavy_sequence is not None and args.light_sequence is not None:
        seqs = {"H": args.heavy_sequence, "L": args.light_sequence}
    elif args.fasta_file is not None:
        seqs = sequence_dict_from_fasta(args.fasta_file)
    else:
        raise ValueError("Missing input sequences")

    check_for_strained_bonds = not args.no_sidechain_bond_check

    if args.n_threads > 0:
        torch.set_num_threads(args.n_threads)

    if args.verbose:
        print(description, flush=True)
        print("Sequences loaded succesfully.\nHeavy and light chains are:", flush=True)
        for chain in "HL":
            print(f"{chain}: {seqs[chain]}", flush=True)
        print("Running sequences through deep learning model...", flush=True)

    try:
        antibody = ABodyBuilder2(numbering_scheme=args.numbering_scheme).predict(seqs)
    except AssertionError as e:
        print(e, flush=True)
        sys.exit(1)

    if args.verbose:
        print("Antibody modelled succesfully, starting refinement.", flush=True)

    if args.to_directory:
        antibody.save_all(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            print("Refinement finished. Saving all outputs to directory", flush=True)
    else:
        antibody.save(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            outfile = args.output if args.output is not None else "ABodyBuilder2_output.pdb"
            print(f"Refinement finished. Saving final structure to {outfile}", flush=True)
|
||||
194
ImmuneBuilder/NanoBodyBuilder2.py
Normal file
194
ImmuneBuilder/NanoBodyBuilder2.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import argparse
|
||||
import sys
|
||||
from ImmuneBuilder.models import StructureModule
|
||||
from ImmuneBuilder.util import get_encoding, to_pdb, find_alignment_transform, download_file, sequence_dict_from_fasta, add_errors_as_bfactors, are_weights_ready
|
||||
from ImmuneBuilder.refine import refine
|
||||
from ImmuneBuilder.sequence_checks import number_sequences
|
||||
|
||||
# Embedding width of each NanoBodyBuilder2 ensemble member.
embed_dim = {f"nanobody_model_{i}": dim
             for i, dim in zip(range(1, 5), (128, 256, 256, 256))}

# Zenodo locations of the pretrained weights, keyed like `embed_dim`.
model_urls = {
    f"nanobody_model_{i}": f"https://zenodo.org/record/7258553/files/nanobody_model_{i}?download=1"
    for i in range(1, 5)
}

# PDB REMARK line prepended to every saved structure.
header = "REMARK NANOBODY STRUCTURE MODELLED USING NANOBODYBUILDER2 \n"
|
||||
|
||||
class Nanobody:
    """An ensemble of predicted nanobody structures.

    On construction the predictions are aligned to each other and a
    per-residue error estimate is derived from ensemble disagreement;
    models are ranked by that estimate (best first).
    """

    def __init__(self, numbered_sequences, predictions):
        self.numbered_sequences = numbered_sequences
        self.atoms = [pred[0] for pred in predictions]
        self.encodings = [pred[1] for pred in predictions]

        with torch.no_grad():
            # Align all predictions on their backbone traces.
            traces = torch.stack([atoms[:, 0] for atoms in self.atoms])
            self.R, self.t = find_alignment_transform(traces)
            self.aligned_traces = (traces - self.t) @ self.R
            deviation = self.aligned_traces - self.aligned_traces.mean(0)
            self.error_estimates = deviation.square().sum(-1)
            self.ranking = [idx.item() for idx in self.error_estimates.mean(-1).argsort()]

    def save_single_unrefined(self, filename, index=0):
        """Write one aligned but unrefined prediction to *filename*."""
        aligned_atoms = (self.atoms[index] - self.t[index]) @ self.R[index]
        pdb_text = to_pdb(self.numbered_sequences, aligned_atoms)

        with open(filename, "w+") as handle:
            handle.write(pdb_text)

    def save_all(self, dirname=None, filename=None, check_for_strained_bonds=True, n_threads=-1):
        """Save every unrefined model plus the refined top-ranked model.

        Also writes the per-residue error estimates as ``error_estimates.npy``.
        """
        dirname = "NanoBodyBuilder2_output" if dirname is None else dirname
        filename = "final_model.pdb" if filename is None else filename
        os.makedirs(dirname, exist_ok=True)

        for rank, model_index in enumerate(self.ranking):
            unrefined_path = os.path.join(dirname, f"rank{rank}_unrefined.pdb")
            self.save_single_unrefined(unrefined_path, index=model_index)

        np.save(os.path.join(dirname, "error_estimates"), self.error_estimates.mean(0).cpu().numpy())
        final_filename = os.path.join(dirname, filename)
        refine(os.path.join(dirname, "rank0_unrefined.pdb"), final_filename,
               check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
        add_errors_as_bfactors(final_filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[header])

    def save(self, filename=None, check_for_strained_bonds=True, n_threads=-1):
        """Refine and save the best model, walking down the ranking on failure."""
        if filename is None:
            filename = "NanoBodyBuilder2_output.pdb"

        for model_index in self.ranking:
            self.save_single_unrefined(filename, index=model_index)
            success = refine(filename, filename,
                             check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if not success:
                # Retry this model once before falling back to the next-ranked one.
                self.save_single_unrefined(filename, index=model_index)
                success = refine(filename, filename,
                                 check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if success:
                break

        if not success:
            print("FAILED TO REFINE (unknown).\nSaving anyways.", flush=True)
        add_errors_as_bfactors(filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[header])
|
||||
|
||||
|
||||
class NanoBodyBuilder2:
    """Nanobody structure predictor wrapping the NanoBodyBuilder2 ensemble.

    Loads the requested ensemble members (downloading weights on first use)
    and exposes :meth:`predict`, which turns a single heavy-chain sequence
    into a :class:`Nanobody` of aligned structure predictions.

    Parameters
    ----------
    model_ids : iterable of int, default (1, 2, 3, 4)
        Which ensemble members to load. (FIX: default changed from a shared
        mutable list to a tuple; behaviour for callers is unchanged.)
    weights_dir : str, optional
        Directory holding the trained weights. Defaults to the package's
        ``trained_model`` directory next to this file.
    numbering_scheme : str, default 'imgt'
        Numbering scheme applied to the output structures.
    """

    def __init__(self, model_ids=(1, 2, 3, 4), weights_dir=None, numbering_scheme='imgt'):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.scheme = numbering_scheme
        if weights_dir is None:
            weights_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "trained_model")

        self.models = {}
        # FIX: loop variable renamed from `id`, which shadowed the builtin.
        for model_id in model_ids:
            model_file = f"nanobody_model_{model_id}"
            model = StructureModule(rel_pos_dim=64, embed_dim=embed_dim[model_file]).to(self.device)
            weights_path = os.path.join(weights_dir, model_file)

            try:
                if not are_weights_ready(weights_path):
                    print(f"Downloading weights for {model_file}...", flush=True)
                    download_file(model_urls[model_file], weights_path)

                model.load_state_dict(torch.load(weights_path, map_location=torch.device(self.device)))
            except Exception as e:
                print(f"ERROR: {model_file} not downloaded or corrupted.", flush=True)
                raise e

            model.eval()
            model.to(torch.get_default_dtype())

            self.models[model_file] = model

    def predict(self, sequence_dict):
        """Predict a structure from ``{"H": vhh_seq}``.

        Returns
        -------
        Nanobody
            One prediction per ensemble member, aligned and ranked by
            ensemble agreement.
        """
        numbered_sequences = number_sequences(sequence_dict, allowed_species=None, scheme=self.scheme)
        # Rebuild the plain sequences from the numbering so both agree exactly.
        sequence_dict = {chain: "".join(residue[1] for residue in numbered_sequences[chain])
                         for chain in numbered_sequences}

        with torch.no_grad():
            # Nanobodies have no light chain; an empty "L" entry is supplied
            # (assumes get_encoding expects one — matches original behaviour).
            sequence_dict["L"] = ""
            encoding = torch.tensor(get_encoding(sequence_dict, "H"), device=self.device,
                                    dtype=torch.get_default_dtype())
            full_seq = sequence_dict["H"]
            outputs = [model(encoding, full_seq) for model in self.models.values()]

        numbered_sequences["L"] = []
        return Nanobody(numbered_sequences, outputs)
|
||||
|
||||
|
||||
def command_line_interface():
    """Command-line entry point for NanoBodyBuilder2."""
    description="""
\/
NanoBodyBuilder2 ⊂'l
A Method for Nanobody Structure Prediction ll
Author: Brennan Abanades Kenyon llama~
Supervisor: Charlotte Deane || ||
'' ''
"""

    parser = argparse.ArgumentParser(
        prog="NanoBodyBuilder2",
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("-H", "--heavy_sequence", help="VHH amino acid sequence", default=None)
    parser.add_argument("-f", "--fasta_file", help="Fasta file containing a heavy chain named H", default=None)

    parser.add_argument("-o", "--output", help="Path to where the output model should be saved. Defaults to the same directory as input file.", default=None)
    parser.add_argument("--to_directory", help="Save all unrefined models and the top ranked refined model to a directory. "
                        "If this flag is set the output argument will be assumed to be a directory", default=False, action="store_true")
    parser.add_argument("--n_threads", help="The number of CPU threads to be used. If this option is set, refinement will be performed on CPU instead of GPU. By default, all available cores will be used.", type=int, default=-1)
    parser.add_argument("-n", "--numbering_scheme", help="The scheme used to number output nanobody structures. Available numbering schemes are: imgt, chothia, kabat, aho, wolfguy, martin and raw. Default is imgt.", default='imgt')
    parser.add_argument("-u", "--no_sidechain_bond_check", help="Don't check for strained bonds. This is a bit faster but will rarely generate unphysical side chains", default=False, action="store_true")
    parser.add_argument("-v", "--verbose", help="Verbose output", default=False, action="store_true")

    args = parser.parse_args()

    # An explicit sequence wins over a fasta file; one of the two is required.
    if args.heavy_sequence is not None:
        seqs = {"H": args.heavy_sequence}
    elif args.fasta_file is not None:
        seqs = sequence_dict_from_fasta(args.fasta_file)
    else:
        raise ValueError("Missing input sequences")

    check_for_strained_bonds = not args.no_sidechain_bond_check

    if args.n_threads > 0:
        torch.set_num_threads(args.n_threads)

    if args.verbose:
        print(description, flush=True)
        print("Sequence loaded succesfully.\nHeavy chain is:", flush=True)
        print("H: " + seqs["H"], flush=True)
        print("Running sequence through deep learning model...", flush=True)

    try:
        antibody = NanoBodyBuilder2(numbering_scheme=args.numbering_scheme).predict(seqs)
    except AssertionError as e:
        print(e, flush=True)
        sys.exit(1)

    if args.verbose:
        print("Nanobody modelled succesfully, starting refinement.", flush=True)

    if args.to_directory:
        antibody.save_all(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            print("Refinement finished. Saving all outputs to directory", flush=True)
    else:
        antibody.save(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            outfile = args.output if args.output is not None else "NanoBodyBuilder2_output.pdb"
            print(f"Refinement finished. Saving final structure to {outfile}", flush=True)
|
||||
216
ImmuneBuilder/TCRBuilder2.py
Normal file
216
ImmuneBuilder/TCRBuilder2.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from ImmuneBuilder.models import StructureModule
|
||||
from ImmuneBuilder.util import get_encoding, to_pdb, find_alignment_transform, download_file, sequence_dict_from_fasta, add_errors_as_bfactors, are_weights_ready
|
||||
from ImmuneBuilder.refine import refine
|
||||
from ImmuneBuilder.sequence_checks import number_sequences
|
||||
|
||||
# Embedding width of each ensemble member, for both the TCRBuilder2+
# ("tcr2_plus_*") and the original TCRBuilder2 ("tcr2_*") weight sets.
embed_dim = {
    "tcr2_plus_model_1": 256,
    "tcr2_plus_model_2": 256,
    "tcr2_plus_model_3": 128,
    "tcr2_plus_model_4": 128,
    "tcr2_model_1": 128,
    "tcr2_model_2": 128,
    "tcr2_model_3": 256,
    "tcr2_model_4": 256,
}

# Pretrained weight locations: TCRBuilder2+ weights live in Zenodo record
# 10892159, the original TCRBuilder2 weights in record 7258553.
# FIX: "tcr2_plus_model_1" previously used the "/records/" path while every
# other entry used "/record/"; both resolve on Zenodo, normalised here for
# consistency.
model_urls = {
    **{f"tcr2_plus_model_{i}": f"https://zenodo.org/record/10892159/files/tcr_model_{i}?download=1"
       for i in range(1, 5)},
    **{f"tcr2_model_{i}": f"https://zenodo.org/record/7258553/files/tcr_model_{i}?download=1"
       for i in range(1, 5)},
}
|
||||
|
||||
class TCR:
    """An ensemble of predicted T-cell receptor structures.

    On construction the predictions are aligned to each other and a
    per-residue error estimate is derived from ensemble disagreement;
    models are ranked by that estimate (best first). *header* is the PDB
    REMARK line written into saved files (defaults to the TCRBuilder2+ one).
    """

    def __init__(self, numbered_sequences, predictions, header=None):
        self.numbered_sequences = numbered_sequences
        self.atoms = [pred[0] for pred in predictions]
        self.encodings = [pred[1] for pred in predictions]
        self.header = header if header is not None else "REMARK TCR STRUCTURE MODELLED USING TCRBUILDER2+ \n"

        with torch.no_grad():
            # Align all predictions on their backbone traces.
            traces = torch.stack([atoms[:, 0] for atoms in self.atoms])
            self.R, self.t = find_alignment_transform(traces)
            self.aligned_traces = (traces - self.t) @ self.R
            deviation = self.aligned_traces - self.aligned_traces.mean(0)
            self.error_estimates = deviation.square().sum(-1)
            self.ranking = [idx.item() for idx in self.error_estimates.mean(-1).argsort()]

    def save_single_unrefined(self, filename, index=0):
        """Write one aligned but unrefined prediction (chains B/A) to *filename*."""
        aligned_atoms = (self.atoms[index] - self.t[index]) @ self.R[index]
        pdb_text = to_pdb(self.numbered_sequences, aligned_atoms, chain_ids="BA")

        with open(filename, "w+") as handle:
            handle.write(pdb_text)

    def save_all(self, dirname=None, filename=None, check_for_strained_bonds=True, n_threads=-1):
        """Save every unrefined model plus the refined top-ranked model.

        Also writes the per-residue error estimates as ``error_estimates.npy``.
        """
        dirname = "TCRBuilder2_output" if dirname is None else dirname
        filename = "final_model.pdb" if filename is None else filename
        os.makedirs(dirname, exist_ok=True)

        for rank, model_index in enumerate(self.ranking):
            unrefined_path = os.path.join(dirname, f"rank{rank}_unrefined.pdb")
            self.save_single_unrefined(unrefined_path, index=model_index)

        np.save(os.path.join(dirname, "error_estimates"), self.error_estimates.mean(0).cpu().numpy())
        final_filename = os.path.join(dirname, filename)
        refine(os.path.join(dirname, "rank0_unrefined.pdb"), final_filename,
               check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
        add_errors_as_bfactors(final_filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[self.header])

    def save(self, filename=None, check_for_strained_bonds=True, n_threads=-1):
        """Refine and save the best model, walking down the ranking on failure."""
        if filename is None:
            filename = "TCRBuilder2_output.pdb"

        for model_index in self.ranking:
            self.save_single_unrefined(filename, index=model_index)
            success = refine(filename, filename,
                             check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if not success:
                # Retry this model once before falling back to the next-ranked one.
                self.save_single_unrefined(filename, index=model_index)
                success = refine(filename, filename,
                                 check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if success:
                break

        if not success:
            print("FAILED TO REFINE (unknown).\nSaving anyways.", flush=True)
        add_errors_as_bfactors(filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[self.header])
|
||||
|
||||
|
||||
class TCRBuilder2:
    """T-cell receptor structure predictor wrapping the TCRBuilder2 ensemble.

    Loads either the TCRBuilder2+ weights (default) or the original
    TCRBuilder2 weights, and exposes :meth:`predict`, which turns beta/alpha
    chain sequences into a :class:`TCR` of aligned structure predictions.

    Parameters
    ----------
    model_ids : iterable of int, default (1, 2, 3, 4)
        Which ensemble members to load. (FIX: default changed from a shared
        mutable list to a tuple; behaviour for callers is unchanged.)
    weights_dir : str, optional
        Directory holding the trained weights. Defaults to the package's
        ``trained_model`` directory next to this file.
    numbering_scheme : str, default 'imgt'
        Numbering scheme applied to the output structures.
    use_TCRBuilder2_PLUS_weights : bool, default True
        If True use the TCRBuilder2+ weight set, otherwise the original.
    """

    def __init__(self, model_ids=(1, 2, 3, 4), weights_dir=None, numbering_scheme='imgt',
                 use_TCRBuilder2_PLUS_weights=True):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.scheme = numbering_scheme
        self.use_TCRBuilder2_PLUS_weights = use_TCRBuilder2_PLUS_weights
        if weights_dir is None:
            weights_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "trained_model")

        self.models = {}
        # FIX: loop variable renamed from `id`, which shadowed the builtin.
        for model_id in model_ids:
            if use_TCRBuilder2_PLUS_weights:
                # The "+" weights are stored on disk under the plain
                # tcr_model_* names but keyed separately in embed_dim/model_urls.
                model_file = f"tcr_model_{model_id}"
                model_key = f"tcr2_plus_model_{model_id}"
            else:
                model_file = f"tcr2_model_{model_id}"
                model_key = f"tcr2_model_{model_id}"
            model = StructureModule(rel_pos_dim=64, embed_dim=embed_dim[model_key]).to(self.device)
            weights_path = os.path.join(weights_dir, model_file)

            try:
                if not are_weights_ready(weights_path):
                    print(f"Downloading weights for {model_file}...", flush=True)
                    download_file(model_urls[model_key], weights_path)

                model.load_state_dict(torch.load(weights_path, map_location=torch.device(self.device)))
            except Exception as e:
                print(f"ERROR: {model_file} not downloaded or corrupted.", flush=True)
                raise e

            model.eval()
            model.to(torch.get_default_dtype())

            self.models[model_file] = model

    def predict(self, sequence_dict):
        """Predict a structure from ``{"B": beta_seq, "A": alpha_seq}``.

        Returns
        -------
        TCR
            One prediction per ensemble member, aligned and ranked by
            ensemble agreement, with a header matching the weight set used.
        """
        numbered_sequences = number_sequences(sequence_dict, scheme=self.scheme)
        # Rebuild the plain sequences from the numbering so both agree exactly.
        sequence_dict = {chain: "".join(residue[1] for residue in numbered_sequences[chain])
                         for chain in numbered_sequences}

        with torch.no_grad():
            # Map beta -> H and alpha -> L: the network takes antibody-style chain keys.
            sequence_dict = {"H": sequence_dict["B"], "L": sequence_dict["A"]}
            encoding = torch.tensor(get_encoding(sequence_dict), device=self.device,
                                    dtype=torch.get_default_dtype())
            full_seq = sequence_dict["H"] + sequence_dict["L"]
            outputs = [model(encoding, full_seq) for model in self.models.values()]

        if self.use_TCRBuilder2_PLUS_weights:
            header = "REMARK TCR STRUCTURE MODELLED USING TCRBUILDER2+ \n"
        else:
            header = "REMARK TCR STRUCTURE MODELLED USING TCRBUILDER2 \n"

        return TCR(numbered_sequences, outputs, header)
|
||||
|
||||
|
||||
def command_line_interface():
    """Command-line entry point for TCRBuilder2."""
    description="""
TCRBuilder2 || ||
A Method for T-Cell Receptor Structure Prediction |VB|VA|
Author: Brennan Abanades Kenyon, Nele Quast |CB|CA|
Supervisor: Charlotte Deane -------------

By default TCRBuilder2 will use TCRBuilder2+ weights.
To use the original weights, see options.
"""

    parser = argparse.ArgumentParser(
        prog="TCRBuilder2",
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("-B", "--beta_sequence", help="Beta chain amino acid sequence", default=None)
    parser.add_argument("-A", "--alpha_sequence", help="Alpha chain amino acid sequence", default=None)
    parser.add_argument("-f", "--fasta_file", help="Fasta file containing a beta and alpha chain named B and A", default=None)

    parser.add_argument("-o", "--output", help="Path to where the output model should be saved. Defaults to the same directory as input file.", default=None)
    parser.add_argument("--to_directory", help="Save all unrefined models and the top ranked refined model to a directory. "
                        "If this flag is set the output argument will be assumed to be a directory", default=False, action="store_true")
    parser.add_argument("--n_threads", help="The number of CPU threads to be used. If this option is set, refinement will be performed on CPU instead of GPU. By default, all available cores will be used.", type=int, default=-1)
    parser.add_argument("-u", "--no_sidechain_bond_check", help="Don't check for strained bonds. This is a bit faster but will rarely generate unphysical side chains", default=False, action="store_true")
    parser.add_argument("-v", "--verbose", help="Verbose output", default=False, action="store_true")
    parser.add_argument("-og", "--original_weights", help="use original TCRBuilder2 weights instead of TCRBuilder2+", default=False, action="store_true")

    args = parser.parse_args()

    # Explicit sequences win over a fasta file; at least one source is required.
    if args.beta_sequence is not None and args.alpha_sequence is not None:
        seqs = {"B": args.beta_sequence, "A": args.alpha_sequence}
    elif args.fasta_file is not None:
        seqs = sequence_dict_from_fasta(args.fasta_file)
    else:
        raise ValueError("Missing input sequences")

    check_for_strained_bonds = not args.no_sidechain_bond_check

    if args.n_threads > 0:
        torch.set_num_threads(args.n_threads)

    if args.verbose:
        print(description, flush=True)
        print("Sequences loaded succesfully.\nAlpha and Beta chains are:", flush=True)
        for chain in "AB":
            print(f"{chain}: {seqs[chain]}", flush=True)
        print("Running sequences through deep learning model...", flush=True)

    try:
        tcr = TCRBuilder2(use_TCRBuilder2_PLUS_weights=not args.original_weights).predict(seqs)
    except AssertionError as e:
        print(e, flush=True)
        sys.exit(1)

    if args.verbose:
        print("TCR modelled succesfully, starting refinement.", flush=True)

    if args.to_directory:
        tcr.save_all(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            print("Refinement finished. Saving all outputs to directory", flush=True)
    else:
        tcr.save(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            outfile = args.output if args.output is not None else "TCRBuilder2_output.pdb"
            print(f"Refinement finished. Saving final structure to {outfile}", flush=True)
|
||||
3
ImmuneBuilder/__init__.py
Normal file
3
ImmuneBuilder/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from ImmuneBuilder.ABodyBuilder2 import ABodyBuilder2
|
||||
from ImmuneBuilder.TCRBuilder2 import TCRBuilder2
|
||||
from ImmuneBuilder.NanoBodyBuilder2 import NanoBodyBuilder2
|
||||
294
ImmuneBuilder/constants.py
Normal file
294
ImmuneBuilder/constants.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# One-letter codes of the 20 standard amino acids. A letter's position in
# this string doubles as the residue's integer id (see r2n at module end).
restypes = 'ARNDCQEGHILKMFPSTWYV'

# Residue names definition:
# One-letter code -> three-letter (PDB) residue name.
restype_1to3 = {
    'A': 'ALA',
    'R': 'ARG',
    'N': 'ASN',
    'D': 'ASP',
    'C': 'CYS',
    'Q': 'GLN',
    'E': 'GLU',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'L': 'LEU',
    'K': 'LYS',
    'M': 'MET',
    'F': 'PHE',
    'P': 'PRO',
    'S': 'SER',
    'T': 'THR',
    'W': 'TRP',
    'Y': 'TYR',
    'V': 'VAL',
}

# Inverse mapping: three-letter (PDB) residue name -> one-letter code.
restype_3to1 = {v: k for k, v in restype_1to3.items()}
|
||||
|
||||
# How atoms are sorted in MLAb:
# Per-residue heavy-atom names in the fixed slot order used by the model's
# 14-slot atom tensors (slots beyond the list length are padding).
residue_atoms = {
    'A': ['CA', 'N', 'C', 'CB', 'O'],
    'C': ['CA', 'N', 'C', 'CB', 'O', 'SG'],
    'D': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'OD1', 'OD2'],
    'E': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD', 'OE1', 'OE2'],
    'F': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
    'G': ['CA', 'N', 'C', 'CA', 'O'],  # G has no CB so I am padding it with CA so the Os are aligned
    'H': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD2', 'CE1', 'ND1', 'NE2'],
    'I': ['CA', 'N', 'C', 'CB', 'O', 'CG1', 'CG2', 'CD1'],
    'K': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD', 'CE', 'NZ'],
    'L': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD1', 'CD2'],
    'M': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CE', 'SD'],
    'N': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'ND2', 'OD1'],
    'P': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD'],
    'Q': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD', 'NE2', 'OE1'],
    'R': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD', 'CZ', 'NE', 'NH1', 'NH2'],
    'S': ['CA', 'N', 'C', 'CB', 'O', 'OG'],
    'T': ['CA', 'N', 'C', 'CB', 'O', 'CG2', 'OG1'],
    'V': ['CA', 'N', 'C', 'CB', 'O', 'CG1', 'CG2'],
    'W': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'NE1'],
    'Y': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH']}

# Per-residue boolean mask over the 14 atom slots: True where the slot holds
# a real atom, False where it is padding.
residue_atoms_mask = {res: len(residue_atoms[res]) * [True] + (14 - len(residue_atoms[res])) * [False] for res in
                      residue_atoms}

# All heavy-atom names that can occur in any residue (PDB naming), plus the
# C-terminal oxygen OXT.
atom_types = [
    'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]
|
||||
|
||||
# Position of atoms in each ref frame
# For every residue: atom name -> [frame_index, (x, y, z)], the atom's
# idealised coordinates expressed in one of the residue's local rigid
# frames. Frame 0 contains the backbone atoms (CA at the origin);
# higher indices are additional torsion frames — presumably psi and
# chi1..chi4 frames; TODO confirm the exact index mapping.
rigid_group_atom_positions2 = {
    'A': {'C': [0, (1.526, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.529, -0.774, -1.205)],
          'N': [0, (-0.525, 1.363, 0.0)], 'O': [3, (-0.627, 1.062, 0.0)]},
    'C': {'C': [0, (1.524, 0.0, 0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.519, -0.773, -1.212)],
          'N': [0, (-0.522, 1.362, -0.0)], 'O': [3, (-0.625, 1.062, -0.0)], 'SG': [4, (-0.728, 1.653, 0.0)]},
    'D': {'C': [0, (1.527, 0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.526, -0.778, -1.208)],
          'CG': [4, (-0.593, 1.398, -0.0)], 'N': [0, (-0.525, 1.362, -0.0)], 'O': [3, (-0.626, 1.062, -0.0)],
          'OD1': [5, (-0.61, 1.091, 0.0)], 'OD2': [5, (-0.592, -1.101, 0.003)]},
    'E': {'C': [0, (1.526, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.526, -0.781, -1.207)],
          'CD': [5, (-0.6, 1.397, 0.0)], 'CG': [4, (-0.615, 1.392, 0.0)], 'N': [0, (-0.528, 1.361, 0.0)],
          'O': [3, (-0.626, 1.062, 0.0)], 'OE1': [6, (-0.607, 1.095, -0.0)], 'OE2': [6, (-0.589, -1.104, 0.001)]},
    'F': {'C': [0, (1.524, 0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.525, -0.776, -1.212)],
          'CD1': [5, (-0.709, 1.195, -0.0)], 'CD2': [5, (-0.706, -1.196, 0.0)], 'CE1': [5, (-2.102, 1.198, -0.0)],
          'CE2': [5, (-2.098, -1.201, -0.0)], 'CG': [4, (-0.607, 1.377, 0.0)], 'CZ': [5, (-2.794, -0.003, 0.001)],
          'N': [0, (-0.518, 1.363, 0.0)], 'O': [3, (-0.626, 1.062, -0.0)]},
    'G': {'C': [0, (1.517, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'N': [0, (-0.572, 1.337, 0.0)],
          'O': [3, (-0.626, 1.062, -0.0)]},
    'H': {'C': [0, (1.525, 0.0, 0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.525, -0.778, -1.208)],
          'CD2': [5, (-0.889, -1.021, -0.003)], 'CE1': [5, (-2.03, 0.851, -0.002)], 'CG': [4, (-0.6, 1.37, -0.0)],
          'N': [0, (-0.527, 1.36, 0.0)], 'ND1': [5, (-0.744, 1.16, -0.0)], 'NE2': [5, (-2.145, -0.466, -0.004)],
          'O': [3, (-0.625, 1.063, 0.0)]},
    'I': {'C': [0, (1.527, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.536, -0.793, -1.213)],
          'CD1': [5, (-0.619, 1.391, 0.0)], 'CG1': [4, (-0.534, 1.437, -0.0)], 'CG2': [4, (-0.54, -0.785, 1.199)],
          'N': [0, (-0.493, 1.373, -0.0)], 'O': [3, (-0.627, 1.062, -0.0)]},
    'K': {'C': [0, (1.526, 0.0, 0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.524, -0.778, -1.208)],
          'CD': [5, (-0.559, 1.417, 0.0)], 'CE': [6, (-0.56, 1.416, 0.0)], 'CG': [4, (-0.619, 1.39, 0.0)],
          'N': [0, (-0.526, 1.362, -0.0)], 'NZ': [7, (-0.554, 1.387, 0.0)], 'O': [3, (-0.626, 1.062, -0.0)]},
    'L': {'C': [0, (1.525, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.522, -0.773, -1.214)],
          'CD1': [5, (-0.53, 1.43, -0.0)], 'CD2': [5, (-0.535, -0.774, -1.2)], 'CG': [4, (-0.678, 1.371, 0.0)],
          'N': [0, (-0.52, 1.363, 0.0)], 'O': [3, (-0.625, 1.063, -0.0)]},
    'M': {'C': [0, (1.525, 0.0, 0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.523, -0.776, -1.21)],
          'CE': [6, (-0.32, 1.786, -0.0)], 'CG': [4, (-0.613, 1.391, -0.0)], 'N': [0, (-0.521, 1.364, -0.0)],
          'O': [3, (-0.625, 1.062, -0.0)], 'SD': [5, (-0.703, 1.695, 0.0)]},
    'N': {'C': [0, (1.526, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.531, -0.787, -1.2)],
          'CG': [4, (-0.584, 1.399, 0.0)], 'N': [0, (-0.536, 1.357, 0.0)], 'ND2': [5, (-0.593, -1.188, -0.001)],
          'O': [3, (-0.625, 1.062, 0.0)], 'OD1': [5, (-0.633, 1.059, 0.0)]},
    'P': {'C': [0, (1.527, -0.0, 0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.546, -0.611, -1.293)],
          'CD': [5, (-0.477, 1.424, 0.0)], 'CG': [4, (-0.382, 1.445, 0.0)], 'N': [0, (-0.566, 1.351, -0.0)],
          'O': [3, (-0.621, 1.066, 0.0)]},
    'Q': {'C': [0, (1.526, 0.0, 0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.525, -0.779, -1.207)],
          'CD': [5, (-0.587, 1.399, -0.0)], 'CG': [4, (-0.615, 1.393, 0.0)], 'N': [0, (-0.526, 1.361, -0.0)],
          'NE2': [6, (-0.593, -1.189, 0.001)], 'O': [3, (-0.626, 1.062, -0.0)], 'OE1': [6, (-0.634, 1.06, 0.0)]},
    'R': {'C': [0, (1.525, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.524, -0.778, -1.209)],
          'CD': [5, (-0.564, 1.414, 0.0)], 'CG': [4, (-0.616, 1.39, -0.0)], 'CZ': [7, (-0.758, 1.093, -0.0)],
          'N': [0, (-0.524, 1.362, -0.0)], 'NE': [6, (-0.539, 1.357, -0.0)], 'NH1': [7, (-0.206, 2.301, 0.0)],
          'NH2': [7, (-2.078, 0.978, -0.0)], 'O': [3, (-0.626, 1.062, 0.0)]},
    'S': {'C': [0, (1.525, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.518, -0.777, -1.211)],
          'N': [0, (-0.529, 1.36, -0.0)], 'O': [3, (-0.626, 1.062, -0.0)], 'OG': [4, (-0.503, 1.325, 0.0)]},
    'T': {'C': [0, (1.526, 0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.516, -0.793, -1.215)],
          'CG2': [4, (-0.55, -0.718, 1.228)], 'N': [0, (-0.517, 1.364, 0.0)], 'O': [3, (-0.626, 1.062, 0.0)],
          'OG1': [4, (-0.472, 1.353, 0.0)]},
    'V': {'C': [0, (1.527, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.533, -0.795, -1.213)],
          'CG1': [4, (-0.54, 1.429, -0.0)], 'CG2': [4, (-0.533, -0.776, -1.203)], 'N': [0, (-0.494, 1.373, -0.0)],
          'O': [3, (-0.627, 1.062, -0.0)]},
    'W': {'C': [0, (1.525, -0.0, 0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.523, -0.776, -1.212)],
          'CD1': [5, (-0.824, 1.091, 0.0)], 'CD2': [5, (-0.854, -1.148, -0.005)], 'CE2': [5, (-2.186, -0.678, -0.007)],
          'CE3': [5, (-0.622, -2.53, -0.007)], 'CG': [4, (-0.609, 1.37, -0.0)], 'CH2': [5, (-3.028, -2.89, -0.013)],
          'CZ2': [5, (-3.283, -1.543, -0.011)], 'CZ3': [5, (-1.715, -3.389, -0.011)], 'N': [0, (-0.521, 1.363, 0.0)],
          'NE1': [5, (-2.14, 0.69, -0.004)], 'O': [3, (-0.627, 1.062, 0.0)]},
    'Y': {'C': [0, (1.524, -0.0, -0.0)], 'CA': [0, (0.0, 0.0, 0.0)], 'CB': [0, (-0.522, -0.776, -1.213)],
          'CD1': [5, (-0.716, 1.195, -0.0)], 'CD2': [5, (-0.713, -1.194, -0.001)], 'CE1': [5, (-2.107, 1.2, -0.002)],
          'CE2': [5, (-2.104, -1.201, -0.003)], 'CG': [4, (-0.607, 1.382, -0.0)], 'CZ': [5, (-2.791, -0.001, -0.003)],
          'N': [0, (-0.522, 1.362, 0.0)], 'O': [3, (-0.627, 1.062, -0.0)],
          'OH': [5, (-4.168, -0.002, -0.005)]}}
|
||||
|
||||
# The four atoms defining each side-chain chi torsion angle of every
# residue, in order chi1, chi2, ... . Residues without a rotatable side
# chain (A, G) have an empty list.
chi_angles_atoms = {'A': [],
                    'C': [['N', 'CA', 'CB', 'SG']],
                    'D': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
                    'E': [['N', 'CA', 'CB', 'CG'],
                          ['CA', 'CB', 'CG', 'CD'],
                          ['CB', 'CG', 'CD', 'OE1']],
                    'F': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
                    'G': [],
                    'H': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
                    'I': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
                    'K': [['N', 'CA', 'CB', 'CG'],
                          ['CA', 'CB', 'CG', 'CD'],
                          ['CB', 'CG', 'CD', 'CE'],
                          ['CG', 'CD', 'CE', 'NZ']],
                    'L': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
                    'M': [['N', 'CA', 'CB', 'CG'],
                          ['CA', 'CB', 'CG', 'SD'],
                          ['CB', 'CG', 'SD', 'CE']],
                    'N': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
                    'P': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
                    'Q': [['N', 'CA', 'CB', 'CG'],
                          ['CA', 'CB', 'CG', 'CD'],
                          ['CB', 'CG', 'CD', 'OE1']],
                    'R': [['N', 'CA', 'CB', 'CG'],
                          ['CA', 'CB', 'CG', 'CD'],
                          ['CB', 'CG', 'CD', 'NE'],
                          ['CG', 'CD', 'NE', 'CZ']],
                    'S': [['N', 'CA', 'CB', 'OG']],
                    'T': [['N', 'CA', 'CB', 'OG1']],
                    'V': [['N', 'CA', 'CB', 'CG1']],
                    'W': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
                    'Y': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']]}
|
||||
|
||||
# --- Tables derived from the definitions above ---

# For every residue, the atom-slot indices (into residue_atoms[res]) of
# the four atoms defining each chi torsion angle.
chi_angles_positions = {
    r: [[residue_atoms[r].index(atom) for atom in angs] for angs in chi_angles_atoms[r]]
    for r in residue_atoms
}

# Central atom of the chi2/chi3/chi4 rotation axes; "CA" is used as a
# placeholder for residues that do not have that chi angle.
chi2_centers = {x: chi_angles_atoms[x][1][-2] if len(chi_angles_atoms[x]) > 1 else "CA" for x in chi_angles_atoms}
chi3_centers = {x: chi_angles_atoms[x][2][-2] if len(chi_angles_atoms[x]) > 2 else "CA" for x in chi_angles_atoms}
chi4_centers = {x: chi_angles_atoms[x][3][-2] if len(chi_angles_atoms[x]) > 3 else "CA" for x in chi_angles_atoms}

# Idealised [frame, position] entry for each of the 14 atom slots of every
# residue; slots beyond the residue's real atoms are padded with a dummy
# [0, (0, 0, 0)] entry.
rel_pos = {
    x: [rigid_group_atom_positions2[x][residue_atoms[x][atom_id]] if len(residue_atoms[x]) > atom_id else [0, (0, 0, 0)]
        for atom_id in range(14)] for x in rigid_group_atom_positions2}

# Van der Waals radius (Angstrom) per element symbol.
van_der_waals_radius = {
    "C": 1.7,
    "N": 1.55,
    "O": 1.52,
    "S": 1.8,
}

# Van der Waals radii for the 14 atom slots of every residue (the element
# is taken from the first letter of the atom name); padding slots get 0.
residue_van_der_waals_radius = {
    x: [van_der_waals_radius[atom[0]] for atom in residue_atoms[x]] + (14 - len(residue_atoms[x])) * ([0]) for x in
    residue_atoms}

# Number of valid rigid frames per residue: two backbone-derived frames
# plus one per chi angle.
valid_rigids = {x: len(chi_angles_atoms[x]) + 2 for x in chi_angles_atoms}

# One-letter residue code -> integer id (index in `restypes`).
r2n = {x: i for i, x in enumerate(restypes)}


def res_to_num(x):
    """Map a one-letter residue code to its integer id.

    Unknown codes map to len(r2n), the out-of-vocabulary bucket. This
    replaces a PEP 8 E731 lambda assignment; behaviour is unchanged.
    """
    return r2n.get(x, len(r2n))
|
||||
230
ImmuneBuilder/models.py
Normal file
230
ImmuneBuilder/models.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from ImmuneBuilder.rigids import Rigid, Rot, rigid_body_identity, vec_from_tensor, global_frames_from_bb_frame_and_torsion_angles, all_atoms_from_global_reference_frames
|
||||
|
||||
class InvariantPointAttention(torch.nn.Module):
    """Invariant Point Attention (IPA) layer.

    Combines three attention signals per head: classic scalar attention on
    node features, a learned linear bias from edge features, and a penalty
    based on distances between learned 3D points expressed through each
    residue's rigid reference frame. Because value points are mapped back
    into local frames before the output projection, the result is invariant
    to global rotations/translations of the structure.
    """

    def __init__(self, node_dim, edge_dim, heads=12, head_dim=16, n_query_points=4, n_value_points=8, **kwargs):
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.n_query_points = n_query_points

        # Flattened sizes of the multi-head projections (3 coords per point).
        node_scalar_attention_inner_dim = heads * head_dim
        node_vector_attention_inner_dim = 3 * n_query_points * heads
        node_vector_attention_value_dim = 3 * n_value_points * heads
        # forward() concatenates, per head: attended edge features, scalar
        # values, and value points (3 coordinates + norm = 4 numbers each).
        after_final_cat_dim = heads * edge_dim + heads * head_dim + heads * n_value_points * 4

        # Per-head weight on the point-distance term, initialised to
        # log(e - 1) (the inverse softplus of 1).
        point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.)) - 1.)
        self.point_weight = torch.nn.Parameter(point_weight_init_value)

        self.to_scalar_qkv = torch.nn.Linear(node_dim, 3 * node_scalar_attention_inner_dim, bias=False)
        self.to_vector_qk = torch.nn.Linear(node_dim, 2 * node_vector_attention_inner_dim, bias=False)
        self.to_vector_v = torch.nn.Linear(node_dim, node_vector_attention_value_dim, bias=False)
        self.to_scalar_edge_attention_bias = torch.nn.Linear(edge_dim, heads, bias=False)
        self.final_linear = torch.nn.Linear(after_final_cat_dim, node_dim)

        # Zero-init the output projection so the residual connection in
        # forward() makes this layer the identity at initialisation.
        with torch.no_grad():
            self.final_linear.weight.fill_(0.0)
            self.final_linear.bias.fill_(0.0)

    def forward(self, node_features, edge_features, rigid):
        """Attend over residues and return residually-updated node features.

        Args:
            node_features: per-residue features, shape (n, node_dim).
            edge_features: pairwise features, shape (n, n, edge_dim).
            rigid: per-residue rigid transforms mapping local points to the
                global frame (supports `@`, `unsqueeze`, `inv`).

        Returns:
            Tensor of shape (n, node_dim): node_features plus the attended
            update.
        """
        # Classic attention on nodes
        scalar_qkv = self.to_scalar_qkv(node_features).chunk(3, dim=-1)
        scalar_q, scalar_k, scalar_v = map(lambda t: rearrange(t, 'n (h d) -> h n d', h=self.heads), scalar_qkv)
        node_scalar = torch.einsum('h i d, h j d -> h i j', scalar_q, scalar_k) * self.head_dim ** (-1 / 2)

        # Linear bias on edges
        edge_bias = rearrange(self.to_scalar_edge_attention_bias(edge_features), 'i j h -> h i j')

        # Reference frame attention
        # Constant weight for the point term (depends on n_query_points).
        wc = (2 / self.n_query_points) ** (1 / 2) / 6
        vector_qk = self.to_vector_qk(node_features).chunk(2, dim=-1)
        vector_q, vector_k = map(lambda x: vec_from_tensor(rearrange(x, 'n (h p d) -> h n p d', h=self.heads, d=3)),
                                 vector_qk)
        rigid_ = rigid.unsqueeze(0).unsqueeze(-1)  # add head and point dimension to rigids

        # Express query/key points in the global frame; pairs of residues
        # whose points end up far apart get their attention suppressed.
        global_vector_k = rigid_ @ vector_k
        global_vector_q = rigid_ @ vector_q
        global_frame_distance = wc * global_vector_q.unsqueeze(-2).dist(global_vector_k.unsqueeze(-3)).sum(
            -1) * rearrange(self.point_weight, "h -> h () ()")

        # Combining attentions
        attention_matrix = (3 ** (-1 / 2) * (node_scalar + edge_bias - global_frame_distance)).softmax(-1)

        # Obtaining outputs
        edge_output = (rearrange(attention_matrix, 'h i j -> i h () j') * rearrange(edge_features,
                                                                                   'i j d -> i () d j')).sum(-1)
        scalar_node_output = torch.einsum('h i j, h j d -> i h d', attention_matrix, scalar_v)

        vector_v = vec_from_tensor(
            rearrange(self.to_vector_v(node_features), 'n (h p d) -> h n p d', h=self.heads, d=3))
        global_vector_v = rigid_ @ vector_v
        attended_global_vector_v = global_vector_v.map(
            lambda x: torch.einsum('h i j, h j p -> h i p', attention_matrix, x))
        # Map attended points back into each residue's local frame so the
        # output stays invariant under global transforms.
        vector_node_output = rigid_.inv() @ attended_global_vector_v
        vector_node_output = torch.stack(
            [vector_node_output.norm(), vector_node_output.x, vector_node_output.y, vector_node_output.z], dim=-1)

        # Concatenate along heads and points
        edge_output = rearrange(edge_output, 'n h d -> n (h d)')
        scalar_node_output = rearrange(scalar_node_output, 'n h d -> n (h d)')
        vector_node_output = rearrange(vector_node_output, 'h n p d -> n (h p d)')

        combined = torch.cat([edge_output, scalar_node_output, vector_node_output], dim=-1)

        return node_features + self.final_linear(combined)
|
||||
|
||||
|
||||
class BackboneUpdate(torch.nn.Module):
    """Predict a small per-residue rigid-body correction.

    A single linear layer maps each residue's features to six numbers:
    three form the vector part of an unnormalised quaternion (rotation)
    and three form a translation. These are converted into a Rigid.
    """

    def __init__(self, node_dim):
        super().__init__()
        # Attribute name kept so existing checkpoints still load.
        self.to_correction = torch.nn.Linear(node_dim, 6)

    def forward(self, node_features, update_mask=None):
        """Return the predicted Rigid update for every residue.

        Args:
            node_features: (n, node_dim) tensor of residue features.
            update_mask: optional per-residue mask; entries of 0 zero out
                the correction, yielding the identity rotation and zero
                translation for those residues.
        """
        quat_vec, translation = self.to_correction(node_features).chunk(2, dim=-1)

        # Optionally freeze a subset of residues.
        if update_mask is not None:
            mask = update_mask[:, None]
            quat_vec = mask * quat_vec
            translation = mask * translation

        # Build a unit quaternion (a, b, c, d) with a = 1 / norm, which is
        # always well defined (norm >= 1).
        norm = (1 + quat_vec.pow(2).sum(-1, keepdim=True)).pow(1 / 2)
        b, c, d = (quat_vec / norm).chunk(3, dim=-1)
        a = 1 / norm
        a, b, c, d = (t.squeeze(-1) for t in (a, b, c, d))

        # Standard quaternion -> rotation matrix conversion.
        rotation = Rot(
            (a ** 2 + b ** 2 - c ** 2 - d ** 2), (2 * b * c - 2 * a * d), (2 * b * d + 2 * a * c),
            (2 * b * c + 2 * a * d), (a ** 2 - b ** 2 + c ** 2 - d ** 2), (2 * c * d - 2 * a * b),
            (2 * b * d - 2 * a * c), (2 * c * d + 2 * a * b), (a ** 2 - b ** 2 - c ** 2 + d ** 2)
        )

        return Rigid(vec_from_tensor(translation), rotation)
|
||||
|
||||
|
||||
class TorsionAngles(torch.nn.Module):
    """Predict normalised (cos, sin)-style 2-vectors for 5 torsion angles per residue."""

    def __init__(self, node_dim):
        super().__init__()

        def residual_mlp():
            # Two-hidden-layer MLP over the concatenated feature pair.
            return torch.nn.Sequential(
                torch.nn.Linear(2 * node_dim, 2 * node_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * node_dim, 2 * node_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2 * node_dim, 2 * node_dim)
            )

        # Attribute names kept so existing checkpoints still load.
        self.residual1 = residual_mlp()
        self.residual2 = residual_mlp()

        self.final_pred = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(2 * node_dim, 10)
        )

        # Zero-init the last layer of each residual branch so both branches
        # start out as identity mappings.
        with torch.no_grad():
            for branch in (self.residual1, self.residual2):
                branch[-1].weight.fill_(0.0)
                branch[-1].bias.fill_(0.0)

    def forward(self, node_features, s_i):
        """Return (unit torsion vectors of shape (n, 5, 2), their pre-normalisation norms)."""
        features = torch.cat([node_features, s_i], axis=-1)

        features = features + self.residual1(features)
        features = features + self.residual2(features)

        # 10 outputs per residue -> 5 torsions x 2 components.
        raw_torsions = rearrange(self.final_pred(features), "i (t d) -> i t d", d=2)
        magnitude = torch.norm(raw_torsions, dim=-1, keepdim=True)

        return raw_torsions / magnitude, magnitude
|
||||
|
||||
|
||||
class StructureUpdate(torch.nn.Module):
    """One structure-module iteration: IPA, residual MLP, backbone update."""

    def __init__(self, node_dim, edge_dim, dropout=0.0, **kwargs):
        super().__init__()
        # Attribute names kept so existing checkpoints still load.
        self.IPA = InvariantPointAttention(node_dim, edge_dim, **kwargs)

        def dropout_then_norm():
            return torch.nn.Sequential(
                torch.nn.Dropout(dropout),
                torch.nn.LayerNorm(node_dim)
            )

        self.norm1 = dropout_then_norm()
        self.norm2 = dropout_then_norm()

        self.residual = torch.nn.Sequential(
            torch.nn.Linear(node_dim, 2 * node_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * node_dim, 2 * node_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * node_dim, node_dim)
        )

        self.torsion_angles = TorsionAngles(node_dim)
        self.backbone_update = BackboneUpdate(node_dim)

        # Zero-init the residual MLP's output layer so the block initially
        # leaves node features unchanged (modulo normalisation).
        with torch.no_grad():
            self.residual[-1].weight.fill_(0.0)
            self.residual[-1].bias.fill_(0.0)

    def forward(self, node_features, edge_features, rigid_pred, update_mask=None):
        """Return (updated node features, rigid frames composed with the predicted update)."""
        state = self.norm1(self.IPA(node_features, edge_features, rigid_pred))
        state = self.norm2(state + self.residual(state))
        updated_rigid = rigid_pred @ self.backbone_update(state, update_mask)
        return state, updated_rigid
|
||||
|
||||
|
||||
class StructureModule(torch.nn.Module):
    """Stack of StructureUpdate layers that folds node features into 3D atoms.

    Starting from identity rigid frames, each layer refines the per-residue
    frames; torsion angles predicted from the final state then place all
    side-chain atoms.
    """

    def __init__(self, node_dim=23, n_layers=8, rel_pos_dim=64, embed_dim=128, **kwargs):
        super().__init__()
        self.n_layers = n_layers
        # Relative sequence offsets are clipped to [-rel_pos_dim, rel_pos_dim].
        self.rel_pos_dim = rel_pos_dim
        self.node_embed = torch.nn.Linear(node_dim, embed_dim)
        # embed_dim - 1 because forward() appends one distance channel.
        self.edge_embed = torch.nn.Linear(2 * rel_pos_dim + 1, embed_dim - 1)

        self.layers = torch.nn.ModuleList(
            [StructureUpdate(node_dim=embed_dim,
                             edge_dim=embed_dim,
                             propagate_rotation_gradient=(i == n_layers - 1),
                             **kwargs)
             for i in range(n_layers)])

    def forward(self, node_features, sequence):
        """Fold the sequence.

        Args:
            node_features: (n, node_dim) per-residue input features.
            sequence: residue sequence; len(sequence) must equal n.

        Returns:
            (all_atoms, new_node_features): predicted atom coordinates (with
            NaNs for heavily clashing side chains) and the final node state.
        """
        # All residues start at the identity frame.
        rigid_in = rigid_body_identity(len(sequence)).to(node_features.device)
        # Signed sequence-separation matrix, clipped and shifted to be a
        # valid one-hot index in [0, 2 * rel_pos_dim].
        relative_positions = (torch.arange(node_features.shape[-2])[None] -
                              torch.arange(node_features.shape[-2])[:, None])
        relative_positions = relative_positions.clamp(min=-self.rel_pos_dim, max=self.rel_pos_dim) + self.rel_pos_dim

        rel_pos_embeddings = torch.nn.functional.one_hot(relative_positions, num_classes=2 * self.rel_pos_dim + 1)
        rel_pos_embeddings = rel_pos_embeddings.to(dtype=node_features.dtype, device=node_features.device)
        rel_pos_embeddings = self.edge_embed(rel_pos_embeddings)

        new_node_features = self.node_embed(node_features)

        for layer in self.layers:
            # Edge features = current pairwise frame-origin distances plus
            # the (static) relative-position embeddings; rebuilt each layer
            # because the frames move.
            edge_features = torch.cat(
                [rigid_in.origin.unsqueeze(-1).dist(rigid_in.origin).unsqueeze(-1), rel_pos_embeddings], dim=-1)
            new_node_features, rigid_in = layer(new_node_features, edge_features, rigid_in)

        # Torsions are predicted from the embedded inputs and final state by
        # the last layer's torsion head.
        torsions, _ = self.layers[-1].torsion_angles(self.node_embed(node_features), new_node_features)

        all_reference_frames = global_frames_from_bb_frame_and_torsion_angles(rigid_in, torsions, sequence)
        all_atoms = all_atoms_from_global_reference_frames(all_reference_frames, sequence)

        # Remove atoms of side chains with outrageous clashes
        # ds[i, j, a, b] = distance between atom a of residue i and atom b of
        # residue j; NaNs (padding) and self-distances are masked with 10.
        ds = torch.linalg.norm(all_atoms[None,:,None] - all_atoms[:,None,:,None], axis = -1)
        ds[torch.isnan(ds) | (ds==0.0)] = 10
        min_ds = ds.min(dim=-1)[0].min(dim=-1)[0].min(dim=-1)[0]
        # Atom slots 5+ are past CB in this project's slot order, i.e. the
        # outer side chain — those get NaN'd for clashing residues.
        all_atoms[min_ds < 0.2, 5:, :] = float("Nan")

        return all_atoms, new_node_features
|
||||
354
ImmuneBuilder/refine.py
Normal file
354
ImmuneBuilder/refine.py
Normal file
@@ -0,0 +1,354 @@
|
||||
import pdbfixer
|
||||
import os
|
||||
import numpy as np
|
||||
from openmm import app, LangevinIntegrator, CustomExternalForce, CustomTorsionForce, OpenMMException, Platform, unit
|
||||
from scipy import spatial
|
||||
import logging
|
||||
logging.disable()
|
||||
|
||||
ENERGY = unit.kilocalories_per_mole
|
||||
LENGTH = unit.angstroms
|
||||
spring_unit = ENERGY / (LENGTH ** 2)
|
||||
|
||||
CLASH_CUTOFF = 0.63
|
||||
|
||||
# Atomic radii for various atom types.
|
||||
atom_radii = {"C": 1.70, "N": 1.55, 'O': 1.52, 'S': 1.80}
|
||||
|
||||
# Sum of van-der-waals radii
|
||||
radii_sums = dict(
|
||||
[(i + j, (atom_radii[i] + atom_radii[j])) for i in list(atom_radii.keys()) for j in list(atom_radii.keys())])
|
||||
# Clash_cutoff-based radii values
|
||||
cutoffs = dict(
|
||||
[(i + j, CLASH_CUTOFF * (radii_sums[i + j])) for i in list(atom_radii.keys()) for j in list(atom_radii.keys())])
|
||||
|
||||
# Using amber14 recommended protein force field
|
||||
forcefield = app.ForceField("amber14/protein.ff14SB.xml")
|
||||
|
||||
|
||||
def refine(input_file, output_file, check_for_strained_bonds=True, tries=3, n=6, n_threads=-1):
    """Attempt refinement up to `tries` times, stopping at the first success.

    Each attempt delegates to `refine_once`; returns True as soon as one
    attempt succeeds, False when every attempt fails.
    """
    return any(
        refine_once(input_file, output_file,
                    check_for_strained_bonds=check_for_strained_bonds,
                    n=n, n_threads=n_threads)
        for _ in range(tries)
    )
|
||||
|
||||
|
||||
def refine_once(input_file, output_file, check_for_strained_bonds=True, n=6, n_threads=-1):
    """Run one refinement attempt: up to `n` rounds of restrained minimisation.

    Loads `input_file`, rebuilds missing atoms, then repeatedly minimises
    while adapting two restraint strengths (k1 positional, k2 cis-fixing)
    until the model passes bond-length, cis-bond, chirality, clash and
    (optionally) strained-sidechain checks. The final structure is written
    to `output_file` regardless of outcome.

    Returns:
        True if all geometry checks passed, False otherwise.
    """
    # Schedules: k1 is weakened when peptide bonds come out wrong; k2 is
    # raised while cis peptide bonds remain (k2 = -1 disables that force).
    k1s = [2.5,1,0.5,0.25,0.1,0.001]
    k2s = [2.5,5,7.5,15,25,50]
    success = False

    # Rebuild any missing residues/atoms in the input model.
    fixer = pdbfixer.PDBFixer(input_file)

    fixer.findMissingResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()

    k1 = k1s[0]
    # Only enable the cis-correction force if cis bonds are present.
    k2 = -1 if cis_check(fixer.topology, fixer.positions) else k2s[0]

    topology, positions = fixer.topology, fixer.positions

    for i in range(n):
        try:
            simulation = minimize_energy(topology, positions, k1=k1, k2 = k2, n_threads=n_threads)
            topology, positions = simulation.topology, simulation.context.getState(getPositions=True).getPositions()
            acceptable_bonds, trans_peptide_bonds = bond_check(topology, positions), cis_check(topology, positions)
        except OpenMMException as e:
            # NOTE(review): `positions` is assigned before this loop, so
            # '"positions" not in locals()' can never be True here and the
            # fatal branch appears unreachable — verify the intended guard.
            if (i == n-1) and ("positions" not in locals()):
                print("OpenMM failed to refine {}".format(input_file), flush=True)
                raise e
            else:
                # Reset to the fixed input model and retry.
                topology, positions = fixer.topology, fixer.positions
                continue

        # If peptide bonds are the wrong length, decrease the strength of the positional restraint
        if not acceptable_bonds:
            k1 = k1s[min(i, len(k1s)-1)]

        # If there are still cis isomers in the model, increase the force to fix these
        if not trans_peptide_bonds:
            k2 = k2s[min(i, len(k2s)-1)]
        else:
            k2 = -1

        if acceptable_bonds and trans_peptide_bonds:
            # If peptide bond lengths and torsions are okay, check and fix the chirality.
            try:
                simulation = chirality_fixer(simulation)
                topology, positions = simulation.topology, simulation.context.getState(getPositions=True).getPositions()
            except OpenMMException as e:
                topology, positions = fixer.topology, fixer.positions
                continue

            if check_for_strained_bonds:
                # If all other checks pass, check and fix strained sidechain bonds:
                try:
                    strained_bonds = strained_sidechain_bonds_check(topology, positions)
                    if len(strained_bonds) > 0:
                        # Fixing may perturb geometry, so re-run checks below.
                        needs_recheck = True
                        topology, positions = strained_sidechain_bonds_fixer(strained_bonds, topology, positions, n_threads=n_threads)
                    else:
                        needs_recheck = False
                except OpenMMException as e:
                    topology, positions = fixer.topology, fixer.positions
                    continue
            else:
                needs_recheck = False

            # If it passes all the tests, we are done
            tests = bond_check(topology, positions) and cis_check(topology, positions)
            if needs_recheck:
                tests = tests and strained_sidechain_bonds_check(topology, positions)
            if tests and stereo_check(topology, positions) and clash_check(topology, positions):
                success = True
                break

    # Always write out the structure we reached, even on failure.
    with open(output_file, "w") as out_handle:
        app.PDBFile.writeFile(topology, positions, out_handle, keepIds=True)

    return success
|
||||
|
||||
|
||||
def minimize_energy(topology, positions, k1=2.5, k2=2.5, n_threads=-1):
    """Energy-minimise a model while keeping it close to the prediction.

    Args:
        topology, positions: OpenMM model to minimise.
        k1: strength of the harmonic restraint tying backbone + CB atoms
            to their input positions (in `spring_unit`).
        k2: strength of the torsion force pushing peptide bonds towards
            trans; values <= 0 disable this force entirely.
        n_threads: if > 0, run on the CPU platform with this many threads.

    Returns:
        The OpenMM Simulation after minimisation.
    """
    # Fill in the gaps with OpenMM Modeller
    modeller = app.Modeller(topology, positions)
    modeller.addHydrogens(forcefield)

    # Set up force field
    system = forcefield.createSystem(modeller.topology)

    # Keep atoms close to initial prediction
    force = CustomExternalForce("k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
    force.addGlobalParameter("k", k1 * spring_unit)
    for p in ["x0", "y0", "z0"]:
        force.addPerParticleParameter(p)

    # Restrain only backbone atoms and CB; side chains stay free to relax.
    for residue in modeller.topology.residues():
        for atom in residue.atoms():
            if atom.name in ["CA", "CB", "N", "C"]:
                force.addParticle(atom.index, modeller.positions[atom.index])

    system.addForce(force)

    if k2 > 0.0:
        # Torsion penalty maximal at theta = 0 (cis), zero at trans, applied
        # to every non-proline peptide omega torsion (CA-C-N-CA).
        cis_force = CustomTorsionForce("10*k2*(1+cos(theta))^2")
        cis_force.addGlobalParameter("k2", k2 * ENERGY)

        for chain in modeller.topology.chains():
            residues = [res for res in chain.residues()]
            relevant_atoms = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "CA", "C"]} for res in residues]
            for i in range(1,len(residues)):
                # Prolines are commonly cis, so they are exempted.
                if residues[i].name == "PRO":
                    continue

                resi = relevant_atoms[i-1]
                n_resi = relevant_atoms[i]
                cis_force.addTorsion(resi["CA"], resi["C"], n_resi["N"], n_resi["CA"])

        system.addForce(cis_force)

    # Set up integrator
    # (temperature=0, friction=0.01, stepSize=0) — only minimisation is run,
    # so the integrator never actually steps.
    integrator = LangevinIntegrator(0, 0.01, 0.0)

    # Set up the simulation
    if n_threads > 0:
        # Set number of threads used by OpenMM
        platform = Platform.getPlatformByName('CPU')
        simulation = app.Simulation(modeller.topology, system, integrator, platform, {'Threads': str(n_threads)})
    else:
        simulation = app.Simulation(modeller.topology, system, integrator)
    simulation.context.setPositions(modeller.positions)

    # Minimize the energy
    simulation.minimizeEnergy()

    return simulation
|
||||
|
||||
|
||||
def chirality_fixer(simulation):
    """Detect D-stereoisomer residues and convert them to L form.

    For each non-GLY residue, the sign of the triple product of the
    (N, C, CB) vectors around CA determines handedness; D residues are
    fixed by mirroring their HA atom through CA, freezing it, and
    re-minimising so the rest of the residue inverts around it.

    Returns:
        The same Simulation, re-minimised if any fixes were applied.
    """
    topology = simulation.topology
    positions = simulation.context.getState(getPositions=True).getPositions()

    d_stereoisomers = []
    for residue in topology.residues():
        # Glycine has no CB and therefore no chiral centre here.
        if residue.name == "GLY":
            continue

        atom_indices = {atom.name:atom.index for atom in residue.atoms() if atom.name in ["N", "CA", "C", "CB"]}
        vectors = [positions[atom_indices[i]] - positions[atom_indices["CA"]] for i in ["N", "C", "CB"]]

        # Negative signed volume => D-stereoisomer.
        if np.dot(np.cross(vectors[0], vectors[1]), vectors[2]) < .0*LENGTH**3:
            # If it is a D-stereoisomer then flip its H atom
            indices = {x.name:x.index for x in residue.atoms() if x.name in ["HA", "CA"]}
            positions[indices["HA"]] = 2*positions[indices["CA"]] - positions[indices["HA"]]

            # Fix the H atom in place
            # Mass 0 makes OpenMM treat the particle as immovable.
            particle_mass = simulation.system.getParticleMass(indices["HA"])
            simulation.system.setParticleMass(indices["HA"], 0.0)
            d_stereoisomers.append((indices["HA"], particle_mass))

    if len(d_stereoisomers) > 0:
        simulation.context.setPositions(positions)

        # Minimize the energy with the evil hydrogens fixed
        simulation.minimizeEnergy()

        # Minimize the energy letting the hydrogens move
        for atom in d_stereoisomers:
            simulation.system.setParticleMass(*atom)
        simulation.minimizeEnergy()

    return simulation
|
||||
|
||||
|
||||
def bond_check(topology, positions):
    """Check that all peptide (C-N) bond lengths are close to ideal.

    Returns False as soon as any inter-residue C->N distance deviates by
    more than 0.1 A from the ideal 1.329 A, True otherwise.
    """
    for chain in topology.chains():
        # Backbone N and C atom indices, one dict per residue in order.
        backbone = [{atom.name: atom.index for atom in res.atoms() if atom.name in ["N", "C"]}
                    for res in chain.residues()]
        # For simplicity we only check the peptide bond length as the rest
        # should be correct as they are hard coded.
        for prev_res, next_res in zip(backbone, backbone[1:]):
            bond_length = np.linalg.norm(positions[prev_res["C"]] - positions[next_res["N"]])
            if abs(bond_length - 1.329 * LENGTH) > 0.1 * LENGTH:
                return False
    return True
|
||||
|
||||
|
||||
def cis_bond(p0, p1, p2, p3):
    """Return True when the p0-p1-p2-p3 torsion is in the cis configuration.

    The two planes (p0, p1, p2) and (p1, p2, p3) have their normals
    compared: a positive dot product means the dihedral angle is closer to
    0 (cis) than to 180 degrees (trans).
    """
    first_bond = p0 - p1    # reversed bond vector p1 -> p0
    central_bond = p2 - p1  # rotation-axis bond
    last_bond = p3 - p2

    normal_a = np.cross(first_bond, central_bond)
    normal_b = np.cross(last_bond, central_bond)
    return np.dot(normal_a, normal_b) > 0
|
||||
|
||||
|
||||
def cis_check(topology, positions):
    """Return True when no non-proline peptide bond is in cis configuration.

    Each omega torsion (CA, C of residue i-1; N, CA of residue i) is tested
    with `cis_bond`; prolines are exempt since cis-proline is common.
    """
    coords = np.array(positions.value_in_unit(LENGTH))
    for chain in topology.chains():
        chain_residues = list(chain.residues())
        backbone = [{a.name: a.index for a in res.atoms() if a.name in ["N", "CA", "C"]}
                    for res in chain_residues]
        for idx, res in enumerate(chain_residues[1:], start=1):
            if res.name == "PRO":
                continue

            prev_atoms, curr_atoms = backbone[idx - 1], backbone[idx]
            omega_points = (coords[prev_atoms["CA"]], coords[prev_atoms["C"]],
                            coords[curr_atoms["N"]], coords[curr_atoms["CA"]])
            if cis_bond(*omega_points):
                return False
    return True
|
||||
|
||||
|
||||
def stereo_check(topology, positions):
    # Returns True if every chiral alpha-carbon has the expected handedness
    # (L-amino acids); a wrong sign indicates a D-stereoisomer.
    pos = np.array(positions.value_in_unit(LENGTH))
    for residue in topology.residues():
        # Glycine has no CB and therefore no chiral centre to check.
        if residue.name == "GLY":
            continue

        atom_indices = {atom.name:atom.index for atom in residue.atoms() if atom.name in ["N", "CA", "C", "CB"]}
        # Vectors from CA to its three substituents N, C and CB.
        vectors = pos[[atom_indices[i] for i in ["N", "C", "CB"]]] - pos[atom_indices["CA"]]

        # The sign of this determinant distinguishes the two stereoisomers.
        if np.linalg.det(vectors) < 0:
            return False
    return True
||||
|
||||
|
||||
def clash_check(topology, positions):
    # Returns True if no pair of heavy atoms from different residues is closer
    # than the per-element-pair distance cutoff (module-level `cutoffs` table).
    heavies = [x for x in topology.atoms() if x.element.symbol != "H"]
    pos = np.array(positions.value_in_unit(LENGTH))[[x.index for x in heavies]]

    # A KD-tree gives all candidate pairs within the largest cutoff cheaply.
    tree = spatial.KDTree(pos)
    pairs = tree.query_pairs(r=max(cutoffs.values()))

    for pair in pairs:
        atom_i, atom_j = heavies[pair[0]], heavies[pair[1]]

        # Atoms within the same residue are covalently bonded by construction.
        if atom_i.residue.index == atom_j.residue.index:
            continue
        # The peptide bond legitimately links C of one residue to N of the next.
        elif (atom_i.name == "C" and atom_j.name == "N") or (atom_i.name == "N" and atom_j.name == "C"):
            continue

        atom_distance = np.linalg.norm(pos[pair[0]] - pos[pair[1]])

        # SG-SG pairs beyond 1.88 are allowed (possible disulphide bond);
        # closer SG-SG pairs still fall through to the cutoff check below.
        if (atom_i.name == "SG" and atom_j.name == "SG") and atom_distance > 1.88:
            continue

        elif atom_distance < (cutoffs[atom_i.element.symbol + atom_j.element.symbol]):
            return False
    return True
||||
|
||||
|
||||
def strained_sidechain_bonds_check(topology, positions):
    # Returns the list of residues containing abnormally strained harmonic
    # bonds (empty list means the structure passes this check).
    atoms = list(topology.atoms())
    pos = np.array(positions.value_in_unit(LENGTH))

    # Build a system only to read the force-field bond parameters.
    system = forcefield.createSystem(topology)
    bonds = [x for x in system.getForces() if type(x).__name__ == "HarmonicBondForce"][0]

    # Initialise arrays for bond details
    n_bonds = bonds.getNumBonds()
    i = np.empty(n_bonds, dtype=int)
    j = np.empty(n_bonds, dtype=int)
    k = np.empty(n_bonds)
    x0 = np.empty(n_bonds)

    # Extract bond details to arrays
    for n in range(n_bonds):
        i[n],j[n],_x0,_k = bonds.getBondParameters(n)
        k[n] = _k.value_in_unit(spring_unit)
        x0[n] = _x0.value_in_unit(LENGTH)

    # Check if there are any abnormally strained bond.
    # Harmonic energy k*(d - x0)^2 compared against an empirical threshold of
    # 100 (in spring_unit * LENGTH**2).
    distance = np.linalg.norm(pos[i] - pos[j], axis=-1)
    check = k*(distance - x0)**2 > 100

    # Return residues with strained bonds if any
    return [atoms[x].residue for x in i[check]]
||||
|
||||
|
||||
def strained_sidechain_bonds_fixer(strained_residues, topology, positions, n_threads=-1):
    """Rebuild the side chains of badly refined residues and re-minimize.

    Args:
        strained_residues: residues flagged by strained_sidechain_bonds_check.
        topology: OpenMM topology of the current model.
        positions: matching atomic positions.
        n_threads: number of CPU threads for OpenMM; -1 lets OpenMM decide.

    Returns:
        (topology, positions) of the repaired, energy-minimised model.
    """
    # Delete all atoms except the main chain for badly refined residues.
    bb_atoms = ["N","CA","C"]
    bad_side_chains = sum([[atom for atom in residue.atoms() if atom.name not in bb_atoms] for residue in strained_residues],[])
    modeller = app.Modeller(topology, positions)
    modeller.delete(bad_side_chains)

    # Save model with deleted side chains to temporary file.
    random_number = str(int(np.random.rand()*10**8))
    tmp_file = f"side_chain_fix_tmp_{random_number}.pdb"
    with open(tmp_file,"w") as handle:
        app.PDBFile.writeFile(modeller.topology, modeller.positions, handle, keepIds=True)

    # Load model into pdbfixer
    fixer = pdbfixer.PDBFixer(tmp_file)
    os.remove(tmp_file)

    # Repair deleted side chains
    fixer.findMissingResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()

    # Fill in the gaps with OpenMM Modeller
    modeller = app.Modeller(fixer.topology, fixer.positions)
    modeller.addHydrogens(forcefield)

    # Set up force field
    system = forcefield.createSystem(modeller.topology)

    # Set up integrator (temperature 0, tiny step: used for minimisation only)
    integrator = LangevinIntegrator(0, 0.01, 0.0)

    # Set up the simulation
    if n_threads > 0:
        # Set number of threads used by OpenMM.
        # BUG FIX: platform properties must be a dict. The previous code passed
        # the *set* {'Threads', str(n_threads)}, which OpenMM rejects, so the
        # thread count was never applied.
        platform = Platform.getPlatformByName('CPU')
        simulation = app.Simulation(modeller.topology, system, integrator, platform, {'Threads': str(n_threads)})
    else:
        simulation = app.Simulation(modeller.topology, system, integrator)
    simulation.context.setPositions(modeller.positions)

    # Minimize the energy
    simulation.minimizeEnergy()

    return simulation.topology, simulation.context.getState(getPositions=True).getPositions()
||||
313
ImmuneBuilder/rigids.py
Normal file
313
ImmuneBuilder/rigids.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import torch
|
||||
from ImmuneBuilder.constants import rigid_group_atom_positions2, chi2_centers, chi3_centers, chi4_centers, rel_pos, \
|
||||
residue_atoms_mask
|
||||
|
||||
class Vector:
    """A batched 3D vector whose x, y and z components are torch tensors of identical shape."""

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
        # Batch shape shared by all three component tensors.
        self.shape = x.shape
        assert (x.shape == y.shape) and (y.shape == z.shape), "x y and z should have the same shape"

    def __add__(self, vec):
        # Component-wise addition.
        return Vector(vec.x + self.x, vec.y + self.y, vec.z + self.z)

    def __sub__(self, vec):
        # Component-wise subtraction.
        return Vector(-vec.x + self.x, -vec.y + self.y, -vec.z + self.z)

    def __mul__(self, param):
        # Scaling by a scalar or broadcastable tensor.
        return Vector(param * self.x, param * self.y, param * self.z)

    def __matmul__(self, vec):
        # Dot product.
        return vec.x * self.x + vec.y * self.y + vec.z * self.z

    def norm(self):
        # Euclidean norm; the epsilon keeps gradients finite at zero length.
        return (self.x ** 2 + self.y ** 2 + self.z ** 2 + 1e-8) ** (1 / 2)

    def cross(self, other):
        # Cross product self x other.
        a = (self.y * other.z - self.z * other.y)
        b = (self.z * other.x - self.x * other.z)
        c = (self.x * other.y - self.y * other.x)
        return Vector(a, b, c)

    def dist(self, other):
        # Euclidean distance between the two vectors (epsilon as in norm()).
        return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2 + (self.z - other.z) ** 2 + 1e-8) ** (1 / 2)

    def unsqueeze(self, dim):
        return Vector(self.x.unsqueeze(dim), self.y.unsqueeze(dim), self.z.unsqueeze(dim))

    def squeeze(self, dim):
        return Vector(self.x.squeeze(dim), self.y.squeeze(dim), self.z.squeeze(dim))

    def map(self, func):
        # Apply func independently to each component tensor.
        return Vector(func(self.x), func(self.y), func(self.z))

    def to(self, device):
        return Vector(self.x.to(device), self.y.to(device), self.z.to(device))

    def __str__(self):
        return "Vector(x={},\ny={},\nz={})\n".format(self.x, self.y, self.z)

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        # Index into the batch dimensions of every component.
        return Vector(self.x[key], self.y[key], self.z[key])
||||
|
||||
|
||||
class Rot:
    """A batched 3x3 rotation matrix stored as nine component tensors (row-major xx..zz)."""

    def __init__(self, xx, xy, xz, yx, yy, yz, zx, zy, zz):
        self.xx = xx
        self.xy = xy
        self.xz = xz
        self.yx = yx
        self.yy = yy
        self.yz = yz
        self.zx = zx
        self.zy = zy
        self.zz = zz
        # Batch shape shared by all nine entries.
        self.shape = xx.shape

    def __matmul__(self, other):
        # Rotate a Vector, or compose with another Rot (matrix product).
        if isinstance(other, Vector):
            return Vector(
                other.x * self.xx + other.y * self.xy + other.z * self.xz,
                other.x * self.yx + other.y * self.yy + other.z * self.yz,
                other.x * self.zx + other.y * self.zy + other.z * self.zz)

        if isinstance(other, Rot):
            return Rot(
                xx=self.xx * other.xx + self.xy * other.yx + self.xz * other.zx,
                xy=self.xx * other.xy + self.xy * other.yy + self.xz * other.zy,
                xz=self.xx * other.xz + self.xy * other.yz + self.xz * other.zz,
                yx=self.yx * other.xx + self.yy * other.yx + self.yz * other.zx,
                yy=self.yx * other.xy + self.yy * other.yy + self.yz * other.zy,
                yz=self.yx * other.xz + self.yy * other.yz + self.yz * other.zz,
                zx=self.zx * other.xx + self.zy * other.yx + self.zz * other.zx,
                zy=self.zx * other.xy + self.zy * other.yy + self.zz * other.zy,
                zz=self.zx * other.xz + self.zy * other.yz + self.zz * other.zz,
            )

        else:
            raise ValueError("Matmul against {}".format(type(other)))

    def inv(self):
        # For a rotation matrix the inverse is simply the transpose.
        return Rot(
            xx=self.xx, xy=self.yx, xz=self.zx,
            yx=self.xy, yy=self.yy, yz=self.zy,
            zx=self.xz, zy=self.yz, zz=self.zz
        )

    def det(self):
        # Determinant by cofactor expansion (+1 for a proper rotation).
        return self.xx * self.yy * self.zz + self.xy * self.yz * self.zx + self.yx * self.zy * self.xz - self.xz * self.yy * self.zx - self.xy * self.yx * self.zz - self.xx * self.zy * self.yz

    def unsqueeze(self, dim):
        return Rot(
            self.xx.unsqueeze(dim=dim), self.xy.unsqueeze(dim=dim), self.xz.unsqueeze(dim=dim),
            self.yx.unsqueeze(dim=dim), self.yy.unsqueeze(dim=dim), self.yz.unsqueeze(dim=dim),
            self.zx.unsqueeze(dim=dim), self.zy.unsqueeze(dim=dim), self.zz.unsqueeze(dim=dim)
        )

    def squeeze(self, dim):
        return Rot(
            self.xx.squeeze(dim=dim), self.xy.squeeze(dim=dim), self.xz.squeeze(dim=dim),
            self.yx.squeeze(dim=dim), self.yy.squeeze(dim=dim), self.yz.squeeze(dim=dim),
            self.zx.squeeze(dim=dim), self.zy.squeeze(dim=dim), self.zz.squeeze(dim=dim)
        )

    def detach(self):
        # Detach every entry from the autograd graph.
        return Rot(
            self.xx.detach(), self.xy.detach(), self.xz.detach(),
            self.yx.detach(), self.yy.detach(), self.yz.detach(),
            self.zx.detach(), self.zy.detach(), self.zz.detach()
        )

    def to(self, device):
        return Rot(
            self.xx.to(device), self.xy.to(device), self.xz.to(device),
            self.yx.to(device), self.yy.to(device), self.yz.to(device),
            self.zx.to(device), self.zy.to(device), self.zz.to(device)
        )

    def __str__(self):
        return "Rot(xx={},\nxy={},\nxz={},\nyx={},\nyy={},\nyz={},\nzx={},\nzy={},\nzz={})\n".format(self.xx, self.xy,
                                                                                                     self.xz, self.yx,
                                                                                                     self.yy, self.yz,
                                                                                                     self.zx, self.zy,
                                                                                                     self.zz)

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        # Index into the batch dimensions of every entry.
        return Rot(
            self.xx[key], self.xy[key], self.xz[key],
            self.yx[key], self.yy[key], self.yz[key],
            self.zx[key], self.zy[key], self.zz[key]
        )
||||
|
||||
|
||||
class Rigid:
    """A rigid-body transform: rotation `rot` followed by translation by `origin`."""

    def __init__(self, origin, rot):
        self.origin = origin
        self.rot = rot
        self.shape = self.origin.shape

    def __matmul__(self, other):
        # Apply to a Vector (rotate, then translate) or compose with a Rigid.
        if isinstance(other, Vector):
            return self.rot @ other + self.origin
        elif isinstance(other, Rigid):
            return Rigid(self.rot @ other.origin + self.origin, self.rot @ other.rot)
        else:
            raise TypeError(f"can't multiply rigid by object of type {type(other)}")

    def inv(self):
        # Inverse transform: (R, t)^-1 = (R^T, -R^T t).
        inv_rot = self.rot.inv()
        t = inv_rot @ self.origin
        return Rigid(Vector(-t.x, -t.y, -t.z), inv_rot)

    def unsqueeze(self, dim=None):
        return Rigid(self.origin.unsqueeze(dim=dim), self.rot.unsqueeze(dim=dim))

    def squeeze(self, dim=None):
        return Rigid(self.origin.squeeze(dim=dim), self.rot.squeeze(dim=dim))

    def to(self, device):
        return Rigid(self.origin.to(device), self.rot.to(device))

    def __str__(self):
        return "Rigid(origin={},\nrot={})".format(self.origin, self.rot)

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        # Index into the batch dimensions of origin and rotation together.
        return Rigid(self.origin[key], self.rot[key])
||||
|
||||
|
||||
def rigid_body_identity(shape):
    """Return a Rigid of the given batch shape with zero translation and identity rotation."""
    def zero():
        return torch.zeros(shape)

    def one():
        return torch.ones(shape)

    origin = Vector(zero(), zero(), zero())
    rot = Rot(one(), zero(), zero(),
              zero(), one(), zero(),
              zero(), zero(), one())
    return Rigid(origin, rot)
||||
|
||||
def vec_from_tensor(tens):
    """Split a (..., 3) tensor into a Vector of its x, y and z components."""
    assert tens.shape[-1] == 3, "What dimension you in?"
    x, y, z = tens[..., 0], tens[..., 1], tens[..., 2]
    return Vector(x, y, z)
||||
|
||||
|
||||
def rigid_from_three_points(origin, y_x_plane, x_axis):
    """Build a Rigid frame at `origin` by Gram-Schmidt: first axis toward `x_axis`,
    second axis in the plane through `y_x_plane`, third completing the frame."""
    e1 = x_axis - origin
    e2 = y_x_plane - origin

    # Normalize the first axis, then remove its component from the second.
    e1 *= 1 / e1.norm()
    e2 = e2 - e1 * (e1 @ e2)
    e2 *= 1 / e2.norm()

    # Cross product completes the orthonormal basis.
    e3 = e1.cross(e2)

    basis = Rot(e1.x, e2.x, e3.x, e1.y, e2.y, e3.y, e1.z, e2.z, e3.z)
    return Rigid(origin, basis)
||||
|
||||
|
||||
def rigid_from_tensor(tens):
    """Build a Rigid frame from a (..., 3, 3) tensor of [origin, y_x_plane, x_axis] points."""
    assert (tens.shape[-1] == 3), "I want 3D points"
    origin, y_x_plane, x_axis = (vec_from_tensor(tens[..., point, :]) for point in range(3))
    return rigid_from_three_points(origin, y_x_plane, x_axis)
||||
|
||||
def stack_rigids(rigids, **kwargs):
    """Stack a list of Rigids into one batched Rigid, component by component.

    Probably best to avoid using very much: it materialises twelve stacked tensors.
    """
    def stack_attr(getter):
        return torch.stack([getter(rig) for rig in rigids], **kwargs)

    stacked_origin = Vector(stack_attr(lambda rig: rig.origin.x),
                            stack_attr(lambda rig: rig.origin.y),
                            stack_attr(lambda rig: rig.origin.z))

    components = ("xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz")
    stacked_rot = Rot(*[stack_attr(lambda rig, name=name: getattr(rig.rot, name))
                        for name in components])

    return Rigid(stacked_origin, stacked_rot)
||||
|
||||
|
||||
def rotate_x_axis_to_new_vector(new_vector):
    # Returns a Rigid with zero translation whose rotation takes the x-axis
    # onto the direction of `new_vector`, built in closed form.
    # Extract coordinates
    c, b, a = new_vector[..., 0], new_vector[..., 1], new_vector[..., 2]

    # Normalize (epsilon keeps the division safe for near-zero vectors)
    n = (c ** 2 + a ** 2 + b ** 2 + 1e-16) ** (1 / 2)
    # NOTE(review): the first component is negated here, so the frame appears
    # to be constructed relative to the -x direction - confirm against callers
    # before changing any signs.
    a, b, c = a / n, b / n, -c / n

    # Set new origin
    new_origin = vec_from_tensor(torch.zeros_like(new_vector))

    # Rotate x-axis to point old origin to new one
    k = (1 - c) / (a ** 2 + b ** 2 + 1e-8)
    new_rot = Rot(-c, b, -a, b, 1 - k * b ** 2, a * b * k, a, -a * b * k, k * a ** 2 - 1)

    return Rigid(new_origin, new_rot)
||||
|
||||
|
||||
def rigid_transformation_from_torsion_angles(torsion_angles, distance_to_new_origin):
    # Builds the local Rigid placing the next torsion frame: translate by
    # `distance_to_new_origin` along x, and rotate about x using the two
    # components of `torsion_angles[..., 0:2]` (presumably (cos, sin) of the
    # angle - confirm against the model head that produces them).
    dev = torsion_angles.device

    zero = torch.zeros(torsion_angles.shape[:-1]).to(dev)
    one = torch.ones(torsion_angles.shape[:-1]).to(dev)
    new_rot = Rot(
        -one, zero, zero,
        zero, torsion_angles[..., 0], torsion_angles[..., 1],
        zero, torsion_angles[..., 1], -torsion_angles[..., 0],
    )
    new_origin = Vector(distance_to_new_origin, zero, zero)

    return Rigid(new_origin, new_rot)
||||
|
||||
|
||||
def global_frames_from_bb_frame_and_torsion_angles(bb_frame, torsion_angles, seq):
    # Composes each residue's backbone frame with its psi and chi1-4 torsion
    # frames, producing one global frame per rigid group (6 per residue).
    # `seq` is the one-letter sequence used to look up ideal local geometry
    # from the constants tables; `torsion_angles[:, i]` is the i-th angle.
    dev = bb_frame.origin.x.device

    # We start with psi
    psi_local_frame_origin = torch.tensor([rel_pos[x][2][1] for x in seq]).to(dev).pow(2).sum(-1).pow(1 / 2)
    psi_local_frame = rigid_transformation_from_torsion_angles(torsion_angles[:, 0], psi_local_frame_origin)
    psi_global_frame = bb_frame @ psi_local_frame

    # Now all the chis. Each chi frame is defined relative to the previous one,
    # so global frames are composed in a chain: bb -> chi1 -> chi2 -> chi3 -> chi4.
    chi1_local_frame_origin = torch.tensor([rel_pos[x][3][1] for x in seq]).to(dev)
    chi1_local_frame = rotate_x_axis_to_new_vector(chi1_local_frame_origin) @ rigid_transformation_from_torsion_angles(
        torsion_angles[:, 1], chi1_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
    chi1_global_frame = bb_frame @ chi1_local_frame

    chi2_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi2_centers[x]][1] for x in seq]).to(dev)
    chi2_local_frame = rotate_x_axis_to_new_vector(chi2_local_frame_origin) @ rigid_transformation_from_torsion_angles(
        torsion_angles[:, 2], chi2_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
    chi2_global_frame = chi1_global_frame @ chi2_local_frame

    chi3_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi3_centers[x]][1] for x in seq]).to(dev)
    chi3_local_frame = rotate_x_axis_to_new_vector(chi3_local_frame_origin) @ rigid_transformation_from_torsion_angles(
        torsion_angles[:, 3], chi3_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
    chi3_global_frame = chi2_global_frame @ chi3_local_frame

    chi4_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi4_centers[x]][1] for x in seq]).to(dev)
    chi4_local_frame = rotate_x_axis_to_new_vector(chi4_local_frame_origin) @ rigid_transformation_from_torsion_angles(
        torsion_angles[:, 4], chi4_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
    chi4_global_frame = chi3_global_frame @ chi4_local_frame

    # Stacked along the last dim in order [bb, psi, chi1, chi2, chi3, chi4].
    return stack_rigids(
        [bb_frame, psi_global_frame, chi1_global_frame, chi2_global_frame, chi3_global_frame, chi4_global_frame],
        dim=-1)
||||
|
||||
|
||||
def all_atoms_from_global_reference_frames(global_reference_frames, seq):
    # Places all (up to 14) heavy atoms of each residue by applying the atom's
    # rigid-group global frame to its ideal local coordinates from `rel_pos`.
    dev = global_reference_frames.origin.x.device

    all_atoms = torch.zeros((len(seq), 14, 3)).to(dev)
    for atom_pos in range(14):
        relative_positions = [rel_pos[x][atom_pos][1] for x in seq]
        # NOTE(review): the frame ids stored in rel_pos appear to be offset by
        # 2 relative to the 6 stacked frames (bb, psi, chi1-4); frames below 2
        # collapse onto the backbone frame - confirm against constants.py.
        local_reference_frame = [max(rel_pos[x][atom_pos][0] - 2, 0) for x in seq]
        local_reference_frame_mask = torch.tensor([[y == x for y in range(6)] for x in local_reference_frame]).to(dev)
        global_atom_vector = global_reference_frames[local_reference_frame_mask] @ vec_from_tensor(
            torch.tensor(relative_positions).to(dev))
        all_atoms[:, atom_pos] = torch.stack([global_atom_vector.x, global_atom_vector.y, global_atom_vector.z], dim=-1)

    # Atom slots a residue type does not have are marked NaN (skipped by to_pdb).
    all_atom_mask = torch.tensor([residue_atoms_mask[x] for x in seq]).to(dev)
    all_atoms[~all_atom_mask] = float("Nan")
    return all_atoms
||||
40
ImmuneBuilder/sequence_checks.py
Normal file
40
ImmuneBuilder/sequence_checks.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from anarci import validate_sequence, anarci, scheme_short_to_long
|
||||
|
||||
def number_single_sequence(sequence, chain, scheme="imgt", allowed_species=['human','mouse']):
    """Number an antibody chain with ANARCI and sanity-check the result.

    Args:
        sequence: amino acid sequence of a full variable domain.
        chain: "H" or "L" (kappa "K" is also accepted for light chains).
        scheme: numbering scheme short name, or "raw" for sequential numbering.
        allowed_species: species germlines ANARCI may align against.

    Returns:
        List of ((number, insertion_code), residue) pairs with gaps removed.
    """
    validate_sequence(sequence)

    try:
        if scheme != "raw":
            scheme = scheme_short_to_long[scheme.lower()]
    except KeyError:
        raise NotImplementedError(f"Unimplemented numbering scheme: {scheme}")

    assert len(sequence) > 70, f"Sequence too short to be an Ig domain. Please give whole sequence:\n{sequence}"

    # Light chains may be kappa as well as lambda.
    allow = [chain]
    if chain == "L":
        allow.append("K")

    # Use imgt scheme for numbering sanity checks
    numbered, _, _ = anarci([("sequence", sequence)], scheme='imgt', output=False, allow=set(allow), allowed_species=allowed_species)

    assert numbered[0], f"Sequence provided as an {chain} chain is not recognised as an {chain} chain."

    # Drop alignment gap positions.
    output = [x for x in numbered[0][0][0] if x[1] != "-"]
    numbers = [x[0][0] for x in output]

    # Check for missing residues assuming imgt numbering
    assert (max(numbers) > 120) and (min(numbers) < 8), f"Sequence missing too many residues to model correctly. Please give whole sequence:\n{sequence}"

    # Renumber once sanity checks done
    if scheme == "raw":
        output = [((i+1, " "),x[1]) for i,x in enumerate(output)]
    elif scheme != 'imgt':
        numbered, _, _ = anarci([("sequence", sequence)], scheme=scheme, output=False, allow=set(allow), allowed_species=allowed_species)
        output = [x for x in numbered[0][0][0] if x[1] != "-"]

    return output
||||
|
||||
|
||||
def number_sequences(seqs, scheme="imgt", allowed_species=['human','mouse']):
    """Number every chain in `seqs` (dict of chain id -> sequence) with the chosen scheme."""
    numbered = {}
    for chain in seqs:
        numbered[chain] = number_single_sequence(seqs[chain], chain, scheme=scheme, allowed_species=allowed_species)
    return numbered
||||
1
ImmuneBuilder/trained_model/.gitkeep
Normal file
1
ImmuneBuilder/trained_model/.gitkeep
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
0
ImmuneBuilder/trained_model/antibody_model_1
Normal file
0
ImmuneBuilder/trained_model/antibody_model_1
Normal file
0
ImmuneBuilder/trained_model/antibody_model_2
Normal file
0
ImmuneBuilder/trained_model/antibody_model_2
Normal file
0
ImmuneBuilder/trained_model/antibody_model_3
Normal file
0
ImmuneBuilder/trained_model/antibody_model_3
Normal file
0
ImmuneBuilder/trained_model/antibody_model_4
Normal file
0
ImmuneBuilder/trained_model/antibody_model_4
Normal file
0
ImmuneBuilder/trained_model/nanobody_model_1
Normal file
0
ImmuneBuilder/trained_model/nanobody_model_1
Normal file
0
ImmuneBuilder/trained_model/nanobody_model_2
Normal file
0
ImmuneBuilder/trained_model/nanobody_model_2
Normal file
0
ImmuneBuilder/trained_model/nanobody_model_3
Normal file
0
ImmuneBuilder/trained_model/nanobody_model_3
Normal file
0
ImmuneBuilder/trained_model/nanobody_model_4
Normal file
0
ImmuneBuilder/trained_model/nanobody_model_4
Normal file
0
ImmuneBuilder/trained_model/tcr2_model_1
Normal file
0
ImmuneBuilder/trained_model/tcr2_model_1
Normal file
0
ImmuneBuilder/trained_model/tcr2_model_2
Normal file
0
ImmuneBuilder/trained_model/tcr2_model_2
Normal file
0
ImmuneBuilder/trained_model/tcr2_model_3
Normal file
0
ImmuneBuilder/trained_model/tcr2_model_3
Normal file
0
ImmuneBuilder/trained_model/tcr2_model_4
Normal file
0
ImmuneBuilder/trained_model/tcr2_model_4
Normal file
0
ImmuneBuilder/trained_model/tcr_model_1
Normal file
0
ImmuneBuilder/trained_model/tcr_model_1
Normal file
0
ImmuneBuilder/trained_model/tcr_model_2
Normal file
0
ImmuneBuilder/trained_model/tcr_model_2
Normal file
0
ImmuneBuilder/trained_model/tcr_model_3
Normal file
0
ImmuneBuilder/trained_model/tcr_model_3
Normal file
0
ImmuneBuilder/trained_model/tcr_model_4
Normal file
0
ImmuneBuilder/trained_model/tcr_model_4
Normal file
136
ImmuneBuilder/util.py
Normal file
136
ImmuneBuilder/util.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from ImmuneBuilder.constants import res_to_num, atom_types, residue_atoms, restype_1to3, restypes
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
import os
|
||||
|
||||
|
||||
def download_file(url, filename):
    """Stream `url` to `filename` on disk in fixed-size chunks and return the path."""
    response = requests.get(url, stream=True)
    with response:
        # Fail loudly on HTTP errors before writing anything.
        response.raise_for_status()
        with open(filename, 'wb+') as out_handle:
            for block in response.iter_content(chunk_size=8192):
                out_handle.write(block)
    return filename
||||
|
||||
|
||||
def get_one_hot(targets, nb_classes=21):
    """One-hot encode an integer array.

    Args:
        targets: array-like of ints in [0, nb_classes); plain lists/tuples are
            accepted as well as ndarrays (the original crashed on non-arrays
            because it read `targets.shape` before converting).
        nb_classes: size of the one-hot dimension.

    Returns:
        Float array of shape `targets.shape + (nb_classes,)`.
    """
    targets = np.asarray(targets)
    res = np.eye(nb_classes)[targets.reshape(-1)]
    return res.reshape(list(targets.shape) + [nb_classes])
||||
|
||||
|
||||
def get_encoding(sequence_dict, chain_ids="HL"):
    """Build the model input encoding: per-residue one-hot amino acid (21 classes)
    concatenated with a one-hot chain indicator (2 classes), chains stacked in order."""
    per_chain = []

    for chain_index, chain in enumerate(chain_ids):
        seq = sequence_dict[chain]
        amino_encoding = get_one_hot(np.array([res_to_num(x) for x in seq]))
        chain_encoding = get_one_hot(chain_index * np.ones(len(seq), dtype=int), 2)
        per_chain.append(np.concatenate([amino_encoding, chain_encoding], axis=-1))

    return np.concatenate(per_chain, axis=0)
||||
|
||||
|
||||
def find_alignment_transform(traces):
    """Kabsch-align every trace in `traces[1:]` onto `traces[0]`.

    Returns a (n_traces, 3, 3) stack of rotations (identity first) and the
    per-trace centroids used for centering.
    """
    centers = traces.mean(-2, keepdim=True)
    centered = traces - centers

    reference, movers = centered[0], centered[1:]

    # Cross-covariance of each mover with the reference, then SVD for the
    # optimal rotation.
    covariance = torch.einsum("i j k, j l -> i k l", movers, reference)
    V, _, W = torch.linalg.svd(covariance)
    U = torch.matmul(V, W)

    # Correct reflections: scale the last singular direction by det(U) so each
    # transform is a proper rotation.
    ones = torch.ones(len(movers), device=U.device)
    signs = torch.stack([ones, ones, torch.linalg.det(U)], dim=1)[:, :, None]
    U = torch.matmul(signs * V, W)

    return torch.cat([torch.eye(3, device=U.device)[None], U]), centers
||||
|
||||
|
||||
def to_pdb(numbered_sequences, all_atoms, chain_ids = "HL"):
    """Render predicted coordinates as PDB ATOM records.

    Args:
        numbered_sequences: dict of chain id -> list of ((number, insertion), residue).
        all_atoms: per-residue atom coordinates indexed [residue, atom_slot, xyz].
        chain_ids: the two chain ids, in output order.

    Returns:
        The PDB ATOM lines joined by newlines (no header or END record).
    """
    atom_index = 0
    pdb_lines = []
    record_type = "ATOM"
    seq = numbered_sequences[chain_ids[0]] + numbered_sequences[chain_ids[1]]
    chain_index = [0]*len(numbered_sequences[chain_ids[0]]) + [1]*len(numbered_sequences[chain_ids[1]])
    chain_id = chain_ids[0]

    for i, amino in enumerate(seq):
        for atom in atom_types:
            if atom in residue_atoms[amino[1]]:
                j = residue_atoms[amino[1]].index(atom)
                pos = all_atoms[i, j]
                # NaN coordinates (atom slots the model did not place) compare
                # unequal to themselves and are skipped.
                if pos.mean() != pos.mean():
                    continue
                name = f' {atom}'
                alt_loc = ''
                res_name_3 = restype_1to3[amino[1]]
                if chain_id != chain_ids[chain_index[i]]:
                    chain_id = chain_ids[chain_index[i]]
                occupancy = 1.00
                b_factor = 0.00
                # NOTE(review): element taken as the first character of the
                # atom name - fine for the C/N/O/S atoms of standard residues.
                element = atom[0]
                charge = ''
                # PDB is a columnar format, every space matters here!
                atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                             f'{res_name_3:>3} {chain_id:>1}'
                             f'{(amino[0][0]):>4}{amino[0][1]:>1}   '
                             f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                             f'{occupancy:>6.2f}{b_factor:>6.2f}          '
                             f'{element:>2}{charge:>2}')
                pdb_lines.append(atom_line)
                atom_index += 1

    return "\n".join(pdb_lines)
||||
|
||||
|
||||
def sequence_dict_from_fasta(fasta_file):
    """Parse a fasta file into a dict of chain id (text after '>') -> sequence.

    Sequences may span multiple lines (the previous implementation kept only
    the first line after each header). Entries containing characters outside
    the standard residue alphabet are dropped.
    """
    out = {}

    with open(fasta_file) as file:
        tokens = file.read().split()

    def _flush(chain_id, parts):
        # Store the accumulated sequence if it is non-empty and valid.
        seq = "".join(parts)
        if chain_id is not None and seq and all(c in restypes for c in seq):
            out[chain_id] = seq

    chain_id, parts = None, []
    for token in tokens:
        if ">" in token:
            # New record: finish the previous one first.
            _flush(chain_id, parts)
            chain_id = token.split(">")[1]
            parts = []
        else:
            parts.append(token)
    _flush(chain_id, parts)

    return out
||||
|
||||
|
||||
def add_errors_as_bfactors(filename, errors, header=None):
    """Rewrite a PDB file in place, writing `errors[i]` into the B-factor
    column of the i-th residue.

    Args:
        filename: path of the PDB file to update in place.
        errors: per-residue error values, in the residue order of the file.
        header: optional list of lines to prepend to the output. (Was a
            mutable default `[]`; `None` avoids the shared-default pitfall.)
    """
    with open(filename) as file:
        txt = file.readlines()

    # Copy so the caller's header list is never aliased or mutated.
    new_txt = list(header) if header is not None else []
    residue_index = -1
    position = " "

    for line in txt:
        if line[:4] == "ATOM":
            # Columns 23-27 hold residue number + insertion code; a change
            # means we have moved on to the next residue.
            current_res = line[22:27]
            if current_res != position:
                position = current_res
                residue_index += 1
            line = line.replace(" 0.00 ",f"{errors[residue_index]:>6.2f} ")
        elif "REMARK 1 CREATED WITH OPENMM" in line:
            line = line.replace(" 1 CREATED WITH OPENMM", "STRUCTURE REFINED USING OPENMM")
            # Pad the edited remark back out to the fixed 80-column PDB width.
            line = line[:-1] + (81-len(line))*" " + "\n"
        new_txt.append(line)

    with open(filename, "w+") as file:
        file.writelines(new_txt)
||||
|
||||
|
||||
def are_weights_ready(weights_path):
    """Return True if `weights_path` exists and holds real model weights.

    Placeholder files are empty or start with the text "EMPTY"; both mean the
    weights still need to be downloaded.
    """
    if not os.path.exists(weights_path) or os.path.getsize(weights_path) == 0:
        return False
    with open(weights_path, "rb") as f:
        # Compare raw bytes and tolerate a trailing newline. The previous
        # str(bytes) != "b'EMPTY'" comparison reported a placeholder written
        # as b"EMPTY\n" as ready.
        return f.readline().strip() != b"EMPTY"
||||
Reference in New Issue
Block a user