Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
- Update container image to harbor.cluster.omic.ai/omic/immunebuilder:latest - Update input/output paths to S3 (s3://omic/eureka/immunebuilder/) - Remove local mount containerOptions (not needed in k8s) - Update homepage to Gitea repo URL - Clean history to remove large model weight blobs
137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
from ImmuneBuilder.constants import res_to_num, atom_types, residue_atoms, restype_1to3, restypes
|
|
|
|
import numpy as np
|
|
import torch
|
|
import requests
|
|
import os
|
|
|
|
|
|
def download_file(url, filename):
    """Stream the resource at ``url`` into the local file ``filename``.

    Args:
        url: Address passed to ``requests.get``.
        filename: Path of the file to create (or truncate) and fill.

    Returns:
        ``filename``, unchanged, once the whole body has been written.

    Raises:
        requests.HTTPError: If the response carries an error status; the
            status is checked before the target file is opened.
    """
    response = requests.get(url, stream=True)
    with response:
        response.raise_for_status()
        with open(filename, 'wb+') as sink:
            # Stream in 8 KiB pieces so the payload is never held in memory
            # all at once.
            sink.writelines(response.iter_content(chunk_size=8192))
    return filename
|
|
|
|
|
|
def get_one_hot(targets, nb_classes=21):
    """One-hot encode integer class labels.

    Args:
        targets: Integer labels as a NumPy array or any array-like (list,
            tuple, scalar). Every value must lie in ``[0, nb_classes)``.
        nb_classes: Number of classes, i.e. the size of the new last axis.

    Returns:
        A float array of shape ``targets.shape + (nb_classes,)`` in which
        ``out[..., k]`` is 1.0 where the label equals ``k`` and 0.0 elsewhere.

    Raises:
        IndexError: If a label is out of range or not an integer.
    """
    # Normalise first so plain Python sequences work too; the previous
    # implementation read ``targets.shape`` and therefore required an ndarray.
    targets = np.asarray(targets)
    if targets.size == 0:
        # An empty array-like defaults to float dtype, which NumPy refuses to
        # use as an index; there are no values to lose in the cast.
        targets = targets.astype(int)
    # Row ``k`` of the identity matrix is the one-hot vector for class ``k``;
    # integer-array indexing appends that row axis to the label shape.
    return np.eye(nb_classes)[targets]
|
|
|
|
|
|
def get_encoding(sequence_dict, chain_ids="HL"):
    """Build the per-residue one-hot encoding for the requested chains.

    For every chain in ``chain_ids`` (in order) each residue is encoded as a
    21-way one-hot of ``res_to_num(residue)`` concatenated with a 2-way
    one-hot of the chain's position within ``chain_ids``.

    Args:
        sequence_dict: Mapping from chain ID to that chain's sequence.
        chain_ids: Chain IDs to encode, in output order.

    Returns:
        A 2-D array with one row per residue across all chains and 23
        columns (21 residue classes + 2 chain classes).
    """
    per_chain = []

    for chain_number, chain in enumerate(chain_ids):
        residues = sequence_dict[chain]
        residue_numbers = np.array([res_to_num(residue) for residue in residues])
        # Every residue of a chain carries the same chain label.
        chain_labels = chain_number * np.ones(len(residues), dtype=int)
        per_chain.append(
            np.concatenate(
                [get_one_hot(residue_numbers), get_one_hot(chain_labels, 2)],
                axis=-1,
            )
        )

    return np.concatenate(per_chain, axis=0)
|
|
|
|
|
|
def find_alignment_transform(traces):
    """Compute SVD-based 3x3 transforms relating each trace to the first one.

    Each trace is centred on its own mean. For every trace after the first,
    the cross-covariance with the first (reference) trace is decomposed with
    an SVD and the two orthogonal factors are recombined into a 3x3 matrix.

    Args:
        traces: Tensor indexed as ``traces[trace, point, coordinate]`` with
            three coordinates per point.

    Returns:
        A tuple ``(transforms, centers)``: ``transforms`` stacks an identity
        matrix (for the reference trace) followed by one 3x3 matrix per
        remaining trace; ``centers`` holds the per-trace means with the
        point axis kept as size 1.
    """
    centers = traces.mean(-2, keepdim=True)
    centered = traces - centers

    reference = centered[0]
    others = centered[1:]

    # One 3x3 cross-covariance matrix per non-reference trace.
    covariance = torch.einsum("i j k, j l -> i k l", others, reference)
    V, _, W = torch.linalg.svd(covariance)

    uncorrected = torch.matmul(V, W)
    device = uncorrected.device

    # The determinant of the uncorrected product is +1 or -1 (up to rounding);
    # multiplying by it below forces the final determinant to be positive.
    # NOTE(review): the sign vector is broadcast over the rows of V, whereas
    # the textbook Kabsch correction flips a column — confirm this is intended.
    ones = torch.ones(len(others), device=device)
    sign = torch.stack([ones, ones, torch.linalg.det(uncorrected)], dim=1)
    U = torch.matmul(sign[:, :, None] * V, W)

    identity = torch.eye(3, device=device)[None]
    return torch.cat([identity, U]), centers
|
|
|
|
|
|
def to_pdb(numbered_sequences, all_atoms, chain_ids="HL"):
    """Render atom coordinates as fixed-column PDB ``ATOM`` records.

    Args:
        numbered_sequences: Mapping from chain ID to a list of
            ``((residue_number, insertion_code), residue_letter)`` entries.
        all_atoms: Coordinates indexed as ``all_atoms[residue, atom_slot]``,
            where ``atom_slot`` is the atom's position in
            ``residue_atoms[residue_letter]``. Residues are ordered first
            chain then second chain.
        chain_ids: The two chain IDs to write, in order.

    Returns:
        The ``ATOM`` lines joined with newlines (no trailing newline). Atoms
        whose coordinates are NaN are skipped. Serial numbers start at 0.
    """
    first_chain = numbered_sequences[chain_ids[0]]
    second_chain = numbered_sequences[chain_ids[1]]
    seq = first_chain + second_chain
    chain_index = [0] * len(first_chain) + [1] * len(second_chain)

    record_type = "ATOM"
    alt_loc = ''
    charge = ''
    occupancy = 1.00
    b_factor = 0.00

    # PDB is a columnar format, every space matters here! The blank runs are
    # built with explicit widths so they cannot be silently collapsed:
    # columns 28-30 separate the insertion code from x, and columns 67-76
    # separate the B-factor from the element symbol.
    icode_gap = " " * 3
    bfactor_gap = " " * 10

    pdb_lines = []
    atom_index = 0

    for i, amino in enumerate(seq):
        chain_id = chain_ids[chain_index[i]]
        for atom in atom_types:
            slots = residue_atoms[amino[1]]
            if atom not in slots:
                continue
            pos = all_atoms[i, slots.index(atom)]
            mean = pos.mean()
            if mean != mean:
                # NaN is the only value that differs from itself: the atom
                # has no coordinates, so leave it out.
                continue

            name = f' {atom}'
            res_name_3 = restype_1to3[amino[1]]
            element = atom[0]

            atom_line = (
                f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                f'{res_name_3:>3} {chain_id:>1}'
                f'{(amino[0][0]):>4}{amino[0][1]:>1}{icode_gap}'
                f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                f'{occupancy:>6.2f}{b_factor:>6.2f}{bfactor_gap}'
                f'{element:>2}{charge:>2}'
            )
            pdb_lines.append(atom_line)
            atom_index += 1

    return "\n".join(pdb_lines)
|
|
|
|
|
|
def sequence_dict_from_fasta(fasta_file):
    """Read chain sequences from a FASTA-style file.

    The file is split on whitespace. Whenever a token containing ``>`` is
    directly followed by a token made up only of characters in ``restypes``,
    that following token is stored under the text after the ``>``.

    Args:
        fasta_file: Path of the file to read.

    Returns:
        Dict mapping chain ID to sequence. A header whose next token is not
        a valid sequence is ignored; a repeated chain ID keeps the last
        valid sequence.
    """
    with open(fasta_file) as handle:
        tokens = handle.read().split()

    sequences = {}
    # Walk over every (token, next token) pair.
    for header, candidate in zip(tokens, tokens[1:]):
        if ">" not in header:
            continue
        if all(residue in restypes for residue in candidate):
            sequences[header.split(">")[1]] = candidate

    return sequences
|
|
|
|
|
|
def add_errors_as_bfactors(filename, errors, header=[]):
    """Rewrite a PDB file in place with per-residue errors as B-factors.

    Args:
        filename: Path of the PDB file to read and then overwrite.
        errors: Indexable of numbers, one per residue in file order; every
            ``ATOM`` line of the n-th residue receives ``errors[n]``.
        header: Lines (each with its own line ending) written before the
            rewritten file contents. The default list is never mutated.

    Raises:
        IndexError: If the file contains more residues than ``errors`` has
            entries.
    """
    with open(filename) as file:
        lines = file.readlines()

    new_txt = list(header)
    residue_index = -1
    current_residue = None

    for line in lines:
        if line.startswith("ATOM"):
            # Chain ID + residue number + insertion code (columns 22-27).
            # The chain ID is included so that two adjacent residues sharing
            # a number in different chains are not merged into one.
            residue_key = line[21:27]
            if residue_key != current_residue:
                current_residue = residue_key
                residue_index += 1

            # Write the value straight into the B-factor field (columns
            # 61-66) rather than searching the line for a "0.00" pattern,
            # which depends on exact whitespace and can change line length.
            line_end = "\n" if line.endswith("\n") else ""
            record = line.rstrip("\n").ljust(66)
            line = f"{record[:60]}{errors[residue_index]:>6.2f}{record[66:]}{line_end}"
        elif line.startswith("REMARK") and " 1 CREATED WITH OPENMM" in line:
            # The remark number is right-justified in its field, so the run
            # of blanks after "REMARK" is not matched literally here.
            line = line.replace(" 1 CREATED WITH OPENMM", "STRUCTURE REFINED USING OPENMM")
            # Pad back to the 80-column record width.
            line = line.rstrip("\n").ljust(80) + "\n"
        new_txt.append(line)

    with open(filename, "w") as file:
        file.writelines(new_txt)
|
|
|
|
|
|
def are_weights_ready(weights_path):
    """Report whether a usable weights file exists at ``weights_path``.

    Args:
        weights_path: Path of the weights file to inspect.

    Returns:
        ``False`` if the path does not exist, the file is empty, or its
        first line is exactly the bytes ``EMPTY``; ``True`` otherwise.
    """
    if not os.path.exists(weights_path):
        return False
    if os.path.getsize(weights_path) == 0:
        return False

    with open(weights_path, "rb") as handle:
        first_line = handle.readline()

    # A file whose first line is the literal marker EMPTY is a placeholder,
    # not real weights.
    return first_line != b"EMPTY"
|