Configure ImmuneBuilder pipeline for WES execution
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled

- Update container image to harbor.cluster.omic.ai/omic/immunebuilder:latest
- Update input/output paths to S3 (s3://omic/eureka/immunebuilder/)
- Remove local mount containerOptions (not needed in k8s)
- Update homepage to Gitea repo URL
- Clean history to remove large model weight blobs
This commit is contained in:
2026-03-16 15:31:38 +01:00
commit 8887cbe592
49 changed files with 8741 additions and 0 deletions

View File

@@ -0,0 +1,189 @@
import torch
import numpy as np
import os
import sys
import argparse
from ImmuneBuilder.models import StructureModule
from ImmuneBuilder.util import get_encoding, to_pdb, find_alignment_transform, download_file, sequence_dict_from_fasta, add_errors_as_bfactors, are_weights_ready
from ImmuneBuilder.refine import refine
from ImmuneBuilder.sequence_checks import number_sequences
# Embedding width of each ensemble member; the four models were trained with
# different hidden sizes.
embed_dim = {
    "antibody_model_1":128,
    "antibody_model_2":256,
    "antibody_model_3":256,
    "antibody_model_4":256
}

# Zenodo download locations for the trained weights of each ensemble member.
model_urls = {
    "antibody_model_1": "https://zenodo.org/record/7258553/files/antibody_model_1?download=1",
    "antibody_model_2": "https://zenodo.org/record/7258553/files/antibody_model_2?download=1",
    "antibody_model_3": "https://zenodo.org/record/7258553/files/antibody_model_3?download=1",
    "antibody_model_4": "https://zenodo.org/record/7258553/files/antibody_model_4?download=1",
}

# PDB REMARK line written at the top of every output structure.
header = "REMARK ANTIBODY STRUCTURE MODELLED USING ABODYBUILDER2 \n"
class Antibody:
    """An ensemble of predicted antibody structures.

    Aligns the per-model backbone traces onto a common frame, ranks the models
    by deviation from the ensemble mean, and writes (optionally refined) PDB
    files with the ensemble spread stored as B-factors.
    """

    def __init__(self, numbered_sequences, predictions):
        # numbered_sequences: chain id -> list of (number, residue) pairs.
        # predictions: one (atoms, encodings) pair per ensemble model.
        self.numbered_sequences = numbered_sequences
        self.atoms = [x[0] for x in predictions]
        self.encodings = [x[1] for x in predictions]

        with torch.no_grad():
            # Align every model's trace (atom slot 0 of each residue) onto a
            # common reference frame.
            traces = torch.stack([x[:,0] for x in self.atoms])
            self.R,self.t = find_alignment_transform(traces)
            self.aligned_traces = (traces-self.t) @ self.R
            # Per-residue squared deviation from the ensemble mean, used both
            # for ranking and as the error estimate written to B-factors.
            self.error_estimates = (self.aligned_traces - self.aligned_traces.mean(0)).square().sum(-1)
            # Model indices sorted best-first (smallest mean deviation).
            self.ranking = [x.item() for x in self.error_estimates.mean(-1).argsort()]

    def save_single_unrefined(self, filename, index=0):
        # Write one aligned, unrefined ensemble member as a PDB file.
        atoms = (self.atoms[index] - self.t[index]) @ self.R[index]
        unrefined = to_pdb(self.numbered_sequences, atoms)
        with open(filename, "w+") as file:
            file.write(unrefined)

    def save_all(self, dirname=None, filename=None, check_for_strained_bonds=True, n_threads=-1):
        # Save all unrefined models, the error estimates, and a refined copy
        # of the top-ranked model.
        if dirname is None:
            dirname="ABodyBuilder2_output"
        if filename is None:
            filename="final_model.pdb"
        os.makedirs(dirname, exist_ok = True)
        for i in range(len(self.atoms)):
            unrefined_filename = os.path.join(dirname,f"rank{i}_unrefined.pdb")
            self.save_single_unrefined(unrefined_filename, index=self.ranking[i])
        np.save(os.path.join(dirname,"error_estimates"), self.error_estimates.mean(0).cpu().numpy())
        final_filename = os.path.join(dirname, filename)
        # Only the best-ranked model is refined.
        refine(os.path.join(dirname,"rank0_unrefined.pdb"), final_filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
        add_errors_as_bfactors(final_filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[header])

    def save(self, filename=None,check_for_strained_bonds=True, n_threads=-1):
        # Try to refine models best-first (retrying each once) and keep the
        # first refinement that succeeds.
        if filename is None:
            filename = "ABodyBuilder2_output.pdb"
        for i in range(len(self.atoms)):
            self.save_single_unrefined(filename, index=self.ranking[i])
            success = refine(filename, filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if success:
                break
            else:
                # Retry the same model once before moving on to the next one.
                self.save_single_unrefined(filename, index=self.ranking[i])
                success = refine(filename, filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
                if success:
                    break
        if not success:
            print(f"FAILED TO REFINE (unknown).\nSaving anyways.", flush=True)
        # The unrefined structure is still saved with error estimates even if
        # refinement failed ("Saving anyways").
        add_errors_as_bfactors(filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[header])
class ABodyBuilder2:
    """Predicts antibody structures with an ensemble of four deep-learning models.

    Weights are downloaded from Zenodo on first use; inference runs on GPU
    when one is available, otherwise on CPU.
    """

    def __init__(self, model_ids=(1, 2, 3, 4), weights_dir=None, numbering_scheme='imgt'):
        """
        Args:
            model_ids: ids of the ensemble members to load (subset of 1-4).
                A tuple default avoids the mutable-default-argument pitfall.
            weights_dir: directory holding the trained weights; defaults to
                the package-local "trained_model" directory.
            numbering_scheme: antibody numbering scheme (e.g. 'imgt', 'chothia').
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.scheme = numbering_scheme

        if weights_dir is None:
            weights_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "trained_model")

        self.models = {}
        for model_id in model_ids:  # renamed from `id` to avoid shadowing the builtin
            model_file = f"antibody_model_{model_id}"
            model = StructureModule(rel_pos_dim=64, embed_dim=embed_dim[model_file]).to(self.device)
            weights_path = os.path.join(weights_dir, model_file)
            try:
                # Fetch weights lazily if they are missing or incomplete.
                if not are_weights_ready(weights_path):
                    print(f"Downloading weights for {model_file}...", flush=True)
                    download_file(model_urls[model_file], weights_path)
                model.load_state_dict(torch.load(weights_path, map_location=torch.device(self.device)))
            except Exception:
                print(f"ERROR: {model_file} not downloaded or corrupted.", flush=True)
                raise
            model.to(torch.get_default_dtype())
            model.eval()
            self.models[model_file] = model

    def predict(self, sequence_dict):
        """Number the input sequences and run every ensemble model on them.

        Args:
            sequence_dict: {"H": heavy_sequence, "L": light_sequence}.
        Returns:
            An Antibody wrapping the per-model predictions.
        """
        numbered_sequences = number_sequences(sequence_dict, scheme=self.scheme)
        # Rebuild plain sequences from the numbered residues (numbering may
        # trim or reorder the raw input).
        sequence_dict = {chain: "".join(x[1] for x in numbered_sequences[chain]) for chain in numbered_sequences}

        with torch.no_grad():
            encoding = torch.tensor(get_encoding(sequence_dict), device=self.device, dtype=torch.get_default_dtype())
            full_seq = sequence_dict["H"] + sequence_dict["L"]
            outputs = [model(encoding, full_seq) for model in self.models.values()]

        return Antibody(numbered_sequences, outputs)
def command_line_interface():
    """CLI entry point: parse arguments, model the antibody, write PDB output."""
    description="""
ABodyBuilder2 \\\ //
A Method for Antibody Structure Prediction \\\ //
Author: Brennan Abanades Kenyon ||
Supervisor: Charlotte Deane ||
    """
    parser = argparse.ArgumentParser(prog="ABodyBuilder2", description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("-H", "--heavy_sequence", help="Heavy chain amino acid sequence", default=None)
    parser.add_argument("-L", "--light_sequence", help="Light chain amino acid sequence", default=None)
    # Typo fix in help text: "amd" -> "and".
    parser.add_argument("-f", "--fasta_file", help="Fasta file containing a heavy and light chain named H and L", default=None)
    parser.add_argument("-o", "--output", help="Path to where the output model should be saved. Defaults to the same directory as input file.", default=None)
    parser.add_argument("--to_directory", help="Save all unrefined models and the top ranked refined model to a directory. "
                        "If this flag is set the output argument will be assumed to be a directory", default=False, action="store_true")
    parser.add_argument("-n", "--numbering_scheme", help="The scheme used to number output antibody structures. Available numbering schemes are: imgt, chothia, kabat, aho, wolfguy, martin and raw. Default is imgt.", default='imgt')
    parser.add_argument("--n_threads", help="The number of CPU threads to be used. If this option is set, refinement will be performed on CPU instead of GPU. By default, all available cores will be used.", type=int, default=-1)
    parser.add_argument("-u", "--no_sidechain_bond_check", help="Don't check for strained bonds. This is a bit faster but will rarely generate unphysical side chains", default=False, action="store_true")
    parser.add_argument("-v", "--verbose", help="Verbose output", default=False, action="store_true")
    args = parser.parse_args()

    # Explicit sequences take precedence over a fasta file.
    if (args.heavy_sequence is not None) and (args.light_sequence is not None):
        seqs = {"H": args.heavy_sequence, "L": args.light_sequence}
    elif args.fasta_file is not None:
        seqs = sequence_dict_from_fasta(args.fasta_file)
    else:
        raise ValueError("Missing input sequences")

    check_for_strained_bonds = not args.no_sidechain_bond_check
    if args.n_threads > 0:
        torch.set_num_threads(args.n_threads)

    if args.verbose:
        print(description, flush=True)
        # Typo fix: "succesfully" -> "successfully"; also no f-string needed.
        print("Sequences loaded successfully.\nHeavy and light chains are:", flush=True)
        for chain in "HL":  # plain loop instead of a side-effect comprehension
            print(f"{chain}: {seqs[chain]}", flush=True)
        print("Running sequences through deep learning model...", flush=True)

    try:
        antibody = ABodyBuilder2(numbering_scheme=args.numbering_scheme).predict(seqs)
    except AssertionError as e:
        # Sequence validation/numbering failures surface as AssertionErrors.
        print(e, flush=True)
        sys.exit(1)

    if args.verbose:
        print("Antibody modelled successfully, starting refinement.", flush=True)

    if args.to_directory:
        antibody.save_all(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            print("Refinement finished. Saving all outputs to directory", flush=True)
    else:
        antibody.save(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            outfile = "ABodyBuilder2_output.pdb" if args.output is None else args.output
            print(f"Refinement finished. Saving final structure to {outfile}", flush=True)

View File

@@ -0,0 +1,194 @@
import torch
import numpy as np
import os
import argparse
import sys
from ImmuneBuilder.models import StructureModule
from ImmuneBuilder.util import get_encoding, to_pdb, find_alignment_transform, download_file, sequence_dict_from_fasta, add_errors_as_bfactors, are_weights_ready
from ImmuneBuilder.refine import refine
from ImmuneBuilder.sequence_checks import number_sequences
# Embedding width of each ensemble member; the four models were trained with
# different hidden sizes.
embed_dim = {
    "nanobody_model_1":128,
    "nanobody_model_2":256,
    "nanobody_model_3":256,
    "nanobody_model_4":256
}

# Zenodo download locations for the trained weights of each ensemble member.
model_urls = {
    "nanobody_model_1": "https://zenodo.org/record/7258553/files/nanobody_model_1?download=1",
    "nanobody_model_2": "https://zenodo.org/record/7258553/files/nanobody_model_2?download=1",
    "nanobody_model_3": "https://zenodo.org/record/7258553/files/nanobody_model_3?download=1",
    "nanobody_model_4": "https://zenodo.org/record/7258553/files/nanobody_model_4?download=1",
}

# PDB REMARK line written at the top of every output structure.
header = "REMARK NANOBODY STRUCTURE MODELLED USING NANOBODYBUILDER2 \n"
class Nanobody:
    """An ensemble of predicted nanobody structures.

    Aligns the per-model backbone traces onto a common frame, ranks the models
    by deviation from the ensemble mean, and writes (optionally refined) PDB
    files with the ensemble spread stored as B-factors.
    """

    def __init__(self, numbered_sequences, predictions):
        # numbered_sequences: chain id -> list of (number, residue) pairs.
        # predictions: one (atoms, encodings) pair per ensemble model.
        self.numbered_sequences = numbered_sequences
        self.atoms = [x[0] for x in predictions]
        self.encodings = [x[1] for x in predictions]

        with torch.no_grad():
            # Align every model's trace (atom slot 0 of each residue) onto a
            # common reference frame.
            traces = torch.stack([x[:,0] for x in self.atoms])
            self.R,self.t = find_alignment_transform(traces)
            self.aligned_traces = (traces-self.t) @ self.R
            # Per-residue squared deviation from the ensemble mean, used both
            # for ranking and as the error estimate written to B-factors.
            self.error_estimates = (self.aligned_traces - self.aligned_traces.mean(0)).square().sum(-1)
            # Model indices sorted best-first (smallest mean deviation).
            self.ranking = [x.item() for x in self.error_estimates.mean(-1).argsort()]

    def save_single_unrefined(self, filename, index=0):
        # Write one aligned, unrefined ensemble member as a PDB file.
        atoms = (self.atoms[index] - self.t[index]) @ self.R[index]
        unrefined = to_pdb(self.numbered_sequences, atoms)
        with open(filename, "w+") as file:
            file.write(unrefined)

    def save_all(self, dirname=None, filename=None, check_for_strained_bonds=True, n_threads=-1):
        # Save all unrefined models, the error estimates, and a refined copy
        # of the top-ranked model.
        if dirname is None:
            dirname="NanoBodyBuilder2_output"
        if filename is None:
            filename="final_model.pdb"
        os.makedirs(dirname, exist_ok = True)
        for i in range(len(self.atoms)):
            unrefined_filename = os.path.join(dirname,f"rank{i}_unrefined.pdb")
            self.save_single_unrefined(unrefined_filename, index=self.ranking[i])
        np.save(os.path.join(dirname,"error_estimates"), self.error_estimates.mean(0).cpu().numpy())
        final_filename = os.path.join(dirname, filename)
        # Only the best-ranked model is refined.
        refine(os.path.join(dirname,"rank0_unrefined.pdb"), final_filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
        add_errors_as_bfactors(final_filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[header])

    def save(self, filename=None, check_for_strained_bonds=True, n_threads=-1):
        # Try to refine models best-first (retrying each once) and keep the
        # first refinement that succeeds.
        if filename is None:
            filename = "NanoBodyBuilder2_output.pdb"
        for i in range(len(self.atoms)):
            self.save_single_unrefined(filename, index=self.ranking[i])
            success = refine(filename, filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if success:
                break
            else:
                # Retry the same model once before moving on to the next one.
                self.save_single_unrefined(filename, index=self.ranking[i])
                success = refine(filename, filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
                if success:
                    break
        if not success:
            print(f"FAILED TO REFINE (unknown).\nSaving anyways.", flush=True)
        # The unrefined structure is still saved with error estimates even if
        # refinement failed ("Saving anyways").
        add_errors_as_bfactors(filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[header])
class NanoBodyBuilder2:
    """Predicts nanobody (VHH) structures with an ensemble of four models.

    Weights are downloaded from Zenodo on first use; inference runs on GPU
    when one is available, otherwise on CPU.
    """

    def __init__(self, model_ids=(1, 2, 3, 4), weights_dir=None, numbering_scheme='imgt'):
        """
        Args:
            model_ids: ids of the ensemble members to load (subset of 1-4).
                A tuple default avoids the mutable-default-argument pitfall.
            weights_dir: directory holding the trained weights; defaults to
                the package-local "trained_model" directory.
            numbering_scheme: numbering scheme for the output structures.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.scheme = numbering_scheme

        if weights_dir is None:
            weights_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "trained_model")

        self.models = {}
        for model_id in model_ids:  # renamed from `id` to avoid shadowing the builtin
            model_file = f"nanobody_model_{model_id}"
            model = StructureModule(rel_pos_dim=64, embed_dim=embed_dim[model_file]).to(self.device)
            weights_path = os.path.join(weights_dir, model_file)
            try:
                # Fetch weights lazily if they are missing or incomplete.
                if not are_weights_ready(weights_path):
                    print(f"Downloading weights for {model_file}...", flush=True)
                    download_file(model_urls[model_file], weights_path)
                model.load_state_dict(torch.load(weights_path, map_location=torch.device(self.device)))
            except Exception:
                print(f"ERROR: {model_file} not downloaded or corrupted.", flush=True)
                raise
            model.eval()
            model.to(torch.get_default_dtype())
            self.models[model_file] = model

    def predict(self, sequence_dict):
        """Number the input sequence and run every ensemble model on it.

        Args:
            sequence_dict: {"H": vhh_sequence} amino acid sequence.
        Returns:
            A Nanobody wrapping the per-model predictions.
        """
        numbered_sequences = number_sequences(sequence_dict, allowed_species=None, scheme=self.scheme)
        sequence_dict = {chain: "".join(x[1] for x in numbered_sequences[chain]) for chain in numbered_sequences}

        with torch.no_grad():
            # Nanobodies have no light chain; feed an empty "L" so the shared
            # heavy/light encoding utilities still work.
            sequence_dict["L"] = ""
            encoding = torch.tensor(get_encoding(sequence_dict, "H"), device=self.device, dtype=torch.get_default_dtype())
            full_seq = sequence_dict["H"]
            outputs = [model(encoding, full_seq) for model in self.models.values()]

        numbered_sequences["L"] = []
        return Nanobody(numbered_sequences, outputs)
def command_line_interface():
    """CLI entry point: parse arguments, model the nanobody, write PDB output."""
    description="""
\/
NanoBodyBuilder2 ⊂'l
A Method for Nanobody Structure Prediction ll
Author: Brennan Abanades Kenyon llama~
Supervisor: Charlotte Deane || ||
'' ''
    """
    parser = argparse.ArgumentParser(prog="NanoBodyBuilder2", description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("-H", "--heavy_sequence", help="VHH amino acid sequence", default=None)
    parser.add_argument("-f", "--fasta_file", help="Fasta file containing a heavy chain named H", default=None)
    parser.add_argument("-o", "--output", help="Path to where the output model should be saved. Defaults to the same directory as input file.", default=None)
    parser.add_argument("--to_directory", help="Save all unrefined models and the top ranked refined model to a directory. "
                        "If this flag is set the output argument will be assumed to be a directory", default=False, action="store_true")
    parser.add_argument("--n_threads", help="The number of CPU threads to be used. If this option is set, refinement will be performed on CPU instead of GPU. By default, all available cores will be used.", type=int, default=-1)
    parser.add_argument("-n", "--numbering_scheme", help="The scheme used to number output nanobody structures. Available numbering schemes are: imgt, chothia, kabat, aho, wolfguy, martin and raw. Default is imgt.", default='imgt')
    parser.add_argument("-u", "--no_sidechain_bond_check", help="Don't check for strained bonds. This is a bit faster but will rarely generate unphysical side chains", default=False, action="store_true")
    parser.add_argument("-v", "--verbose", help="Verbose output", default=False, action="store_true")
    args = parser.parse_args()

    # An explicit sequence takes precedence over a fasta file.
    if args.heavy_sequence is not None:
        seqs = {"H": args.heavy_sequence}
    elif args.fasta_file is not None:
        seqs = sequence_dict_from_fasta(args.fasta_file)
    else:
        raise ValueError("Missing input sequences")

    check_for_strained_bonds = not args.no_sidechain_bond_check
    if args.n_threads > 0:
        torch.set_num_threads(args.n_threads)

    if args.verbose:
        print(description, flush=True)
        # Typo fix: "succesfully" -> "successfully"; also no f-string needed.
        print("Sequence loaded successfully.\nHeavy chain is:", flush=True)
        print("H: " + seqs["H"], flush=True)
        print("Running sequence through deep learning model...", flush=True)

    try:
        antibody = NanoBodyBuilder2(numbering_scheme=args.numbering_scheme).predict(seqs)
    except AssertionError as e:
        # Sequence validation/numbering failures surface as AssertionErrors.
        print(e, flush=True)
        sys.exit(1)

    if args.verbose:
        print("Nanobody modelled successfully, starting refinement.", flush=True)

    if args.to_directory:
        antibody.save_all(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            print("Refinement finished. Saving all outputs to directory", flush=True)
    else:
        antibody.save(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            outfile = "NanoBodyBuilder2_output.pdb" if args.output is None else args.output
            print(f"Refinement finished. Saving final structure to {outfile}", flush=True)

View File

@@ -0,0 +1,216 @@
import torch
import numpy as np
import os
import sys
import argparse
from ImmuneBuilder.models import StructureModule
from ImmuneBuilder.util import get_encoding, to_pdb, find_alignment_transform, download_file, sequence_dict_from_fasta, add_errors_as_bfactors, are_weights_ready
from ImmuneBuilder.refine import refine
from ImmuneBuilder.sequence_checks import number_sequences
# Embedding width of each ensemble member. "tcr2_plus" keys are the
# TCRBuilder2+ weights, "tcr2" keys the original TCRBuilder2 weights.
embed_dim = {
    "tcr2_plus_model_1":256,
    "tcr2_plus_model_2":256,
    "tcr2_plus_model_3":128,
    "tcr2_plus_model_4":128,
    "tcr2_model_1":128,
    "tcr2_model_2":128,
    "tcr2_model_3":256,
    "tcr2_model_4":256,
}

# Zenodo download locations for the trained weights.
# NOTE(review): tcr2_plus_model_1 uses "records" in the URL path while the
# others use "record" -- both resolve on Zenodo, but confirm this is intended.
model_urls = {
    "tcr2_plus_model_1": "https://zenodo.org/records/10892159/files/tcr_model_1?download=1",
    "tcr2_plus_model_2": "https://zenodo.org/record/10892159/files/tcr_model_2?download=1",
    "tcr2_plus_model_3": "https://zenodo.org/record/10892159/files/tcr_model_3?download=1",
    "tcr2_plus_model_4": "https://zenodo.org/record/10892159/files/tcr_model_4?download=1",
    "tcr2_model_1": "https://zenodo.org/record/7258553/files/tcr_model_1?download=1",
    "tcr2_model_2": "https://zenodo.org/record/7258553/files/tcr_model_2?download=1",
    "tcr2_model_3": "https://zenodo.org/record/7258553/files/tcr_model_3?download=1",
    "tcr2_model_4": "https://zenodo.org/record/7258553/files/tcr_model_4?download=1",
}
class TCR:
    """An ensemble of predicted T-cell receptor structures.

    Aligns the per-model backbone traces onto a common frame, ranks the models
    by deviation from the ensemble mean, and writes (optionally refined) PDB
    files with the ensemble spread stored as B-factors.
    """

    def __init__(self, numbered_sequences, predictions, header=None):
        # numbered_sequences: chain id -> list of (number, residue) pairs.
        # predictions: one (atoms, encodings) pair per ensemble model.
        # header: PDB REMARK line; defaults to the TCRBuilder2+ remark.
        self.numbered_sequences = numbered_sequences
        self.atoms = [x[0] for x in predictions]
        self.encodings = [x[1] for x in predictions]

        if header is not None:
            self.header = header
        else:
            self.header = "REMARK TCR STRUCTURE MODELLED USING TCRBUILDER2+ \n"

        with torch.no_grad():
            # Align every model's trace (atom slot 0 of each residue) onto a
            # common reference frame.
            traces = torch.stack([x[:,0] for x in self.atoms])
            self.R,self.t = find_alignment_transform(traces)
            self.aligned_traces = (traces-self.t) @ self.R
            # Per-residue squared deviation from the ensemble mean, used both
            # for ranking and as the error estimate written to B-factors.
            self.error_estimates = (self.aligned_traces - self.aligned_traces.mean(0)).square().sum(-1)
            # Model indices sorted best-first (smallest mean deviation).
            self.ranking = [x.item() for x in self.error_estimates.mean(-1).argsort()]

    def save_single_unrefined(self, filename, index=0):
        # Write one aligned, unrefined ensemble member as a PDB file.
        # Chains are labelled B (beta) and A (alpha).
        atoms = (self.atoms[index] - self.t[index]) @ self.R[index]
        unrefined = to_pdb(self.numbered_sequences, atoms, chain_ids="BA")
        with open(filename, "w+") as file:
            file.write(unrefined)

    def save_all(self, dirname=None, filename=None, check_for_strained_bonds=True, n_threads=-1):
        # Save all unrefined models, the error estimates, and a refined copy
        # of the top-ranked model.
        if dirname is None:
            dirname="TCRBuilder2_output"
        if filename is None:
            filename="final_model.pdb"
        os.makedirs(dirname, exist_ok = True)
        for i in range(len(self.atoms)):
            unrefined_filename = os.path.join(dirname,f"rank{i}_unrefined.pdb")
            self.save_single_unrefined(unrefined_filename, index=self.ranking[i])
        np.save(os.path.join(dirname,"error_estimates"), self.error_estimates.mean(0).cpu().numpy())
        final_filename = os.path.join(dirname, filename)
        # Only the best-ranked model is refined.
        refine(os.path.join(dirname,"rank0_unrefined.pdb"), final_filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
        add_errors_as_bfactors(final_filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[self.header])

    def save(self, filename=None, check_for_strained_bonds=True, n_threads=-1):
        # Try to refine models best-first (retrying each once) and keep the
        # first refinement that succeeds.
        if filename is None:
            filename = "TCRBuilder2_output.pdb"
        for i in range(len(self.atoms)):
            self.save_single_unrefined(filename, index=self.ranking[i])
            success = refine(filename, filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
            if success:
                break
            else:
                # Retry the same model once before moving on to the next one.
                self.save_single_unrefined(filename, index=self.ranking[i])
                success = refine(filename, filename, check_for_strained_bonds=check_for_strained_bonds, n_threads=n_threads)
                if success:
                    break
        if not success:
            print(f"FAILED TO REFINE (unknown).\nSaving anyways.", flush=True)
        # The unrefined structure is still saved with error estimates even if
        # refinement failed ("Saving anyways").
        add_errors_as_bfactors(filename, self.error_estimates.mean(0).sqrt().cpu().numpy(), header=[self.header])
class TCRBuilder2:
    """Predicts T-cell receptor structures with an ensemble of four models.

    By default the TCRBuilder2+ weights are used; pass
    use_TCRBuilder2_PLUS_weights=False for the original TCRBuilder2 weights.
    """

    def __init__(self, model_ids=(1, 2, 3, 4), weights_dir=None, numbering_scheme='imgt', use_TCRBuilder2_PLUS_weights=True):
        """
        Args:
            model_ids: ids of the ensemble members to load (subset of 1-4).
                A tuple default avoids the mutable-default-argument pitfall.
            weights_dir: directory holding the trained weights; defaults to
                the package-local "trained_model" directory.
            numbering_scheme: numbering scheme for the output structures.
            use_TCRBuilder2_PLUS_weights: load TCRBuilder2+ weights (True) or
                the original TCRBuilder2 weights (False).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.scheme = numbering_scheme
        self.use_TCRBuilder2_PLUS_weights = use_TCRBuilder2_PLUS_weights

        if weights_dir is None:
            weights_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "trained_model")

        self.models = {}
        for model_id in model_ids:  # renamed from `id` to avoid shadowing the builtin
            if use_TCRBuilder2_PLUS_weights:
                # TCRBuilder2+ weights are stored on disk as "tcr_model_N" but
                # keyed as "tcr2_plus_model_N" in embed_dim / model_urls.
                model_file = f"tcr_model_{model_id}"
                model_key = f'tcr2_plus_model_{model_id}'
            else:
                model_file = f"tcr2_model_{model_id}"
                model_key = f'tcr2_model_{model_id}'
            model = StructureModule(rel_pos_dim=64, embed_dim=embed_dim[model_key]).to(self.device)
            weights_path = os.path.join(weights_dir, model_file)
            try:
                # Fetch weights lazily if they are missing or incomplete.
                if not are_weights_ready(weights_path):
                    print(f"Downloading weights for {model_file}...", flush=True)
                    download_file(model_urls[model_key], weights_path)
                model.load_state_dict(torch.load(weights_path, map_location=torch.device(self.device)))
            except Exception:
                print(f"ERROR: {model_file} not downloaded or corrupted.", flush=True)
                raise
            model.eval()
            model.to(torch.get_default_dtype())
            self.models[model_file] = model

    def predict(self, sequence_dict):
        """Number the input sequences and run every ensemble model on them.

        Args:
            sequence_dict: {"B": beta_sequence, "A": alpha_sequence}.
        Returns:
            A TCR wrapping the per-model predictions.
        """
        numbered_sequences = number_sequences(sequence_dict, scheme=self.scheme)
        sequence_dict = {chain: "".join(x[1] for x in numbered_sequences[chain]) for chain in numbered_sequences}

        with torch.no_grad():
            # Reuse the antibody heavy/light encoding machinery: beta -> H,
            # alpha -> L.
            sequence_dict = {"H": sequence_dict["B"], "L": sequence_dict["A"]}
            encoding = torch.tensor(get_encoding(sequence_dict), device=self.device, dtype=torch.get_default_dtype())
            full_seq = sequence_dict["H"] + sequence_dict["L"]
            outputs = [model(encoding, full_seq) for model in self.models.values()]

        if self.use_TCRBuilder2_PLUS_weights:
            header = "REMARK TCR STRUCTURE MODELLED USING TCRBUILDER2+ \n"
        else:
            header = "REMARK TCR STRUCTURE MODELLED USING TCRBUILDER2 \n"
        return TCR(numbered_sequences, outputs, header)
def command_line_interface():
    """CLI entry point: parse arguments, model the TCR, write PDB output."""
    description="""
TCRBuilder2 || ||
A Method for T-Cell Receptor Structure Prediction |VB|VA|
Author: Brennan Abanades Kenyon, Nele Quast |CB|CA|
Supervisor: Charlotte Deane -------------
By default TCRBuilder2 will use TCRBuilder2+ weights.
To use the original weights, see options.
    """
    parser = argparse.ArgumentParser(prog="TCRBuilder2", description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("-B", "--beta_sequence", help="Beta chain amino acid sequence", default=None)
    parser.add_argument("-A", "--alpha_sequence", help="Alpha chain amino acid sequence", default=None)
    parser.add_argument("-f", "--fasta_file", help="Fasta file containing a beta and alpha chain named B and A", default=None)
    parser.add_argument("-o", "--output", help="Path to where the output model should be saved. Defaults to the same directory as input file.", default=None)
    parser.add_argument("--to_directory", help="Save all unrefined models and the top ranked refined model to a directory. "
                        "If this flag is set the output argument will be assumed to be a directory", default=False, action="store_true")
    parser.add_argument("--n_threads", help="The number of CPU threads to be used. If this option is set, refinement will be performed on CPU instead of GPU. By default, all available cores will be used.", type=int, default=-1)
    parser.add_argument("-u", "--no_sidechain_bond_check", help="Don't check for strained bonds. This is a bit faster but will rarely generate unphysical side chains", default=False, action="store_true")
    parser.add_argument("-v", "--verbose", help="Verbose output", default=False, action="store_true")
    parser.add_argument("-og", "--original_weights", help="use original TCRBuilder2 weights instead of TCRBuilder2+", default=False, action="store_true")
    args = parser.parse_args()

    # Explicit sequences take precedence over a fasta file.
    if (args.beta_sequence is not None) and (args.alpha_sequence is not None):
        seqs = {"B": args.beta_sequence, "A": args.alpha_sequence}
    elif args.fasta_file is not None:
        seqs = sequence_dict_from_fasta(args.fasta_file)
    else:
        raise ValueError("Missing input sequences")

    check_for_strained_bonds = not args.no_sidechain_bond_check
    if args.n_threads > 0:
        torch.set_num_threads(args.n_threads)

    if args.verbose:
        print(description, flush=True)
        # Typo fix: "succesfully" -> "successfully"; also no f-string needed.
        print("Sequences loaded successfully.\nAlpha and Beta chains are:", flush=True)
        for chain in "AB":  # plain loop instead of a side-effect comprehension
            print(f"{chain}: {seqs[chain]}", flush=True)
        print("Running sequences through deep learning model...", flush=True)

    try:
        tcr = TCRBuilder2(use_TCRBuilder2_PLUS_weights=not args.original_weights).predict(seqs)
    except AssertionError as e:
        # Sequence validation/numbering failures surface as AssertionErrors.
        print(e, flush=True)
        sys.exit(1)

    if args.verbose:
        print("TCR modelled successfully, starting refinement.", flush=True)

    if args.to_directory:
        tcr.save_all(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            print("Refinement finished. Saving all outputs to directory", flush=True)
    else:
        tcr.save(args.output, check_for_strained_bonds=check_for_strained_bonds, n_threads=args.n_threads)
        if args.verbose:
            outfile = "TCRBuilder2_output.pdb" if args.output is None else args.output
            print(f"Refinement finished. Saving final structure to {outfile}", flush=True)

View File

@@ -0,0 +1,3 @@
from ImmuneBuilder.ABodyBuilder2 import ABodyBuilder2
from ImmuneBuilder.TCRBuilder2 import TCRBuilder2
from ImmuneBuilder.NanoBodyBuilder2 import NanoBodyBuilder2

294
ImmuneBuilder/constants.py Normal file
View File

@@ -0,0 +1,294 @@
# One-letter codes of the 20 standard amino acids, in the canonical ordering
# used throughout this module (indexing into this string defines residue ids).
restypes = 'ARNDCQEGHILKMFPSTWYV'

# Residue names definition: one-letter -> three-letter PDB residue name.
restype_1to3 = {
    'A': 'ALA',
    'R': 'ARG',
    'N': 'ASN',
    'D': 'ASP',
    'C': 'CYS',
    'Q': 'GLN',
    'E': 'GLU',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'L': 'LEU',
    'K': 'LYS',
    'M': 'MET',
    'F': 'PHE',
    'P': 'PRO',
    'S': 'SER',
    'T': 'THR',
    'W': 'TRP',
    'Y': 'TYR',
    'V': 'VAL',
}
# Inverse mapping: three-letter -> one-letter.
restype_3to1 = {v: k for k, v in restype_1to3.items()}

# How atoms are sorted in MLAb: per-residue heavy-atom ordering (max 14 slots).
residue_atoms = {
    'A': ['CA', 'N', 'C', 'CB', 'O'],
    'C': ['CA', 'N', 'C', 'CB', 'O', 'SG'],
    'D': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'OD1', 'OD2'],
    'E': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD', 'OE1', 'OE2'],
    'F': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
    'G': ['CA', 'N', 'C', 'CA', 'O'], # G has no CB so I am padding it with CA so the Os are aligned
    'H': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD2', 'CE1', 'ND1', 'NE2'],
    'I': ['CA', 'N', 'C', 'CB', 'O', 'CG1', 'CG2', 'CD1'],
    'K': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD', 'CE', 'NZ'],
    'L': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD1', 'CD2'],
    'M': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CE', 'SD'],
    'N': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'ND2', 'OD1'],
    'P': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD'],
    'Q': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD', 'NE2', 'OE1'],
    'R': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD', 'CZ', 'NE', 'NH1', 'NH2'],
    'S': ['CA', 'N', 'C', 'CB', 'O', 'OG'],
    'T': ['CA', 'N', 'C', 'CB', 'O', 'CG2', 'OG1'],
    'V': ['CA', 'N', 'C', 'CB', 'O', 'CG1', 'CG2'],
    'W': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'NE1'],
    'Y': ['CA', 'N', 'C', 'CB', 'O', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH']}

# Boolean mask over the 14 atom slots: True where the residue has a real atom,
# False for padding slots.
residue_atoms_mask = {res: len(residue_atoms[res]) * [True] + (14 - len(residue_atoms[res])) * [False] for res in
                      residue_atoms}

# All heavy-atom names that can occur in any residue (PDB naming).
atom_types = [
    'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]
# Position of atoms in each ref frame
# (residue -> atom name -> [frame index, (x, y, z) offset]; the integer
# presumably selects one of the residue's rigid frames -- see valid_rigids).
rigid_group_atom_positions2 = {'A': {'C': [0, (1.526, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.529, -0.774, -1.205)],
    'N': [0, (-0.525, 1.363, 0.0)],
    'O': [3, (-0.627, 1.062, 0.0)]},
'C': {'C': [0, (1.524, 0.0, 0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.519, -0.773, -1.212)],
    'N': [0, (-0.522, 1.362, -0.0)],
    'O': [3, (-0.625, 1.062, -0.0)],
    'SG': [4, (-0.728, 1.653, 0.0)]},
'D': {'C': [0, (1.527, 0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.526, -0.778, -1.208)],
    'CG': [4, (-0.593, 1.398, -0.0)],
    'N': [0, (-0.525, 1.362, -0.0)],
    'O': [3, (-0.626, 1.062, -0.0)],
    'OD1': [5, (-0.61, 1.091, 0.0)],
    'OD2': [5, (-0.592, -1.101, 0.003)]},
'E': {'C': [0, (1.526, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.526, -0.781, -1.207)],
    'CD': [5, (-0.6, 1.397, 0.0)],
    'CG': [4, (-0.615, 1.392, 0.0)],
    'N': [0, (-0.528, 1.361, 0.0)],
    'O': [3, (-0.626, 1.062, 0.0)],
    'OE1': [6, (-0.607, 1.095, -0.0)],
    'OE2': [6, (-0.589, -1.104, 0.001)]},
'F': {'C': [0, (1.524, 0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.525, -0.776, -1.212)],
    'CD1': [5, (-0.709, 1.195, -0.0)],
    'CD2': [5, (-0.706, -1.196, 0.0)],
    'CE1': [5, (-2.102, 1.198, -0.0)],
    'CE2': [5, (-2.098, -1.201, -0.0)],
    'CG': [4, (-0.607, 1.377, 0.0)],
    'CZ': [5, (-2.794, -0.003, 0.001)],
    'N': [0, (-0.518, 1.363, 0.0)],
    'O': [3, (-0.626, 1.062, -0.0)]},
'G': {'C': [0, (1.517, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'N': [0, (-0.572, 1.337, 0.0)],
    'O': [3, (-0.626, 1.062, -0.0)]},
'H': {'C': [0, (1.525, 0.0, 0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.525, -0.778, -1.208)],
    'CD2': [5, (-0.889, -1.021, -0.003)],
    'CE1': [5, (-2.03, 0.851, -0.002)],
    'CG': [4, (-0.6, 1.37, -0.0)],
    'N': [0, (-0.527, 1.36, 0.0)],
    'ND1': [5, (-0.744, 1.16, -0.0)],
    'NE2': [5, (-2.145, -0.466, -0.004)],
    'O': [3, (-0.625, 1.063, 0.0)]},
'I': {'C': [0, (1.527, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.536, -0.793, -1.213)],
    'CD1': [5, (-0.619, 1.391, 0.0)],
    'CG1': [4, (-0.534, 1.437, -0.0)],
    'CG2': [4, (-0.54, -0.785, 1.199)],
    'N': [0, (-0.493, 1.373, -0.0)],
    'O': [3, (-0.627, 1.062, -0.0)]},
'K': {'C': [0, (1.526, 0.0, 0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.524, -0.778, -1.208)],
    'CD': [5, (-0.559, 1.417, 0.0)],
    'CE': [6, (-0.56, 1.416, 0.0)],
    'CG': [4, (-0.619, 1.39, 0.0)],
    'N': [0, (-0.526, 1.362, -0.0)],
    'NZ': [7, (-0.554, 1.387, 0.0)],
    'O': [3, (-0.626, 1.062, -0.0)]},
'L': {'C': [0, (1.525, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.522, -0.773, -1.214)],
    'CD1': [5, (-0.53, 1.43, -0.0)],
    'CD2': [5, (-0.535, -0.774, -1.2)],
    'CG': [4, (-0.678, 1.371, 0.0)],
    'N': [0, (-0.52, 1.363, 0.0)],
    'O': [3, (-0.625, 1.063, -0.0)]},
'M': {'C': [0, (1.525, 0.0, 0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.523, -0.776, -1.21)],
    'CE': [6, (-0.32, 1.786, -0.0)],
    'CG': [4, (-0.613, 1.391, -0.0)],
    'N': [0, (-0.521, 1.364, -0.0)],
    'O': [3, (-0.625, 1.062, -0.0)],
    'SD': [5, (-0.703, 1.695, 0.0)]},
'N': {'C': [0, (1.526, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.531, -0.787, -1.2)],
    'CG': [4, (-0.584, 1.399, 0.0)],
    'N': [0, (-0.536, 1.357, 0.0)],
    'ND2': [5, (-0.593, -1.188, -0.001)],
    'O': [3, (-0.625, 1.062, 0.0)],
    'OD1': [5, (-0.633, 1.059, 0.0)]},
'P': {'C': [0, (1.527, -0.0, 0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.546, -0.611, -1.293)],
    'CD': [5, (-0.477, 1.424, 0.0)],
    'CG': [4, (-0.382, 1.445, 0.0)],
    'N': [0, (-0.566, 1.351, -0.0)],
    'O': [3, (-0.621, 1.066, 0.0)]},
'Q': {'C': [0, (1.526, 0.0, 0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.525, -0.779, -1.207)],
    'CD': [5, (-0.587, 1.399, -0.0)],
    'CG': [4, (-0.615, 1.393, 0.0)],
    'N': [0, (-0.526, 1.361, -0.0)],
    'NE2': [6, (-0.593, -1.189, 0.001)],
    'O': [3, (-0.626, 1.062, -0.0)],
    'OE1': [6, (-0.634, 1.06, 0.0)]},
'R': {'C': [0, (1.525, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.524, -0.778, -1.209)],
    'CD': [5, (-0.564, 1.414, 0.0)],
    'CG': [4, (-0.616, 1.39, -0.0)],
    'CZ': [7, (-0.758, 1.093, -0.0)],
    'N': [0, (-0.524, 1.362, -0.0)],
    'NE': [6, (-0.539, 1.357, -0.0)],
    'NH1': [7, (-0.206, 2.301, 0.0)],
    'NH2': [7, (-2.078, 0.978, -0.0)],
    'O': [3, (-0.626, 1.062, 0.0)]},
'S': {'C': [0, (1.525, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.518, -0.777, -1.211)],
    'N': [0, (-0.529, 1.36, -0.0)],
    'O': [3, (-0.626, 1.062, -0.0)],
    'OG': [4, (-0.503, 1.325, 0.0)]},
'T': {'C': [0, (1.526, 0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.516, -0.793, -1.215)],
    'CG2': [4, (-0.55, -0.718, 1.228)],
    'N': [0, (-0.517, 1.364, 0.0)],
    'O': [3, (-0.626, 1.062, 0.0)],
    'OG1': [4, (-0.472, 1.353, 0.0)]},
'V': {'C': [0, (1.527, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.533, -0.795, -1.213)],
    'CG1': [4, (-0.54, 1.429, -0.0)],
    'CG2': [4, (-0.533, -0.776, -1.203)],
    'N': [0, (-0.494, 1.373, -0.0)],
    'O': [3, (-0.627, 1.062, -0.0)]},
'W': {'C': [0, (1.525, -0.0, 0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.523, -0.776, -1.212)],
    'CD1': [5, (-0.824, 1.091, 0.0)],
    'CD2': [5, (-0.854, -1.148, -0.005)],
    'CE2': [5, (-2.186, -0.678, -0.007)],
    'CE3': [5, (-0.622, -2.53, -0.007)],
    'CG': [4, (-0.609, 1.37, -0.0)],
    'CH2': [5, (-3.028, -2.89, -0.013)],
    'CZ2': [5, (-3.283, -1.543, -0.011)],
    'CZ3': [5, (-1.715, -3.389, -0.011)],
    'N': [0, (-0.521, 1.363, 0.0)],
    'NE1': [5, (-2.14, 0.69, -0.004)],
    'O': [3, (-0.627, 1.062, 0.0)]},
'Y': {'C': [0, (1.524, -0.0, -0.0)],
    'CA': [0, (0.0, 0.0, 0.0)],
    'CB': [0, (-0.522, -0.776, -1.213)],
    'CD1': [5, (-0.716, 1.195, -0.0)],
    'CD2': [5, (-0.713, -1.194, -0.001)],
    'CE1': [5, (-2.107, 1.2, -0.002)],
    'CE2': [5, (-2.104, -1.201, -0.003)],
    'CG': [4, (-0.607, 1.382, -0.0)],
    'CZ': [5, (-2.791, -0.001, -0.003)],
    'N': [0, (-0.522, 1.362, 0.0)],
    'O': [3, (-0.627, 1.062, -0.0)],
    'OH': [5, (-4.168, -0.002, -0.005)]}}

# Atom quadruples defining each chi torsion angle, per residue (chi1..chi4).
chi_angles_atoms = {'A': [],
'C': [['N', 'CA', 'CB', 'SG']],
'D': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'E': [['N', 'CA', 'CB', 'CG'],
    ['CA', 'CB', 'CG', 'CD'],
    ['CB', 'CG', 'CD', 'OE1']],
'F': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'G': [],
'H': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
'I': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
'K': [['N', 'CA', 'CB', 'CG'],
    ['CA', 'CB', 'CG', 'CD'],
    ['CB', 'CG', 'CD', 'CE'],
    ['CG', 'CD', 'CE', 'NZ']],
'L': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'M': [['N', 'CA', 'CB', 'CG'],
    ['CA', 'CB', 'CG', 'SD'],
    ['CB', 'CG', 'SD', 'CE']],
'N': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
'P': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
'Q': [['N', 'CA', 'CB', 'CG'],
    ['CA', 'CB', 'CG', 'CD'],
    ['CB', 'CG', 'CD', 'OE1']],
'R': [['N', 'CA', 'CB', 'CG'],
    ['CA', 'CB', 'CG', 'CD'],
    ['CB', 'CG', 'CD', 'NE'],
    ['CG', 'CD', 'NE', 'CZ']],
'S': [['N', 'CA', 'CB', 'OG']],
'T': [['N', 'CA', 'CB', 'OG1']],
'V': [['N', 'CA', 'CB', 'CG1']],
'W': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
'Y': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']]}

# Same chi definitions as chi_angles_atoms, but expressed as indices into the
# residue_atoms ordering for each residue.
chi_angles_positions = {}
for r in residue_atoms:
    chi_angles_positions[r] = []
    for angs in chi_angles_atoms[r]:
        chi_angles_positions[r].append([residue_atoms[r].index(atom) for atom in angs])

# Pivot atom of chi2/chi3/chi4 per residue; 'CA' is used as a placeholder for
# residues without that chi angle.
chi2_centers = {x: chi_angles_atoms[x][1][-2] if len(chi_angles_atoms[x]) > 1 else "CA" for x in chi_angles_atoms}
chi3_centers = {x: chi_angles_atoms[x][2][-2] if len(chi_angles_atoms[x]) > 2 else "CA" for x in chi_angles_atoms}
chi4_centers = {x: chi_angles_atoms[x][3][-2] if len(chi_angles_atoms[x]) > 3 else "CA" for x in chi_angles_atoms}

# Reference-frame entries for all 14 atom slots per residue, padded with
# [0, (0, 0, 0)] for missing atoms.
rel_pos = {
    x: [rigid_group_atom_positions2[x][residue_atoms[x][atom_id]] if len(residue_atoms[x]) > atom_id else [0, (0, 0, 0)]
        for atom_id in range(14)] for x in rigid_group_atom_positions2}

# Van der Waals radius (Angstrom) per element symbol.
van_der_waals_radius = {
    "C": 1.7,
    "N": 1.55,
    "O": 1.52,
    "S": 1.8,
}

# Per-residue radii for the 14 atom slots (element taken from the first letter
# of the atom name), zero-padded for missing atoms.
residue_van_der_waals_radius = {
    x: [van_der_waals_radius[atom[0]] for atom in residue_atoms[x]] + (14 - len(residue_atoms[x])) * ([0]) for x in
    residue_atoms}
# Number of valid rigid frames per residue: two backbone frames plus one per
# chi angle.
valid_rigids = {x: len(chi_angles_atoms[x]) + 2 for x in chi_angles_atoms}

# One-letter residue code -> index in `restypes`.
r2n = {x: i for i, x in enumerate(restypes)}


# PEP 8: use a def instead of assigning a lambda to a name (same callable
# interface for existing callers).
def res_to_num(x):
    """Return the index of residue `x` in `restypes`; unknown codes map to len(r2n)."""
    return r2n[x] if x in r2n else len(r2n)

230
ImmuneBuilder/models.py Normal file
View File

@@ -0,0 +1,230 @@
import torch
from einops import rearrange
from ImmuneBuilder.rigids import Rigid, Rot, rigid_body_identity, vec_from_tensor, global_frames_from_bb_frame_and_torsion_angles, all_atoms_from_global_reference_frames
class InvariantPointAttention(torch.nn.Module):
    """Invariant Point Attention layer.

    Per head, three attention terms are combined: scalar query/key dot
    products, a learned bias from edge features, and distances between
    query/key points mapped into the global frame by each residue's rigid
    transform — making the update invariant to global rotation/translation.
    """

    def __init__(self, node_dim, edge_dim, heads=12, head_dim=16, n_query_points=4, n_value_points=8, **kwargs):
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.n_query_points = n_query_points
        node_scalar_attention_inner_dim = heads * head_dim
        node_vector_attention_inner_dim = 3 * n_query_points * heads
        node_vector_attention_value_dim = 3 * n_value_points * heads
        # forward() concatenates attended edges, scalar values, and
        # (norm + xyz) of the vector values for every head.
        after_final_cat_dim = heads * edge_dim + heads * head_dim + heads * n_value_points * 4
        # Initialised to log(e - 1). NOTE(review): in forward the weight is
        # applied directly (no softplus) — confirm against the reference IPA.
        point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.)) - 1.)
        self.point_weight = torch.nn.Parameter(point_weight_init_value)
        self.to_scalar_qkv = torch.nn.Linear(node_dim, 3 * node_scalar_attention_inner_dim, bias=False)
        self.to_vector_qk = torch.nn.Linear(node_dim, 2 * node_vector_attention_inner_dim, bias=False)
        self.to_vector_v = torch.nn.Linear(node_dim, node_vector_attention_value_dim, bias=False)
        self.to_scalar_edge_attention_bias = torch.nn.Linear(edge_dim, heads, bias=False)
        self.final_linear = torch.nn.Linear(after_final_cat_dim, node_dim)
        # Zero-init the output projection so the residual starts as identity.
        with torch.no_grad():
            self.final_linear.weight.fill_(0.0)
            self.final_linear.bias.fill_(0.0)

    def forward(self, node_features, edge_features, rigid):
        """Return node_features plus the attention update (same shape)."""
        # Classic attention on nodes
        scalar_qkv = self.to_scalar_qkv(node_features).chunk(3, dim=-1)
        scalar_q, scalar_k, scalar_v = map(lambda t: rearrange(t, 'n (h d) -> h n d', h=self.heads), scalar_qkv)
        node_scalar = torch.einsum('h i d, h j d -> h i j', scalar_q, scalar_k) * self.head_dim ** (-1 / 2)
        # Linear bias on edges
        edge_bias = rearrange(self.to_scalar_edge_attention_bias(edge_features), 'i j h -> h i j')
        # Reference frame attention
        # wc: weighting of the point-distance term.
        wc = (2 / self.n_query_points) ** (1 / 2) / 6
        vector_qk = self.to_vector_qk(node_features).chunk(2, dim=-1)
        vector_q, vector_k = map(lambda x: vec_from_tensor(rearrange(x, 'n (h p d) -> h n p d', h=self.heads, d=3)),
                                 vector_qk)
        rigid_ = rigid.unsqueeze(0).unsqueeze(-1)  # add head and point dimension to rigids
        global_vector_k = rigid_ @ vector_k
        global_vector_q = rigid_ @ vector_q
        # Pairwise query/key distances in the global frame, summed over points.
        global_frame_distance = wc * global_vector_q.unsqueeze(-2).dist(global_vector_k.unsqueeze(-3)).sum(
            -1) * rearrange(self.point_weight, "h -> h () ()")
        # Combining attentions
        attention_matrix = (3 ** (-1 / 2) * (node_scalar + edge_bias - global_frame_distance)).softmax(-1)
        # Obtaining outputs
        edge_output = (rearrange(attention_matrix, 'h i j -> i h () j') * rearrange(edge_features,
                                                                                   'i j d -> i () d j')).sum(-1)
        scalar_node_output = torch.einsum('h i j, h j d -> i h d', attention_matrix, scalar_v)
        vector_v = vec_from_tensor(
            rearrange(self.to_vector_v(node_features), 'n (h p d) -> h n p d', h=self.heads, d=3))
        global_vector_v = rigid_ @ vector_v
        attended_global_vector_v = global_vector_v.map(
            lambda x: torch.einsum('h i j, h j p -> h i p', attention_matrix, x))
        # Map attended points back into each residue's local frame.
        vector_node_output = rigid_.inv() @ attended_global_vector_v
        vector_node_output = torch.stack(
            [vector_node_output.norm(), vector_node_output.x, vector_node_output.y, vector_node_output.z], dim=-1)
        # Concatenate along heads and points
        edge_output = rearrange(edge_output, 'n h d -> n (h d)')
        scalar_node_output = rearrange(scalar_node_output, 'n h d -> n (h d)')
        vector_node_output = rearrange(vector_node_output, 'h n p d -> n (h p d)')
        combined = torch.cat([edge_output, scalar_node_output, vector_node_output], dim=-1)
        return node_features + self.final_linear(combined)
class BackboneUpdate(torch.nn.Module):
    """Predict a small per-residue rigid-body correction (rotation + translation)."""

    def __init__(self, node_dim):
        super().__init__()
        # 6 outputs: 3 for the quaternion imaginary part, 3 for translation.
        self.to_correction = torch.nn.Linear(node_dim, 6)

    def forward(self, node_features, update_mask=None):
        """Return a Rigid correction per residue; masked residues get identity."""
        # Predict quaternions and translation vector
        rot, t = self.to_correction(node_features).chunk(2, dim=-1)
        # I may not want to update all residues
        if update_mask is not None:
            rot = update_mask[:, None] * rot
            t = update_mask[:, None] * t
        # Normalize quaternions
        # Quaternion (1, b, c, d) is normalised, so small predictions yield
        # rotations near the identity.
        norm = (1 + rot.pow(2).sum(-1, keepdim=True)).pow(1 / 2)
        b, c, d = (rot / norm).chunk(3, dim=-1)
        a = 1 / norm
        a, b, c, d = a.squeeze(-1), b.squeeze(-1), c.squeeze(-1), d.squeeze(-1)
        # Make rotation matrix from quaternions
        R = Rot(
            (a ** 2 + b ** 2 - c ** 2 - d ** 2), (2 * b * c - 2 * a * d), (2 * b * d + 2 * a * c),
            (2 * b * c + 2 * a * d), (a ** 2 - b ** 2 + c ** 2 - d ** 2), (2 * c * d - 2 * a * b),
            (2 * b * d - 2 * a * c), (2 * c * d + 2 * a * b), (a ** 2 - b ** 2 - c ** 2 + d ** 2)
        )
        return Rigid(vec_from_tensor(t), R)
class TorsionAngles(torch.nn.Module):
    """Predict 5 torsion angles per residue as unit 2-vectors (cos/sin pairs)."""

    def __init__(self, node_dim):
        super().__init__()
        # Two residual MLP blocks over the concatenated (node, s_i) features.
        self.residual1 = torch.nn.Sequential(
            torch.nn.Linear(2 * node_dim, 2 * node_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * node_dim, 2 * node_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * node_dim, 2 * node_dim)
        )
        self.residual2 = torch.nn.Sequential(
            torch.nn.Linear(2 * node_dim, 2 * node_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * node_dim, 2 * node_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * node_dim, 2 * node_dim)
        )
        # 10 outputs = 5 angles x 2 components each.
        self.final_pred = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(2 * node_dim, 10)
        )
        # Zero-init the last layer of each residual branch so the block
        # starts out as the identity mapping.
        with torch.no_grad():
            self.residual1[-1].weight.fill_(0.0)
            self.residual2[-1].weight.fill_(0.0)
            self.residual1[-1].bias.fill_(0.0)
            self.residual2[-1].bias.fill_(0.0)

    def forward(self, node_features, s_i):
        """Return (unit torsions of shape (n, 5, 2), their raw norms)."""
        full_feat = torch.cat([node_features, s_i], axis=-1)
        full_feat = full_feat + self.residual1(full_feat)
        full_feat = full_feat + self.residual2(full_feat)
        torsions = rearrange(self.final_pred(full_feat), "i (t d) -> i t d", d=2)
        # Project each 2-vector onto the unit circle; the raw norm is
        # returned alongside for the caller's use.
        norm = torch.norm(torsions, dim=-1, keepdim=True)
        return torsions / norm, norm
class StructureUpdate(torch.nn.Module):
    """One refinement layer: IPA attention, a residual MLP, and a rigid-frame update."""

    def __init__(self, node_dim, edge_dim, dropout=0.0, **kwargs):
        super().__init__()

        def dropout_then_norm():
            # Dropout followed by LayerNorm; both norm stages share this shape.
            return torch.nn.Sequential(
                torch.nn.Dropout(dropout),
                torch.nn.LayerNorm(node_dim),
            )

        hidden = 2 * node_dim
        self.IPA = InvariantPointAttention(node_dim, edge_dim, **kwargs)
        self.norm1 = dropout_then_norm()
        self.norm2 = dropout_then_norm()
        self.residual = torch.nn.Sequential(
            torch.nn.Linear(node_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, node_dim),
        )
        self.torsion_angles = TorsionAngles(node_dim)
        self.backbone_update = BackboneUpdate(node_dim)
        # Zero-init the final residual layer so the block starts as identity.
        with torch.no_grad():
            self.residual[-1].weight.fill_(0.0)
            self.residual[-1].bias.fill_(0.0)

    def forward(self, node_features, edge_features, rigid_pred, update_mask=None):
        """Return (updated node features, composed rigid frames)."""
        attended = self.norm1(self.IPA(node_features, edge_features, rigid_pred))
        updated = self.norm2(attended + self.residual(attended))
        rigid_new = rigid_pred @ self.backbone_update(updated, update_mask)
        return updated, rigid_new
class StructureModule(torch.nn.Module):
    """Structure prediction stack: embed per-residue features, run n_layers of
    StructureUpdate, then place all heavy atoms from the final frames and
    predicted torsion angles."""

    def __init__(self, node_dim=23, n_layers=8, rel_pos_dim=64, embed_dim=128, **kwargs):
        super().__init__()
        self.n_layers = n_layers
        self.rel_pos_dim = rel_pos_dim
        self.node_embed = torch.nn.Linear(node_dim, embed_dim)
        # One edge-embedding slot is reserved for the inter-origin distance
        # appended in forward(), hence embed_dim - 1.
        self.edge_embed = torch.nn.Linear(2 * rel_pos_dim + 1, embed_dim - 1)
        self.layers = torch.nn.ModuleList(
            [StructureUpdate(node_dim=embed_dim,
                             edge_dim=embed_dim,
                             propagate_rotation_gradient=(i == n_layers - 1),
                             **kwargs)
             for i in range(n_layers)])

    def forward(self, node_features, sequence):
        """Return (all_atoms of shape (len(sequence), 14, 3), final node features)."""
        # Every residue starts at the identity frame ("black hole" init).
        rigid_in = rigid_body_identity(len(sequence)).to(node_features.device)
        # Signed residue-index offsets, clamped to +/- rel_pos_dim, one-hot encoded.
        relative_positions = (torch.arange(node_features.shape[-2])[None] -
                              torch.arange(node_features.shape[-2])[:, None])
        relative_positions = relative_positions.clamp(min=-self.rel_pos_dim, max=self.rel_pos_dim) + self.rel_pos_dim
        rel_pos_embeddings = torch.nn.functional.one_hot(relative_positions, num_classes=2 * self.rel_pos_dim + 1)
        rel_pos_embeddings = rel_pos_embeddings.to(dtype=node_features.dtype, device=node_features.device)
        rel_pos_embeddings = self.edge_embed(rel_pos_embeddings)
        new_node_features = self.node_embed(node_features)
        for layer in self.layers:
            # Edge features = current inter-origin distances + static rel-pos embedding.
            edge_features = torch.cat(
                [rigid_in.origin.unsqueeze(-1).dist(rigid_in.origin).unsqueeze(-1), rel_pos_embeddings], dim=-1)
            new_node_features, rigid_in = layer(new_node_features, edge_features, rigid_in)
        # Torsions are predicted from the re-embedded inputs and the final features.
        torsions, _ = self.layers[-1].torsion_angles(self.node_embed(node_features), new_node_features)
        all_reference_frames = global_frames_from_bb_frame_and_torsion_angles(rigid_in, torsions, sequence)
        all_atoms = all_atoms_from_global_reference_frames(all_reference_frames, sequence)
        # Remove atoms of side chains with outrageous clashes
        ds = torch.linalg.norm(all_atoms[None,:,None] - all_atoms[:,None,:,None], axis = -1)
        # Mask out self-distances and NaN (missing-atom) entries.
        ds[torch.isnan(ds) | (ds==0.0)] = 10
        min_ds = ds.min(dim=-1)[0].min(dim=-1)[0].min(dim=-1)[0]
        # NaN-out atom slots 5 onward (presumably side-chain atoms — TODO
        # confirm against residue_atoms ordering) for residues whose closest
        # contact is under 0.2.
        all_atoms[min_ds < 0.2, 5:, :] = float("Nan")
        return all_atoms, new_node_features

354
ImmuneBuilder/refine.py Normal file
View File

@@ -0,0 +1,354 @@
import pdbfixer
import os
import numpy as np
from openmm import app, LangevinIntegrator, CustomExternalForce, CustomTorsionForce, OpenMMException, Platform, unit
from scipy import spatial
import logging
# Silence all logging below CRITICAL (OpenMM/pdbfixer are noisy during refinement).
logging.disable()

# Working units for this module.
ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms
spring_unit = ENERGY / (LENGTH ** 2)

# Two heavy atoms closer than CLASH_CUTOFF * (sum of their vdW radii) clash.
CLASH_CUTOFF = 0.63
# Atomic radii for various atom types.
atom_radii = {"C": 1.70, "N": 1.55, 'O': 1.52, 'S': 1.80}
# Sum of van-der-waals radii
radii_sums = dict(
    [(i + j, (atom_radii[i] + atom_radii[j])) for i in list(atom_radii.keys()) for j in list(atom_radii.keys())])
# Clash_cutoff-based radii values
cutoffs = dict(
    [(i + j, CLASH_CUTOFF * (radii_sums[i + j])) for i in list(atom_radii.keys()) for j in list(atom_radii.keys())])
# Using amber14 recommended protein force field
forcefield = app.ForceField("amber14/protein.ff14SB.xml")
def refine(input_file, output_file, check_for_strained_bonds=True, tries=3, n=6, n_threads=-1):
    """Run refine_once up to *tries* times; True as soon as one pass succeeds."""
    return any(
        refine_once(input_file, output_file,
                    check_for_strained_bonds=check_for_strained_bonds,
                    n=n, n_threads=n_threads)
        for _ in range(tries)
    )
def refine_once(input_file, output_file, check_for_strained_bonds=True, n=6, n_threads=-1):
    """Single refinement pass: iterative restrained minimisation plus checks.

    Each round minimises, then verifies peptide-bond lengths, omega (cis/trans)
    torsions, chirality, strained side-chain bonds, stereochemistry and clashes.
    Restraint strengths are rescheduled between rounds based on what failed.

    Args:
        input_file: PDB file with the raw predicted structure.
        output_file: path the refined PDB is written to (written even when
            refinement did not fully converge).
        check_for_strained_bonds: also detect and rebuild strained side chains.
        n: maximum number of minimisation rounds.
        n_threads: CPU threads for OpenMM; <= 0 lets OpenMM decide.

    Returns:
        True when all geometry checks passed, False otherwise.
    """
    # Restraint schedules: k1 (positional restraint) relaxes over rounds,
    # k2 (cis-bond penalty) strengthens while cis bonds remain.
    k1s = [2.5,1,0.5,0.25,0.1,0.001]
    k2s = [2.5,5,7.5,15,25,50]
    success = False
    fixer = pdbfixer.PDBFixer(input_file)
    fixer.findMissingResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    k1 = k1s[0]
    # k2 < 0 disables the cis-bond torsion force (model already all-trans).
    k2 = -1 if cis_check(fixer.topology, fixer.positions) else k2s[0]
    topology, positions = fixer.topology, fixer.positions
    for i in range(n):
        try:
            simulation = minimize_energy(topology, positions, k1=k1, k2 = k2, n_threads=n_threads)
            topology, positions = simulation.topology, simulation.context.getState(getPositions=True).getPositions()
            acceptable_bonds, trans_peptide_bonds = bond_check(topology, positions), cis_check(topology, positions)
        except OpenMMException as e:
            # NOTE(review): `positions` is bound before the loop, so
            # `"positions" not in locals()` is always False and this raise
            # appears unreachable — confirm whether last-round failures
            # should propagate instead of returning False.
            if (i == n-1) and ("positions" not in locals()):
                print("OpenMM failed to refine {}".format(input_file), flush=True)
                raise e
            else:
                # Restart the round from the unminimised fixer model.
                topology, positions = fixer.topology, fixer.positions
                continue
        # If peptide bonds are the wrong length, decrease the strength of the positional restraint
        if not acceptable_bonds:
            k1 = k1s[min(i, len(k1s)-1)]
        # If there are still cis isomers in the model, increase the force to fix these
        if not trans_peptide_bonds:
            k2 = k2s[min(i, len(k2s)-1)]
        else:
            k2 = -1
        if acceptable_bonds and trans_peptide_bonds:
            # If peptide bond lengths and torsions are okay, check and fix the chirality.
            try:
                simulation = chirality_fixer(simulation)
                topology, positions = simulation.topology, simulation.context.getState(getPositions=True).getPositions()
            except OpenMMException as e:
                topology, positions = fixer.topology, fixer.positions
                continue
            if check_for_strained_bonds:
                # If all other checks pass, check and fix strained sidechain bonds:
                try:
                    strained_bonds = strained_sidechain_bonds_check(topology, positions)
                    if len(strained_bonds) > 0:
                        needs_recheck = True
                        topology, positions = strained_sidechain_bonds_fixer(strained_bonds, topology, positions, n_threads=n_threads)
                    else:
                        needs_recheck = False
                except OpenMMException as e:
                    topology, positions = fixer.topology, fixer.positions
                    continue
            else:
                needs_recheck = False
            # If it passes all the tests, we are done
            tests = bond_check(topology, positions) and cis_check(topology, positions)
            if needs_recheck:
                # BUGFIX: strained_sidechain_bonds_check returns the residues
                # that are *still strained*, so success requires the list to be
                # empty. The previous `tests and strained_...check(...)` kept a
                # truthy (bad) result and falsified a clean (empty) one.
                tests = tests and len(strained_sidechain_bonds_check(topology, positions)) == 0
            if tests and stereo_check(topology, positions) and clash_check(topology, positions):
                success = True
                break
    with open(output_file, "w") as out_handle:
        app.PDBFile.writeFile(topology, positions, out_handle, keepIds=True)
    return success
def minimize_energy(topology, positions, k1=2.5, k2=2.5, n_threads=-1):
    """Energy-minimise the model under two optional restraints.

    Args:
        topology, positions: OpenMM model to minimise.
        k1: strength of the harmonic restraint keeping backbone atoms
            (CA, CB, N, C) near their predicted positions.
        k2: strength of the torsion penalty pushing omega angles towards
            trans; k2 <= 0 disables the cis force entirely.
        n_threads: CPU threads for OpenMM; <= 0 lets OpenMM decide.

    Returns:
        The OpenMM Simulation after minimisation.
    """
    # Fill in the gaps with OpenMM Modeller
    modeller = app.Modeller(topology, positions)
    modeller.addHydrogens(forcefield)
    # Set up force field
    system = forcefield.createSystem(modeller.topology)
    # Keep atoms close to initial prediction
    force = CustomExternalForce("k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
    force.addGlobalParameter("k", k1 * spring_unit)
    for p in ["x0", "y0", "z0"]:
        force.addPerParticleParameter(p)
    for residue in modeller.topology.residues():
        for atom in residue.atoms():
            if atom.name in ["CA", "CB", "N", "C"]:
                force.addParticle(atom.index, modeller.positions[atom.index])
    system.addForce(force)
    if k2 > 0.0:
        # Penalty is maximal at theta = 0 (cis) and zero at theta = pi (trans).
        cis_force = CustomTorsionForce("10*k2*(1+cos(theta))^2")
        cis_force.addGlobalParameter("k2", k2 * ENERGY)
        for chain in modeller.topology.chains():
            residues = [res for res in chain.residues()]
            relevant_atoms = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "CA", "C"]} for res in residues]
            for i in range(1,len(residues)):
                # Prolines are allowed to stay cis.
                if residues[i].name == "PRO":
                    continue
                resi = relevant_atoms[i-1]
                n_resi = relevant_atoms[i]
                cis_force.addTorsion(resi["CA"], resi["C"], n_resi["N"], n_resi["CA"])
        system.addForce(cis_force)
    # Set up integrator
    # Integrator is unused for dynamics (0 steps) but required by Simulation.
    integrator = LangevinIntegrator(0, 0.01, 0.0)
    # Set up the simulation
    if n_threads > 0:
        # Set number of threads used by OpenMM
        platform = Platform.getPlatformByName('CPU')
        simulation = app.Simulation(modeller.topology, system, integrator, platform, {'Threads': str(n_threads)})
    else:
        simulation = app.Simulation(modeller.topology, system, integrator)
    simulation.context.setPositions(modeller.positions)
    # Minimize the energy
    simulation.minimizeEnergy()
    return simulation
def chirality_fixer(simulation):
    """Flip D-stereoisomer residues back to L by mirroring their HA atom.

    For each non-GLY residue whose (N, C, CB) triple product around CA has
    the wrong sign, the HA hydrogen is reflected through CA, pinned in place
    (mass 0) for one minimisation, then released and re-minimised.

    Returns the same Simulation, minimised with the corrected chirality.
    """
    topology = simulation.topology
    positions = simulation.context.getState(getPositions=True).getPositions()
    d_stereoisomers = []
    for residue in topology.residues():
        # Glycine has no CB, so it has no chirality to check.
        if residue.name == "GLY":
            continue
        atom_indices = {atom.name:atom.index for atom in residue.atoms() if atom.name in ["N", "CA", "C", "CB"]}
        vectors = [positions[atom_indices[i]] - positions[atom_indices["CA"]] for i in ["N", "C", "CB"]]
        # Negative triple product => D-stereoisomer.
        if np.dot(np.cross(vectors[0], vectors[1]), vectors[2]) < .0*LENGTH**3:
            # If it is a D-stereoisomer then flip its H atom
            indices = {x.name:x.index for x in residue.atoms() if x.name in ["HA", "CA"]}
            positions[indices["HA"]] = 2*positions[indices["CA"]] - positions[indices["HA"]]
            # Fix the H atom in place
            particle_mass = simulation.system.getParticleMass(indices["HA"])
            simulation.system.setParticleMass(indices["HA"], 0.0)
            d_stereoisomers.append((indices["HA"], particle_mass))
    if len(d_stereoisomers) > 0:
        simulation.context.setPositions(positions)
        # Minimize the energy with the evil hydrogens fixed
        simulation.minimizeEnergy()
        # Minimize the energy letting the hydrogens move
        for atom in d_stereoisomers:
            simulation.system.setParticleMass(*atom)
        simulation.minimizeEnergy()
    return simulation
def bond_check(topology, positions):
    """True if every peptide C(i)-N(i+1) bond is within 0.1 of 1.329 (LENGTH units)."""
    for chain in topology.chains():
        residues = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "C"]} for res in chain.residues()]
        for i in range(len(residues)-1):
            # For simplicity we only check the peptide bond length as the rest should be correct as they are hard coded
            v = np.linalg.norm(positions[residues[i]["C"]] - positions[residues[i+1]["N"]])
            if abs(v - 1.329*LENGTH) > 0.1*LENGTH:
                return False
    return True
def cis_bond(p0, p1, p2, p3):
    """True when the dihedral defined by p0-p1-p2-p3 is cis-like (same-side).

    Compares the two plane normals around the central p1-p2 bond; a positive
    dot product means p0 and p3 fall on the same side (cis).
    """
    bond_01 = p1 - p0
    bond_12 = p2 - p1
    bond_23 = p3 - p2
    normal_a = np.cross(-bond_01, bond_12)
    normal_b = np.cross(bond_23, bond_12)
    return np.dot(normal_a, normal_b) > 0
def cis_check(topology, positions):
    """True if every non-proline omega (peptide) bond is trans."""
    pos = np.array(positions.value_in_unit(LENGTH))
    for chain in topology.chains():
        residues = [res for res in chain.residues()]
        relevant_atoms = [{atom.name:atom.index for atom in res.atoms() if atom.name in ["N", "CA", "C"]} for res in residues]
        for i in range(1,len(residues)):
            # Cis-proline is biologically common, so prolines are exempt.
            if residues[i].name == "PRO":
                continue
            resi = relevant_atoms[i-1]
            n_resi = relevant_atoms[i]
            # Omega dihedral atoms: CA(i-1), C(i-1), N(i), CA(i).
            p0,p1,p2,p3 = pos[resi["CA"]],pos[resi["C"]],pos[n_resi["N"]],pos[n_resi["CA"]]
            if cis_bond(p0,p1,p2,p3):
                return False
    return True
def stereo_check(topology, positions):
    """True if every non-GLY residue is the L-stereoisomer.

    Checks the sign of the determinant of the (N, C, CB) vectors measured
    from CA; a negative determinant marks a D-stereoisomer.
    """
    pos = np.array(positions.value_in_unit(LENGTH))
    for residue in topology.residues():
        # Glycine has no CB and therefore no chirality.
        if residue.name == "GLY":
            continue
        atom_indices = {atom.name:atom.index for atom in residue.atoms() if atom.name in ["N", "CA", "C", "CB"]}
        vectors = pos[[atom_indices[i] for i in ["N", "C", "CB"]]] - pos[atom_indices["CA"]]
        if np.linalg.det(vectors) < 0:
            return False
    return True
def clash_check(topology, positions):
    """True if no pair of heavy atoms from different residues clashes.

    A clash is a distance below CLASH_CUTOFF * (sum of vdW radii). Peptide
    C-N pairs are skipped, and SG-SG pairs beyond 1.88 are exempt so that
    disulfide bonds (~2.05) are not flagged.
    """
    heavies = [x for x in topology.atoms() if x.element.symbol != "H"]
    pos = np.array(positions.value_in_unit(LENGTH))[[x.index for x in heavies]]
    # KDTree so only pairs within the largest possible cutoff are inspected.
    tree = spatial.KDTree(pos)
    pairs = tree.query_pairs(r=max(cutoffs.values()))
    for pair in pairs:
        atom_i, atom_j = heavies[pair[0]], heavies[pair[1]]
        # Intra-residue contacts are covalent geometry, not clashes.
        if atom_i.residue.index == atom_j.residue.index:
            continue
        elif (atom_i.name == "C" and atom_j.name == "N") or (atom_i.name == "N" and atom_j.name == "C"):
            continue
        atom_distance = np.linalg.norm(pos[pair[0]] - pos[pair[1]])
        if (atom_i.name == "SG" and atom_j.name == "SG") and atom_distance > 1.88:
            continue
        elif atom_distance < (cutoffs[atom_i.element.symbol + atom_j.element.symbol]):
            return False
    return True
def strained_sidechain_bonds_check(topology, positions):
    """Return residues that contain an abnormally strained harmonic bond.

    Extracts every HarmonicBondForce term from the force field and flags
    bonds whose harmonic energy k*(d - x0)^2 exceeds 100 (in spring_unit *
    LENGTH^2, i.e. kcal/mol). An empty list means no strained bonds.
    """
    atoms = list(topology.atoms())
    pos = np.array(positions.value_in_unit(LENGTH))
    system = forcefield.createSystem(topology)
    bonds = [x for x in system.getForces() if type(x).__name__ == "HarmonicBondForce"][0]
    # Initialise arrays for bond details
    n_bonds = bonds.getNumBonds()
    i = np.empty(n_bonds, dtype=int)
    j = np.empty(n_bonds, dtype=int)
    k = np.empty(n_bonds)
    x0 = np.empty(n_bonds)
    # Extract bond details to arrays
    for n in range(n_bonds):
        i[n],j[n],_x0,_k = bonds.getBondParameters(n)
        k[n] = _k.value_in_unit(spring_unit)
        x0[n] = _x0.value_in_unit(LENGTH)
    # Check if there are any abnormally strained bond
    distance = np.linalg.norm(pos[i] - pos[j], axis=-1)
    check = k*(distance - x0)**2 > 100
    # Return residues with strained bonds if any
    return [atoms[x].residue for x in i[check]]
def strained_sidechain_bonds_fixer(strained_residues, topology, positions, n_threads=-1):
    """Rebuild the side chains of strained residues and re-minimise.

    Deletes every non-backbone atom of *strained_residues*, lets pdbfixer
    reconstruct them from template geometry, then runs a short energy
    minimisation on the repaired model.

    Args:
        strained_residues: residues flagged by strained_sidechain_bonds_check.
        topology, positions: current OpenMM model.
        n_threads: CPU threads for OpenMM; <= 0 lets OpenMM decide.

    Returns:
        Tuple (topology, positions) of the re-minimised model.
    """
    # Delete all atoms except the main chain for badly refined residues.
    bb_atoms = ["N","CA","C"]
    bad_side_chains = sum([[atom for atom in residue.atoms() if atom.name not in bb_atoms] for residue in strained_residues],[])
    modeller = app.Modeller(topology, positions)
    modeller.delete(bad_side_chains)
    # Save model with deleted side chains to temporary file.
    random_number = str(int(np.random.rand()*10**8))
    tmp_file = f"side_chain_fix_tmp_{random_number}.pdb"
    with open(tmp_file,"w") as handle:
        app.PDBFile.writeFile(modeller.topology, modeller.positions, handle, keepIds=True)
    # Load model into pdbfixer
    fixer = pdbfixer.PDBFixer(tmp_file)
    os.remove(tmp_file)
    # Repair deleted side chains
    fixer.findMissingResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    # Fill in the gaps with OpenMM Modeller
    modeller = app.Modeller(fixer.topology, fixer.positions)
    modeller.addHydrogens(forcefield)
    # Set up force field
    system = forcefield.createSystem(modeller.topology)
    # Set up integrator
    integrator = LangevinIntegrator(0, 0.01, 0.0)
    # Set up the simulation
    if n_threads > 0:
        # Set number of threads used by OpenMM
        platform = Platform.getPlatformByName('CPU')
        # BUGFIX: platform properties must be a dict. The original passed a
        # set literal {'Threads', str(n_threads)}, which OpenMM rejects;
        # compare the correct dict usage in minimize_energy.
        simulation = app.Simulation(modeller.topology, system, integrator, platform, {'Threads': str(n_threads)})
    else:
        simulation = app.Simulation(modeller.topology, system, integrator)
    simulation.context.setPositions(modeller.positions)
    # Minimize the energy
    simulation.minimizeEnergy()
    return simulation.topology, simulation.context.getState(getPositions=True).getPositions()

313
ImmuneBuilder/rigids.py Normal file
View File

@@ -0,0 +1,313 @@
import torch
from ImmuneBuilder.constants import rigid_group_atom_positions2, chi2_centers, chi3_centers, chi4_centers, rel_pos, \
residue_atoms_mask
class Vector:
    """A 3-vector whose components x, y, z are tensors of identical shape.

    Supports elementwise arithmetic, dot/cross products, distances and a few
    tensor-style helpers (unsqueeze/squeeze/map/to/indexing).
    """

    def __init__(self, x, y, z):
        assert (x.shape == y.shape) and (y.shape == z.shape), "x y and z should have the same shape"
        self.x = x
        self.y = y
        self.z = z
        self.shape = x.shape

    def _wrap(self, op):
        # Apply *op* to each component, producing a new Vector.
        return Vector(op(self.x), op(self.y), op(self.z))

    def __add__(self, vec):
        return Vector(self.x + vec.x, self.y + vec.y, self.z + vec.z)

    def __sub__(self, vec):
        return Vector(self.x - vec.x, self.y - vec.y, self.z - vec.z)

    def __mul__(self, param):
        return self._wrap(lambda component: param * component)

    def __matmul__(self, vec):
        # Dot product.
        return self.x * vec.x + self.y * vec.y + self.z * vec.z

    def norm(self):
        # Small epsilon keeps the gradient finite at zero length.
        return (self @ self + 1e-8) ** (1 / 2)

    def cross(self, other):
        return Vector(
            self.y * other.z - self.z * other.y,
            self.z * other.x - self.x * other.z,
            self.x * other.y - self.y * other.x,
        )

    def dist(self, other):
        return (self - other).norm()

    def unsqueeze(self, dim):
        return self._wrap(lambda t: t.unsqueeze(dim))

    def squeeze(self, dim):
        return self._wrap(lambda t: t.squeeze(dim))

    def map(self, func):
        return self._wrap(func)

    def to(self, device):
        return self._wrap(lambda t: t.to(device))

    def __str__(self):
        return "Vector(x={},\ny={},\nz={})\n".format(self.x, self.y, self.z)

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        return self._wrap(lambda t: t[key])
class Rot:
    """A 3x3 rotation matrix stored as nine same-shaped tensors (xx..zz)."""

    _COMPONENTS = ("xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz")

    def __init__(self, xx, xy, xz, yx, yy, yz, zx, zy, zz):
        self.xx, self.xy, self.xz = xx, xy, xz
        self.yx, self.yy, self.yz = yx, yy, yz
        self.zx, self.zy, self.zz = zx, zy, zz
        self.shape = xx.shape

    def _wrap(self, op):
        # Apply *op* to every matrix entry, producing a new Rot.
        return Rot(*(op(getattr(self, name)) for name in self._COMPONENTS))

    def __matmul__(self, other):
        if isinstance(other, Vector):
            # Matrix-vector product, row by row.
            rotated = [
                sum(getattr(self, row + axis) * getattr(other, axis) for axis in "xyz")
                for row in "xyz"
            ]
            return Vector(*rotated)
        if isinstance(other, Rot):
            # Matrix-matrix product via index arithmetic on component names.
            entries = {}
            for row in "xyz":
                for col in "xyz":
                    entries[row + col] = sum(
                        getattr(self, row + k) * getattr(other, k + col) for k in "xyz")
            return Rot(**entries)
        else:
            raise ValueError("Matmul against {}".format(type(other)))

    def inv(self):
        # A rotation's inverse is its transpose.
        return Rot(**{row + col: getattr(self, col + row) for row in "xyz" for col in "xyz"})

    def det(self):
        return self.xx * self.yy * self.zz + self.xy * self.yz * self.zx + self.yx * self.zy * self.xz - self.xz * self.yy * self.zx - self.xy * self.yx * self.zz - self.xx * self.zy * self.yz

    def unsqueeze(self, dim):
        return self._wrap(lambda t: t.unsqueeze(dim=dim))

    def squeeze(self, dim):
        return self._wrap(lambda t: t.squeeze(dim=dim))

    def detach(self):
        return self._wrap(lambda t: t.detach())

    def to(self, device):
        return self._wrap(lambda t: t.to(device))

    def __str__(self):
        return "Rot(xx={},\nxy={},\nxz={},\nyx={},\nyy={},\nyz={},\nzx={},\nzy={},\nzz={})\n".format(self.xx, self.xy,
                                                                                                     self.xz, self.yx,
                                                                                                     self.yy, self.yz,
                                                                                                     self.zx, self.zy,
                                                                                                     self.zz)

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        return self._wrap(lambda t: t[key])
class Rigid:
    """A rigid transform: rotation *rot* followed by translation to *origin*."""

    def __init__(self, origin, rot):
        self.origin = origin
        self.rot = rot
        self.shape = self.origin.shape

    def __matmul__(self, other):
        # Compose with another rigid, or apply to a point.
        if isinstance(other, Rigid):
            return Rigid(self.rot @ other.origin + self.origin, self.rot @ other.rot)
        if isinstance(other, Vector):
            return self.rot @ other + self.origin
        raise TypeError(f"can't multiply rigid by object of type {type(other)}")

    def inv(self):
        # (R, t)^-1 = (R^T, -R^T t)
        rot_transposed = self.rot.inv()
        shifted = rot_transposed @ self.origin
        return Rigid(Vector(-shifted.x, -shifted.y, -shifted.z), rot_transposed)

    def unsqueeze(self, dim=None):
        return Rigid(self.origin.unsqueeze(dim=dim), self.rot.unsqueeze(dim=dim))

    def squeeze(self, dim=None):
        return Rigid(self.origin.squeeze(dim=dim), self.rot.squeeze(dim=dim))

    def to(self, device):
        return Rigid(self.origin.to(device), self.rot.to(device))

    def __str__(self):
        return "Rigid(origin={},\nrot={})".format(self.origin, self.rot)

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        return Rigid(self.origin[key], self.rot[key])
def rigid_body_identity(shape):
    """Identity rigid transform (zero translation, identity rotation) of *shape*."""
    zeros = torch.zeros(shape)
    ones = torch.ones(shape)
    origin = Vector(zeros, zeros, zeros)
    rot = Rot(ones, zeros, zeros,
              zeros, ones, zeros,
              zeros, zeros, ones)
    return Rigid(origin, rot)
def vec_from_tensor(tens):
    """Split the (size-3) trailing axis of *tens* into a Vector."""
    assert tens.shape[-1] == 3, "What dimension you in?"
    x, y, z = tens[..., 0], tens[..., 1], tens[..., 2]
    return Vector(x, y, z)
def rigid_from_three_points(origin, y_x_plane, x_axis):
    """Build a reference frame from three points by Gram-Schmidt.

    The first basis vector points from *origin* towards *x_axis*; the second
    lies in the plane spanned with *y_x_plane*; the third completes a
    right-handed basis via the cross product.
    """
    e1 = x_axis - origin
    e1 = e1 * (1 / e1.norm())
    e2 = y_x_plane - origin
    e2 = e2 - e1 * (e1 @ e2)
    e2 = e2 * (1 / e2.norm())
    e3 = e1.cross(e2)
    basis = Rot(e1.x, e2.x, e3.x,
                e1.y, e2.y, e3.y,
                e1.z, e2.z, e3.z)
    return Rigid(origin, basis)
def rigid_from_tensor(tens):
    """Frame from a tensor of three stacked points along dim -2."""
    assert (tens.shape[-1] == 3), "I want 3D points"
    origin = vec_from_tensor(tens[..., 0, :])
    y_x_plane = vec_from_tensor(tens[..., 1, :])
    x_axis = vec_from_tensor(tens[..., 2, :])
    return rigid_from_three_points(origin, y_x_plane, x_axis)
def stack_rigids(rigids, **kwargs):
    """torch.stack applied componentwise over a sequence of Rigids.

    Probably best to avoid using very much.
    """
    def stacked(getter):
        return torch.stack([getter(rig) for rig in rigids], **kwargs)

    origin = Vector(stacked(lambda r: r.origin.x),
                    stacked(lambda r: r.origin.y),
                    stacked(lambda r: r.origin.z))
    rot = Rot(*(stacked(lambda r, n=name: getattr(r.rot, n))
                for name in ("xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz")))
    return Rigid(origin, rot)
def rotate_x_axis_to_new_vector(new_vector):
    """Rigid rotation (zero translation) taking the x-axis onto *new_vector*.

    NOTE(review): coordinates are unpacked as (c, b, a) = (x, y, z) and c is
    negated after normalisation; the closed form below appears to be a
    Rodrigues-style construction specialised to the x-axis — confirm against
    the torsion-frame conventions in rigid_transformation_from_torsion_angles
    before changing any signs.
    """
    # Extract coordinates
    c, b, a = new_vector[..., 0], new_vector[..., 1], new_vector[..., 2]
    # Normalize
    n = (c ** 2 + a ** 2 + b ** 2 + 1e-16) ** (1 / 2)
    a, b, c = a / n, b / n, -c / n
    # Set new origin
    new_origin = vec_from_tensor(torch.zeros_like(new_vector))
    # Rotate x-axis to point old origin to new one
    # Epsilon keeps k finite when new_vector is (anti)parallel to x.
    k = (1 - c) / (a ** 2 + b ** 2 + 1e-8)
    new_rot = Rot(-c, b, -a, b, 1 - k * b ** 2, a * b * k, a, -a * b * k, k * a ** 2 - 1)
    return Rigid(new_origin, new_rot)
def rigid_transformation_from_torsion_angles(torsion_angles, distance_to_new_origin):
    """Rigid transform: rotate about x by a torsion angle, translate along x.

    *torsion_angles* carries unit 2-vectors (cos/sin-like components) in its
    last dimension; *distance_to_new_origin* is the translation along x.

    NOTE(review): the leading -1 and the sign pattern in the rotation fold a
    frame flip into the transform; confirm against rotate_x_axis_to_new_vector
    before modifying.
    """
    dev = torsion_angles.device
    zero = torch.zeros(torsion_angles.shape[:-1]).to(dev)
    one = torch.ones(torsion_angles.shape[:-1]).to(dev)
    new_rot = Rot(
        -one, zero, zero,
        zero, torsion_angles[..., 0], torsion_angles[..., 1],
        zero, torsion_angles[..., 1], -torsion_angles[..., 0],
    )
    new_origin = Vector(distance_to_new_origin, zero, zero)
    return Rigid(new_origin, new_rot)
def global_frames_from_bb_frame_and_torsion_angles(bb_frame, torsion_angles, seq):
    """Chain torsion-defined local frames onto the backbone frame.

    Returns six frames per residue stacked on the last dim:
    [backbone, psi, chi1, chi2, chi3, chi4]. Each chi_k frame composes onto
    chi_{k-1}, so later torsions inherit all earlier rotations.
    """
    dev = bb_frame.origin.x.device
    # We start with psi
    # psi origin distance comes from the idealised atom offsets (rel_pos).
    psi_local_frame_origin = torch.tensor([rel_pos[x][2][1] for x in seq]).to(dev).pow(2).sum(-1).pow(1 / 2)
    psi_local_frame = rigid_transformation_from_torsion_angles(torsion_angles[:, 0], psi_local_frame_origin)
    psi_global_frame = bb_frame @ psi_local_frame
    # Now all the chis
    # chi1 hangs off the backbone frame directly; chi2..4 chain sequentially.
    chi1_local_frame_origin = torch.tensor([rel_pos[x][3][1] for x in seq]).to(dev)
    chi1_local_frame = rotate_x_axis_to_new_vector(chi1_local_frame_origin) @ rigid_transformation_from_torsion_angles(
        torsion_angles[:, 1], chi1_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
    chi1_global_frame = bb_frame @ chi1_local_frame
    chi2_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi2_centers[x]][1] for x in seq]).to(dev)
    chi2_local_frame = rotate_x_axis_to_new_vector(chi2_local_frame_origin) @ rigid_transformation_from_torsion_angles(
        torsion_angles[:, 2], chi2_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
    chi2_global_frame = chi1_global_frame @ chi2_local_frame
    chi3_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi3_centers[x]][1] for x in seq]).to(dev)
    chi3_local_frame = rotate_x_axis_to_new_vector(chi3_local_frame_origin) @ rigid_transformation_from_torsion_angles(
        torsion_angles[:, 3], chi3_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
    chi3_global_frame = chi2_global_frame @ chi3_local_frame
    chi4_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi4_centers[x]][1] for x in seq]).to(dev)
    chi4_local_frame = rotate_x_axis_to_new_vector(chi4_local_frame_origin) @ rigid_transformation_from_torsion_angles(
        torsion_angles[:, 4], chi4_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
    chi4_global_frame = chi3_global_frame @ chi4_local_frame
    return stack_rigids(
        [bb_frame, psi_global_frame, chi1_global_frame, chi2_global_frame, chi3_global_frame, chi4_global_frame],
        dim=-1)
def all_atoms_from_global_reference_frames(global_reference_frames, seq):
    """Place all 14 atom slots per residue from the stacked global frames.

    For each slot, the owning rigid group's frame is selected and applied to
    the atom's idealised local offset; atom slots a residue does not have
    are filled with NaN via residue_atoms_mask.
    """
    dev = global_reference_frames.origin.x.device
    all_atoms = torch.zeros((len(seq), 14, 3)).to(dev)
    for atom_pos in range(14):
        relative_positions = [rel_pos[x][atom_pos][1] for x in seq]
        # Frame index per residue, clamped at 0 (the backbone frame).
        local_reference_frame = [max(rel_pos[x][atom_pos][0] - 2, 0) for x in seq]
        # One-hot mask selecting one of the 6 stacked frames per residue.
        local_reference_frame_mask = torch.tensor([[y == x for y in range(6)] for x in local_reference_frame]).to(dev)
        global_atom_vector = global_reference_frames[local_reference_frame_mask] @ vec_from_tensor(
            torch.tensor(relative_positions).to(dev))
        all_atoms[:, atom_pos] = torch.stack([global_atom_vector.x, global_atom_vector.y, global_atom_vector.z], dim=-1)
    all_atom_mask = torch.tensor([residue_atoms_mask[x] for x in seq]).to(dev)
    all_atoms[~all_atom_mask] = float("Nan")
    return all_atoms

View File

@@ -0,0 +1,40 @@
from anarci import validate_sequence, anarci, scheme_short_to_long
def number_single_sequence(sequence, chain, scheme="imgt", allowed_species=['human','mouse']):
    """Number an antibody chain with ANARCI and sanity-check completeness.

    Args:
        sequence: amino-acid string of the whole variable domain.
        chain: "H" or "L"; kappa ("K") is also allowed when chain == "L".
        scheme: short numbering-scheme name, or "raw" for sequential numbers.
        allowed_species: passed through to ANARCI.

    Returns:
        List of ((number, insertion_code), residue) pairs with gaps removed.

    NOTE(review): validation uses `assert`, which is stripped under
    `python -O`; consider raising explicit exceptions instead.
    """
    validate_sequence(sequence)
    try:
        if scheme != "raw":
            scheme = scheme_short_to_long[scheme.lower()]
    except KeyError:
        raise NotImplementedError(f"Unimplemented numbering scheme: {scheme}")
    assert len(sequence) > 70, f"Sequence too short to be an Ig domain. Please give whole sequence:\n{sequence}"
    allow = [chain]
    if chain == "L":
        allow.append("K")
    # Use imgt scheme for numbering sanity checks
    numbered, _, _ = anarci([("sequence", sequence)], scheme='imgt', output=False, allow=set(allow), allowed_species=allowed_species)
    assert numbered[0], f"Sequence provided as an {chain} chain is not recognised as an {chain} chain."
    output = [x for x in numbered[0][0][0] if x[1] != "-"]
    numbers = [x[0][0] for x in output]
    # Check for missing residues assuming imgt numbering
    assert (max(numbers) > 120) and (min(numbers) < 8), f"Sequence missing too many residues to model correctly. Please give whole sequence:\n{sequence}"
    # Renumber once sanity checks done
    if scheme == "raw":
        output = [((i+1, " "),x[1]) for i,x in enumerate(output)]
    elif scheme != 'imgt':
        numbered, _, _ = anarci([("sequence", sequence)], scheme=scheme, output=False, allow=set(allow), allowed_species=allowed_species)
        output = [x for x in numbered[0][0][0] if x[1] != "-"]
    return output
def number_sequences(seqs, scheme="imgt", allowed_species=['human','mouse']):
    """Number every chain in *seqs* (a {chain_id: sequence} dict),
    returning {chain_id: numbered_output} via number_single_sequence."""
    numbered = {}
    for chain in seqs:
        numbered[chain] = number_single_sequence(seqs[chain], chain, scheme=scheme, allowed_species=allowed_species)
    return numbered

View File

@@ -0,0 +1 @@

View File

View File

View File

View File

View File

View File

View File

View File

136
ImmuneBuilder/util.py Normal file
View File

@@ -0,0 +1,136 @@
from ImmuneBuilder.constants import res_to_num, atom_types, residue_atoms, restype_1to3, restypes
import numpy as np
import torch
import requests
import os
def download_file(url, filename):
    """Stream the content at *url* into *filename* (8 KiB chunks) and
    return *filename*. Raises requests.HTTPError on a bad status."""
    response = requests.get(url, stream=True)
    with response:
        response.raise_for_status()
        with open(filename, 'wb+') as handle:
            for chunk in response.iter_content(chunk_size=8192):
                handle.write(chunk)
    return filename
def get_one_hot(targets, nb_classes=21):
    """Return a one-hot encoding of integer class indices.

    Args:
        targets: array-like of integer class indices, any shape.
            (Generalized: the previous version read ``targets.shape`` and
            therefore only accepted ndarrays.)
        nb_classes: size of the one-hot dimension (default 21).

    Returns:
        np.ndarray of shape ``targets.shape + (nb_classes,)``.
    """
    targets = np.asarray(targets)
    res = np.eye(nb_classes)[targets.reshape(-1)]
    return res.reshape(list(targets.shape) + [nb_classes])
def get_encoding(sequence_dict, chain_ids="HL"):
    """Per-residue input encoding for the given chains, concatenated in
    chain_ids order.

    Each residue gets a 21-way amino-acid one-hot plus a 2-way chain
    one-hot (position of the chain in chain_ids), joined on the feature
    axis. Returns an array of shape (total_residues, 23).
    """
    per_chain = []
    for chain_index, chain in enumerate(chain_ids):
        sequence = sequence_dict[chain]
        amino_one_hot = get_one_hot(np.array([res_to_num(residue) for residue in sequence]))
        chain_one_hot = get_one_hot(chain_index * np.ones(len(sequence), dtype=int), 2)
        per_chain.append(np.concatenate([amino_one_hot, chain_one_hot], axis=-1))
    return np.concatenate(per_chain, axis=0)
def find_alignment_transform(traces):
    """Kabsch-style alignment of a stack of coordinate traces.

    The first trace is the reference; every other trace gets the SVD
    rotation that best aligns it (after centring) onto the reference.

    Args:
        traces: tensor of shape (n_models, n_points, 3).

    Returns:
        (rotations, centers): rotations is (n_models, 3, 3) with the
        identity in slot 0; centers are per-trace centroids shaped so
        callers can do ``(traces - centers) @ rotations``.
    """
    centers = traces.mean(-2, keepdim=True)
    centred = traces - centers
    reference, movable = centred[0], centred[1:]

    # Cross-covariance of each movable trace with the reference.
    covariances = torch.einsum("i j k, j l -> i k l", movable, reference)
    V, _, W = torch.linalg.svd(covariances)

    # Standard Kabsch correction: flip the last singular axis wherever the
    # raw V @ W is a reflection (negative determinant).
    ones = torch.ones(len(movable), device=V.device)
    det_sign = torch.linalg.det(torch.matmul(V, W))
    correction = torch.stack([ones, ones, det_sign], dim=1)[:, :, None]
    rotations = torch.matmul(correction * V, W)

    # The reference aligns to itself with the identity rotation.
    identity = torch.eye(3, device=rotations.device)[None]
    return torch.cat([identity, rotations]), centers
def to_pdb(numbered_sequences, all_atoms, chain_ids = "HL"):
    """Render predicted atom coordinates as PDB ATOM records.

    Args:
        numbered_sequences: dict mapping chain id -> list of
            ((residue_number, insertion_code), one_letter_amino_acid)
            pairs (the structure implied by the amino[0][0]/amino[0][1]/
            amino[1] accesses below).
        all_atoms: per-residue atom coordinates indexed as
            all_atoms[residue, atom_slot]; NaN coordinates mark atoms
            that are absent for that residue.
        chain_ids: the two chain ids, concatenated in this order.

    Returns:
        The ATOM records joined by newlines (no TER/END records,
        no trailing newline).
    """
    atom_index = 0
    pdb_lines = []
    record_type = "ATOM"
    # Flatten both chains into one residue list, remembering which chain
    # each residue came from so the chain-id column can switch mid-file.
    seq = numbered_sequences[chain_ids[0]] + numbered_sequences[chain_ids[1]]
    chain_index = [0]*len(numbered_sequences[chain_ids[0]]) + [1]*len(numbered_sequences[chain_ids[1]])
    chain_id = chain_ids[0]
    for i, amino in enumerate(seq):
        # Emit atoms in canonical atom_types order, but only those that
        # exist for this residue type.
        for atom in atom_types:
            if atom in residue_atoms[amino[1]]:
                j = residue_atoms[amino[1]].index(atom)
                pos = all_atoms[i, j]
                # NaN != NaN: skip atoms whose coordinates are missing.
                if pos.mean() != pos.mean():
                    continue
                name = f' {atom}'
                alt_loc = ''
                res_name_3 = restype_1to3[amino[1]]
                # Cross into the second chain: update the chain-id column.
                if chain_id != chain_ids[chain_index[i]]:
                    chain_id = chain_ids[chain_index[i]]
                occupancy = 1.00
                b_factor = 0.00  # placeholder; later replaced with error estimates
                element = atom[0]
                charge = ''
                # PDB is a columnar format, every space matters here!
                atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                             f'{res_name_3:>3} {chain_id:>1}'
                             f'{(amino[0][0]):>4}{amino[0][1]:>1}   '
                             f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                             f'{occupancy:>6.2f}{b_factor:>6.2f}          '
                             f'{element:>2}{charge:>2}')
                pdb_lines.append(atom_line)
                atom_index += 1
    return "\n".join(pdb_lines)
def sequence_dict_from_fasta(fasta_file):
    """Parse a FASTA file into {chain_id: sequence}.

    The chain id is the first whitespace-delimited token after ">" on a
    header line. Sequence lines are concatenated, so sequences wrapped
    over multiple lines are handled correctly (the previous token-based
    implementation silently kept only the first line of a wrapped
    sequence). Entries containing characters outside the recognised
    residue alphabet are dropped, as before.

    Args:
        fasta_file: path to a FASTA-formatted text file.

    Returns:
        dict mapping chain id to its sequence string.
    """
    out = {}
    chain_id = None
    chunks = []

    def _flush():
        # Record the finished entry, but only if every character is a
        # recognised residue code.
        if chain_id is not None and chunks:
            sequence = "".join(chunks)
            if all(c in restypes for c in sequence):
                out[chain_id] = sequence

    with open(fasta_file) as file:
        for raw in file:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                _flush()
                header_tokens = line[1:].split()
                # ">H heavy chain" -> chain id "H"
                chain_id = header_tokens[0] if header_tokens else ""
                chunks = []
            elif chain_id is not None:
                chunks.append(line)
    _flush()
    return out
def add_errors_as_bfactors(filename, errors, header=None):
    """Rewrite *filename* in place, putting per-residue error estimates in
    the B-factor column of every ATOM record.

    Args:
        filename: path to a PDB file whose ATOM records carry a " 0.00 "
            B-factor placeholder (as written by to_pdb / OpenMM).
        errors: indexable of per-residue errors, consumed in residue order.
            Must cover every residue in the file (IndexError otherwise).
        header: optional list of lines to prepend verbatim (e.g. a REMARK).
            Default changed from a mutable ``[]`` to ``None`` to avoid the
            shared-mutable-default pitfall; behavior is unchanged.
    """
    if header is None:
        header = []
    with open(filename) as file:
        txt = file.readlines()
    new_txt = [x for x in header]
    residue_index = -1
    position = " "
    for line in txt:
        if line[:4] == "ATOM":
            # Columns 23-27 hold residue number + insertion code; advance
            # the error index whenever they change.
            current_res = line[22:27]
            if current_res != position:
                position = current_res
                residue_index += 1
            # NOTE(review): relies on " 0.00 " only ever matching the
            # B-factor field in these generated files — confirm if the
            # input PDBs ever come from elsewhere.
            line = line.replace(" 0.00 ",f"{errors[residue_index]:>6.2f} ")
        elif "REMARK 1 CREATED WITH OPENMM" in line:
            line = line.replace(" 1 CREATED WITH OPENMM", "STRUCTURE REFINED USING OPENMM")
        # Pad every record to a fixed 80-column line (plus newline).
        line = line[:-1] + (81-len(line))*" " + "\n"
        new_txt.append(line)
    with open(filename, "w+") as file:
        file.writelines(new_txt)
def are_weights_ready(weights_path):
    """Check that a downloaded weights file exists and is not a placeholder.

    A missing file, an empty file, or a file whose first line is the
    literal sentinel ``EMPTY`` all count as not ready.

    Fix: the previous version compared ``str(f.readline())`` against the
    string "b'EMPTY'", which only matched a sentinel with no trailing
    newline (``str(b"EMPTY\\n")`` is "b'EMPTY\\n'"). Comparing bytes
    directly handles both forms.

    Args:
        weights_path: path to the weights file.

    Returns:
        True when the file looks like real weights, False otherwise.
    """
    if not os.path.exists(weights_path) or os.path.getsize(weights_path) == 0:
        return False
    with open(weights_path, "rb") as f:
        return f.readline().strip() != b"EMPTY"