Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths

This commit is contained in:
2026-03-17 17:57:24 +01:00
commit 6eef3bb748
108 changed files with 28144 additions and 0 deletions

308
rf2aa/data/covale.py Normal file
View File

@@ -0,0 +1,308 @@
import torch
from openbabel import openbabel
from typing import Optional
from dataclasses import dataclass
from tempfile import NamedTemporaryFile
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.parsers import parse_mol
from rf2aa.data.small_molecule import compute_features_from_obmol
from rf2aa.util import get_bond_feats
@dataclass
class MoleculeToMoleculeBond:
    """A covalent bond between two molecules inside one combined chain.

    All indices are 0-based: chain_index_* selects the molecule within the
    combined chain, absolute_atom_index_* the atom within that molecule.
    The chirality fields hold "CW"/"CCW" when the new bond creates a
    stereocenter on that side, or None when it does not.
    """
    chain_index_first: int
    absolute_atom_index_first: int
    chain_index_second: int
    absolute_atom_index_second: int
    new_chirality_atom_first: Optional[str]
    new_chirality_atom_second: Optional[str]
@dataclass
class AtomizedResidue:
    """Bookkeeping for a protein residue that was atomized so it can be
    covalently bonded to a ligand.

    chain_index_in_combined_chain indexes the molecule within its combined
    chain (set to None once N/C indices have been made absolute, see
    update_absolute_indices_after_combination). The N/C indices appear to
    point at the residue's backbone N and C atoms in the generated sdf —
    TODO confirm. original_chain / index_in_original_chain locate the
    residue in the source protein (0-indexed).
    """
    chain: str
    chain_index_in_combined_chain: int
    absolute_N_index_in_chain: int
    absolute_C_index_in_chain: int
    original_chain: str
    index_in_original_chain: int
def load_covalent_molecules(protein_inputs, config, model_runner):
    """Build per-chain small-molecule inputs for covalently linked ligands.

    Returns None when no covalent bonds are configured; otherwise returns a
    tuple (chainid_to_input, residues_to_atomize).

    Raises:
        ValueError: if covale_inputs are given without small-molecule inputs.
    """
    if config.covale_inputs is None:
        return None
    if config.sm_inputs is None:
        raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")

    # SECURITY NOTE(review): eval() on config-provided text executes arbitrary
    # code; the config is assumed trusted here. ast.literal_eval would be the
    # safer choice if the input is always a literal list.
    covalent_bonds = eval(config.covale_inputs)
    sm_inputs = delete_leaving_atoms(config.sm_inputs)
    residues_to_atomize, combined_molecules, extra_bonds = find_residues_to_atomize(
        protein_inputs, sm_inputs, covalent_bonds, model_runner
    )
    chainid_to_input = {}
    for chain, combined_molecule in combined_molecules.items():
        extra_bonds_for_chain = extra_bonds[chain]
        msa, bond_feats, xyz, Ls = get_combined_atoms_bonds(combined_molecule)
        # indices become absolute w.r.t. the combined chain once molecules merge
        residues_to_atomize = update_absolute_indices_after_combination(residues_to_atomize, chain, Ls)
        mol = make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds_for_chain)
        # chirality edits require regenerating 3D coordinates
        xyz = recompute_xyz_after_chirality(mol)
        chain_input = compute_features_from_obmol(mol, msa, xyz, model_runner)  # renamed: don't shadow builtin input()
        chainid_to_input[chain] = chain_input
    return chainid_to_input, residues_to_atomize
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
    """Identify protein residues that must be atomized to realize covalent bonds.

    Each bond is ((prot_chain, res_idx, atom_name), (sm_chain, atom_num),
    (chirality_first, chirality_second)) with 1-indexed residue/atom numbers
    and "null" meaning "no new stereocenter on this side".

    Returns (residues_to_atomize, combined_molecules, extra_bonds) where
    combined_molecules maps sm chain id -> list of sdf files merged into one
    chain and extra_bonds maps sm chain id -> MoleculeToMoleculeBond list.
    """
    residues_to_atomize = []  # hold on to delete wayward inputs
    combined_molecules = {}  # combined multiple molecules that are bonded
    extra_bonds = {}
    for bond in covalent_bonds:
        prot_chid, prot_res_idx, atom_to_bond = bond[0]
        sm_chid, sm_atom_num = bond[1]
        chirality_first_atom, chirality_second_atom = bond[2]
        # "null" means no chirality update is expected on that side
        if chirality_first_atom.strip() == "null":
            chirality_first_atom = None
        if chirality_second_atom.strip() == "null":
            chirality_second_atom = None

        sm_atom_num = int(sm_atom_num) - 1  # user input is 1-indexed

        # FIX: was `try: assert ... except: continue` — an assert caught by a
        # bare except; an explicit membership check keeps the skip behavior
        # without swallowing unrelated errors.
        if sm_chid not in sm_inputs:
            print(f"Skipping bond: {bond} since no sm chain {sm_chid} was provided")
            continue
        assert sm_inputs[sm_chid].input_type == "sdf", "only sdf inputs can be covalently linked to proteins"

        if prot_chid not in protein_inputs:
            raise ValueError(f"first atom in covale_input must be present in a protein chain. "
                             f"Given chain: {prot_chid} was not in given protein chains: "
                             f"{list(protein_inputs.keys())}")

        residue = (prot_chid, prot_res_idx, atom_to_bond)
        file, atom_index = convert_residue_to_molecule(protein_inputs, residue, model_runner)

        if sm_chid not in combined_molecules:
            combined_molecules[sm_chid] = [sm_inputs[sm_chid].input]
        # NOTE(review): inserting at the front was flagged upstream as
        # "this is a bug, revert" — kept as-is to preserve current behavior;
        # confirm intended ordering before changing.
        combined_molecules[sm_chid].insert(0, file)
        absolute_chain_index_first = combined_molecules[sm_chid].index(sm_inputs[sm_chid].input)
        absolute_chain_index_second = combined_molecules[sm_chid].index(file)

        if sm_chid not in extra_bonds:
            extra_bonds[sm_chid] = []
        extra_bonds[sm_chid].append(MoleculeToMoleculeBond(
            absolute_chain_index_first,
            sm_atom_num,
            absolute_chain_index_second,
            atom_index,
            new_chirality_atom_first=chirality_first_atom,
            new_chirality_atom_second=chirality_second_atom
        ))
        residues_to_atomize.append(AtomizedResidue(
            sm_chid,
            absolute_chain_index_second,
            0,
            2,
            prot_chid,
            int(prot_res_idx) - 1  # 0-index the residue
        ))
    return residues_to_atomize, combined_molecules, extra_bonds
def convert_residue_to_molecule(protein_inputs, residue, model_runner):
    """Convert a protein residue into an sdf file and report the index of the
    atom that will participate in the covalent bond.

    Returns (temp_sdf_path, bond_atom_index).
    """
    chain_id, res_idx, bonded_atom_name = residue
    seq_pos = int(res_idx) - 1  # residue numbering is 1-indexed
    aa_num = protein_inputs[chain_id].query_sequence()[seq_pos]
    entry = model_runner.molecule_db[ChemData().num2aa[aa_num]]

    # keep only the leaving-atom flags that belong to heavy atoms
    heavy_positions = {i for i, name in enumerate(entry["atom_id"]) if name[0] != "H"}
    heavy_leaving_flags = [flag for i, flag in enumerate(entry["leaving"]) if i in heavy_positions]

    # write the reference sdf, strip its leaving atoms, write the result
    stripped_sdf = delete_leaving_atoms_single_chain(
        create_and_populate_temp_file(entry["sdf"]), heavy_leaving_flags
    )
    # bond atom index is looked up in the full (unstripped) atom-name list
    bond_atom_index = entry["atom_id"].index(bonded_atom_name.strip())
    return create_and_populate_temp_file(stripped_sdf), bond_atom_index
def get_combined_atoms_bonds(combined_molecule):
    """Parse every sdf file in a combined chain and stitch together atoms,
    block-diagonal bond features and coordinates.

    Returns (atoms, bond_feats, xyz, Ls) where Ls holds per-molecule lengths.
    """
    msas, bond_blocks, coords, lengths = [], [], [], []
    for mol_file in combined_molecule:
        obmol, msa, ins, xyz, mask = parse_mol(
            mol_file,
            filetype="sdf",
            string=False,
            generate_conformer=True,
            find_automorphs=False,
        )
        msas.append(msa)
        bond_blocks.append(get_bond_feats(obmol))
        coords.append(xyz)
        lengths.append(msa.shape[0])

    atoms = torch.cat(msas)
    total = sum(lengths)
    # bonds exist only within a molecule, so assemble a block-diagonal matrix
    bond_feats = torch.zeros((total, total)).long()
    start = 0
    for block in bond_blocks:
        n = block.shape[0]
        bond_feats[start:start + n, start:start + n] = block
        start += n
    xyz = torch.cat(coords, dim=1)[0]
    return atoms, bond_feats, xyz, lengths
def make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds):
    """Assemble an OBMol from atom tokens, coordinates and a bond matrix, then
    add the inter-molecule covalent bonds and apply chirality updates."""
    mol = openbabel.OBMol()

    # atoms: token -> element name -> atomic number
    for idx, token in enumerate(msa):
        atom = mol.NewAtom()
        atom.SetAtomicNum(ChemData().atomtype2atomnum[ChemData().num2aa[token]])
        atom.SetVector(float(xyz[idx, 0]), float(xyz[idx, 1]), float(xyz[idx, 2]))

    # intra-molecule bonds from the nonzero entries of the bond matrix
    rows, cols = bond_feats.nonzero(as_tuple=True)
    for r, c in zip(rows, cols):
        mol.AddBond(make_openbabel_bond(mol, r.item(), c.item(), bond_feats[r, c].item()))

    # covalent links between molecules of the combined chain
    for link in extra_bonds:
        i_abs = get_absolute_index_from_relative_indices(
            link.chain_index_first,
            link.absolute_atom_index_first,
            Ls,
        )
        j_abs = get_absolute_index_from_relative_indices(
            link.chain_index_second,
            link.absolute_atom_index_second,
            Ls,
        )
        # all covale bonds are single bonds
        mol.AddBond(make_openbabel_bond(mol, i_abs, j_abs, 1))
        set_chirality(mol, i_abs, link.new_chirality_atom_first)
        set_chirality(mol, j_abs, link.new_chirality_atom_second)
    return mol
def make_openbabel_bond(mol, i, j, order):
    """Create (without attaching) an OBBond between 0-indexed atoms i and j.

    A bond order of 4 encodes an aromatic bond, stored as order 2 plus the
    aromatic flag.
    """
    new_bond = openbabel.OBBond()
    new_bond.SetBegin(mol.GetAtom(i + 1))  # openbabel atoms are 1-indexed
    new_bond.SetEnd(mol.GetAtom(j + 1))
    if order != 4:
        new_bond.SetBondOrder(order)
    else:
        new_bond.SetBondOrder(2)
        new_bond.SetAromatic()
    return new_bond
def set_chirality(mol, absolute_atom_index, new_chirality):
    # Apply the user-requested winding ("CW"/"CCW") at a tetrahedral
    # stereocenter. `absolute_atom_index` is 0-indexed; openbabel atoms are
    # 1-indexed, hence the +1 below.
    stereo = openbabel.OBStereoFacade(mol)
    if stereo.HasTetrahedralStereo(absolute_atom_index+1):
        # NOTE(review): HasTetrahedralStereo is queried with the 1-based atom
        # index while GetTetrahedralStereo receives GetId() — confirm the two
        # identifier spaces agree for molecules built via NewAtom().
        tetstereo = stereo.GetTetrahedralStereo(mol.GetAtom(absolute_atom_index+1).GetId())
        if tetstereo is None:
            return
        assert new_chirality is not None, "you have introduced a new stereocenter, \
so you must specify its chirality either as CW, or CCW"
        config = tetstereo.GetConfig()
        # chirality_options maps "CW"/"CCW" to openbabel winding constants
        config.winding = chirality_options[new_chirality]
        tetstereo.SetConfig(config)
        print("Updating chirality...")
    else:
        # no stereocenter here, so specifying a chirality is a user error
        assert new_chirality is None, "you have specified a chirality without creating a new chiral center"
# Maps the user-facing chirality labels to openbabel winding constants.
chirality_options = {
    "CW": openbabel.OBStereo.Clockwise,
    "CCW": openbabel.OBStereo.AntiClockwise,
}
def recompute_xyz_after_chirality(obmol):
    """Regenerate 3D coordinates for `obmol` after chirality edits.

    Builds an initial geometry, relaxes it with the MMFF94 force field
    (fast rotor search only), and returns a (1, natoms, 3) float tensor.

    Raises:
        ValueError: if the force field cannot be set up for the molecule.
    """
    builder = openbabel.OBBuilder()
    builder.Build(obmol)
    ff = openbabel.OBForceField.FindForceField("mmff94")
    # guard clause instead of if/else; also dropped the pointless f-prefix on
    # a format string with no placeholders
    if not ff.Setup(obmol):
        raise ValueError("Failed to generate 3D coordinates for molecule (unknown).")
    ff.FastRotorSearch()
    ff.GetCoordinates(obmol)  # write the optimized coordinates back onto obmol

    # hoist the per-atom lookup: one GetAtom call per atom instead of three
    atoms = [obmol.GetAtom(i) for i in range(1, obmol.NumAtoms() + 1)]
    atom_coords = torch.tensor([[a.x(), a.y(), a.z()] for a in atoms]).unsqueeze(0)  # (1, natoms, 3)
    return atom_coords
def delete_leaving_atoms(sm_inputs):
    # Strip flagged leaving atoms from every small-molecule input that carries
    # an "is_leaving" mask; inputs without the mask pass through untouched.
    # NOTE: mutates sm_inputs in place (via update) and returns it.
    updated_sm_inputs = {}
    for chain in sm_inputs:
        if "is_leaving" not in sm_inputs[chain]:
            continue
        # SECURITY NOTE(review): eval() on config-provided text; assumed
        # trusted — ast.literal_eval would be safer for a literal list.
        is_leaving = eval(sm_inputs[chain]["is_leaving"])
        sdf_string = delete_leaving_atoms_single_chain(sm_inputs[chain]["input"], is_leaving)
        # replace the entry with a temp sdf file holding the stripped molecule
        updated_sm_inputs[chain] = {
            "input": create_and_populate_temp_file(sdf_string),
            "input_type": "sdf"
        }
    sm_inputs.update(updated_sm_inputs)
    return sm_inputs
def delete_leaving_atoms_single_chain(filename, is_leaving):
    """Delete the atoms flagged in `is_leaving` from the sdf file `filename`
    and return the stripped molecule as an sdf string.

    `is_leaving` must hold one truthy/falsy flag per atom of the molecule.
    """
    obmol, msa, ins, xyz, mask = parse_mol(
        filename,
        filetype="sdf",
        string=False,
        generate_conformer=True
    )
    assert len(is_leaving) == obmol.NumAtoms()
    leaving_indices = torch.tensor(is_leaving).nonzero()
    # BUG FIX: delete from the highest index down. OpenBabel renumbers the
    # remaining atoms after each DeleteAtom, so deleting in ascending order
    # removes the wrong atoms whenever more than one atom is flagged.
    for idx in sorted((i.item() for i in leaving_indices), reverse=True):
        obmol.DeleteAtom(obmol.GetAtom(idx + 1))  # openbabel is 1-indexed
    obConversion = openbabel.OBConversion()
    obConversion.SetInAndOutFormats("sdf", "sdf")
    sdf_string = obConversion.WriteString(obmol)
    return sdf_string
def get_absolute_index_from_relative_indices(chain_index, absolute_index_in_chain, Ls):
    """Translate a (chain, within-chain) atom index into an index into the
    concatenation of all chains, whose lengths are given by Ls."""
    preceding_atoms = sum(Ls[:chain_index])
    return preceding_atoms + absolute_index_in_chain
def update_absolute_indices_after_combination(residues_to_atomize, chain, Ls):
    """Rewrite N/C atom indices of atomized residues on `chain` to be absolute
    within the combined chain; residues on other chains pass through as-is."""
    rewritten = []
    for res in residues_to_atomize:
        if res.chain != chain:
            rewritten.append(res)
            continue
        n_abs = get_absolute_index_from_relative_indices(
            res.chain_index_in_combined_chain,
            res.absolute_N_index_in_chain,
            Ls)
        c_abs = get_absolute_index_from_relative_indices(
            res.chain_index_in_combined_chain,
            res.absolute_C_index_in_chain,
            Ls)
        rewritten.append(AtomizedResidue(
            res.chain,
            None,  # chain index no longer meaningful once absolute
            n_abs,
            c_abs,
            res.original_chain,
            res.index_in_original_chain,
        ))
    return rewritten
def create_and_populate_temp_file(data):
    """Write `data` to a fresh temporary file (kept on disk, delete=False)
    and return its path."""
    tmp = NamedTemporaryFile(mode='w+', delete=False)
    with tmp:
        tmp.write(data)
    return tmp.name

202
rf2aa/data/data_loader.py Normal file
View File

@@ -0,0 +1,202 @@
import torch
from dataclasses import dataclass, fields
from typing import Optional, List
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.data_loader_utils import MSAFeaturize, get_bond_distances, generate_xyz_prev
from rf2aa.kinematics import xyz_to_t2d
from rf2aa.util import get_prot_sm_mask, xyz_t_to_frame_xyz, same_chain_from_bond_feats, \
Ls_from_same_chain_2d, idx_from_Ls, is_atom
@dataclass
class RawInputData:
    """Raw per-example features (MSA, bonds, templates, chirals) plus the
    bookkeeping needed to atomize residues and build network inputs."""
    msa: torch.Tensor
    ins: torch.Tensor
    bond_feats: torch.Tensor
    xyz_t: torch.Tensor
    mask_t: torch.Tensor
    t1d: torch.Tensor
    chirals: torch.Tensor
    atom_frames: torch.Tensor
    taxids: Optional[List[str]] = None
    term_info: Optional[torch.Tensor] = None
    chain_lengths: Optional[List] = None
    idx: Optional[List] = None

    def query_sequence(self):
        """First MSA row is the query sequence."""
        return self.msa[0]

    def sequence_string(self):
        """One-letter amino-acid string for the query sequence."""
        three_letter_sequence = [ChemData().num2aa[num] for num in self.query_sequence()]
        return "".join([ChemData().aa_321[three] for three in three_letter_sequence])

    def is_atom(self):
        """Boolean mask over positions that are atoms (not residues)."""
        return is_atom(self.query_sequence())

    def length(self):
        return self.msa.shape[1]

    def get_chain_bins_from_chain_lengths(self):
        """Map chain id -> (start, end) half-open span in the combined input."""
        if self.chain_lengths is None:
            raise ValueError("Cannot call get_chain_bins_from_chain_lengths without "
                             "setting chain_lengths. Chain_lengths is set in merge_inputs")
        chain_bins = {}
        running_length = 0
        for chain, length in self.chain_lengths:
            chain_bins[chain] = (running_length, running_length + length)
            running_length = running_length + length
        return chain_bins

    def update_protein_features_after_atomize(self, residues_to_atomize):
        """Wire atomized residues into their protein chains and drop the
        original residue positions from all per-position features."""
        if self.chain_lengths is None:
            # BUG FIX: the original raised a bare string ("raise ('...')"),
            # which is a TypeError in Python 3; raise a proper exception.
            raise ValueError("Cannot update protein features without chain_lengths. "
                             "merge_inputs must be called before this function")
        chain_bins = self.get_chain_bins_from_chain_lengths()
        keep = torch.ones(self.length())
        prev_absolute_index = None
        prev_C = None
        # need to atomize residues from N term to Cterm to handle atomizing neighbors
        residues_to_atomize = sorted(residues_to_atomize, key=lambda x: x.original_chain + str(x.index_in_original_chain))
        for residue in residues_to_atomize:
            original_chain_start_index, original_chain_end_index = chain_bins[residue.original_chain]
            absolute_index_in_combined_input = original_chain_start_index + residue.index_in_original_chain
            atomized_chain_start_index, atomized_chain_end_index = chain_bins[residue.chain]
            N_index = atomized_chain_start_index + residue.absolute_N_index_in_chain
            C_index = atomized_chain_start_index + residue.absolute_C_index_in_chain

            # if residue is first in the chain, no extra bond feats to the previous residue
            if absolute_index_in_combined_input != original_chain_start_index:
                self.bond_feats[absolute_index_in_combined_input - 1, N_index] = ChemData().RESIDUE_ATOM_BOND
                self.bond_feats[N_index, absolute_index_in_combined_input - 1] = ChemData().RESIDUE_ATOM_BOND
            # if residue is last in chain, no extra bond feats to the following residue
            if absolute_index_in_combined_input != original_chain_end_index - 1:
                self.bond_feats[absolute_index_in_combined_input + 1, C_index] = ChemData().RESIDUE_ATOM_BOND
                self.bond_feats[C_index, absolute_index_in_combined_input + 1] = ChemData().RESIDUE_ATOM_BOND
            keep[absolute_index_in_combined_input] = 0

            # neighboring residues that were both atomized get a direct C-N bond
            if prev_absolute_index is not None:
                if prev_absolute_index + 1 == absolute_index_in_combined_input:
                    self.bond_feats[prev_C, N_index] = 1
                    self.bond_feats[N_index, prev_C] = 1
            prev_absolute_index = absolute_index_in_combined_input
            prev_C = C_index
        # remove protein features of atomized residues
        self.keep_features(keep.bool())

    def keep_features(self, keep):
        """Subset every per-position feature to positions where keep is True.

        Raises:
            ValueError: if any atom position would be removed.
        """
        if not torch.all(keep[self.is_atom()]):
            raise ValueError("cannot remove atoms")
        self.msa = self.msa[:, keep]
        self.ins = self.ins[:, keep]
        self.bond_feats = self.bond_feats[keep][:, keep]
        self.xyz_t = self.xyz_t[:, keep]
        self.t1d = self.t1d[:, keep]
        self.mask_t = self.mask_t[:, keep]
        if self.term_info is not None:
            self.term_info = self.term_info[keep]
        if self.idx is not None:
            self.idx = self.idx[keep]
        # assumes all chirals are after all protein residues
        self.chirals[..., :-1] = self.chirals[..., :-1] - torch.sum(~keep)

    def construct_features(self, model_runner):
        """Featurize this raw input into the RFInput bundle the network consumes."""
        loader_params = model_runner.config.loader_params
        B, L = 1, self.length()
        seq, msa_clust, msa_seed, msa_extra, mask_pos = MSAFeaturize(
            self.msa.long(),
            self.ins.long(),
            loader_params,
            p_mask=loader_params.get("p_msa_mask", 0),
            term_info=self.term_info,
            deterministic=model_runner.deterministic,
        )
        dist_matrix = get_bond_distances(self.bond_feats)
        # xyz_prev, mask_prev = generate_xyz_prev(self.xyz_t, self.mask_t, loader_params)
        # xyz_prev = torch.nan_to_num(xyz_prev)
        # NOTE: The above is the way things "should" be done, this is for compatability with training.
        xyz_prev = ChemData().INIT_CRDS.reshape(1, ChemData().NTOTAL, 3).repeat(L, 1, 1)
        self.xyz_t = torch.nan_to_num(self.xyz_t)
        mask_t_2d = get_prot_sm_mask(self.mask_t, seq[0])
        mask_t_2d = mask_t_2d[:, None] * mask_t_2d[:, :, None]  # (B, T, L, L)
        xyz_t_frame = xyz_t_to_frame_xyz(self.xyz_t[None], self.msa[0], self.atom_frames)
        t2d = xyz_to_t2d(xyz_t_frame, mask_t_2d[None])
        t2d = t2d[0]
        # get torsion angles from templates
        seq_tmp = self.t1d[..., :-1].argmax(dim=-1)
        alpha, _, alpha_mask, _ = model_runner.xyz_converter.get_torsions(
            self.xyz_t.reshape(-1, L, ChemData().NTOTAL, 3),
            seq_tmp, mask_in=self.mask_t.reshape(-1, L, ChemData().NTOTAL))
        alpha = alpha.reshape(B, -1, L, ChemData().NTOTALDOFS, 2)
        alpha_mask = alpha_mask.reshape(B, -1, L, ChemData().NTOTALDOFS, 1)
        alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3 * ChemData().NTOTALDOFS)
        alpha_t = alpha_t[0]
        alpha_prev = torch.zeros((L, ChemData().NTOTALDOFS, 2))
        same_chain = same_chain_from_bond_feats(self.bond_feats)

        return RFInput(
            msa_latent=msa_seed,
            msa_full=msa_extra,
            seq=seq,
            seq_unmasked=self.query_sequence(),
            bond_feats=self.bond_feats,
            dist_matrix=dist_matrix,
            chirals=self.chirals,
            atom_frames=self.atom_frames.long(),
            xyz_prev=xyz_prev,
            alpha_prev=alpha_prev,
            t1d=self.t1d,
            t2d=t2d,
            xyz_t=self.xyz_t[..., 1, :],
            alpha_t=alpha_t.float(),
            mask_t=mask_t_2d.float(),
            same_chain=same_chain.long(),
            idx=self.idx
        )
@dataclass
class RFInput:
    """Fully featurized network input; recycling fields default to None.

    to(gpu) moves every tensor field to the device in place; add_batch_dim()
    prepends a singleton batch axis to every tensor field.
    """
    msa_latent: torch.Tensor
    msa_full: torch.Tensor
    seq: torch.Tensor
    seq_unmasked: torch.Tensor
    idx: torch.Tensor
    bond_feats: torch.Tensor
    dist_matrix: torch.Tensor
    chirals: torch.Tensor
    atom_frames: torch.Tensor
    xyz_prev: torch.Tensor
    alpha_prev: torch.Tensor
    t1d: torch.Tensor
    t2d: torch.Tensor
    xyz_t: torch.Tensor
    alpha_t: torch.Tensor
    mask_t: torch.Tensor
    same_chain: torch.Tensor
    msa_prev: Optional[torch.Tensor] = None
    pair_prev: Optional[torch.Tensor] = None
    state_prev: Optional[torch.Tensor] = None
    mask_recycle: Optional[torch.Tensor] = None

    def to(self, gpu):
        """Move every tensor-valued field onto the given device, in place."""
        for f in fields(self):
            value = getattr(self, f.name)
            if torch.is_tensor(value):
                setattr(self, f.name, value.to(gpu))

    def add_batch_dim(self):
        """Mimic the pytorch dataloader at inference time: give every tensor
        field a leading singleton batch dimension."""
        for f in fields(self):
            value = getattr(self, f.name)
            if torch.is_tensor(value):
                setattr(self, f.name, value.unsqueeze(0))

View File

@@ -0,0 +1,909 @@
import torch
import warnings
import time
from icecream import ic
from torch.utils import data
import os, csv, random, pickle, gzip, itertools, time, ast, copy, sys
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(script_dir)
sys.path.append(script_dir+'/../')
import numpy as np
import scipy
import networkx as nx
from rf2aa.data.parsers import parse_a3m, parse_pdb
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.util import random_rot_trans, \
is_atom, is_protein, is_nucleic, is_atom
def MSABlockDeletion(msa, ins, nb=5):
    '''
    Randomly delete `nb` contiguous blocks of sequences from an MSA.

    Input: MSA having shape (N, L)
    output: new MSA with block deletion (row 0, the query, is always kept)
    '''
    N, L = msa.shape
    block_size = max(int(N * 0.3), 1)
    starts = np.random.randint(low=1, high=N, size=nb)  # (nb,)
    # expand each start into a block, clip into [1, N-1] so the query survives
    doomed = np.unique(np.clip(starts[:, None] + np.arange(block_size)[None, :], 1, N - 1))
    keep = np.ones(N, bool)
    keep[doomed] = 0
    return msa[keep], ins[keep]
def cluster_sum(data, assignment, N_seq, N_res):
    """Sum rows of `data` (N, N_res, F) into clusters given per-row cluster
    ids in `assignment`; returns a (N_seq, N_res, F) tensor."""
    n_feat = data.shape[-1]
    target = torch.zeros(N_seq, N_res, n_feat, device=data.device)
    index = assignment.view(-1, 1, 1).expand(-1, N_res, n_feat)
    return target.scatter_add(0, index, data.float())
def get_term_feats(Ls):
    """Creates N/C-terminus binary features: column 0 flags the first residue
    of each chain, column 1 the last."""
    term_info = torch.zeros((sum(Ls), 2)).float()
    offset = 0
    for chain_len in Ls:
        term_info[offset, 0] = 1.0                  # N-terminus flag
        term_info[offset + chain_len - 1, 1] = 1.0  # C-terminus flag
        offset += chain_len
    return term_info
def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
                 term_info=None, tocpu=False, fixbb=False, seed_msa_clus=None, deterministic=False):
    '''
    Input: full MSA information (after Block deletion if necessary) & full insertion information
    Output: seed MSA features & extra sequences

    Seed MSA features:
        - aatype of seed sequence (20 regular aa + 1 gap/unknown + 1 mask)
        - profile of clustered sequences (22)
        - insertion statistics (2)
        - N-term or C-term? (2)
    extra sequence features:
        - aatype of extra sequence (22)
        - insertion info (1)
        - N-term or C-term? (2)
    '''
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        # BUG FIX: removed a leftover debug line here ("msa = msa[:2]") that
        # was explicitly marked "TODO: delete me, just for testing purposes";
        # it truncated the MSA to two sequences whenever deterministic=True.

    if fixbb:
        p_mask = 0
        msa = msa[:1]
        ins = ins[:1]

    N, L = msa.shape

    if term_info is None:
        if len(L_s) == 0:
            L_s = [L]
        term_info = get_term_feats(L_s)
    term_info = term_info.to(msa.device)

    #binding_site = torch.zeros((L,1), device=msa.device).float()
    binding_site = torch.zeros((L,0), device=msa.device).float() # keeping this off for now (Jue 12/19)

    # raw MSA profile
    raw_profile = torch.nn.functional.one_hot(msa, num_classes=ChemData().NAATOKENS) # N x L x NAATOKENS
    raw_profile = raw_profile.float().mean(dim=0) # L x NAATOKENS

    # Select Nclust sequence randomly (seed MSA or latent MSA)
    Nclust = (min(N, params['MAXLAT'])-1) // nmer
    Nclust = Nclust*nmer + 1

    if N > Nclust*2:
        Nextra = N - Nclust
    else:
        Nextra = N
    Nextra = min(Nextra, params['MAXSEQ']) // nmer
    Nextra = max(1, Nextra * nmer)

    b_seq = list()
    b_msa_clust = list()
    b_msa_seed = list()
    b_msa_extra = list()
    b_mask_pos = list()
    for i_cycle in range(params['MAXCYCLE']):
        sample_mono = torch.randperm((N-1)//nmer, device=msa.device)
        sample = [sample_mono + imer*((N-1)//nmer) for imer in range(nmer)]
        sample = torch.stack(sample, dim=-1)
        sample = sample.reshape(-1)

        # add MSA clusters pre-chosen before calling this function
        if seed_msa_clus is not None:
            sample_orig_shape = sample.shape
            sample_seed = seed_msa_clus[i_cycle]
            sample_more = torch.tensor([i for i in sample if i not in sample_seed])
            N_sample_more = len(sample) - len(sample_seed)
            if N_sample_more > 0:
                sample_more = sample_more[torch.randperm(len(sample_more))[:N_sample_more]]
                sample = torch.cat([sample_seed, sample_more])
            else:
                sample = sample_seed[:len(sample)] # take all clusters from pre-chosen ones

        msa_clust = torch.cat((msa[:1,:], msa[1:,:][sample[:Nclust-1]]), dim=0)
        ins_clust = torch.cat((ins[:1,:], ins[1:,:][sample[:Nclust-1]]), dim=0)

        # 15% random masking
        # - 10%: aa replaced with a uniformly sampled random amino acid
        # - 10%: aa replaced with an amino acid sampled from the MSA profile
        # - 10%: not replaced
        # - 70%: replaced with a special token ("mask")
        random_aa = torch.tensor([[0.05]*20 + [0.0]*(ChemData().NAATOKENS-20)], device=msa.device)
        same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=ChemData().NAATOKENS)
        # explicitly remove probabilities from nucleic acids and atoms
        #same_aa[..., ChemData().NPROTAAS:] = 0
        #raw_profile[...,ChemData().NPROTAAS:] = 0
        probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa
        #probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)

        # explicitly set the probability of masking for nucleic acids and atoms
        #probs[...,is_protein(seq),ChemData().MASKINDEX]=0.7
        #probs[...,~is_protein(seq), :] = 0 # probably overkill but set all none protein elements to 0
        #probs[1:, ~is_protein(seq),20] = 1.0 # want to leave the gaps as gaps
        #probs[0,is_nucleic(seq), ChemData().MASKINDEX] = 1.0
        #probs[0,is_atom(seq), ChemData().aa2num["ATM"]] = 1.0

        sampler = torch.distributions.categorical.Categorical(probs=probs)
        mask_sample = sampler.sample()

        mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask
        mask_pos[msa_clust>ChemData().MASKINDEX]=False # no masking on NAs
        use_seq = msa_clust
        msa_masked = torch.where(mask_pos, mask_sample, use_seq)
        b_seq.append(msa_masked[0].clone())

        ## get extra sequenes
        if N > Nclust*2: # there are enough extra sequences
            msa_extra = msa[1:,:][sample[Nclust-1:]]
            ins_extra = ins[1:,:][sample[Nclust-1:]]
            extra_mask = torch.full(msa_extra.shape, False, device=msa_extra.device)
        elif N - Nclust < 1: # no extra sequences, use all masked seed sequence as extra one
            msa_extra = msa_masked.clone()
            ins_extra = ins_clust.clone()
            extra_mask = mask_pos.clone()
        else: # it has extra sequences, but not enough to cover extra sequence requirements
            msa_add = msa[1:,:][sample[Nclust-1:]]
            ins_add = ins[1:,:][sample[Nclust-1:]]
            mask_add = torch.full(msa_add.shape, False, device=msa_add.device)
            msa_extra = torch.cat((msa_masked, msa_add), dim=0)
            ins_extra = torch.cat((ins_clust, ins_add), dim=0)
            extra_mask = torch.cat((mask_pos, mask_add), dim=0)
        N_extra = msa_extra.shape[0]

        # clustering (assign remaining sequences to their closest cluster by Hamming distance
        msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=ChemData().NAATOKENS)
        msa_extra_onehot = torch.nn.functional.one_hot(msa_extra, num_classes=ChemData().NAATOKENS)
        count_clust = torch.logical_and(~mask_pos, msa_clust != 20).float() # 20: index for gap, ignore both masked & gaps
        count_extra = torch.logical_and(~extra_mask, msa_extra != 20).float()
        agreement = torch.matmul((count_extra[:,:,None]*msa_extra_onehot).view(N_extra, -1), (count_clust[:,:,None]*msa_clust_onehot).view(Nclust, -1).T)
        assignment = torch.argmax(agreement, dim=-1)

        # seed MSA features
        # 1. one_hot encoded aatype: msa_clust_onehot
        # 2. cluster profile
        count_extra = ~extra_mask
        count_clust = ~mask_pos
        msa_clust_profile = cluster_sum(count_extra[:,:,None]*msa_extra_onehot, assignment, Nclust, L)
        msa_clust_profile += count_clust[:,:,None]*msa_clust_profile
        count_profile = cluster_sum(count_extra[:,:,None], assignment, Nclust, L).view(Nclust, L)
        count_profile += count_clust
        count_profile += eps
        msa_clust_profile /= count_profile[:,:,None]
        # 3. insertion statistics
        msa_clust_del = cluster_sum((count_extra*ins_extra)[:,:,None], assignment, Nclust, L).view(Nclust, L)
        msa_clust_del += count_clust*ins_clust
        msa_clust_del /= count_profile
        ins_clust = (2.0/np.pi)*torch.arctan(ins_clust.float()/3.0) # (from 0 to 1)
        msa_clust_del = (2.0/np.pi)*torch.arctan(msa_clust_del.float()/3.0) # (from 0 to 1)
        ins_clust = torch.stack((ins_clust, msa_clust_del), dim=-1)

        if fixbb:
            assert params['MAXCYCLE'] == 1
            msa_clust_profile = msa_clust_onehot
            msa_extra_onehot = msa_clust_onehot
            ins_clust[:] = 0
            ins_extra[:] = 0
            # This is how it is done in rfdiff, but really it seems like it should be all 0.
            # Keeping as-is for now for consistency, as it may be used in downstream masking done
            # by apply_masks.
            mask_pos = torch.full_like(msa_clust, 1).bool()
        msa_seed = torch.cat((msa_clust_onehot, msa_clust_profile, ins_clust, term_info[None].expand(Nclust,-1,-1)), dim=-1)

        # extra MSA features
        ins_extra = (2.0/np.pi)*torch.arctan(ins_extra[:Nextra].float()/3.0) # (from 0 to 1)
        try:
            msa_extra = torch.cat((msa_extra_onehot[:Nextra], ins_extra[:,:,None], term_info[None].expand(Nextra,-1,-1)), dim=-1)
        except Exception as e:
            print('msa_extra.shape',msa_extra.shape)
            print('ins_extra.shape',ins_extra.shape)
            # BUG FIX: the original swallowed this failure and appended a
            # half-built msa_extra; re-raise so shape errors surface here
            # instead of corrupting downstream stacking.
            raise

        if (tocpu):
            b_msa_clust.append(msa_clust.cpu())
            b_msa_seed.append(msa_seed.cpu())
            b_msa_extra.append(msa_extra.cpu())
            b_mask_pos.append(mask_pos.cpu())
        else:
            b_msa_clust.append(msa_clust)
            b_msa_seed.append(msa_seed)
            b_msa_extra.append(msa_extra)
            b_mask_pos.append(mask_pos)

    b_seq = torch.stack(b_seq)
    b_msa_clust = torch.stack(b_msa_clust)
    b_msa_seed = torch.stack(b_msa_seed)
    b_msa_extra = torch.stack(b_msa_extra)
    b_mask_pos = torch.stack(b_mask_pos)

    return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos
def blank_template(n_tmpl, L, random_noise=5.0, deterministic: bool = False):
    """Build `n_tmpl` empty (all-gap) templates of length L: jittered initial
    coordinates, zero confidence, no valid atoms, empty template ids."""
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
    jitter = torch.rand(n_tmpl, L, 1, 3) * random_noise - random_noise / 2
    xyz = ChemData().INIT_CRDS.reshape(1, 1, ChemData().NTOTAL, 3).repeat(n_tmpl, L, 1, 1) + jitter
    gap_tokens = torch.full((n_tmpl, L), 20).long()
    t1d = torch.nn.functional.one_hot(gap_tokens, num_classes=ChemData().NAATOKENS - 1).float()  # all gaps
    confidence = torch.zeros((n_tmpl, L, 1)).float()
    t1d = torch.cat((t1d, confidence), -1)
    mask_t = torch.full((n_tmpl, L, ChemData().NTOTAL), False)
    return xyz, t1d, mask_t, np.full((n_tmpl), "")
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, npick_global=None, pick_top=True, same_chain=None, random_noise=5, deterministic: bool = False):
    """Featurize up to `npick` templates from an hhsearch template dictionary.

    Templates whose sequence identity exceeds params['SEQID'] are dropped; if
    nothing usable remains, blank templates are returned instead.

    Returns (xyz, t1d, mask_t, tplt_ids).
    """
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
    seqID_cut = params['SEQID']

    if npick_global is None:  # FIX: identity comparison instead of "== None"
        npick_global = max(npick, 1)

    ntplt = len(tplt['ids'])
    if (ntplt < 1) or (npick < 1): #no templates in hhsearch file or not want to use templ
        # NOTE(review): `deterministic` is not forwarded to blank_template
        # here or below — confirm whether that is intentional.
        return blank_template(npick_global, qlen, random_noise)

    # ignore templates having too high seqID
    if seqID_cut <= 100.0:
        tplt_valid_idx = torch.where(tplt['f0d'][0,:,4] < seqID_cut)[0]
        tplt['ids'] = np.array(tplt['ids'])[tplt_valid_idx]
    else:
        tplt_valid_idx = torch.arange(len(tplt['ids']))

    # check again if there are templates having seqID < cutoff
    ntplt = len(tplt['ids'])
    npick = min(npick, ntplt)
    if npick<1: # no templates
        return blank_template(npick_global, qlen, random_noise)

    if not pick_top: # select randomly among all possible templates
        sample = torch.randperm(ntplt)[:npick]
    else: # only consider top 50 templates
        sample = torch.randperm(min(50,ntplt))[:npick]

    xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(npick_global,qlen,1,1) + torch.rand(1,qlen,1,3)*random_noise
    mask_t = torch.full((npick_global,qlen,ChemData().NTOTAL),False) # True for valid atom, False for missing atom
    t1d = torch.full((npick_global, qlen), 20).long()
    t1d_val = torch.zeros((npick_global, qlen)).float()

    for i,nt in enumerate(sample):
        tplt_idx = tplt_valid_idx[nt]
        sel = torch.where(tplt['qmap'][0,:,1]==tplt_idx)[0]
        pos = tplt['qmap'][0,sel,0] + offset
        ntmplatoms = tplt['xyz'].shape[2] # will be bigger for NA templates
        xyz[i,pos,:ntmplatoms] = tplt['xyz'][0,sel]
        mask_t[i,pos,:ntmplatoms] = tplt['mask'][0,sel].bool()
        # 1-D features: alignment confidence
        t1d[i,pos] = tplt['seq'][0,sel]
        t1d_val[i,pos] = tplt['f1d'][0,sel,2] # alignment confidence
        # xyz[i] = center_and_realign_missing(xyz[i], mask_t[i], same_chain=same_chain)

    t1d = torch.nn.functional.one_hot(t1d, num_classes=ChemData().NAATOKENS-1).float() # (no mask token)
    t1d = torch.cat((t1d, t1d_val[...,None]), dim=-1)
    tplt_ids = np.array(tplt["ids"])[sample].flatten() # np.array of chain ids (ordered)
    return xyz, t1d, mask_t, tplt_ids
def merge_hetero_templates(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids, Ls_prot, deterministic: bool = False):
    """Diagonally tiles template coordinates, 1d input features, and masks across
    template and residue dimensions. 1st template is concatenated directly on residue
    dimension after a random rotation & translation.
    """
    # total templates across all chains
    N_tmpl_tot = sum([x.shape[0] for x in xyz_t_prot])
    # start from blank templates spanning the concatenated protein length
    xyz_t_out, f1d_t_out, mask_t_out, _ = blank_template(N_tmpl_tot, sum(Ls_prot))
    tplt_ids_out = np.full((N_tmpl_tot),"", dtype=object) # rk bad practice.. should fix

    i_tmpl = 0  # next free row in the template dimension (row 0 is special)
    i_res = 0   # running offset along the residue dimension
    for xyz_, f1d_, mask_, ids in zip(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids):
        N_tmpl, L_tmpl = xyz_.shape[:2]
        if i_tmpl == 0:
            i1, i2 = 1, N_tmpl
        else:
            i1, i2 = i_tmpl, i_tmpl+N_tmpl - 1

        # 1st template is concatenated directly, so that all atoms are set in xyz_prev
        xyz_t_out[0, i_res:i_res+L_tmpl] = random_rot_trans(xyz_[0:1], deterministic=deterministic)
        f1d_t_out[0, i_res:i_res+L_tmpl] = f1d_[0]
        mask_t_out[0, i_res:i_res+L_tmpl] = mask_[0]
        if not tplt_ids_out[0]: # only add first template
            tplt_ids_out[0] = ids[0]
        # remaining templates are diagonally tiled
        xyz_t_out[i1:i2, i_res:i_res+L_tmpl] = xyz_[1:]
        f1d_t_out[i1:i2, i_res:i_res+L_tmpl] = f1d_[1:]
        mask_t_out[i1:i2, i_res:i_res+L_tmpl] = mask_[1:]
        tplt_ids_out[i1:i2] = ids[1:]

        # row 0 was consumed by the first chain, so later chains advance one less
        if i_tmpl == 0:
            i_tmpl += N_tmpl
        else:
            i_tmpl += N_tmpl-1
        i_res += L_tmpl
    return xyz_t_out, f1d_t_out, mask_t_out, tplt_ids_out
def generate_xyz_prev(xyz_t, mask_t, params):
    """
    allows you to use different initializations for the coordinate track specified in params
    """
    L = xyz_t.shape[1]
    if params["BLACK_HOLE_INIT"]:
        # BUG FIX: blank_template returns four values (xyz, t1d, mask_t, ids);
        # the original 3-way unpack raised ValueError whenever this branch ran.
        xyz_t, _, mask_t, _ = blank_template(1, L)
    return xyz_t[0].clone(), mask_t[0].clone()
### merge msa & insertion statistics of two proteins having different taxID
def merge_a3m_hetero(a3mA, a3mB, L_s):
    """Merge MSA & insertion statistics of two proteins having different taxIDs.

    The query rows are concatenated directly. Every non-query row of A is
    padded with gaps on the right (sum(L_s[1:]) columns) and every non-query
    row of B on the left (L_s[0] columns), so each sequence only covers its
    own chain.
    """
    def _merged_rows(key, pad_value):
        # query row spans both chains
        rows = [torch.cat([a3mA[key][0], a3mB[key][0]]).unsqueeze(0)]  # (1, L)
        if a3mA[key].shape[0] > 1:
            rows.append(torch.nn.functional.pad(
                a3mA[key][1:], (0, sum(L_s[1:])), "constant", pad_value))  # pad gaps
        if a3mB[key].shape[0] > 1:
            rows.append(torch.nn.functional.pad(
                a3mB[key][1:], (L_s[0], 0), "constant", pad_value))
        return torch.cat(rows, dim=0)

    a3m = {'msa': _merged_rows('msa', 20), 'ins': _merged_rows('ins', 0)}

    # merge taxids, dropping B's query entry to stay aligned with the rows
    if 'taxid' in a3mA and 'taxid' in a3mB:
        a3m['taxid'] = np.concatenate([np.array(a3mA['taxid']), np.array(a3mB['taxid'])[1:]])
    return a3m
# merge msa & insertion statistics of units in homo-oligomers
def merge_a3m_homo(msa_orig, ins_orig, nmer, mode="default"):
    """Tile a monomer MSA into an `nmer`-copy homo-oligomer MSA.

    mode="repeat": every row is copied into every unit (fully paired).
    mode="diag":   each non-query row appears in exactly one unit,
                   laid out block-diagonally under a concatenated query.
    default:       first unit is unpaired; the remaining units all share
                   one block of rows.
    Returns a dict with keys "msa" and "ins".
    """
    N, L = msa_orig.shape[:2]
    if mode == "repeat":
        # AAAAAA
        # AAAAAA
        return {"msa": torch.tile(msa_orig, (1, nmer)),
                "ins": torch.tile(ins_orig, (1, nmer))}
    if mode == "diag":
        # AAAAAA
        # A-----
        # -A----
        # --A---
        # ---A--
        # ----A-
        # -----A
        n_extra = N - 1  # non-query rows per unit
        total_rows = 1 + n_extra * nmer
        msa = torch.full((total_rows, L * nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
        ins = torch.full((total_rows, L * nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)
        for i_unit in range(nmer):
            col = i_unit * L
            row = 1 + i_unit * n_extra
            msa[0, col:col + L] = msa_orig[0]
            ins[0, col:col + L] = ins_orig[0]
            msa[row:row + n_extra, col:col + L] = msa_orig[1:]
            ins[row:row + n_extra, col:col + L] = ins_orig[1:]
        return {"msa": msa, "ins": ins}
    # default:
    # AAAAAA
    # A-----
    # -AAAAA
    msa = torch.full((2 * N - 1, L * nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
    ins = torch.full((2 * N - 1, L * nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)
    msa[:N, :L] = msa_orig
    ins[:N, :L] = ins_orig
    for i_unit in range(1, nmer):
        col = i_unit * L
        msa[0, col:col + L] = msa_orig[0]
        ins[0, col:col + L] = ins_orig[0]
        # all units after the first share the same rows (paired with each other)
        msa[N:, col:col + L] = msa_orig[1:]
        ins[N:, col:col + L] = ins_orig[1:]
    return {"msa": msa, "ins": ins}
def merge_msas(a3m_list, L_s):
    """
    takes a list of a3m dictionaries with keys msa, ins and a list of protein lengths and creates a
    combined MSA

    NOTE(review): each dict is also expected to carry "taxID" and "hash" keys
    (capitalization differs from the 'taxid' key used by join_msas_by_taxid in
    this file) -- confirm against callers.
    """
    seen = set()   # MSA hashes already merged; repeated hash => homomer, merged unpaired
    taxIDs = []    # taxIDs of every row accumulated so far, in row order
    a3mA = a3m_list[0]
    taxIDs.extend(a3mA["taxID"])
    seen.update(a3mA["hash"])
    msaA, insA = a3mA["msa"], a3mA["ins"]
    for i in range(1, len(a3m_list)):
        a3mB = a3m_list[i]
        pair_taxIDs = set(taxIDs).intersection(set(a3mB["taxID"]))
        if a3mB["hash"] in seen or len(pair_taxIDs) < 5: #homomer/not enough pairs
            # unpaired merge: B's non-query rows are gap-padded below the block built so far
            a3mA = {"msa": msaA, "ins": insA}
            L_s_to_merge = [sum(L_s[:i]), L_s[i]]
            a3mA = merge_a3m_hetero(a3mA, a3mB, L_s_to_merge)
            msaA, insA = a3mA["msa"], a3mA["ins"]
            taxIDs.extend(a3mB["taxID"])
        else:
            # paired merge: for each shared taxID pick one representative row per side
            final_pairsA = []
            final_pairsB = []
            msaB, insB = a3mB["msa"], a3mB["ins"]
            for pair in pair_taxIDs:
                pair_a3mA = np.where(np.array(taxIDs)==pair)[0]
                pair_a3mB = np.where(a3mB["taxID"]==pair)[0]
                # NOTE(review): argmin of the per-row match count picks the row
                # LEAST similar to the query; argmax would pick the most similar -- confirm intent
                msaApair = torch.argmin(torch.sum(msaA[pair_a3mA, :] == msaA[0, :],axis=-1))
                msaBpair = torch.argmin(torch.sum(msaB[pair_a3mB, :] == msaB[0, :],axis=-1))
                final_pairsA.append(pair_a3mA[msaApair])
                final_pairsB.append(pair_a3mB[msaBpair])
            # rows of A without a partner get gap (20) in B's columns
            paired_msaB = torch.full((msaA.shape[0], L_s[i]), 20).long() # (N_seq_A, L_B)
            paired_msaB[final_pairsA] = msaB[final_pairsB]
            msaA = torch.cat([msaA, paired_msaB], dim=1)
            insA = torch.zeros_like(msaA) # paired MSAs in our dataset dont have insertions
            seen.update(a3mB["hash"])
    return msaA, insA
def remove_all_gap_seqs(a3m):
    """Drop MSA rows that contain nothing but gap tokens (in place on `a3m`)."""
    gap_token = ChemData().UNKINDEX
    keep = (a3m['msa'] != gap_token).any(dim=1)
    a3m['msa'] = a3m['msa'][keep]
    a3m['ins'] = a3m['ins'][keep]
    return a3m
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
    """Joins (or "pairs") 2 MSAs by matching sequences with the same
    taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
    ID, only the sequence with the highest sequence identity to the query (1st
    sequence in MSA) will be paired.
    Sequences that aren't paired will be padded and added to the bottom of the
    joined MSA. If a subregion of the input MSAs overlap (represent the same
    chain), the subregion residue indices can be given as `idx_overlap`, and
    the overlap region of the unpaired sequences will be included in the joined
    MSA.
    Parameters
    ----------
    a3mA : dict
        First MSA to be joined, with keys `msa` (N_seq, L_seq), `ins` (N_seq,
        L_seq), `taxid` (N_seq,), and optionally `is_paired` (N_seq,), a
        boolean tensor indicating whether each sequence is fully paired. Can be
        a multi-MSA (contain >2 sub-MSAs).
    a3mB : dict
        2nd MSA to be joined, with keys `msa`, `ins`, `taxid`, and optionally
        `is_paired`. Can be a multi-MSA ONLY if not overlapping with 1st MSA.
    idx_overlap : tuple or list (optional)
        Start and end indices of overlap region in 1st MSA, followed by the
        same in 2nd MSA.
    Returns
    -------
    a3m : dict
        Paired MSA, with keys `msa`, `ins`, `taxid` and `is_paired`.
    """
    # preprocess overlap region
    L_A, L_B = a3mA['msa'].shape[1], a3mB['msa'].shape[1]
    if idx_overlap is not None:
        i1A, i2A, i1B, i2B = idx_overlap
        i1B_new, i2B_new = (0, i1B) if i2B==L_B else (i2B, L_B) # MSA B residues that don't overlap MSA A
        assert((i1B==0) or (i2B==a3mB['msa'].shape[1])), \
            "When overlapping with 1st MSA, 2nd MSA must comprise at most 2 sub-MSAs "\
            "(i.e. residue range should include 0 or a3mB['msa'].shape[1])"
    else:
        i1B_new, i2B_new = (0, L_B)
    # pair sequences
    taxids_shared = a3mA['taxid'][np.isin(a3mA['taxid'],a3mB['taxid'])]
    i_pairedA, i_pairedB = [], []
    for taxid in taxids_shared:
        i_match = np.where(a3mA['taxid']==taxid)[0]
        # NOTE(review): argmin of the per-row match count selects the row LEAST
        # similar to the query, but the docstring promises highest identity --
        # confirm intent (argmax?)
        i_match_best = torch.argmin(torch.sum(a3mA['msa'][i_match]==a3mA['msa'][0], axis=1))
        i_pairedA.append(i_match[i_match_best])
        i_match = np.where(a3mB['taxid']==taxid)[0]
        i_match_best = torch.argmin(torch.sum(a3mB['msa'][i_match]==a3mB['msa'][0], axis=1))
        i_pairedB.append(i_match[i_match_best])
    # unpaired sequences
    i_unpairedA = np.setdiff1d(np.arange(a3mA['msa'].shape[0]), i_pairedA)
    i_unpairedB = np.setdiff1d(np.arange(a3mB['msa'].shape[0]), i_pairedB)
    N_paired, N_unpairedA, N_unpairedB = len(i_pairedA), len(i_unpairedA), len(i_unpairedB)
    # handle overlap region
    # if msa A consists of sub-MSAs 1,2,3 and msa B of 2,4 (i.e overlap region is 2),
    # this diagram shows how the variables below make up the final multi-MSA
    # (* denotes nongaps, - denotes gaps)
    # 1 2 3 4
    # |*|*|*|*| msa_paired
    # |*|*|*|-| msaA_unpaired
    # |-|*|-|*| msaB_unpaired
    if idx_overlap is not None:
        assert((a3mA['msa'][i_pairedA, i1A:i2A]==a3mB['msa'][i_pairedB, i1B:i2B]) |
               (a3mA['msa'][i_pairedA, i1A:i2A]==ChemData().UNKINDEX)).all(),\
            'Paired MSAs should be identical (or 1st MSA should be all gaps) in overlap region'
        # overlap region gets sequences from 2nd MSA bc sometimes 1st MSA will be all gaps here
        msa_paired = torch.cat([a3mA['msa'][i_pairedA, :i1A],
                                a3mB['msa'][i_pairedB, i1B:i2B],
                                a3mA['msa'][i_pairedA, i2A:],
                                a3mB['msa'][i_pairedB, i1B_new:i2B_new] ], dim=1)
        msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
                                   torch.full((N_unpairedA, i2B_new-i1B_new), ChemData().UNKINDEX) ], dim=1)
        msaB_unpaired = torch.cat([torch.full((N_unpairedB, i1A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB, i1B:i2B],
                                   torch.full((N_unpairedB, L_A-i2A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB, i1B_new:i2B_new] ], dim=1)
    else:
        # no overlap region, simple offset pad & stack
        # this code is actually a special case of "if" block above, but writing
        # this out explicitly here to make the logic more clear
        msa_paired = torch.cat([a3mA['msa'][i_pairedA], a3mB['msa'][i_pairedB, i1B_new:i2B_new]], dim=1)
        msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
                                   torch.full((N_unpairedA, L_B), ChemData().UNKINDEX)], dim=1) # pad with gaps
        msaB_unpaired = torch.cat([torch.full((N_unpairedB, L_A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB]], dim=1) # pad with gaps
    # stack paired & unpaired
    msa = torch.cat([msa_paired, msaA_unpaired, msaB_unpaired], dim=0)
    taxids = np.concatenate([a3mA['taxid'][i_pairedA], a3mA['taxid'][i_unpairedA], a3mB['taxid'][i_unpairedB]])
    # label "fully paired" sequences (a row of MSA that was never padded with gaps)
    # output seq is fully paired if seqs A & B both started out as paired and were paired to
    # each other on tax ID.
    # NOTE: there is a rare edge case that is ignored here for simplicity: if
    # pMSA 0+1 and 1+2 are joined and then joined to 2+3, a seq that exists in
    # 0+1 and 2+3 but NOT 1+2 will become fully paired on the last join but
    # will not be labeled as such here
    is_pairedA = a3mA['is_paired'] if 'is_paired' in a3mA else torch.ones((a3mA['msa'].shape[0],)).bool()
    is_pairedB = a3mB['is_paired'] if 'is_paired' in a3mB else torch.ones((a3mB['msa'].shape[0],)).bool()
    is_paired = torch.cat([is_pairedA[i_pairedA] & is_pairedB[i_pairedB],
                           torch.zeros((N_unpairedA + N_unpairedB,)).bool()])
    # insertion features in paired MSAs are assumed to be zero
    a3m = dict(msa=msa, ins=torch.zeros_like(msa), taxid=taxids, is_paired=is_paired)
    return a3m
def load_minimal_multi_msa(hash_list, taxid_list, Ls, params):
    """Load a multi-MSA, which is a MSA that is paired across more than 2
    chains. This loads the MSA for unique chains. Use 'expand_multi_msa` to
    duplicate portions of the MSA for homo-oligomer repeated chains.
    Given a list of unique MSA hashes, loads all MSAs (using paired MSAs where
    it can) and pairs sequences across as many sub-MSAs as possible by matching
    taxonomic ID. For details on how pairing is done, see
    `join_msas_by_taxid()`
    Parameters
    ----------
    hash_list : list of str
        Hashes of MSAs to load and join. Must not contain duplicates.
    taxid_list : list of str
        Taxonomic IDs of query sequences of each input MSA (may contain
        None/'' for chains without taxonomy information).
    Ls : list of int
        Lengths of the chains corresponding to the hashes.
    params : dict
        Data loading parameters; 'COMPL_DIR' and 'PDB_DIR' locate a3m files.
    Returns
    -------
    a3m_out : dict
        Multi-MSA with all input MSAs. Keys: `msa`,`ins` [torch.Tensor (N_seq, L)],
        `taxid` [np.array (Nseq,)], `is_paired` [torch.Tensor (N_seq,)]
    hashes_out : list of str
        Hashes of MSAs in the order that they are joined in `a3m_out`.
        Contains the same elements as the input `hash_list` but may be in a
        different order.
    Ls_out : list of int
        Lengths of each chain in `a3m_out`
    """
    assert(len(hash_list)==len(set(hash_list))), 'Input MSA hashes must be unique'
    # the lists below are constructed such that `a3m_list[i_a3m]` is a multi-MSA
    # comprising sub-MSAs whose indices in the input lists are
    # `i_in = idx_list_groups[i_a3m][i_submsa]`, i.e. the sub-MSA hashes are
    # `hash_list[i_in]` and lengths are `Ls[i_in]`.
    # Each sub-MSA spans a region of its multi-MSA `a3m_list[i_a3m][:,i_start:i_end]`,
    # where `(i_start,i_end) = res_range_groups[i_a3m][i_submsa]`
    a3m_list = [] # list of multi-MSAs
    idx_list_groups = [] # list of lists of indices of input chains making up each multi-MSA
    res_range_groups = [] # list of lists of start and end residues of each sub-MSA in multi-MSA
    # iterate through all pairs of hashes and look for paired MSAs (pMSAs)
    # NOTE: in the below, if pMSAs are loaded for hashes 0+1 and then 2+3, and
    # later a pMSA is found for 0+2, the last MSA will not be loaded. The 0+1
    # and 2+3 pMSAs will still be joined on taxID at the end, but sequences
    # only present in the 0+2 pMSA pMSAs will be missed. this is probably very
    # rare and so is ignored here for simplicity.
    N = len(hash_list)
    for i1, i2 in itertools.permutations(range(N),2):
        idx_list = [x for group in idx_list_groups for x in group] # flattened list of loaded hashes
        if i1 in idx_list and i2 in idx_list: continue # already loaded
        # bug fix: this guard previously compared the integer indices i1/i2 to ''
        # (always False); the intent is to skip chains with no taxonomic ID
        if not taxid_list[i1] or not taxid_list[i2]: continue # no taxID means no pMSA
        # a paired MSA exists
        if taxid_list[i1]==taxid_list[i2]:
            h1, h2 = hash_list[i1], hash_list[i2]
            fn = params['COMPL_DIR']+'/pMSA/'+h1[:3]+'/'+h2[:3]+'/'+h1+'_'+h2+'.a3m.gz'
            if os.path.exists(fn):
                msa, ins, taxid = parse_a3m(fn, paired=True)
                a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins), taxid=taxid,
                               is_paired=torch.ones(msa.shape[0]).bool())
                res_range1 = (0,Ls[i1])
                res_range2 = (Ls[i1],msa.shape[1])
                # both hashes are new, add paired MSA to list
                if i1 not in idx_list and i2 not in idx_list:
                    a3m_list.append(a3m_new)
                    idx_list_groups.append([i1,i2])
                    res_range_groups.append([res_range1, res_range2])
                # one of the hashes is already in a multi-MSA
                # find that multi-MSA and join the new pMSA to it
                elif i1 in idx_list:
                    # which multi-MSA & sub-MSA has the hash with index `i1`?
                    i_a3m = np.where([i1 in group for group in idx_list_groups])[0][0]
                    i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i1)[0][0]
                    idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range1
                    a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)
                    idx_list_groups[i_a3m].append(i2)
                    L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
                    L_new = res_range2[1] - res_range2[0]
                    res_range_groups[i_a3m].append((L, L+L_new))
                elif i2 in idx_list:
                    # which multi-MSA & sub-MSA has the hash with index `i2`?
                    i_a3m = np.where([i2 in group for group in idx_list_groups])[0][0]
                    i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i2)[0][0]
                    idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range2
                    a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)
                    idx_list_groups[i_a3m].append(i1)
                    L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
                    L_new = res_range1[1] - res_range1[0]
                    res_range_groups[i_a3m].append((L, L+L_new))
    # add unpaired MSAs
    # ungroup hash indices now, since we're done making multi-MSAs
    idx_list = [x for group in idx_list_groups for x in group]
    for i in range(N):
        if i not in idx_list:
            fn = params['PDB_DIR'] + '/a3m/' + hash_list[i][:3] + '/' + hash_list[i] + '.a3m.gz'
            msa, ins, taxid = parse_a3m(fn)
            a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins),
                           taxid=taxid, is_paired=torch.ones(msa.shape[0]).bool())
            a3m_list.append(a3m_new)
            idx_list.append(i)
    Ls_out = [Ls[i] for i in idx_list]
    hashes_out = [hash_list[i] for i in idx_list]
    # join multi-MSAs & unpaired MSAs
    a3m_out = a3m_list[0]
    for i in range(1, len(a3m_list)):
        a3m_out = join_msas_by_taxid(a3m_out, a3m_list[i])
    return a3m_out, hashes_out, Ls_out
def expand_multi_msa(a3m, hashes_in, hashes_out, Ls_in, Ls_out, params=None):
    """Expands a multi-MSA of unique chains into an MSA of a
    hetero-homo-oligomer in which some chains appear more than once. The query
    sequences (1st sequence of MSA) are concatenated directly along the
    residue dimention. The remaining sequences are offset-tiled (i.e. "padded &
    stacked") so that exact repeat sequences aren't paired.
    For example, if the original multi-MSA contains unique chains 1,2,3 but
    the final chain order is 1,2,1,3,3,1, this function will output an MSA like
    (where - denotes a block of gap characters):
    1 2 - 3 - -
    - - 1 - 3 -
    - - - - - 1
    Parameters
    ----------
    a3m : dict
        Contains torch.Tensors `msa` and `ins` (N_seq, L) and np.array `taxid` (Nseq,),
        representing the multi-MSA of unique chains.
    hashes_in : list of str
        Unique MSA hashes used in `a3m`.
    hashes_out : list of str
        Non-unique MSA hashes desired in expanded MSA.
    Ls_in : list of int
        Lengths of each chain in `a3m`
    Ls_out : list of int
        Lengths of each chain desired in expanded MSA.
    params : dict, optional
        Data loading parameters. Unused here; accepted so callers that pass
        them (e.g. `load_multi_msa`) keep working.
    Returns
    -------
    a3m : dict
        Contains torch.Tensors `msa` and `ins` of expanded MSA. No
        taxids because no further joining needs to be done.
    """
    assert(len(hashes_out)==len(Ls_out))
    assert(set(hashes_in)==set(hashes_out))
    assert(a3m['msa'].shape[1]==sum(Ls_in))
    # figure out which oligomeric repeat is represented by each hash in `hashes_out`
    # each new repeat will be offset in sequence dimension of final MSA
    counts = dict()
    n_copy = [] # n-th copy of this hash in `hashes`
    for h in hashes_out:
        if h in counts:
            counts[h] += 1
        else:
            counts[h] = 1
        n_copy.append(counts[h])
    # num sequences in source & destination MSAs
    N_in = a3m['msa'].shape[0]
    N_out = (N_in-1)*max(n_copy)+1 # concatenate query seqs, pad&stack the rest
    # source MSA
    msa_in, ins_in = a3m['msa'], a3m['ins']
    # initialize destination MSA to gap characters
    msa_out = torch.full((N_out, sum(Ls_out)), ChemData().UNKINDEX)
    ins_out = torch.full((N_out, sum(Ls_out)), 0)
    # for each destination chain
    for i_out, h_out in enumerate(hashes_out):
        # identify index of source chain
        i_in = np.where(np.array(hashes_in)==h_out)[0][0]
        # residue indexes
        i1_res_in = sum(Ls_in[:i_in])
        i2_res_in = sum(Ls_in[:i_in+1])
        i1_res_out = sum(Ls_out[:i_out])
        i2_res_out = sum(Ls_out[:i_out+1])
        # copy over query sequence
        # (bug fix: insertions were previously copied from `msa_in` instead of `ins_in`)
        msa_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
        ins_out[0, i1_res_out:i2_res_out] = ins_in[0, i1_res_in:i2_res_in]
        # offset non-query sequences along sequence dimension based on repeat number of a given hash
        i1_seq_out = 1+(n_copy[i_out]-1)*(N_in-1)
        i2_seq_out = 1+n_copy[i_out]*(N_in-1)
        # copy over non-query sequences
        msa_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = msa_in[1:, i1_res_in:i2_res_in]
        ins_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = ins_in[1:, i1_res_in:i2_res_in]
    # only 1st oligomeric repeat can be fully paired
    is_paired_out = torch.cat([a3m['is_paired'], torch.zeros((N_out-N_in,)).bool()])
    a3m_out = dict(msa=msa_out, ins=ins_out, is_paired=is_paired_out)
    a3m_out = remove_all_gap_seqs(a3m_out)
    return a3m_out
def load_multi_msa(chain_ids, Ls, chid2hash, chid2taxid, params):
    """Loads multi-MSA for an arbitrary number of protein chains. Tries to
    locate paired MSAs and pair sequences across all chains by taxonomic ID.
    Unpaired sequences are padded and stacked on the bottom.

    Parameters: chain_ids/Ls are parallel lists of chain IDs and lengths;
    chid2hash maps chain ID -> MSA hash, chid2taxid maps chain ID -> taxid
    (chains may be absent from chid2taxid); params holds data-loading paths.
    Returns the expanded multi-MSA dict (keys `msa`, `ins`, `is_paired`).
    """
    # get MSA hashes (used to locate a3m files) and taxonomic IDs (used to determine pairing)
    hashes = []
    hashes_unique = []
    taxids_unique = []
    Ls_unique = []
    for chid,L_ in zip(chain_ids, Ls):
        hashes.append(chid2hash[chid])
        if chid2hash[chid] not in hashes_unique:
            hashes_unique.append(chid2hash[chid])
            # .get() returns None for chains without taxonomy info
            taxids_unique.append(chid2taxid.get(chid))
            Ls_unique.append(L_)
    # loads multi-MSA for unique chains
    a3m_prot, hashes_unique, Ls_unique = \
        load_minimal_multi_msa(hashes_unique, taxids_unique, Ls_unique, params)
    # expands multi-MSA to repeat chains of homo-oligomers
    # NOTE(review): expand_multi_msa as defined in this file takes 5 positional
    # arguments; passing `params` relies on the signature accepting it -- confirm
    a3m_prot = expand_multi_msa(a3m_prot, hashes_unique, hashes, Ls_unique, Ls, params)
    return a3m_prot
def choose_multimsa_clusters(msa_seq_is_paired, params):
    """Return seed-cluster indices for MSA featurization of a multi-MSA.

    When more than a quarter of the rows are fully paired, returns None so
    MSAFeaturize samples clusters randomly. Otherwise returns one index
    tensor per recycle, each reserving up to half of params['MAXLAT']
    clusters for fully-paired rows (selection is randomized per recycle).
    """
    if msa_seq_is_paired.float().mean() > 0.25:
        # enough fully paired sequences, just let MSAFeaturize choose randomly
        return None
    n_seed = params['MAXLAT'] // 2
    paired_idx = torch.where(msa_seq_is_paired)[0]
    return [
        paired_idx[torch.randperm(len(paired_idx))][:n_seed]
        for _ in range(params['MAXCYCLE'])
    ]
#fd
def get_bond_distances(bond_feats):
    """Topological (bond-graph) distance between every pair of atoms.

    Edges are bond-feature values in (0, 5); disconnected pairs (e.g. the
    protein portion) come back as inf from shortest_path and are kept as-is.
    Returns a float32 tensor of shape (L, L).
    """
    adjacency = ((bond_feats > 0) * (bond_feats < 5)).long().numpy()
    dist = scipy.sparse.csgraph.shortest_path(adjacency, directed=False)
    return torch.from_numpy(dist).float()
def get_pdb(pdbfilename, plddtfilename, item, lddtcut, sccut):
    """Load a structure and mask out atoms with low pLDDT confidence.

    An atom is kept only when it was resolved by parse_pdb, its residue
    pLDDT exceeds `lddtcut`, and — for sidechain atom slots (index 5+) —
    the residue pLDDT also exceeds `sccut`. The first 5 (backbone) slots
    are exempt from the sidechain cutoff.
    """
    xyz, mask, res_idx = parse_pdb(pdbfilename)
    plddt = np.load(plddtfilename)
    # sidechain gate: trust sidechains only on high-pLDDT residues
    sidechain_ok = np.full_like(mask, False)
    sidechain_ok[plddt > sccut] = True
    sidechain_ok[:, :5] = True  # backbone always passes this gate
    keep = np.logical_and(mask, sidechain_ok)
    keep = np.logical_and(keep, (plddt > lddtcut)[:, None])
    return {'xyz': torch.tensor(xyz), 'mask': torch.tensor(keep),
            'idx': torch.tensor(res_idx), 'label': item}
def get_msa(a3mfilename, item, maxseq=5000):
    """Load an MSA from an a3m file.

    Parameters: a3mfilename is the path to the a3m file; item is an
    arbitrary label stored in the returned dict; maxseq caps the number of
    sequences read. Returns a dict with `msa`/`ins` tensors, `taxIDs` and
    `label`.
    """
    # bug fix: maxseq was previously ignored (a hard-coded 5000 was passed)
    msa, ins, taxIDs = parse_a3m(a3mfilename, maxseq=maxseq)
    return {'msa':torch.tensor(msa), 'ins':torch.tensor(ins), 'taxIDs':taxIDs, 'label':item}

208
rf2aa/data/merge_inputs.py Normal file
View File

@@ -0,0 +1,208 @@
import torch
from hashlib import md5
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, merge_hetero_templates, get_term_feats, join_msas_by_taxid, expand_multi_msa
from rf2aa.data.data_loader import RawInputData
from rf2aa.util import center_and_realign_missing, same_chain_from_bond_feats, random_rot_trans, idx_from_Ls
def merge_protein_inputs(protein_inputs, deterministic: bool = False):
    """Merge per-chain protein RawInputData into one input.

    protein_inputs maps chain ID -> RawInputData. Returns (merged_input,
    chain_lengths) where chain_lengths is a list of (chain_id, length)
    tuples. For a single chain, only a random rotation/translation is
    applied to the first template.
    """
    if len(protein_inputs) == 0:
        return None,[]
    elif len(protein_inputs) == 1:
        chain = list(protein_inputs.keys())[0]
        input = list(protein_inputs.values())[0]
        xyz_t = input.xyz_t
        # randomize the first template's frame (no-op offsets when deterministic)
        xyz_t[0:1] = random_rot_trans(xyz_t[0:1], deterministic=deterministic)
        input.xyz_t = xyz_t
        return input, [(chain, input.length())]
    # handle merging MSAs and such
    # first determine which sequence are identical, then which one have mergeable MSAs
    # then cat the templates, other feats
    else:
        a3m_list = [
            {"msa": input.msa,
             "ins": input.ins,
             # NOTE(review): join_msas_by_taxid requires a populated 'taxid'
             # array; assumes input.taxids is not None here -- confirm
             "taxid": input.taxids
            }
            for input in protein_inputs.values()
        ]
        # identical sequences share an MSA; hash the sequence string to detect them
        hash_list = [md5(input.sequence_string().encode()).hexdigest() for input in protein_inputs.values()]
        lengths_list = [input.length() for input in protein_inputs.values()]
        seen = set()
        unique_indices = []
        for idx, hash in enumerate(hash_list):
            if hash not in seen:
                unique_indices.append(idx)
                seen.add(hash)
        unique_a3m = [a3m for i, a3m in enumerate(a3m_list) if i in unique_indices ]
        unique_hashes = [value for index, value in enumerate(hash_list) if index in unique_indices]
        unique_lengths_list = [value for index, value in enumerate(lengths_list) if index in unique_indices]
        if len(unique_a3m) >1:
            # hetero-oligomer: pair unique MSAs by taxid, then expand repeats
            a3m_out = unique_a3m[0]
            for unq_a3m in unique_a3m[1:]:
                a3m_out = join_msas_by_taxid(a3m_out, unq_a3m)
            a3m_out = expand_multi_msa(a3m_out, unique_hashes, hash_list, unique_lengths_list, lengths_list)
        else:
            # homo-oligomer: offset-tile the single unique MSA
            a3m = unique_a3m[0]
            msa, ins = a3m["msa"], a3m["ins"]
            a3m_out = merge_a3m_homo(msa, ins, len(hash_list))
        # merge templates
        max_template_dim = max([input.xyz_t.shape[0] for input in protein_inputs.values()])
        xyz_t_list = [input.xyz_t for input in protein_inputs.values()]
        mask_t_list = [input.mask_t for input in protein_inputs.values()]
        t1d_list = [input.t1d for input in protein_inputs.values()]
        ids = ["inference"] * len(t1d_list)
        xyz_t, t1d, mask_t, _ = merge_hetero_templates(xyz_t_list, t1d_list, mask_t_list, ids, lengths_list, deterministic=deterministic)
        # proteins carry no chirals or atom frames (those are small-molecule features)
        atom_frames = torch.zeros(0,3,2)
        chirals = torch.zeros(0,5)
        # block-diagonal bond features: no inter-chain bonds
        L_total = sum(lengths_list)
        bond_feats = torch.zeros((L_total, L_total)).long()
        offset = 0
        for bf in [input.bond_feats for input in protein_inputs.values()]:
            L = bf.shape[0]
            bond_feats[offset:offset+L, offset:offset+L] = bf
            offset += L
        chain_lengths = list(zip(protein_inputs.keys(), lengths_list))
        merged_input = RawInputData(
            a3m_out["msa"],
            a3m_out["ins"],
            bond_feats,
            xyz_t[:max_template_dim],
            mask_t[:max_template_dim],
            t1d[:max_template_dim],
            chirals,
            atom_frames,
            taxids=None
        )
        return merged_input, chain_lengths
def merge_na_inputs(na_inputs):
    """Fold all nucleic-acid inputs into one RawInputData.

    Features are simply concatenated chain by chain. Returns the merged
    input (None when `na_inputs` is empty) and a list of
    (chain_id, length) tuples.
    """
    merged = None
    chain_lengths = []
    for chain_id, na_input in na_inputs.items():
        chain_lengths.append((chain_id, na_input.length()))
        merged = merge_two_inputs(merged, na_input)
    return merged, chain_lengths
def merge_sm_inputs(sm_inputs):
    """Fold all small-molecule inputs into one RawInputData.

    Features are simply concatenated molecule by molecule. Returns the
    merged input (None when `sm_inputs` is empty) and a list of
    (chain_id, length) tuples.
    """
    merged = None
    chain_lengths = []
    for chain_id, sm_input in sm_inputs.items():
        chain_lengths.append((chain_id, sm_input.length()))
        merged = merge_two_inputs(merged, sm_input)
    return merged, chain_lengths
def merge_two_inputs(first_input, second_input):
    """Concatenate two RawInputData objects of arbitrary kinds into one.

    Either argument may be None, in which case the other is returned
    unchanged (None when both are None). MSAs are merged with gap padding,
    bond features become block-diagonal, templates are concatenated along
    the residue dimension, and the second input's chiral residue indices
    are offset by the first input's length.

    Returns a new RawInputData; neither input is modified (the previous
    implementation offset `second_input.chirals` in place, which corrupted
    inputs reused across calls).
    """
    if first_input is None:
        return second_input
    if second_input is None:
        return first_input
    Ls = [first_input.length(), second_input.length()]
    L_total = sum(Ls)
    # merge msas (unpaired rows are gap-padded below each other)
    a3m = merge_a3m_hetero(
        {"msa": first_input.msa, "ins": first_input.ins},
        {"msa": second_input.msa, "ins": second_input.ins},
        Ls,
    )
    # merge bond_feats: block-diagonal, no bonds across the two inputs
    bond_feats = torch.zeros((L_total, L_total)).long()
    bond_feats[:Ls[0], :Ls[0]] = first_input.bond_feats
    bond_feats[Ls[0]:, Ls[0]:] = second_input.bond_feats
    # merge templates along the residue dimension
    # NOTE(review): assumes both inputs carry the same number of templates -- confirm
    xyz_t = torch.cat([first_input.xyz_t, second_input.xyz_t], dim=1)
    t1d = torch.cat([first_input.t1d, second_input.t1d], dim=1)
    mask_t = torch.cat([first_input.mask_t, second_input.mask_t], dim=1)
    # handle chirals: all columns except the last hold residue indices and
    # need offsetting; clone so the caller's object is left untouched
    second_chirals = second_input.chirals
    if second_chirals.shape[0] > 0:
        second_chirals = second_chirals.clone()
        second_chirals[:, :-1] += first_input.length()
    chirals = torch.cat([first_input.chirals, second_chirals])
    # cat atom frames
    atom_frames = torch.cat([first_input.atom_frames, second_input.atom_frames])
    return RawInputData(
        a3m["msa"],
        a3m["ins"],
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        chirals,
        atom_frames,
        taxids=None
    )
def merge_all(
    protein_inputs,
    na_inputs,
    sm_inputs,
    residues_to_atomize,
    deterministic: bool = False,
):
    """Merge protein, nucleic-acid and small-molecule inputs into a single
    RawInputData with combined MSA, bond features, templates, terminus
    features and residue indexing.

    Raises ValueError when all three input groups are empty. When
    `residues_to_atomize` is non-empty, protein features are re-indexed for
    covalent-modification atomization at the end.
    """
    protein_inputs, protein_chain_lengths = merge_protein_inputs(protein_inputs, deterministic=deterministic)
    na_inputs, na_chain_lengths = merge_na_inputs(na_inputs)
    sm_inputs, sm_chain_lengths = merge_sm_inputs(sm_inputs)
    if protein_inputs is None and na_inputs is None and sm_inputs is None:
        raise ValueError("No valid inputs were provided")
    running_inputs = merge_two_inputs(protein_inputs, na_inputs) #could handle pairing protein/NA MSAs here
    running_inputs = merge_two_inputs(running_inputs, sm_inputs)
    all_chain_lengths = protein_chain_lengths + na_chain_lengths + sm_chain_lengths
    running_inputs.chain_lengths = all_chain_lengths
    all_lengths = get_Ls_from_chain_lengths(running_inputs.chain_lengths)
    protein_lengths = get_Ls_from_chain_lengths(protein_chain_lengths)
    # terminus features only apply to the protein chains (which come first)
    term_info = get_term_feats(all_lengths)
    term_info[sum(protein_lengths):, :] = 0
    running_inputs.term_info = term_info
    xyz_t = running_inputs.xyz_t
    mask_t = running_inputs.mask_t
    # fix: was a duplicated assignment (`same_chain = same_chain = ...`)
    same_chain = same_chain_from_bond_feats(running_inputs.bond_feats)
    # center each template and re-align chains with missing coordinates
    ntempl = xyz_t.shape[0]
    xyz_t = torch.stack(
        [center_and_realign_missing(xyz_t[i], mask_t[i], same_chain=same_chain) for i in range(ntempl)]
    )
    xyz_t = torch.nan_to_num(xyz_t)
    running_inputs.xyz_t = xyz_t
    running_inputs.idx = idx_from_Ls(all_lengths)
    # after everything is merged need to add bond feats for covales
    # reindex protein feats function
    if residues_to_atomize:
        running_inputs.update_protein_features_after_atomize(residues_to_atomize)
    return running_inputs
def get_Ls_from_chain_lengths(chain_lengths):
    """Extract the length component from a list of (chain_id, length) tuples."""
    return [pair[1] for pair in chain_lengths]

View File

@@ -0,0 +1,46 @@
import numpy as np
import torch
from rf2aa.data.parsers import parse_mixed_fasta, parse_multichain_fasta
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, blank_template
from rf2aa.data.data_loader import RawInputData
from rf2aa.util import get_protein_bond_feats
def load_nucleic_acid(fasta_fn, input_type, model_runner):
    """Build a RawInputData for a single DNA or RNA chain from a fasta file.

    Raises ValueError for input types other than "dna"/"rna" and for fastas
    containing more than one chain. The MSA is subsampled to the loader's
    MAXSEQ (always keeping the query); templates are blank.
    """
    if input_type not in ["dna", "rna"]:
        raise ValueError("Only DNA and RNA inputs allowed for nucleic acids")
    is_dna = input_type == "dna"
    loader_params = model_runner.config.loader_params
    msa, ins, L = parse_multichain_fasta(fasta_fn, rna_alphabet=not is_dna, dna_alphabet=is_dna)
    max_seq = loader_params["MAXSEQ"]
    if msa.shape[0] > max_seq:
        # random subsample, forcing the query into row 0
        keep = np.random.permutation(msa.shape[0])[:max_seq]
        keep[0] = 0
        msa = msa[keep]
        ins = ins[keep]
    if len(L) > 1:
        raise ValueError("Please provide separate fasta files for each nucleic acid chain")
    L = L[0]
    xyz_t, t1d, mask_t, _ = blank_template(loader_params["n_templ"], L)
    return RawInputData(
        torch.from_numpy(msa),
        torch.from_numpy(ins),
        get_protein_bond_feats(L),
        xyz_t,
        mask_t,
        t1d,
        torch.zeros(0, 5),      # chirals: none for nucleic acids
        torch.zeros(0, 3, 2),   # atom frames: none for nucleic acids
        taxids=None,
    )

812
rf2aa/data/parsers.py Normal file
View File

@@ -0,0 +1,812 @@
import numpy as np
import scipy
import scipy.spatial
import string
import os,re
from os.path import exists
import random
import rf2aa.util as util
import gzip
import rf2aa
from rf2aa.ffindex import *
import torch
from openbabel import openbabel
from rf2aa.chemical import ChemicalData as ChemData
def get_dislf(seq, xyz, mask):
    """Detect disulfide bonds from resolved cysteine SG positions.

    Considers all pairs of cysteines whose SG atom (slot 5) is resolved and
    reports those whose SG-SG distance lies in (1.7, 2.3) Angstrom.
    Returns a list of (res_i, res_j) index tuples.
    """
    cys_idx = ((seq==ChemData().aa2num['CYS']) * mask[:,5]).nonzero().squeeze(-1) # cys[5]=='sg'
    sg_xyz = xyz[cys_idx, 5]
    row, col = torch.triu_indices(sg_xyz.shape[0], sg_xyz.shape[0], 1)
    sg_dist = torch.linalg.norm(sg_xyz[row, :] - sg_xyz[col, :], dim=-1)
    bonded = (sg_dist > 1.7) * (sg_dist < 2.3)
    pairs = []
    for k in bonded.nonzero():
        pairs.append((cys_idx[row[k]].item(), cys_idx[col[k]].item()))
    return pairs
def read_template_pdb(L, pdb_fn, target_chain=None):
    """Read a template structure from a PDB file.

    Parameters
    ----------
    L : int
        Length (in residues) of the target; PDB residue numbers are used as
        1-based indices into length-L output arrays.
    pdb_fn : str
        Path to the template PDB file.
    target_chain : str, optional
        Currently unused; kept for interface compatibility.

    Returns
    -------
    xyz : torch.Tensor (1, L, 36, 3)
        Template atom coordinates (NaN where missing).
    t1d : torch.Tensor (1, L, 33)
        One-hot sequence (32 classes) plus backbone-confidence column.
    """
    # first pass: collect the full sequence and per-chain lengths
    seq_full = list()
    # bug fix: L_s and offset were referenced before assignment below,
    # raising NameError at the first chain break
    L_s = list()
    offset = 0
    prev_chain=''
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            if line[12:16].strip() != "CA":
                continue
            if line[21] != prev_chain:
                if len(seq_full) > 0:
                    L_s.append(len(seq_full)-offset)
                offset = len(seq_full)
                prev_chain = line[21]
            aa = line[17:20]
            seq_full.append(ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20)

    seq_full = torch.tensor(seq_full).long()

    # second pass: fill coordinates and sequence for every recognized atom
    xyz = torch.full((L, 36, 3), np.nan).float()
    seq = torch.full((L,), 20).long()
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            aa_idx = ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20
            #
            idx = resNo - 1
            for i_atm, tgtatm in enumerate(ChemData().aa2long[aa_idx]):
                if tgtatm == atom:
                    xyz[idx, i_atm, :] = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
                    break
            seq[idx] = aa_idx

    # residues with a complete backbone (first 3 atom slots) get confidence 0.1
    mask = torch.logical_not(torch.isnan(xyz[:,:3,0])) # (L, 3)
    mask = mask.all(dim=-1)[:,None]
    conf = torch.where(mask, torch.full((L,1),0.1), torch.zeros(L,1)).float()
    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
    t1d = torch.cat((seq_1hot, conf), -1)
    return xyz[None], t1d[None]
def read_multichain_pdb(pdb_fn, tmpl_chain=None, tmpl_conf=0.1):
    """Read a multi-chain PDB, building an MSA stub, full coordinates, and a
    template restricted to `tmpl_chain` (confidence `tmpl_conf`).

    NOTE(review): the 3-row MSA construction below only masks two blocks
    (rows 1 and 2), so it appears to assume exactly two chains -- confirm.
    Returns (msa, ins, L_s, xyz_t, mask_t, t1d, dslf) where dslf lists
    detected disulfide pairs.
    """
    print ('read_multichain_pdb',tmpl_chain)
    # get full sequence from PDB
    seq_full = list()
    L_s = list()
    prev_chain=''
    offset = 0
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            if line[12:16].strip() != "CA":
                continue
            if line[21] != prev_chain:
                if len(seq_full) > 0:
                    L_s.append(len(seq_full)-offset)
                offset = len(seq_full)
                prev_chain = line[21]
            aa = line[17:20]
            seq_full.append(ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20)
    L_s.append(len(seq_full) - offset)
    seq_full = torch.tensor(seq_full).long()
    L = len(seq_full)
    # 3-row MSA: full query + one row per chain with the other chain gapped out
    msa = torch.stack((seq_full,seq_full,seq_full), dim=0)
    msa[1,:L_s[0]] = 20
    msa[2,L_s[0]:] = 20
    ins = torch.zeros_like(msa)
    # initialize coordinates to randomly jittered ideal residue frames
    xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1) + torch.rand(1,L,1,3)*5.0
    xyz_t = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1) + torch.rand(1,L,1,3)*5.0
    mask = torch.full((1, L, ChemData().NTOTAL), False)
    mask_t = torch.full((1, L, ChemData().NTOTAL), False)
    seq = torch.full((1, L,), 20).long()
    conf = torch.zeros(1, L,1).float()
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            outbatch = 0  # NOTE(review): never used -- leftover?
            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            aa_idx = ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20
            idx = resNo - 1
            for i_atm, tgtatm in enumerate(ChemData().aa2long[aa_idx]):
                if tgtatm == atom:
                    xyz_i = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
                    xyz[0, idx, i_atm, :] = xyz_i
                    mask[0, idx, i_atm] = True
                    # only tmpl_chain atoms populate the template tensors
                    if line[21] == tmpl_chain:
                        xyz_t[0, idx, i_atm, :] = xyz_i
                        mask_t[0, idx, i_atm] = True
                    break
            seq[0, idx] = aa_idx
    if (mask_t.any()):
        # NOTE(review): realigns using the FULL structure (xyz/mask) rather than
        # the template tensors (xyz_t/mask_t) -- confirm this is intended
        xyz_t[0] = rf2aa.util.center_and_realign_missing(xyz[0], mask[0])
    dslf = get_dislf(seq[0], xyz[0], mask[0])
    # assign confidence 'CONF' to all residues with backbone in template
    conf = torch.where(mask_t[...,:3].all(dim=-1)[...,None], torch.full((1,L,1),tmpl_conf), torch.zeros(L,1)).float()
    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=ChemData().NAATOKENS-1).float()
    t1d = torch.cat((seq_1hot, conf), -1)
    return msa, ins, L_s, xyz_t, mask_t, t1d, dslf
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
    """Parse a single-chain fasta file into integer MSA and insertion arrays.

    Parameters
    ----------
    filename : str
        Path to the fasta file.
    maxseq : int
        Maximum number of sequences to read (bug fix: previously accepted
        but never enforced).
    rmsa_alphabet : bool
        Use the RNA alphabet instead of the protein one.

    Returns
    -------
    msa : np.ndarray (N_seq, L) uint8
    ins : np.ndarray (N_seq, L) uint8 — all zeros (plain fasta has no insertions).
    """
    msa = []
    ins = []
    # fix: the file handle was previously never closed
    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue
            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue
            msa.append(line)
            ins.append(np.zeros((len(line))))
            if len(msa) >= maxseq:
                break
    # convert letters into numbers
    # NOTE(review): assumes all sequences have equal length -- confirm upstream
    if rmsa_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i
    ins = np.array(ins, dtype=np.uint8)
    return msa,ins
# Parse a fasta file containing multiple chains separated by '/'
def parse_multichain_fasta(filename, maxseq=10000, rna_alphabet=False, dna_alphabet=False):
    """Parse a multi-chain fasta ('/' separates chains) into integer arrays.

    Lowercase (insertion) characters are stripped; 'B' is normalized to 'D'.
    Chain lengths are taken from the first sequence. At most `maxseq`
    sequences are read. Returns (msa, ins, L_s) where msa/ins are uint8
    arrays of shape (N_seq, L) and L_s lists per-chain lengths.
    """
    msa = []
    ins = []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    L_s = []
    # fix: the file handle was previously never closed
    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue
            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue
            # remove lowercase letters and append to MSA
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B','D') # hacky...
            if L_s == []:
                L_s = [len(x) for x in msa_i.split('/')]
            msa_i = msa_i.replace('/','')
            msa.append(msa_i)
            ins.append(np.zeros((len(msa_i))))
            if (len(msa) >= maxseq):
                break
    # convert letters into numbers
    if rna_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGUN"), dtype='|S1').view(np.uint8)
    elif dna_alphabet:
        alphabet = np.array(list("00000000000000000000-0ACGTD00000"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i
    ins = np.array(ins, dtype=np.uint8)
    return msa,ins,L_s
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=10000):
    """Parse a paired protein/RNA FASTA alignment.

    Each record is "PROTEIN/RNA"; records lacking the '/' separator are split
    at the query protein length. Unpaired records (all-gap protein or RNA
    halves) are each capped at maxseq//2.

    NOTE(review): this function is defined twice in this file; the later
    definition (default maxseq=8000) shadows this one at import time.

    Returns
    -------
    msa : np.ndarray (N_seq, Lp+Lr) uint8
        Protein encoding (cols 0..Lp) concatenated with RNA encoding.
    ins : np.ndarray uint8
        All-zero insertion counts, same shape as msa.
    """
    msa1, msa2 = [], []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    unpaired_r, unpaired_p = 0, 0
    # context manager so the file handle is closed on all paths (was leaked)
    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue
            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue
            # remove lowercase letters (insertions)
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B','D') # hacky...
            msas_i = msa_i.split('/')
            if (len(msas_i)==1):
                # no separator: split at the query protein length
                msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]
            if (len(msa1)==0 or (
                len(msas_i[0])==len(msa1[0]) and len(msas_i[1])==len(msa2[0])
            )):
                # skip if we've already found half of our limit in unpaired protein seqs
                if sum([1 for x in msas_i[1] if x != '-']) == 0:
                    unpaired_p += 1
                    if unpaired_p > maxseq // 2:
                        continue
                # skip if we've already found half of our limit in unpaired rna seqs
                if sum([1 for x in msas_i[0] if x != '-']) == 0:
                    unpaired_r += 1
                    if unpaired_r > maxseq // 2:
                        continue
                msa1.append(msas_i[0])
                msa2.append(msas_i[1])
            else:
                # bug fix: the last field previously printed len(msas_i[1]) twice
                print("Len error", filename, len(msas_i[0]), len(msa1[0]),
                      len(msas_i[1]), len(msa2[0]))
            if (len(msa1) >= maxseq):
                break
    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa1[msa1 == alphabet[i]] = i
    msa1[msa1>=31] = 21 # anything unknown to 'X'
    alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa2[msa2 == alphabet[i]] = i
    msa2[msa2>=31] = 30 # anything unknown to 'N'
    msa = np.concatenate((msa1,msa2),axis=-1)
    ins = np.zeros(msa.shape, dtype=np.uint8)
    return msa, ins
# parse a fasta alignment IF it exists
# otherwise return single-sequence msa
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
    """Parse `filename` as a FASTA alignment when it exists; otherwise
    integer-encode the single query sequence `seq` with the protein alphabet."""
    if not exists(filename):
        # -0 are UNK/mask
        encoding = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
        encoded = np.array([list(seq)], dtype='|S1').view(np.uint8)
        for tok in range(encoding.shape[0]):
            encoded[encoded == encoding[tok]] = tok
        return (encoded, np.zeros_like(encoded))
    return parse_fasta(filename, maxseq, rmsa_alphabet)
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=8000):
    """Parse a paired protein/RNA FASTA alignment (see the earlier duplicate).

    NOTE(review): duplicate definition — this copy (default maxseq=8000)
    shadows the earlier one (default maxseq=10000); consider deleting one.

    Returns
    -------
    msa : np.ndarray (N_seq, Lp+Lr) uint8
        Protein encoding (cols 0..Lp) concatenated with RNA encoding.
    ins : np.ndarray uint8
        All-zero insertion counts, same shape as msa.
    """
    msa1, msa2 = [], []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    unpaired_r, unpaired_p = 0, 0
    # context manager so the file handle is closed on all paths (was leaked)
    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue
            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue
            # remove lowercase letters (insertions)
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B','D') # hacky...
            msas_i = msa_i.split('/')
            if (len(msas_i)==1):
                # no separator: split at the query protein length
                msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]
            if (len(msa1)==0 or (
                len(msas_i[0])==len(msa1[0]) and len(msas_i[1])==len(msa2[0])
            )):
                # skip if we've already found half of our limit in unpaired protein seqs
                if sum([1 for x in msas_i[1] if x != '-']) == 0:
                    unpaired_p += 1
                    if unpaired_p > maxseq // 2:
                        continue
                # skip if we've already found half of our limit in unpaired rna seqs
                if sum([1 for x in msas_i[0] if x != '-']) == 0:
                    unpaired_r += 1
                    if unpaired_r > maxseq // 2:
                        continue
                msa1.append(msas_i[0])
                msa2.append(msas_i[1])
            else:
                # bug fix: the last field previously printed len(msas_i[1]) twice
                print("Len error", filename, len(msas_i[0]), len(msa1[0]),
                      len(msas_i[1]), len(msa2[0]))
            if (len(msa1) >= maxseq):
                break
    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa1[msa1 == alphabet[i]] = i
    msa1[msa1>=31] = 21 # anything unknown to 'X'
    alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa2[msa2 == alphabet[i]] = i
    msa2[msa2>=31] = 30 # anything unknown to 'N'
    msa = np.concatenate((msa1,msa2),axis=-1)
    ins = np.zeros(msa.shape, dtype=np.uint8)
    return msa, ins
# read A3M and convert letters into
# integers in the 0..20 range,
# also keep track of insertions
def parse_a3m(filename, maxseq=8000, paired=False):
    """Parse an A3M alignment file (optionally gzipped).

    Lowercase letters are insertions: they are removed from the sequence and
    counted per match position.

    Parameters
    ----------
    filename : str
        Path to the .a3m (or .a3m.gz) file.
    maxseq : int
        Maximum number of sequences to read.
    paired : bool
        If True, headers are assumed to contain only a taxonomy ID.

    Returns
    -------
    msa : np.ndarray (N_seq, L) uint8
        Integer-encoded sequences (unknown characters mapped to gap, 20).
    ins : np.ndarray (N_seq, L) uint8
        Per-position insertion counts.
    taxIDs : np.ndarray of str
        One taxonomy ID per sequence ("" when none found in the header).
    """
    msa = []
    ins = []
    taxIDs = []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    # support gzipped files transparently; close the handle on all paths
    if filename.split('.')[-1] == 'gz':
        fstream = gzip.open(filename, 'rt')
    else:
        fstream = open(filename, 'r')
    with fstream:
        for i, line in enumerate(fstream):
            # skip labels
            if line[0] == '>':
                if paired: # paired MSAs only have a TAXID in the fasta header
                    taxIDs.append(line[1:].strip())
                else: # unpaired MSAs have all the metadata so use regex to pull out TAXID
                    if i == 0:
                        taxIDs.append("query")
                    else:
                        match = re.search(r'TaxID=(\d+)', line)
                        if match:
                            taxIDs.append(match.group(1))
                        else:
                            taxIDs.append("")  # header without a TaxID
                continue
            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue
            # remove lowercase letters and append to MSA
            msa.append(line.translate(table))
            # sequence length
            L = len(msa[-1])
            # 0 - match or gap; 1 - insertion
            a = np.array([0 if c.isupper() or c == '-' else 1 for c in line])
            # bug fix (readability): was named `i`, shadowing the enumerate counter
            i_ins = np.zeros((L))
            if np.sum(a) > 0:
                # positions of insertions
                pos = np.where(a == 1)[0]
                # shift by occurrence
                a = pos - np.arange(pos.shape[0])
                # position of insertions in cleaned sequence and their length
                pos, num = np.unique(a, return_counts=True)
                # append to the matrix of insertions
                i_ins[pos] = num
            ins.append(i_ins)
            if (len(msa) >= maxseq):
                break
    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i
    # treat all unknown characters as gaps
    msa[msa > 20] = 20
    ins = np.array(ins, dtype=np.uint8)
    return msa, ins, np.array(taxIDs)
# read and extract xyz coords of N,Ca,C atoms
# from a PDB file
def parse_pdb(filename, seq=False, lddt_mask=False):
    """Read a PDB file and return (xyz, mask, idx[, seq]) arrays.

    Delegates to parse_pdb_lines_w_seq when seq=True, else parse_pdb_lines.
    """
    # bug fix: the file handle was opened without ever being closed
    with open(filename, 'r') as f:
        lines = f.readlines()
    if seq:
        return parse_pdb_lines_w_seq(lines, lddt_mask=lddt_mask)
    return parse_pdb_lines(lines)
def parse_pdb_lines_w_seq(lines, lddt_mask=False):
    """Parse ATOM records into (xyz, mask, idx, seq) arrays.

    Residues are keyed by their CA (protein) or P (nucleic acid) atom.
    xyz has shape (N_res, NTOTAL, 3); unobserved atoms are zeroed and
    masked out. If lddt_mask is set, the mask is additionally gated on the
    per-residue value read from the B-factor column.
    """
    # residues observed in the structure: (chain letter, res num, aa, B-factor/pLDDT)
    res = [(l[21:22].strip(), l[22:26], l[17:20], l[60:66].strip())
           for l in lines if l[:4] == "ATOM" and l[12:16].strip() in ["CA", "P"]]
    pdb_idx_s = [(r[0], int(r[1])) for r in res]
    idx_s = [int(r[1]) for r in res]
    plddt = [float(r[3]) for r in res]
    seq = [ChemData().aa2num[r[2]] if r[2] in ChemData().aa2num.keys() else 20 for r in res]
    # map (chain, resno) -> row once (keeping the first occurrence, matching
    # list.index semantics) instead of an O(n) scan per ATOM line
    res_to_row = {}
    for row, key in enumerate(pdb_idx_s):
        res_to_row.setdefault(key, row)
    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), ChemData().NTOTAL, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
        idx = res_to_row[(chain, resNo)]
        for i_atm, tgtatm in enumerate(ChemData().aa2long[ChemData().aa2num[aa]]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break
    # save atom mask; NaN marks unobserved atoms
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0
    if lddt_mask:
        # confidence gating: first 5 atoms need value > .70, the rest > .85
        plddt = np.array(plddt)
        mask_lddt = np.full_like(mask, False)
        mask_lddt[plddt > .85, 5:] = True
        mask_lddt[plddt > .70, :5] = True
        mask = np.logical_and(mask, mask_lddt)
    return xyz, mask, np.array(idx_s), np.array(seq)
#'''
def parse_pdb_lines(lines):
    """Parse ATOM records into (xyz, mask, idx) arrays.

    Residues are keyed by their CA (protein) or P (nucleic acid) atom.
    xyz has shape (N_res, NTOTAL, 3); unobserved atoms are zeroed and
    masked out.
    """
    # residues observed in the structure: (chain letter, res num, aa, B-factor)
    res = [(l[21:22].strip(), l[22:26], l[17:20], l[60:66].strip())
           for l in lines if l[:4] == "ATOM" and l[12:16].strip() in ["CA", "P"]]
    pdb_idx_s = [(r[0], int(r[1])) for r in res]
    idx_s = [int(r[1]) for r in res]
    # map (chain, resno) -> row once (keeping the first occurrence, matching
    # list.index semantics) instead of an O(n) scan per ATOM line
    res_to_row = {}
    for row, key in enumerate(pdb_idx_s):
        res_to_row.setdefault(key, row)
    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), ChemData().NTOTAL, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
        idx = res_to_row[(chain, resNo)]
        for i_atm, tgtatm in enumerate(ChemData().aa2long[ChemData().aa2num[aa]]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break
    # save atom mask; NaN marks unobserved atoms
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0
    return xyz, mask, np.array(idx_s)
def parse_templates(item, params):
    """Parse hhsearch template hits for `item` from .atab/.hhr files and the
    FFindex template database.

    Returns (xyz, mask, qmap, f0d, f1d, ids) for all hits with a parsed
    structure and at least 10 aligned residues.
    """
    # init FFindexDB of templates and extract template IDs present in the DB
    ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
                     read_data(params['FFDB']+'_pdb.ffdata'))

    # process tabulated hhsearch output to get
    # matched positions and positional scores
    infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
    hits = []
    with open(infile, "r") as f:
        for l in f:
            if l[0]=='>':
                key = l[1:].split()[0]
                hits.append([key,[],[]])
            elif "score" in l or "dssp" in l:
                continue
            else:
                hi = l.split()[:5]+[0.0,0.0,0.0]
                hits[-1][1].append([int(hi[0]),int(hi[1])])
                hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    #  Identities, Similarity, Sum_probs, Template_Neff]
    with open(infile[:-4]+'hhr', "r") as f:
        lines = f.readlines()
    pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
    for i,posi in enumerate(pos):
        hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])

    # parse template structures from the FFindex database
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines(data))

    # process hits: keep those with a parsed structure and >= 10 aligned cols
    counter = 0
    xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
    for data in hits:
        if len(data)<7:
            continue
        qi,ti = np.array(data[1]).T
        _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue
        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
        counter += 1
    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(bool)
    # bug fix: np.long was removed from NumPy; use np.int64 (as in
    # parse_templates_raw)
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)
    return xyz,mask,qmap,f0d,f1d,ids
def parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=20):
    """Parse up to `max_templ` hhsearch template hits from .hhr/.atab files.

    Returns
    -------
    (xyz, mask, qmap, f0d, f1d, seq, ids)
        Torch tensors for the first six, plus the list of template ids.
        Only hits with a parsed structure and >= 10 aligned residues are kept.
    """
    # process tabulated hhsearch output to get
    # matched positions and positional scores
    hits = []
    with open(atab_fn, "r") as f:
        for l in f:
            if l[0]=='>':
                if len(hits) == max_templ:
                    break
                key = l[1:].split()[0]
                hits.append([key,[],[]])
            elif "score" in l or "dssp" in l:
                continue
            else:
                hi = l.split()[:5]+[0.0,0.0,0.0]
                hits[-1][1].append([int(hi[0]),int(hi[1])])
                hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    #  Identities, Similarity, Sum_probs, Template_Neff]
    with open(hhr_fn, "r") as f:
        lines = f.readlines()
    pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
    for i,posi in enumerate(pos[:len(hits)]):
        hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])

    # parse template structures from the FFindex database
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            print ("Failed to find %s in *_pdb.ffindex"%hi[0])
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines_w_seq(data))

    # process hits: keep those with a parsed structure and >= 10 aligned cols
    counter = 0
    xyz,qmap,mask,f0d,f1d,ids,seq = [],[],[],[],[],[],[]
    for data in hits:
        if len(data)<7:
            continue
        qi,ti = np.array(data[1]).T
        _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue
        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        seq.append(data[-1][sel2])
        qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
        counter += 1
    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(bool)
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)
    seq = np.hstack(seq).astype(np.int64)
    return torch.from_numpy(xyz), torch.from_numpy(mask), torch.from_numpy(qmap), \
           torch.from_numpy(f0d), torch.from_numpy(f1d), torch.from_numpy(seq), ids
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
    """Build up-to-n_templ template features for a query of length qlen.

    Returns (xyz, mask, t1d); all-gap blank features when no templates are
    found.
    """
    # bug fix: parse_templates_raw returns 7 values
    # (xyz, mask, qmap, f0d, f1d, seq, ids) — the original unpacked only 6,
    # which raises ValueError at runtime. f0d is unused here.
    xyz_t, mask_t, qmap, _, t1d, seq, ids = parse_templates_raw(
        ffdb, hhr_fn, atab_fn, max_templ=max(n_templ, 20)
    )
    ntmplatoms = xyz_t.shape[1]

    npick = min(n_templ, len(ids))
    if npick < 1: # no templates
        xyz = torch.full((1,qlen,ChemData().NTOTAL,3),np.nan).float()
        mask = torch.full((1,qlen,ChemData().NTOTAL),False)
        t1d = torch.nn.functional.one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float() # all gaps
        t1d = torch.cat((t1d, torch.zeros((1,qlen,1)).float()), -1)
        return xyz, mask, t1d

    sample = torch.arange(npick)
    #
    xyz = torch.full((npick, qlen, ChemData().NTOTAL, 3), np.nan).float()
    mask = torch.full((npick, qlen, ChemData().NTOTAL), False)
    f1d = torch.full((npick, qlen), 20).long()
    f1d_val = torch.zeros((npick, qlen, 1)).float()
    #
    for i, nt in enumerate(sample):
        # rows of qmap belonging to template nt; col 0 is the query position
        sel = torch.where(qmap[:,1] == nt)[0]
        pos = qmap[sel, 0]
        xyz[i, pos] = xyz_t[sel]
        mask[i, pos, :ntmplatoms] = mask_t[sel].bool()
        f1d[i, pos] = seq[sel]
        f1d_val[i,pos] = t1d[sel, 2].unsqueeze(-1)
        xyz[i] = util.center_and_realign_missing(xyz[i], mask[i], seq=f1d[i])
    f1d = torch.nn.functional.one_hot(f1d, num_classes=ChemData().NAATOKENS-1).float()
    f1d = torch.cat((f1d, f1d_val), dim=-1)
    return xyz, mask, f1d
def clean_sdffile(filename):
    """Return the contents of an SDF file with the second letter of each
    element symbol lowercased (e.g. FE -> Fe) so openbabel can parse it
    correctly."""
    with open(filename) as fh:
        lines = fh.readlines()
    # the counts line (4th line of an SDF) gives the number of atom records
    num_atoms = int(lines[3][:3])
    cleaned = []
    for lineno, line in enumerate(lines):
        if 4 <= lineno < 4 + num_atoms:
            # column 32 holds the second character of the element symbol
            cleaned.append(line[:32] + line[32].lower() + line[33:])
        else:
            cleaned.append(line)
    return ''.join(cleaned)
def parse_mol(filename, filetype="mol2", string=False, remove_H=True, find_automorphs=True, generate_conformer: bool = False):
    """Parse small molecule ligand.
    Parameters
    ----------
    filename : str
    filetype : str
    string : bool
        If True, `filename` is a string containing the molecule data.
    remove_H : bool
        Whether to remove hydrogen atoms.
    find_automorphs : bool
        Whether to enumerate atom symmetry permutations.
    generate_conformer : bool
        If True, build 3D coordinates with openbabel's OBBuilder and relax
        them with the MMFF94 force field; raises ValueError if force-field
        setup fails.
    Returns
    -------
    obmol : OBMol
        openbabel molecule object representing the ligand
    msa : torch.Tensor (N_atoms,) long
        Integer-encoded "sequence" (atom types) of ligand
    ins : torch.Tensor (N_atoms,) long
        Insertion features (all zero) for RF input
    atom_coords : torch.Tensor (N_symmetry, N_atoms, 3) float
        Atom coordinates
    mask : torch.Tensor (N_symmetry, N_atoms) bool
        Boolean mask for whether atom exists
    """
    obConversion = openbabel.OBConversion()
    obConversion.SetInFormat(filetype)
    obmol = openbabel.OBMol()
    if string:
        obConversion.ReadString(obmol,filename)
    elif filetype=='sdf':
        # SDF needs element-symbol case fixed (e.g. FE -> Fe) before parsing
        molstring = clean_sdffile(filename)
        obConversion.ReadString(obmol,molstring)
    else:
        obConversion.ReadFile(obmol,filename)
    if generate_conformer:
        builder = openbabel.OBBuilder()
        builder.Build(obmol)
        ff = openbabel.OBForceField.FindForceField("mmff94")
        did_setup = ff.Setup(obmol)
        if did_setup:
            # relax the generated conformer and copy coordinates back
            ff.FastRotorSearch()
            ff.GetCoordinates(obmol)
        else:
            raise ValueError(f"Failed to generate 3D coordinates for molecule (unknown).")
    if remove_H:
        obmol.DeleteHydrogens()
        # the above sometimes fails to get all the hydrogens
        # NOTE: openbabel atom indices are 1-based; the index is only advanced
        # when no atom is deleted, because deletion renumbers later atoms
        i = 1
        while i < obmol.NumAtoms()+1:
            if obmol.GetAtom(i).GetAtomicNum()==1:
                obmol.DeleteAtom(obmol.GetAtom(i))
            else:
                i += 1
    # unrecognized elements fall back to the generic 'ATM' token
    atomtypes = [ChemData().atomnum2atomtype.get(obmol.GetAtom(i).GetAtomicNum(), 'ATM')
                 for i in range(1, obmol.NumAtoms()+1)]
    msa = torch.tensor([ChemData().aa2num[x] for x in atomtypes])
    ins = torch.zeros_like(msa)
    atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
                                for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
    mask = torch.full(atom_coords.shape[:-1], True) # (1, natoms,)
    if find_automorphs:
        # expand coords/mask over symmetry-equivalent atom permutations
        atom_coords, mask = util.get_automorphs(obmol, atom_coords[0], mask[0])
    return obmol, msa, ins, atom_coords, mask

View File

@@ -0,0 +1,35 @@
import os
from hydra import initialize, compose
from pathlib import Path
import subprocess
#from rf2aa.run_inference import ModelRunner
def make_msa(
    fasta_file,
    chain,
    model_runner
):
    """Run the sequence/template search pipeline for one chain.

    Results are cached under <output_path>/<job_name>/<chain>; when the three
    expected output files already exist, the search is skipped.

    Returns
    -------
    (out_a3m, out_hhr, out_atab) : tuple[Path, Path, Path]
        Paths to the alignment and template-search outputs.
    """
    out_dir_base = Path(model_runner.config.output_path)
    # bug fix: this local was named `hash`, shadowing the builtin
    job_hash = model_runner.config.job_name
    out_dir = out_dir_base / job_hash / chain
    out_dir.mkdir(parents=True, exist_ok=True)

    db_params = model_runner.config.database_params
    command = db_params.command
    search_base = db_params.sequencedb
    num_cpus = db_params.num_cpus
    ram_gb = db_params.mem
    template_database = db_params.hhdb

    out_a3m = out_dir / "t000_.msa0.a3m"
    out_atab = out_dir / "t000_.atab"
    out_hhr = out_dir / "t000_.hhr"
    if out_a3m.exists() and out_atab.exists() and out_hhr.exists():
        # cached results from a previous run
        return out_a3m, out_hhr, out_atab

    # NOTE(review): shell=True with interpolated paths is shell-injection
    # prone if any path contains metacharacters; inputs are assumed trusted.
    search_command = f"./{command} {fasta_file} {out_dir} {num_cpus} {ram_gb} {search_base} {template_database}"
    print(search_command)
    subprocess.run(search_command, shell=True)
    return out_a3m, out_hhr, out_atab

93
rf2aa/data/protein.py Normal file
View File

@@ -0,0 +1,93 @@
import torch
from rf2aa.data.data_loader import RawInputData
from rf2aa.data.data_loader_utils import blank_template, TemplFeaturize
from rf2aa.data.parsers import parse_a3m, parse_templates_raw
from rf2aa.data.preprocessing import make_msa
from rf2aa.util import get_protein_bond_feats
def get_templates(
    qlen,
    ffdb,
    hhr_fn,
    atab_fn,
    seqID_cut,
    n_templ,
    pick_top: bool = True,
    offset: int = 0,
    random_noise: float = 5.0,
    deterministic: bool = False,
):
    """Parse hhsearch results and featurize up to n_templ templates for a
    query of length qlen."""
    (
        xyz_p,
        mask_p,
        qmap_p,
        f0d_p,
        f1d_p,
        seq_p,
        ids_p,
    ) = parse_templates_raw(ffdb, hhr_fn=hhr_fn, atab_fn=atab_fn)
    # a batch dimension is added to every tensor field; ids stays a plain list
    tplt = {
        "xyz": xyz_p[None],
        "mask": mask_p[None],
        "qmap": qmap_p[None],
        "f0d": f0d_p[None],
        "f1d": f1d_p[None],
        "seq": seq_p[None],
        "ids": ids_p,
    }
    return TemplFeaturize(
        tplt,
        qlen,
        {"SEQID": seqID_cut},
        offset=offset,
        npick=n_templ,
        pick_top=pick_top,
        random_noise=random_noise,
        deterministic=deterministic,
    )
def load_protein(msa_file, hhr_fn, atab_fn, model_runner):
    """Load one protein chain: parse its MSA and (optionally) its template
    search results, and assemble them into a RawInputData."""
    msa, ins, taxIDs = parse_a3m(msa_file)
    # NOTE: this next line is a bug, but is the way that
    # the code is written in the original implementation!
    ins[0] = msa[0]
    L = msa.shape[1]
    have_templates = hhr_fn is not None and atab_fn is not None
    if not have_templates:
        print("No templates provided")
        xyz_t, t1d, mask_t, _ = blank_template(1, L)
    else:
        loader_params = model_runner.config.loader_params
        xyz_t, t1d, mask_t, _ = get_templates(
            L,
            model_runner.ffdb,
            hhr_fn,
            atab_fn,
            seqID_cut=loader_params.seqid,
            n_templ=loader_params.n_templ,
            deterministic=model_runner.deterministic,
        )
    return RawInputData(
        torch.from_numpy(msa),
        torch.from_numpy(ins),
        get_protein_bond_feats(L),
        xyz_t,
        mask_t,
        t1d,
        torch.zeros(0, 5),      # chirals: none for polymer residues
        torch.zeros(0, 3, 2),   # atom_frames: none for polymer residues
        taxids=taxIDs,
    )
def generate_msa_and_load_protein(fasta_file, chain, model_runner):
    """Run the MSA/template search for `fasta_file`, then load the chain."""
    a3m_path, hhr_path, atab_path = make_msa(fasta_file, chain, model_runner)
    return load_protein(str(a3m_path), str(hhr_path), str(atab_path), model_runner)

View File

@@ -0,0 +1,41 @@
import torch
from rf2aa.data.data_loader import RawInputData
from rf2aa.data.data_loader_utils import blank_template
from rf2aa.data.parsers import parse_mol
from rf2aa.kinematics import get_chirals
from rf2aa.util import get_bond_feats, get_nxgraph, get_atom_frames
def load_small_molecule(input_file, input_type, model_runner):
    """Parse a small-molecule input (file path, or raw string for SMILES)
    and build its RF input features."""
    # SMILES inputs arrive as raw strings rather than file paths
    as_string = input_type == "smiles"
    obmol, msa, ins, xyz, mask = parse_mol(
        input_file, filetype=input_type, string=as_string, generate_conformer=True
    )
    return compute_features_from_obmol(obmol, msa, xyz, model_runner)
def compute_features_from_obmol(obmol, msa, xyz, model_runner):
    """Build RawInputData features for a ligand from its openbabel molecule."""
    num_atoms = msa.shape[0]
    # insertion features are meaningless for atoms; regenerate as zeros
    ins = torch.zeros_like(msa)
    bond_feats = get_bond_feats(obmol)
    xyz_t, t1d, mask_t, _ = blank_template(
        model_runner.config.loader_params.n_templ,
        num_atoms,
        deterministic=model_runner.deterministic,
    )
    chirals = get_chirals(obmol, xyz[0])
    frames = get_atom_frames(msa, get_nxgraph(obmol))
    return RawInputData(
        msa[None], ins[None], bond_feats, xyz_t, mask_t, t1d, chirals, frames,
        taxids=None,
    )
def remove_leaving_atoms(input, is_leaving):
    """Drop leaving-group atoms from `input`, keeping everything else."""
    return input.keep_features(~is_leaving)