Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
This commit is contained in:
308
rf2aa/data/covale.py
Normal file
308
rf2aa/data/covale.py
Normal file
@@ -0,0 +1,308 @@
|
||||
import torch
|
||||
from openbabel import openbabel
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from tempfile import NamedTemporaryFile
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.data.parsers import parse_mol
|
||||
from rf2aa.data.small_molecule import compute_features_from_obmol
|
||||
from rf2aa.util import get_bond_feats
|
||||
|
||||
|
||||
@dataclass
class MoleculeToMoleculeBond:
    """A covalent bond between two atoms of two sub-molecules in a combined chain.

    `chain_index_*` is the 0-based index of the sub-molecule within the combined
    multi-molecule chain; `absolute_atom_index_*` is the 0-based atom index
    within that sub-molecule (see get_absolute_index_from_relative_indices).
    """
    # 0-based index of the first sub-molecule within the combined chain
    chain_index_first: int
    # 0-based atom index within the first sub-molecule
    absolute_atom_index_first: int
    chain_index_second: int
    absolute_atom_index_second: int
    # "CW" / "CCW" winding if the new bond creates a stereocenter, else None
    new_chirality_atom_first: Optional[str]
    new_chirality_atom_second: Optional[str]
|
||||
|
||||
@dataclass
class AtomizedResidue:
    """Bookkeeping for a protein residue that has been converted into atoms.

    Records where the atomized residue lives in its (combined) small-molecule
    chain and where it came from in the original protein chain, so the protein
    features can later be rewired (see update_protein_features_after_atomize).
    """
    # id of the small-molecule chain the atomized residue was merged into
    chain: str
    # index of this residue's sub-molecule within the combined chain
    chain_index_in_combined_chain: int
    # 0-based index of the backbone N atom within the combined chain
    absolute_N_index_in_chain: int
    # 0-based index of the backbone C atom within the combined chain
    absolute_C_index_in_chain: int
    # protein chain id the residue originally belonged to
    original_chain: str
    # 0-based residue index within the original protein chain
    index_in_original_chain: int
|
||||
|
||||
|
||||
def load_covalent_molecules(protein_inputs, config, model_runner):
    """Build small-molecule inputs for ligands covalently bonded to protein residues.

    Parses the covalent-bond spec from the config, atomizes the bonded protein
    residues, merges them with their ligands, regenerates coordinates after any
    chirality edits, and featurizes each combined molecule.

    Returns:
        None if no covalent inputs are configured, otherwise a tuple
        (chainid_to_input, residues_to_atomize).

    Raises:
        ValueError: if covale_inputs is given without sm_inputs.
    """
    if config.covale_inputs is None:
        return None

    if config.sm_inputs is None:
        raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")

    # NOTE(security): eval on a config-supplied string executes arbitrary code;
    # only run with trusted configs. If the format is purely literal,
    # ast.literal_eval would be the safe replacement -- TODO confirm format.
    covalent_bonds = eval(config.covale_inputs)
    sm_inputs = delete_leaving_atoms(config.sm_inputs)
    residues_to_atomize, combined_molecules, extra_bonds = find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner)
    chainid_to_input = {}
    for chain, combined_molecule in combined_molecules.items():
        extra_bonds_for_chain = extra_bonds[chain]
        msa, bond_feats, xyz, Ls = get_combined_atoms_bonds(combined_molecule)
        # indices recorded per sub-molecule become absolute in the combined chain
        residues_to_atomize = update_absolute_indices_after_combination(residues_to_atomize, chain, Ls)
        mol = make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds_for_chain)
        # chirality edits invalidate the old conformer; rebuild coordinates
        xyz = recompute_xyz_after_chirality(mol)
        # renamed from `input` to avoid shadowing the builtin
        sm_input = compute_features_from_obmol(mol, msa, xyz, model_runner)
        chainid_to_input[chain] = sm_input

    return chainid_to_input, residues_to_atomize
|
||||
|
||||
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
    """Identify protein residues covalently bonded to ligands and group them.

    Each bond spec is ((prot_chain, res_idx, atom_name), (sm_chain, atom_num),
    (chirality_first, chirality_second)). The bonded residue is converted to its
    own sdf molecule, grouped with its ligand, and the inter-molecule bond is
    recorded.

    Returns:
        residues_to_atomize: list[AtomizedResidue] to delete from the protein.
        combined_molecules: dict sm_chain_id -> list of sdf files to merge.
        extra_bonds: dict sm_chain_id -> list[MoleculeToMoleculeBond].

    Raises:
        ValueError: for non-sdf ligand inputs or an unknown protein chain.
    """
    residues_to_atomize = [] # hold on to delete wayward inputs
    combined_molecules = {} # combined multiple molecules that are bonded
    extra_bonds = {}
    for bond in covalent_bonds:
        prot_chid, prot_res_idx, atom_to_bond = bond[0]
        sm_chid, sm_atom_num = bond[1]
        chirality_first_atom, chirality_second_atom = bond[2]
        # "null" means the bond does not create a new stereocenter
        if chirality_first_atom.strip() == "null":
            chirality_first_atom = None
        if chirality_second_atom.strip() == "null":
            chirality_second_atom = None

        sm_atom_num = int(sm_atom_num) - 1 # 0 index
        # was an `assert` swallowed by a bare `except:`; make the skip explicit
        if sm_chid not in sm_inputs:
            print(f"Skipping bond: {bond} since no sm chain {sm_chid} was provided")
            continue
        # was an `assert` (stripped under python -O); raise explicitly instead
        if sm_inputs[sm_chid].input_type != "sdf":
            raise ValueError("only sdf inputs can be covalently linked to proteins")
        if prot_chid not in protein_inputs:
            raise ValueError(f"first atom in covale_input must be present in\
                a protein chain. Given chain: {prot_chid} was not in \
                given protein chains: {list(protein_inputs.keys())}")

        residue = (prot_chid, prot_res_idx, atom_to_bond)
        file, atom_index = convert_residue_to_molecule(protein_inputs, residue, model_runner)
        if sm_chid not in combined_molecules:
            combined_molecules[sm_chid] = [sm_inputs[sm_chid].input]
        # NOTE: original author marked this prepend "this is a bug, revert".
        # Preserved as-is because the chain indices recorded below depend on
        # this ordering; revisit together with its downstream consumers.
        combined_molecules[sm_chid].insert(0, file)
        absolute_chain_index_first = combined_molecules[sm_chid].index(sm_inputs[sm_chid].input)
        absolute_chain_index_second = combined_molecules[sm_chid].index(file)

        if sm_chid not in extra_bonds:
            extra_bonds[sm_chid] = []
        extra_bonds[sm_chid].append(MoleculeToMoleculeBond(
            absolute_chain_index_first,
            sm_atom_num,
            absolute_chain_index_second,
            atom_index,
            new_chirality_atom_first=chirality_first_atom,
            new_chirality_atom_second=chirality_second_atom
        ))
        residues_to_atomize.append(AtomizedResidue(
            sm_chid,
            absolute_chain_index_second,
            0,  # presumably backbone N is atom 0 of the atomized residue -- confirm
            2,  # presumably backbone C is atom 2 (N, CA, C order) -- confirm
            prot_chid,
            int(prot_res_idx) - 1
        ))

    return residues_to_atomize, combined_molecules, extra_bonds
|
||||
|
||||
def convert_residue_to_molecule(protein_inputs, residue, model_runner):
    """Turn a protein residue into a standalone small molecule (sdf file).

    Looks the residue up in the runner's molecule database, strips its leaving
    atoms, and returns (path_to_sdf_tempfile, index_of_bonding_atom).
    """
    chain_id, res_idx, bond_atom_name = residue
    seq = protein_inputs[chain_id].query_sequence()
    residue_name = ChemData().num2aa[seq[int(res_idx) - 1]]  # residue ids are 1-based
    info = model_runner.molecule_db[residue_name]

    template_path = create_and_populate_temp_file(info["sdf"])
    # restrict leaving-atom flags to heavy atoms (hydrogens are not parsed)
    heavy = {i for i, name in enumerate(info["atom_id"]) if name[0] != "H"}
    leaving_flags = [flag for i, flag in enumerate(info["leaving"]) if i in heavy]

    stripped_sdf = delete_leaving_atoms_single_chain(template_path, leaving_flags)
    out_path = create_and_populate_temp_file(stripped_sdf)
    bond_atom_index = info["atom_id"].index(bond_atom_name.strip())
    return out_path, bond_atom_index
|
||||
|
||||
def get_combined_atoms_bonds(combined_molecule):
    """Parse each sdf in `combined_molecule` and merge them into one chain.

    Returns:
        atoms: concatenated atom-token tensor.
        bond_feats: block-diagonal (L_total, L_total) bond matrix.
        xyz: concatenated template coordinates.
        Ls: per-molecule lengths.
    """
    token_chunks, bond_blocks, coord_chunks, lengths = [], [], [], []
    for sdf_path in combined_molecule:
        obmol, msa, ins, xyz, mask = parse_mol(
            sdf_path,
            filetype="sdf",
            string=False,
            generate_conformer=True,
            find_automorphs=False
        )
        token_chunks.append(msa)
        bond_blocks.append(get_bond_feats(obmol))
        coord_chunks.append(xyz)
        lengths.append(msa.shape[0])

    atoms = torch.cat(token_chunks)
    total = sum(lengths)
    # assemble a block-diagonal bond matrix; inter-molecule bonds are added later
    bond_feats = torch.zeros((total, total)).long()
    start = 0
    for block in bond_blocks:
        n = block.shape[0]
        bond_feats[start:start + n, start:start + n] = block
        start += n
    xyz = torch.cat(coord_chunks, dim=1)[0]
    return atoms, bond_feats, xyz, lengths
|
||||
|
||||
def make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds):
    """Build an openbabel molecule from atom tokens, a bond matrix and coordinates.

    `msa` holds per-atom element tokens, `bond_feats` an (L, L) bond-order
    matrix, `xyz` the coordinates, `Ls` the sub-molecule lengths used to resolve
    the relative indices stored in `extra_bonds` (MoleculeToMoleculeBond).
    """
    mol = openbabel.OBMol()
    # one OBAtom per token, with element number and 3D position
    for i,k in enumerate(msa):
        element = ChemData().num2aa[k]
        atomnum = ChemData().atomtype2atomnum[element]
        a = mol.NewAtom()
        a.SetAtomicNum(atomnum)
        a.SetVector(float(xyz[i,0]), float(xyz[i,1]), float(xyz[i,2]))

    # intra-molecule bonds from the bond matrix.
    # NOTE(review): if bond_feats is symmetric, both (i, j) and (j, i) are
    # nonzero and each bond is added twice -- confirm OBMol tolerates this.
    first_index, second_index = bond_feats.nonzero(as_tuple=True)
    for i, j in zip(first_index, second_index):
        order = bond_feats[i,j]
        bond = make_openbabel_bond(mol, i.item(), j.item(), order.item())
        mol.AddBond(bond)

    # inter-molecule (covalent-link) bonds; relative indices made absolute via Ls
    for bond in extra_bonds:
        absolute_index_first = get_absolute_index_from_relative_indices(
            bond.chain_index_first,
            bond.absolute_atom_index_first,
            Ls
        )
        absolute_index_second = get_absolute_index_from_relative_indices(
            bond.chain_index_second,
            bond.absolute_atom_index_second,
            Ls
        )
        order = 1 #all covale bonds are single bonds
        openbabel_bond = make_openbabel_bond(mol, absolute_index_first, absolute_index_second, order)
        mol.AddBond(openbabel_bond)
        # apply user-specified winding if the new bond created a stereocenter
        set_chirality(mol, absolute_index_first, bond.new_chirality_atom_first)
        set_chirality(mol, absolute_index_second, bond.new_chirality_atom_second)
    return mol
|
||||
|
||||
def make_openbabel_bond(mol, i, j, order):
    """Create (but do not add) an OBBond between atoms i and j (0-based).

    An `order` of 4 is the sentinel for an aromatic bond.
    """
    new_bond = openbabel.OBBond()
    new_bond.SetBegin(mol.GetAtom(i + 1))  # openbabel atoms are 1-indexed
    new_bond.SetEnd(mol.GetAtom(j + 1))
    if order != 4:
        new_bond.SetBondOrder(order)
    else:
        # aromatic bonds are stored as order-2 bonds flagged aromatic
        new_bond.SetBondOrder(2)
        new_bond.SetAromatic()
    return new_bond
|
||||
|
||||
def set_chirality(mol, absolute_atom_index, new_chirality):
    """Apply a user-specified winding to a tetrahedral stereocenter.

    `new_chirality` must be "CW"/"CCW" when the atom is a stereocenter and
    None when it is not; a mismatch trips the asserts below.
    """
    stereo = openbabel.OBStereoFacade(mol)
    # openbabel atoms are 1-indexed
    if stereo.HasTetrahedralStereo(absolute_atom_index+1):
        # NOTE(review): HasTetrahedralStereo is queried with the atom index but
        # GetTetrahedralStereo with the atom's Id -- confirm these agree.
        tetstereo = stereo.GetTetrahedralStereo(mol.GetAtom(absolute_atom_index+1).GetId())
        if tetstereo is None:
            return

        assert new_chirality is not None, "you have introduced a new stereocenter, \
            so you must specify its chirality either as CW, or CCW"

        config = tetstereo.GetConfig()
        # chirality_options is the module-level label->winding map defined below
        config.winding = chirality_options[new_chirality]
        tetstereo.SetConfig(config)
        print("Updating chirality...")
    else:
        assert new_chirality is None, "you have specified a chirality without creating a new chiral center"
|
||||
|
||||
# Mapping from user-facing chirality labels to openbabel winding constants
# (used by set_chirality above; module-level so it is resolved at call time).
chirality_options = {
    "CW": openbabel.OBStereo.Clockwise,
    "CCW": openbabel.OBStereo.AntiClockwise,
}
|
||||
|
||||
def recompute_xyz_after_chirality(obmol):
    """Regenerate 3D coordinates after stereochemistry edits.

    Rebuilds the molecule's geometry and relaxes it with MMFF94, returning the
    coordinates as a (1, natoms, 3) tensor.

    Raises:
        ValueError: if the force field cannot be set up for the molecule.
    """
    builder = openbabel.OBBuilder()
    builder.Build(obmol)
    ff = openbabel.OBForceField.FindForceField("mmff94")
    did_setup = ff.Setup(obmol)
    if did_setup:
        # quick rotor search rather than a full conformer scan
        ff.FastRotorSearch()
        ff.GetCoordinates(obmol)
    else:
        raise ValueError(f"Failed to generate 3D coordinates for molecule (unknown).")
    # openbabel atoms are 1-indexed
    atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
        for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
    return atom_coords
|
||||
|
||||
def delete_leaving_atoms(sm_inputs):
    """Strip flagged leaving atoms from each small-molecule input.

    Chains without an "is_leaving" entry are left untouched; chains with one
    are rewritten to fresh sdf temp files and their entries replaced in place.
    """
    updated_sm_inputs = {}
    for chain in sm_inputs:
        # NOTE(review): entries are treated as mappings here ("is_leaving" in x,
        # x["input"]) while other call sites use attribute access
        # (sm_inputs[chid].input_type) -- confirm the container type.
        if "is_leaving" not in sm_inputs[chain]:
            continue
        # NOTE(security): eval on a config-supplied string; trusted configs only.
        is_leaving = eval(sm_inputs[chain]["is_leaving"])
        sdf_string = delete_leaving_atoms_single_chain(sm_inputs[chain]["input"], is_leaving)
        updated_sm_inputs[chain] = {
            "input": create_and_populate_temp_file(sdf_string),
            "input_type": "sdf"
        }

    # mutates and returns the caller's dict
    sm_inputs.update(updated_sm_inputs)
    return sm_inputs
|
||||
|
||||
def delete_leaving_atoms_single_chain(filename, is_leaving):
    """Remove flagged leaving atoms from an sdf file and return the new sdf string.

    Parameters:
        filename: path to an sdf file.
        is_leaving: per-atom boolean flags; length must match the parsed molecule.

    Raises:
        ValueError: if the flag count does not match the atom count.
    """
    obmol, msa, ins, xyz, mask = parse_mol(
        filename,
        filetype="sdf",
        string=False,
        generate_conformer=True
    )
    # was an `assert` (stripped under python -O); raise explicitly instead
    if len(is_leaving) != obmol.NumAtoms():
        raise ValueError(
            f"is_leaving has {len(is_leaving)} flags but molecule has "
            f"{obmol.NumAtoms()} atoms"
        )
    leaving_indices = torch.tensor(is_leaving).nonzero()
    # bugfix: delete from highest index to lowest. OBMol renumbers atoms on
    # deletion, so ascending deletion removes the wrong atoms whenever more
    # than one atom is flagged.
    for idx in sorted((i.item() for i in leaving_indices), reverse=True):
        obmol.DeleteAtom(obmol.GetAtom(idx + 1))  # openbabel is 1-indexed
    obConversion = openbabel.OBConversion()
    obConversion.SetInAndOutFormats("sdf", "sdf")
    sdf_string = obConversion.WriteString(obmol)
    return sdf_string
|
||||
|
||||
def get_absolute_index_from_relative_indices(chain_index, absolute_index_in_chain, Ls):
    """Map (sub-chain index, atom index within that sub-chain) to an index in
    the concatenated chain. `Ls` holds each sub-chain's length, in order."""
    preceding = Ls[:chain_index]
    return sum(preceding) + absolute_index_in_chain
|
||||
|
||||
def update_absolute_indices_after_combination(residues_to_atomize, chain, Ls):
    """Rewrite per-sub-molecule N/C indices as absolute indices in the combined chain.

    Residues belonging to other chains are passed through untouched.
    """
    updated = []
    for res in residues_to_atomize:
        if res.chain != chain:
            updated.append(res)
            continue
        n_abs = get_absolute_index_from_relative_indices(
            res.chain_index_in_combined_chain,
            res.absolute_N_index_in_chain,
            Ls)
        c_abs = get_absolute_index_from_relative_indices(
            res.chain_index_in_combined_chain,
            res.absolute_C_index_in_chain,
            Ls)
        # the sub-molecule index is no longer meaningful once indices are absolute
        updated.append(AtomizedResidue(
            res.chain,
            None,
            n_abs,
            c_abs,
            res.original_chain,
            res.index_in_original_chain,
        ))
    return updated
|
||||
|
||||
def create_and_populate_temp_file(data):
    """Write `data` to a fresh temporary file and return its path.

    The file is intentionally not deleted on close; callers re-open it by name.
    """
    with NamedTemporaryFile(mode='w+', delete=False) as handle:
        handle.write(data)
        return handle.name
|
||||
202
rf2aa/data/data_loader.py
Normal file
202
rf2aa/data/data_loader.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import torch
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional, List
|
||||
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
from rf2aa.data.data_loader_utils import MSAFeaturize, get_bond_distances, generate_xyz_prev
|
||||
from rf2aa.kinematics import xyz_to_t2d
|
||||
from rf2aa.util import get_prot_sm_mask, xyz_t_to_frame_xyz, same_chain_from_bond_feats, \
|
||||
Ls_from_same_chain_2d, idx_from_Ls, is_atom
|
||||
|
||||
|
||||
|
||||
@dataclass
class RawInputData:
    """Raw per-example model inputs, prior to final featurization.

    `msa`/`ins` are (N_seq, L) token / insertion tensors; `bond_feats` is
    (L, L); the template tensors (`xyz_t`, `mask_t`, `t1d`) carry a leading
    template dimension. `chain_lengths` (list of (chain_id, length)) and `idx`
    are filled in later (by merge_inputs).
    """
    msa: torch.Tensor
    ins: torch.Tensor
    bond_feats: torch.Tensor
    xyz_t: torch.Tensor
    mask_t: torch.Tensor
    t1d: torch.Tensor
    chirals: torch.Tensor
    atom_frames: torch.Tensor
    taxids: Optional[List[str]] = None
    term_info: Optional[torch.Tensor] = None
    chain_lengths: Optional[List] = None
    idx: Optional[List] = None

    def query_sequence(self):
        """Return the first MSA row (the query sequence tokens)."""
        return self.msa[0]

    def sequence_string(self):
        """Return the query sequence as a one-letter-code string."""
        three_letter_sequence = [ChemData().num2aa[num] for num in self.query_sequence()]
        return "".join([ChemData().aa_321[three] for three in three_letter_sequence])

    def is_atom(self):
        """Boolean mask over positions holding atom (not residue) tokens."""
        return is_atom(self.query_sequence())

    def length(self):
        """Total number of positions (residues + atoms)."""
        return self.msa.shape[1]

    def get_chain_bins_from_chain_lengths(self):
        """Map each chain id to its (start, end) half-open slice in the
        concatenated input.

        Raises:
            ValueError: if `chain_lengths` has not been set yet.
        """
        if self.chain_lengths is None:
            raise ValueError("Cannot call get_chain_bins_from_chain_lengths without \
                setting chain_lengths. Chain_lengths is set in merge_inputs")
        chain_bins = {}
        running_length = 0
        for chain, length in self.chain_lengths:
            chain_bins[chain] = (running_length, running_length + length)
            running_length = running_length + length
        return chain_bins

    def update_protein_features_after_atomize(self, residues_to_atomize):
        """Wire atomized residues into the bond graph and drop the originals.

        Adds residue<->atom bond features between each atomized residue's N/C
        atoms and its former neighbors, links consecutively atomized residues
        together, then removes the replaced residue positions.

        Raises:
            ValueError: if `chain_lengths` has not been set (merge_inputs must
                run first).
        """
        if self.chain_lengths is None:
            # bugfix: this used to be `raise("...")`, which raises a TypeError
            # ("exceptions must derive from BaseException"), masking the real
            # error; raise a proper ValueError instead.
            raise ValueError("Cannot update protein features without chain_lengths. \
                merge_inputs must be called before this function")
        chain_bins = self.get_chain_bins_from_chain_lengths()
        keep = torch.ones(self.length())
        prev_absolute_index = None
        prev_C = None
        # need to atomize residues from N term to Cterm to handle atomizing neighbors
        residues_to_atomize = sorted(residues_to_atomize, key=lambda x: x.original_chain + str(x.index_in_original_chain))
        for residue in residues_to_atomize:
            original_chain_start_index, original_chain_end_index = chain_bins[residue.original_chain]
            absolute_index_in_combined_input = original_chain_start_index + residue.index_in_original_chain

            atomized_chain_start_index, atomized_chain_end_index = chain_bins[residue.chain]
            N_index = atomized_chain_start_index + residue.absolute_N_index_in_chain
            C_index = atomized_chain_start_index + residue.absolute_C_index_in_chain
            # if residue is first in the chain, no extra bond feats to preceding residue
            if absolute_index_in_combined_input != original_chain_start_index:
                self.bond_feats[absolute_index_in_combined_input - 1, N_index] = ChemData().RESIDUE_ATOM_BOND
                self.bond_feats[N_index, absolute_index_in_combined_input - 1] = ChemData().RESIDUE_ATOM_BOND

            # if residue is last in chain, no extra bonds feats to following residue
            if absolute_index_in_combined_input != original_chain_end_index - 1:
                self.bond_feats[absolute_index_in_combined_input + 1, C_index] = ChemData().RESIDUE_ATOM_BOND
                self.bond_feats[C_index, absolute_index_in_combined_input + 1] = ChemData().RESIDUE_ATOM_BOND
            # mark the replaced residue position for removal
            keep[absolute_index_in_combined_input] = 0

            # find neighboring residues that were atomized; bond prev C to this N
            if prev_absolute_index is not None:
                if prev_absolute_index + 1 == absolute_index_in_combined_input:
                    self.bond_feats[prev_C, N_index] = 1
                    self.bond_feats[N_index, prev_C] = 1

            prev_absolute_index = absolute_index_in_combined_input
            prev_C = C_index
        # remove protein features for the atomized residues
        self.keep_features(keep.bool())

    def keep_features(self, keep):
        """Filter every per-position feature by the boolean mask `keep`.

        Raises:
            ValueError: if the mask would remove atom positions.
        """
        if not torch.all(keep[self.is_atom()]):
            raise ValueError("cannot remove atoms")
        self.msa = self.msa[:, keep]
        self.ins = self.ins[:, keep]
        self.bond_feats = self.bond_feats[keep][:, keep]
        self.xyz_t = self.xyz_t[:, keep]
        self.t1d = self.t1d[:, keep]
        self.mask_t = self.mask_t[:, keep]
        if self.term_info is not None:
            self.term_info = self.term_info[keep]
        if self.idx is not None:
            self.idx = self.idx[keep]
        # assumes all chirals are after all protein residues, so every chiral
        # index shifts down by the number of removed positions
        self.chirals[..., :-1] = self.chirals[..., :-1] - torch.sum(~keep)

    def construct_features(self, model_runner):
        """Assemble the full RFInput feature set for a forward pass."""
        loader_params = model_runner.config.loader_params
        B, L = 1, self.length()
        seq, msa_clust, msa_seed, msa_extra, mask_pos = MSAFeaturize(
            self.msa.long(),
            self.ins.long(),
            loader_params,
            p_mask=loader_params.get("p_msa_mask", 0),
            term_info=self.term_info,
            deterministic=model_runner.deterministic,
        )
        dist_matrix = get_bond_distances(self.bond_feats)

        # xyz_prev, mask_prev = generate_xyz_prev(self.xyz_t, self.mask_t, loader_params)
        # xyz_prev = torch.nan_to_num(xyz_prev)

        # NOTE: The above is the way things "should" be done, this is for compatability with training.
        xyz_prev = ChemData().INIT_CRDS.reshape(1, ChemData().NTOTAL, 3).repeat(L, 1, 1)

        self.xyz_t = torch.nan_to_num(self.xyz_t)

        mask_t_2d = get_prot_sm_mask(self.mask_t, seq[0])
        mask_t_2d = mask_t_2d[:, None] * mask_t_2d[:, :, None]  # (B, T, L, L)

        xyz_t_frame = xyz_t_to_frame_xyz(self.xyz_t[None], self.msa[0], self.atom_frames)
        t2d = xyz_to_t2d(xyz_t_frame, mask_t_2d[None])
        t2d = t2d[0]
        # get torsion angles from templates
        seq_tmp = self.t1d[..., :-1].argmax(dim=-1)
        alpha, _, alpha_mask, _ = model_runner.xyz_converter.get_torsions(
            self.xyz_t.reshape(-1, L, ChemData().NTOTAL, 3),
            seq_tmp, mask_in=self.mask_t.reshape(-1, L, ChemData().NTOTAL))
        alpha = alpha.reshape(B, -1, L, ChemData().NTOTALDOFS, 2)
        alpha_mask = alpha_mask.reshape(B, -1, L, ChemData().NTOTALDOFS, 1)
        alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3 * ChemData().NTOTALDOFS)
        alpha_t = alpha_t[0]
        alpha_prev = torch.zeros((L, ChemData().NTOTALDOFS, 2))

        same_chain = same_chain_from_bond_feats(self.bond_feats)
        return RFInput(
            msa_latent=msa_seed,
            msa_full=msa_extra,
            seq=seq,
            seq_unmasked=self.query_sequence(),
            bond_feats=self.bond_feats,
            dist_matrix=dist_matrix,
            chirals=self.chirals,
            atom_frames=self.atom_frames.long(),
            xyz_prev=xyz_prev,
            alpha_prev=alpha_prev,
            t1d=self.t1d,
            t2d=t2d,
            # NOTE(review): integer index `1` (not a slice) drops one axis here;
            # confirm it selects the intended element along that dimension.
            xyz_t=self.xyz_t[..., 1, :],
            alpha_t=alpha_t.float(),
            mask_t=mask_t_2d.float(),
            same_chain=same_chain.long(),
            idx=self.idx
        )
|
||||
|
||||
|
||||
@dataclass
class RFInput:
    """Fully featurized inputs for one RF2AA forward pass.

    Every required field is a tensor; the trailing optional fields carry
    recycling state between iterations.
    """
    msa_latent: torch.Tensor
    msa_full: torch.Tensor
    seq: torch.Tensor
    seq_unmasked: torch.Tensor
    idx: torch.Tensor
    bond_feats: torch.Tensor
    dist_matrix: torch.Tensor
    chirals: torch.Tensor
    atom_frames: torch.Tensor
    xyz_prev: torch.Tensor
    alpha_prev: torch.Tensor
    t1d: torch.Tensor
    t2d: torch.Tensor
    xyz_t: torch.Tensor
    alpha_t: torch.Tensor
    mask_t: torch.Tensor
    same_chain: torch.Tensor
    # recycling state (None on the first iteration)
    msa_prev: Optional[torch.Tensor] = None
    pair_prev: Optional[torch.Tensor] = None
    state_prev: Optional[torch.Tensor] = None
    mask_recycle: Optional[torch.Tensor] = None

    def to(self, gpu):
        """Move every tensor field to the given device, in place."""
        for f in fields(self):
            value = getattr(self, f.name)
            if torch.is_tensor(value):
                setattr(self, f.name, value.to(gpu))

    def add_batch_dim(self):
        """ mimic pytorch dataloader at inference time"""
        # prepend a singleton batch dimension to every tensor field
        for f in fields(self):
            value = getattr(self, f.name)
            if torch.is_tensor(value):
                setattr(self, f.name, value[None])
||||
909
rf2aa/data/data_loader_utils.py
Normal file
909
rf2aa/data/data_loader_utils.py
Normal file
@@ -0,0 +1,909 @@
|
||||
import torch
|
||||
import warnings
|
||||
import time
|
||||
from icecream import ic
|
||||
from torch.utils import data
|
||||
import os, csv, random, pickle, gzip, itertools, time, ast, copy, sys
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(script_dir)
|
||||
sys.path.append(script_dir+'/../')
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import networkx as nx
|
||||
|
||||
from rf2aa.data.parsers import parse_a3m, parse_pdb
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
|
||||
|
||||
from rf2aa.util import random_rot_trans, \
|
||||
is_atom, is_protein, is_nucleic, is_atom
|
||||
|
||||
|
||||
def MSABlockDeletion(msa, ins, nb=5):
    """Randomly delete `nb` contiguous blocks of sequences from an MSA.

    The query (row 0) always survives: deletion indices are clipped to [1, N-1].
    Input arrays have shape (N, L); returns the filtered (msa, ins).
    """
    N, L = msa.shape
    block_size = max(int(N * 0.3), 1)
    starts = np.random.randint(low=1, high=N, size=nb)  # (nb,)
    # all rows covered by any block, clipped so the query is never deleted
    doomed = np.unique(np.clip(starts[:, None] + np.arange(block_size)[None, :], 1, N - 1))
    keep_mask = np.ones(N, bool)
    keep_mask[doomed] = 0

    return msa[keep_mask], ins[keep_mask]
|
||||
|
||||
def cluster_sum(data, assignment, N_seq, N_res):
    """Sum `data` rows into their assigned clusters.

    data: (N_extra, N_res, C); assignment: (N_extra,) cluster ids in [0, N_seq).
    Returns a (N_seq, N_res, C) tensor of per-cluster sums.
    """
    channels = data.shape[-1]
    index = assignment.view(-1, 1, 1).expand(-1, N_res, channels)
    out = torch.zeros(N_seq, N_res, channels, device=data.device)
    return out.scatter_add(0, index, data.float())
|
||||
|
||||
def get_term_feats(Ls):
    """Binary N-/C-terminus flags for a concatenation of chains of lengths `Ls`.

    Returns a (sum(Ls), 2) float tensor: column 0 marks each chain's first
    position, column 1 its last.
    """
    term_info = torch.zeros((sum(Ls), 2)).float()
    offset = 0
    for chain_len in Ls:
        term_info[offset, 0] = 1.0                    # N-terminus of this chain
        term_info[offset + chain_len - 1, 1] = 1.0    # C-terminus of this chain
        offset += chain_len
    return term_info
|
||||
|
||||
|
||||
def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
                 term_info=None, tocpu=False, fixbb=False, seed_msa_clus=None, deterministic=False):
    '''
    Input: full MSA information (after Block deletion if necessary) & full insertion information
    Output: seed MSA features & extra sequences

    Seed MSA features:
        - aatype of seed sequence (20 regular aa + 1 gap/unknown + 1 mask)
        - profile of clustered sequences (22)
        - insertion statistics (2)
        - N-term or C-term? (2)
    extra sequence features:
        - aatype of extra sequence (22)
        - insertion info (1)
        - N-term or C-term? (2)

    Note: the mutable default `L_s=[]` is only read, never mutated, so it is
    safe here.
    '''
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        # bugfix: removed leftover debug code (`msa = msa[:2]`, explicitly
        # marked "TODO: delete me, just for testing purposes") that truncated
        # the MSA to two sequences in deterministic mode while leaving `ins`
        # at full depth, desynchronizing the two tensors.

    if fixbb:
        p_mask = 0
        msa = msa[:1]
        ins = ins[:1]
    N, L = msa.shape

    if term_info is None:
        if len(L_s) == 0:
            L_s = [L]
        term_info = get_term_feats(L_s)
    term_info = term_info.to(msa.device)

    #binding_site = torch.zeros((L,1), device=msa.device).float()
    binding_site = torch.zeros((L,0), device=msa.device).float() # keeping this off for now (Jue 12/19)

    # raw MSA profile
    raw_profile = torch.nn.functional.one_hot(msa, num_classes=ChemData().NAATOKENS) # N x L x NAATOKENS
    raw_profile = raw_profile.float().mean(dim=0) # L x NAATOKENS

    # Select Nclust sequence randomly (seed MSA or latent MSA)
    Nclust = (min(N, params['MAXLAT'])-1) // nmer
    Nclust = Nclust*nmer + 1

    if N > Nclust*2:
        Nextra = N - Nclust
    else:
        Nextra = N
    Nextra = min(Nextra, params['MAXSEQ']) // nmer
    Nextra = max(1, Nextra * nmer)
    #
    b_seq = list()
    b_msa_clust = list()
    b_msa_seed = list()
    b_msa_extra = list()
    b_mask_pos = list()
    for i_cycle in range(params['MAXCYCLE']):
        sample_mono = torch.randperm((N-1)//nmer, device=msa.device)
        sample = [sample_mono + imer*((N-1)//nmer) for imer in range(nmer)]
        sample = torch.stack(sample, dim=-1)
        sample = sample.reshape(-1)

        # add MSA clusters pre-chosen before calling this function
        if seed_msa_clus is not None:
            sample_orig_shape = sample.shape
            sample_seed = seed_msa_clus[i_cycle]
            sample_more = torch.tensor([i for i in sample if i not in sample_seed])
            N_sample_more = len(sample) - len(sample_seed)
            if N_sample_more > 0:
                sample_more = sample_more[torch.randperm(len(sample_more))[:N_sample_more]]
                sample = torch.cat([sample_seed, sample_more])
            else:
                sample = sample_seed[:len(sample)] # take all clusters from pre-chosen ones

        msa_clust = torch.cat((msa[:1,:], msa[1:,:][sample[:Nclust-1]]), dim=0)
        ins_clust = torch.cat((ins[:1,:], ins[1:,:][sample[:Nclust-1]]), dim=0)

        # 15% random masking
        # - 10%: aa replaced with a uniformly sampled random amino acid
        # - 10%: aa replaced with an amino acid sampled from the MSA profile
        # - 10%: not replaced
        # - 70%: replaced with a special token ("mask")
        random_aa = torch.tensor([[0.05]*20 + [0.0]*(ChemData().NAATOKENS-20)], device=msa.device)
        same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=ChemData().NAATOKENS)
        # explicitly remove probabilities from nucleic acids and atoms
        #same_aa[..., ChemData().NPROTAAS:] = 0
        #raw_profile[...,ChemData().NPROTAAS:] = 0
        probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa
        #probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)

        # explicitly set the probability of masking for nucleic acids and atoms
        #probs[...,is_protein(seq),ChemData().MASKINDEX]=0.7
        #probs[...,~is_protein(seq), :] = 0 # probably overkill but set all none protein elements to 0
        #probs[1:, ~is_protein(seq),20] = 1.0 # want to leave the gaps as gaps
        #probs[0,is_nucleic(seq), ChemData().MASKINDEX] = 1.0
        #probs[0,is_atom(seq), ChemData().aa2num["ATM"]] = 1.0

        sampler = torch.distributions.categorical.Categorical(probs=probs)
        mask_sample = sampler.sample()

        mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask
        mask_pos[msa_clust>ChemData().MASKINDEX]=False # no masking on NAs
        use_seq = msa_clust
        msa_masked = torch.where(mask_pos, mask_sample, use_seq)
        b_seq.append(msa_masked[0].clone())

        ## get extra sequenes
        if N > Nclust*2: # there are enough extra sequences
            msa_extra = msa[1:,:][sample[Nclust-1:]]
            ins_extra = ins[1:,:][sample[Nclust-1:]]
            extra_mask = torch.full(msa_extra.shape, False, device=msa_extra.device)
        elif N - Nclust < 1: # no extra sequences; reuse the masked cluster
            msa_extra = msa_masked.clone()
            ins_extra = ins_clust.clone()
            extra_mask = mask_pos.clone()
        else: # it's not enough to reuse extra sequences, mix both
            msa_add = msa[1:,:][sample[Nclust-1:]]
            ins_add = ins[1:,:][sample[Nclust-1:]]
            mask_add = torch.full(msa_add.shape, False, device=msa_add.device)
            msa_extra = torch.cat((msa_masked, msa_add), dim=0)
            ins_extra = torch.cat((ins_clust, ins_add), dim=0)
            extra_mask = torch.cat((mask_pos, mask_add), dim=0)
        N_extra = msa_extra.shape[0]

        # clustering (assign remaining sequences to their closest cluster by Hamming distance
        msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=ChemData().NAATOKENS)
        msa_extra_onehot = torch.nn.functional.one_hot(msa_extra, num_classes=ChemData().NAATOKENS)
        count_clust = torch.logical_and(~mask_pos, msa_clust != 20).float() # 20: index for gap, ignore both masked & gaps
        count_extra = torch.logical_and(~extra_mask, msa_extra != 20).float()
        agreement = torch.matmul((count_extra[:,:,None]*msa_extra_onehot).view(N_extra, -1), (count_clust[:,:,None]*msa_clust_onehot).view(Nclust, -1).T)
        assignment = torch.argmax(agreement, dim=-1)

        # seed MSA features
        # 1. one_hot encoded aatype: msa_clust_onehot
        # 2. cluster profile
        count_extra = ~extra_mask
        count_clust = ~mask_pos
        msa_clust_profile = cluster_sum(count_extra[:,:,None]*msa_extra_onehot, assignment, Nclust, L)
        msa_clust_profile += count_clust[:,:,None]*msa_clust_profile
        count_profile = cluster_sum(count_extra[:,:,None], assignment, Nclust, L).view(Nclust, L)
        count_profile += count_clust
        count_profile += eps
        msa_clust_profile /= count_profile[:,:,None]
        # 3. insertion statistics
        msa_clust_del = cluster_sum((count_extra*ins_extra)[:,:,None], assignment, Nclust, L).view(Nclust, L)
        msa_clust_del += count_clust*ins_clust
        msa_clust_del /= count_profile
        ins_clust = (2.0/np.pi)*torch.arctan(ins_clust.float()/3.0) # (from 0 to 1)
        msa_clust_del = (2.0/np.pi)*torch.arctan(msa_clust_del.float()/3.0) # (from 0 to 1)
        ins_clust = torch.stack((ins_clust, msa_clust_del), dim=-1)
        #
        if fixbb:
            assert params['MAXCYCLE'] == 1
            msa_clust_profile = msa_clust_onehot
            msa_extra_onehot = msa_clust_onehot
            ins_clust[:] = 0
            ins_extra[:] = 0
            # This is how it is done in rfdiff, but really it seems like it should be all 0.
            # Keeping as-is for now for consistency, as it may be used in downstream masking done
            # by apply_masks.
            mask_pos = torch.full_like(msa_clust, 1).bool()
        msa_seed = torch.cat((msa_clust_onehot, msa_clust_profile, ins_clust, term_info[None].expand(Nclust,-1,-1)), dim=-1)

        # extra MSA features
        ins_extra = (2.0/np.pi)*torch.arctan(ins_extra[:Nextra].float()/3.0) # (from 0 to 1)
        try:
            msa_extra = torch.cat((msa_extra_onehot[:Nextra], ins_extra[:,:,None], term_info[None].expand(Nextra,-1,-1)), dim=-1)
        except Exception as e:
            # NOTE(review): on failure this only prints shapes and falls
            # through with msa_extra unconverted -- confirm intent.
            print('msa_extra.shape',msa_extra.shape)
            print('ins_extra.shape',ins_extra.shape)

        if (tocpu):
            b_msa_clust.append(msa_clust.cpu())
            b_msa_seed.append(msa_seed.cpu())
            b_msa_extra.append(msa_extra.cpu())
            b_mask_pos.append(mask_pos.cpu())
        else:
            b_msa_clust.append(msa_clust)
            b_msa_seed.append(msa_seed)
            b_msa_extra.append(msa_extra)
            b_mask_pos.append(mask_pos)

    b_seq = torch.stack(b_seq)
    b_msa_clust = torch.stack(b_msa_clust)
    b_msa_seed = torch.stack(b_msa_seed)
    b_msa_extra = torch.stack(b_msa_extra)
    b_mask_pos = torch.stack(b_mask_pos)

    return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos
|
||||
|
||||
def blank_template(n_tmpl, L, random_noise=5.0, deterministic: bool = False):
    """Build an empty (all-gap) set of template features.

    Parameters
    ----------
    n_tmpl : int
        Number of blank templates to create.
    L : int
        Number of residues per template.
    random_noise : float
        Width of the uniform jitter added to the ideal coordinates.
    deterministic : bool
        If True, seed all RNGs with 0 before generating the jitter.

    Returns
    -------
    xyz : torch.Tensor (n_tmpl, L, NTOTAL, 3) — jittered ideal coordinates
    t1d : torch.Tensor (n_tmpl, L, NAATOKENS) — all-gap one-hot + zero confidence
    mask_t : torch.Tensor (n_tmpl, L, NTOTAL) — all False (no resolved atoms)
    ids : np.ndarray (n_tmpl,) — empty-string template ids
    """
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

    chem = ChemData()

    # ideal coordinates tiled over (template, residue), jittered by +/- random_noise/2
    base_crds = chem.INIT_CRDS.reshape(1, 1, chem.NTOTAL, 3).repeat(n_tmpl, L, 1, 1)
    jitter = torch.rand(n_tmpl, L, 1, 3) * random_noise - random_noise / 2
    xyz = base_crds + jitter

    # sequence channel is all gap tokens (20); confidence channel is zero
    gap_seq = torch.full((n_tmpl, L), 20).long()
    seq_1hot = torch.nn.functional.one_hot(gap_seq, num_classes=chem.NAATOKENS - 1).float()
    conf_col = torch.zeros((n_tmpl, L, 1)).float()
    t1d = torch.cat((seq_1hot, conf_col), -1)

    # a blank template has no resolved atoms
    mask_t = torch.full((n_tmpl, L, chem.NTOTAL), False)
    template_ids = np.full((n_tmpl), "")
    return xyz, t1d, mask_t, template_ids
|
||||
|
||||
|
||||
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, npick_global=None, pick_top=True, same_chain=None, random_noise=5, deterministic: bool = False):
    """Featurize up to `npick` templates for a query of length `qlen`.

    Templates with sequence identity above params['SEQID'] are discarded; the
    surviving templates are sampled (from the top 50 if `pick_top`) and their
    coordinates/masks/1-D features are scattered into query-indexed tensors.

    Parameters
    ----------
    tplt : dict
        Parsed hhsearch template data. Assumes keys 'ids', 'f0d', 'f1d',
        'qmap', 'seq', 'xyz', 'mask' with the layouts indexed below —
        TODO confirm against the template parser.
    qlen : int
        Query length.
    params : dict
        Must contain 'SEQID' (max allowed template sequence identity).
    offset : int
        Residue offset added to template→query position mapping.
    npick : int
        Number of templates to featurize.
    npick_global : int or None
        Size of the template dimension of the outputs; defaults to max(npick, 1).
    pick_top : bool
        If True, sample among the top 50 templates; otherwise among all.
    random_noise : float
        Jitter width for unfilled coordinates.
    deterministic : bool
        If True, seed all RNGs with 0.

    Returns
    -------
    xyz (npick_global, qlen, NTOTAL, 3), t1d (npick_global, qlen, NAATOKENS),
    mask_t (npick_global, qlen, NTOTAL), tplt_ids (np.ndarray of chain ids).
    Falls back to `blank_template(...)` when no usable template remains.
    """
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

    seqID_cut = params['SEQID']

    # IDIOM FIX: compare to None with `is`, not `==`
    if npick_global is None:
        npick_global = max(npick, 1)

    ntplt = len(tplt['ids'])
    if (ntplt < 1) or (npick < 1):  # no templates in hhsearch file or not want to use templ
        return blank_template(npick_global, qlen, random_noise)

    # ignore templates having too high seqID
    if seqID_cut <= 100.0:
        tplt_valid_idx = torch.where(tplt['f0d'][0,:,4] < seqID_cut)[0]
        tplt['ids'] = np.array(tplt['ids'])[tplt_valid_idx]
    else:
        tplt_valid_idx = torch.arange(len(tplt['ids']))

    # check again if there are templates having seqID < cutoff
    ntplt = len(tplt['ids'])
    npick = min(npick, ntplt)
    if npick < 1:  # no templates
        return blank_template(npick_global, qlen, random_noise)

    if not pick_top:  # select randomly among all possible templates
        sample = torch.randperm(ntplt)[:npick]
    else:  # only consider top 50 templates
        sample = torch.randperm(min(50, ntplt))[:npick]

    # initialize outputs to jittered ideal coords / gaps / zero confidence
    xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(npick_global,qlen,1,1) + torch.rand(1,qlen,1,3)*random_noise
    mask_t = torch.full((npick_global,qlen,ChemData().NTOTAL), False)  # True for valid atom, False for missing atom
    t1d = torch.full((npick_global, qlen), 20).long()
    t1d_val = torch.zeros((npick_global, qlen)).float()
    for i, nt in enumerate(sample):
        tplt_idx = tplt_valid_idx[nt]
        # rows of qmap belonging to this template, and their query positions
        sel = torch.where(tplt['qmap'][0,:,1]==tplt_idx)[0]
        pos = tplt['qmap'][0,sel,0] + offset

        ntmplatoms = tplt['xyz'].shape[2]  # will be bigger for NA templates
        xyz[i,pos,:ntmplatoms] = tplt['xyz'][0,sel]
        mask_t[i,pos,:ntmplatoms] = tplt['mask'][0,sel].bool()

        # 1-D features: alignment confidence
        t1d[i,pos] = tplt['seq'][0,sel]
        t1d_val[i,pos] = tplt['f1d'][0,sel,2]  # alignment confidence
        # xyz[i] = center_and_realign_missing(xyz[i], mask_t[i], same_chain=same_chain)

    t1d = torch.nn.functional.one_hot(t1d, num_classes=ChemData().NAATOKENS-1).float()  # (no mask token)
    t1d = torch.cat((t1d, t1d_val[...,None]), dim=-1)

    tplt_ids = np.array(tplt["ids"])[sample].flatten()  # np.array of chain ids (ordered)
    return xyz, t1d, mask_t, tplt_ids
|
||||
|
||||
def merge_hetero_templates(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids, Ls_prot, deterministic: bool = False):
    """Diagonally tiles template coordinates, 1d input features, and masks across
    template and residue dimensions. 1st template is concatenated directly on residue
    dimension after a random rotation & translation.

    Inputs are per-chain lists; xyz_t_prot[i] has shape (N_tmpl_i, Ls_prot[i], ...).
    Returns combined (xyz_t, f1d_t, mask_t, tplt_ids) over all chains.
    """
    # total template count across all chains
    N_tmpl_tot = sum([x.shape[0] for x in xyz_t_prot])

    # start from all-gap blanks; filled-in regions are overwritten below
    xyz_t_out, f1d_t_out, mask_t_out, _ = blank_template(N_tmpl_tot, sum(Ls_prot))
    tplt_ids_out = np.full((N_tmpl_tot),"", dtype=object) # rk bad practice.. should fix
    i_tmpl = 0  # next free row in the template dimension
    i_res = 0   # running residue offset of the current chain
    for xyz_, f1d_, mask_, ids in zip(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids):
        N_tmpl, L_tmpl = xyz_.shape[:2]
        # rows [i1, i2) receive this chain's non-first templates; template 0 of
        # every chain shares output row 0, so the first chain reserves rows 1..N_tmpl
        if i_tmpl == 0:
            i1, i2 = 1, N_tmpl
        else:
            i1, i2 = i_tmpl, i_tmpl+N_tmpl - 1

        # 1st template is concatenated directly, so that all atoms are set in xyz_prev
        xyz_t_out[0, i_res:i_res+L_tmpl] = random_rot_trans(xyz_[0:1], deterministic=deterministic)
        f1d_t_out[0, i_res:i_res+L_tmpl] = f1d_[0]
        mask_t_out[0, i_res:i_res+L_tmpl] = mask_[0]

        if not tplt_ids_out[0]: # only add first template
            tplt_ids_out[0] = ids[0]
        # remaining templates are diagonally tiled
        xyz_t_out[i1:i2, i_res:i_res+L_tmpl] = xyz_[1:]
        f1d_t_out[i1:i2, i_res:i_res+L_tmpl] = f1d_[1:]
        mask_t_out[i1:i2, i_res:i_res+L_tmpl] = mask_[1:]
        tplt_ids_out[i1:i2] = ids[1:]
        # advance row cursor; after the first chain, each chain consumes N_tmpl-1
        # rows because its template 0 was merged into output row 0
        if i_tmpl == 0:
            i_tmpl += N_tmpl
        else:
            i_tmpl += N_tmpl-1
        i_res += L_tmpl

    return xyz_t_out, f1d_t_out, mask_t_out, tplt_ids_out
|
||||
|
||||
def generate_xyz_prev(xyz_t, mask_t, params):
    """
    allows you to use different initializations for the coordinate track specified in params

    Parameters
    ----------
    xyz_t : torch.Tensor (N_tmpl, L, n_atoms, 3)
        Template coordinates; the first template seeds the coordinate track.
    mask_t : torch.Tensor (N_tmpl, L, n_atoms)
        Template atom validity masks.
    params : dict
        If params["BLACK_HOLE_INIT"] is truthy, ignore the templates and
        initialize from a single blank template instead.

    Returns
    -------
    (xyz_prev, mask_prev) : clones of the first template's coords and mask.
    """
    L = xyz_t.shape[1]
    if params["BLACK_HOLE_INIT"]:
        # BUGFIX: blank_template returns 4 values (xyz, t1d, mask_t, ids);
        # the original 3-way unpack raised ValueError whenever this branch ran.
        xyz_t, _, mask_t, _ = blank_template(1, L)
    return xyz_t[0].clone(), mask_t[0].clone()
|
||||
|
||||
### merge msa & insertion statistics of two proteins having different taxID
|
||||
def merge_a3m_hetero(a3mA, a3mB, L_s):
    """Merge the MSAs of two different chains into one side-by-side MSA.

    Query rows (row 0 of each MSA) are concatenated directly; the remaining
    rows of A are right-padded and the remaining rows of B left-padded so that
    non-query sequences are never paired. Pad value is 20 (gap) for 'msa' and
    0 for 'ins'. If both inputs carry 'taxid', the taxids are concatenated
    (dropping B's query taxid).

    L_s : [len_A, len_B] — chain lengths; sum(L_s[1:]) pads A, L_s[0] pads B.
    """
    width_right = sum(L_s[1:])  # gap block appended to A-only rows
    width_left = L_s[0]         # gap block prepended to B-only rows

    merged = {}
    for key, fill in (('msa', 20), ('ins', 0)):
        rows_a, rows_b = a3mA[key], a3mB[key]
        # the two query rows fuse into a single full-length query
        blocks = [torch.cat([rows_a[0], rows_b[0]]).unsqueeze(0)]  # (1, L)
        if rows_a.shape[0] > 1:
            blocks.append(torch.nn.functional.pad(rows_a[1:], (0, width_right), "constant", fill))
        if rows_b.shape[0] > 1:
            blocks.append(torch.nn.functional.pad(rows_b[1:], (width_left, 0), "constant", fill))
        merged[key] = torch.cat(blocks, dim=0)

    # merge taxids when both sides have them (B's query duplicate is dropped)
    if 'taxid' in a3mA and 'taxid' in a3mB:
        merged['taxid'] = np.concatenate([np.array(a3mA['taxid']), np.array(a3mB['taxid'])[1:]])

    return merged
|
||||
|
||||
# merge msa & insertion statistics of units in homo-oligomers
|
||||
def merge_a3m_homo(msa_orig, ins_orig, nmer, mode="default"):
    """Tile a single-chain MSA into an nmer homo-oligomer MSA.

    mode="repeat": every row is repeated across all copies.
        AAAAAA
        AAAAAA
    mode="diag": query spans all copies; each non-query row appears once per
    copy, offset along the residue dimension.
        AAAAAA
        A-----
        -A----
        ...
    default: full MSA on the first copy; non-query rows of the remaining
    copies share one padded-and-stacked row block.
        AAAAAA
        A-----
        -AAAAA
    Gap token is 20 for msa, 0 for ins. Returns {"msa": ..., "ins": ...}.
    """
    n_rows, L_mono = msa_orig.shape[:2]

    if mode == "repeat":
        tiled_msa = torch.tile(msa_orig, (1, nmer))
        tiled_ins = torch.tile(ins_orig, (1, nmer))

    elif mode == "diag":
        n_extra = n_rows - 1
        out_rows = 1 + n_extra * nmer
        out_cols = L_mono * nmer
        tiled_msa = torch.full((out_rows, out_cols), 20, dtype=msa_orig.dtype, device=msa_orig.device)
        tiled_ins = torch.full((out_rows, out_cols), 0, dtype=ins_orig.dtype, device=msa_orig.device)

        col, row = 0, 1
        for _ in range(nmer):
            # query covers every copy; extras occupy a fresh row block per copy
            tiled_msa[0, col:col+L_mono] = msa_orig[0]
            tiled_msa[row:row+n_extra, col:col+L_mono] = msa_orig[1:]
            tiled_ins[0, col:col+L_mono] = ins_orig[0]
            tiled_ins[row:row+n_extra, col:col+L_mono] = ins_orig[1:]
            col += L_mono
            row += n_extra

    else:
        tiled_msa = torch.full((2*n_rows - 1, L_mono*nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
        tiled_ins = torch.full((2*n_rows - 1, L_mono*nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)

        # first copy gets the whole original MSA
        tiled_msa[:n_rows, :L_mono] = msa_orig
        tiled_ins[:n_rows, :L_mono] = ins_orig

        # all later copies share the same extra-row block, shifted per copy
        col = L_mono
        for _ in range(1, nmer):
            tiled_msa[0, col:col+L_mono] = msa_orig[0]
            tiled_msa[n_rows:, col:col+L_mono] = msa_orig[1:]
            tiled_ins[0, col:col+L_mono] = ins_orig[0]
            tiled_ins[n_rows:, col:col+L_mono] = ins_orig[1:]
            col += L_mono

    return {"msa": tiled_msa, "ins": tiled_ins}
|
||||
|
||||
def merge_msas(a3m_list, L_s):
    """
    takes a list of a3m dictionaries with keys msa, ins and a list of protein lengths and creates a
    combined MSA

    Each a3m is also assumed to carry "taxID" (per-sequence taxonomic IDs) and
    "hash" (MSA identity key) — TODO confirm against the callers' a3m schema.
    Chains sharing at least 5 taxIDs with the accumulated MSA are paired on
    taxID; everything else is merged unpaired via merge_a3m_hetero.
    Returns the combined (msa, ins) tensors.
    """
    seen = set()     # hashes already incorporated (detects homomer repeats)
    taxIDs = []      # per-row taxIDs of the accumulated MSA
    a3mA = a3m_list[0]
    taxIDs.extend(a3mA["taxID"])
    seen.update(a3mA["hash"])
    msaA, insA = a3mA["msa"], a3mA["ins"]
    for i in range(1, len(a3m_list)):
        a3mB = a3m_list[i]
        pair_taxIDs = set(taxIDs).intersection(set(a3mB["taxID"]))
        if a3mB["hash"] in seen or len(pair_taxIDs) < 5: #homomer/not enough pairs
            # unpaired merge: pad-and-stack B next to the accumulated MSA
            a3mA = {"msa": msaA, "ins": insA}
            L_s_to_merge = [sum(L_s[:i]), L_s[i]]
            a3mA = merge_a3m_hetero(a3mA, a3mB, L_s_to_merge)
            msaA, insA = a3mA["msa"], a3mA["ins"]
            taxIDs.extend(a3mB["taxID"])
        else:
            # paired merge: for each shared taxID, keep the single sequence on
            # each side that is least similar to its query row (argmin of
            # per-position matches), then splice B's picks next to A's rows
            final_pairsA = []
            final_pairsB = []
            msaB, insB = a3mB["msa"], a3mB["ins"]
            for pair in pair_taxIDs:
                pair_a3mA = np.where(np.array(taxIDs)==pair)[0]
                pair_a3mB = np.where(a3mB["taxID"]==pair)[0]
                msaApair = torch.argmin(torch.sum(msaA[pair_a3mA, :] == msaA[0, :],axis=-1))
                msaBpair = torch.argmin(torch.sum(msaB[pair_a3mB, :] == msaB[0, :],axis=-1))
                final_pairsA.append(pair_a3mA[msaApair])
                final_pairsB.append(pair_a3mB[msaBpair])
            # rows of A with no partner get all-gap (20) padding for chain i
            paired_msaB = torch.full((msaA.shape[0], L_s[i]), 20).long() # (N_seq_A, L_B)
            paired_msaB[final_pairsA] = msaB[final_pairsB]
            msaA = torch.cat([msaA, paired_msaB], dim=1)
            insA = torch.zeros_like(msaA) # paired MSAs in our dataset dont have insertions
            seen.update(a3mB["hash"])

    return msaA, insA
|
||||
|
||||
def remove_all_gap_seqs(a3m):
    """Removes sequences that are all gaps from an MSA represented as `a3m` dictionary.

    Filters both 'msa' and 'ins' in place (the dict is mutated) and returns it.
    """
    all_gap_rows = (a3m['msa'] == ChemData().UNKINDEX).all(dim=1)
    keep = ~all_gap_rows
    a3m['msa'] = a3m['msa'][keep]
    a3m['ins'] = a3m['ins'][keep]
    return a3m
|
||||
|
||||
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
    """Joins (or "pairs") 2 MSAs by matching sequences with the same
    taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
    ID, only the sequence with the highest sequence identity to the query (1st
    sequence in MSA) will be paired.

    Sequences that aren't paired will be padded and added to the bottom of the
    joined MSA. If a subregion of the input MSAs overlap (represent the same
    chain), the subregion residue indices can be given as `idx_overlap`, and
    the overlap region of the unpaired sequences will be included in the joined
    MSA.

    Parameters
    ----------
    a3mA : dict
        First MSA to be joined, with keys `msa` (N_seq, L_seq), `ins` (N_seq,
        L_seq), `taxid` (N_seq,), and optionally `is_paired` (N_seq,), a
        boolean tensor indicating whether each sequence is fully paired. Can be
        a multi-MSA (contain >2 sub-MSAs).
    a3mB : dict
        2nd MSA to be joined, with keys `msa`, `ins`, `taxid`, and optionally
        `is_paired`. Can be a multi-MSA ONLY if not overlapping with 1st MSA.
    idx_overlap : tuple or list (optional)
        Start and end indices of overlap region in 1st MSA, followed by the
        same in 2nd MSA.

    Returns
    -------
    a3m : dict
        Paired MSA, with keys `msa`, `ins`, `taxid` and `is_paired`.
    """
    # preprocess overlap region
    L_A, L_B = a3mA['msa'].shape[1], a3mB['msa'].shape[1]
    if idx_overlap is not None:
        i1A, i2A, i1B, i2B = idx_overlap
        i1B_new, i2B_new = (0, i1B) if i2B==L_B else (i2B, L_B) # MSA B residues that don't overlap MSA A
        assert((i1B==0) or (i2B==a3mB['msa'].shape[1])), \
            "When overlapping with 1st MSA, 2nd MSA must comprise at most 2 sub-MSAs "\
            "(i.e. residue range should include 0 or a3mB['msa'].shape[1])"
    else:
        # no overlap: the whole of MSA B is "new" residues
        i1B_new, i2B_new = (0, L_B)

    # pair sequences
    taxids_shared = a3mA['taxid'][np.isin(a3mA['taxid'],a3mB['taxid'])]
    i_pairedA, i_pairedB = [], []

    for taxid in taxids_shared:
        # on each side, among rows with this taxid, pick the one with the
        # fewest positions matching the query row (argmin of match count)
        i_match = np.where(a3mA['taxid']==taxid)[0]
        i_match_best = torch.argmin(torch.sum(a3mA['msa'][i_match]==a3mA['msa'][0], axis=1))
        i_pairedA.append(i_match[i_match_best])

        i_match = np.where(a3mB['taxid']==taxid)[0]
        i_match_best = torch.argmin(torch.sum(a3mB['msa'][i_match]==a3mB['msa'][0], axis=1))
        i_pairedB.append(i_match[i_match_best])

    # unpaired sequences
    i_unpairedA = np.setdiff1d(np.arange(a3mA['msa'].shape[0]), i_pairedA)
    i_unpairedB = np.setdiff1d(np.arange(a3mB['msa'].shape[0]), i_pairedB)
    N_paired, N_unpairedA, N_unpairedB = len(i_pairedA), len(i_unpairedA), len(i_unpairedB)

    # handle overlap region
    # if msa A consists of sub-MSAs 1,2,3 and msa B of 2,4 (i.e overlap region is 2),
    # this diagram shows how the variables below make up the final multi-MSA
    # (* denotes nongaps, - denotes gaps)
    # 1 2 3 4
    # |*|*|*|*|  msa_paired
    # |*|*|*|-|  msaA_unpaired
    # |-|*|-|*|  msaB_unpaired
    if idx_overlap is not None:
        assert((a3mA['msa'][i_pairedA, i1A:i2A]==a3mB['msa'][i_pairedB, i1B:i2B]) |
               (a3mA['msa'][i_pairedA, i1A:i2A]==ChemData().UNKINDEX)).all(),\
            'Paired MSAs should be identical (or 1st MSA should be all gaps) in overlap region'

        # overlap region gets sequences from 2nd MSA bc sometimes 1st MSA will be all gaps here
        msa_paired = torch.cat([a3mA['msa'][i_pairedA, :i1A],
                                a3mB['msa'][i_pairedB, i1B:i2B],
                                a3mA['msa'][i_pairedA, i2A:],
                                a3mB['msa'][i_pairedB, i1B_new:i2B_new] ], dim=1)
        msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
                                   torch.full((N_unpairedA, i2B_new-i1B_new), ChemData().UNKINDEX) ], dim=1)
        msaB_unpaired = torch.cat([torch.full((N_unpairedB, i1A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB, i1B:i2B],
                                   torch.full((N_unpairedB, L_A-i2A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB, i1B_new:i2B_new] ], dim=1)
    else:
        # no overlap region, simple offset pad & stack
        # this code is actually a special case of "if" block above, but writing
        # this out explicitly here to make the logic more clear
        msa_paired = torch.cat([a3mA['msa'][i_pairedA], a3mB['msa'][i_pairedB, i1B_new:i2B_new]], dim=1)
        msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
                                   torch.full((N_unpairedA, L_B), ChemData().UNKINDEX)], dim=1) # pad with gaps
        msaB_unpaired = torch.cat([torch.full((N_unpairedB, L_A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB]], dim=1) # pad with gaps

    # stack paired & unpaired
    msa = torch.cat([msa_paired, msaA_unpaired, msaB_unpaired], dim=0)
    taxids = np.concatenate([a3mA['taxid'][i_pairedA], a3mA['taxid'][i_unpairedA], a3mB['taxid'][i_unpairedB]])

    # label "fully paired" sequences (a row of MSA that was never padded with gaps)
    # output seq is fully paired if seqs A & B both started out as paired and were paired to
    # each other on tax ID.
    # NOTE: there is a rare edge case that is ignored here for simplicity: if
    # pMSA 0+1 and 1+2 are joined and then joined to 2+3, a seq that exists in
    # 0+1 and 2+3 but NOT 1+2 will become fully paired on the last join but
    # will not be labeled as such here
    is_pairedA = a3mA['is_paired'] if 'is_paired' in a3mA else torch.ones((a3mA['msa'].shape[0],)).bool()
    is_pairedB = a3mB['is_paired'] if 'is_paired' in a3mB else torch.ones((a3mB['msa'].shape[0],)).bool()
    is_paired = torch.cat([is_pairedA[i_pairedA] & is_pairedB[i_pairedB],
                           torch.zeros((N_unpairedA + N_unpairedB,)).bool()])

    # insertion features in paired MSAs are assumed to be zero
    a3m = dict(msa=msa, ins=torch.zeros_like(msa), taxid=taxids, is_paired=is_paired)
    return a3m
|
||||
|
||||
|
||||
def load_minimal_multi_msa(hash_list, taxid_list, Ls, params):
    """Load a multi-MSA, which is a MSA that is paired across more than 2
    chains. This loads the MSA for unique chains. Use 'expand_multi_msa` to
    duplicate portions of the MSA for homo-oligomer repeated chains.

    Given a list of unique MSA hashes, loads all MSAs (using paired MSAs where
    it can) and pairs sequences across as many sub-MSAs as possible by matching
    taxonomic ID. For details on how pairing is done, see
    `join_msas_by_taxid()`

    Parameters
    ----------
    hash_list : list of str
        Hashes of MSAs to load and join. Must not contain duplicates.
    taxid_list : list of str
        Taxonomic IDs of query sequences of each input MSA.
    Ls : list of int
        Lengths of the chains corresponding to the hashes.
    params : dict
        Data-loading paths; uses 'COMPL_DIR' (paired MSAs) and 'PDB_DIR'
        (unpaired a3m files).

    Returns
    -------
    a3m_out : dict
        Multi-MSA with all input MSAs. Keys: `msa`,`ins` [torch.Tensor (N_seq, L)],
        `taxid` [np.array (Nseq,)], `is_paired` [torch.Tensor (N_seq,)]
    hashes_out : list of str
        Hashes of MSAs in the order that they are joined in `a3m_out`.
        Contains the same elements as the input `hash_list` but may be in a
        different order.
    Ls_out : list of int
        Lengths of each chain in `a3m_out`
    """
    assert(len(hash_list)==len(set(hash_list))), 'Input MSA hashes must be unique'

    # the lists below are constructed such that `a3m_list[i_a3m]` is a multi-MSA
    # comprising sub-MSAs whose indices in the input lists are
    # `i_in = idx_list_groups[i_a3m][i_submsa]`, i.e. the sub-MSA hashes are
    # `hash_list[i_in]` and lengths are `Ls[i_in]`.
    # Each sub-MSA spans a region of its multi-MSA `a3m_list[i_a3m][:,i_start:i_end]`,
    # where `(i_start,i_end) = res_range_groups[i_a3m][i_submsa]`
    a3m_list = []         # list of multi-MSAs
    idx_list_groups = []  # list of lists of indices of input chains making up each multi-MSA
    res_range_groups = [] # list of lists of start and end residues of each sub-MSA in multi-MSA

    # iterate through all pairs of hashes and look for paired MSAs (pMSAs)
    # NOTE: in the below, if pMSAs are loaded for hashes 0+1 and then 2+3, and
    # later a pMSA is found for 0+2, the last MSA will not be loaded. The 0+1
    # and 2+3 pMSAs will still be joined on taxID at the end, but sequences
    # only present in the 0+2 pMSA pMSAs will be missed. this is probably very
    # rare and so is ignored here for simplicity.
    N = len(hash_list)
    for i1, i2 in itertools.permutations(range(N),2):

        idx_list = [x for group in idx_list_groups for x in group] # flattened list of loaded hashes
        if i1 in idx_list and i2 in idx_list: continue # already loaded
        # BUGFIX: the original guard was `if i1 == '' or i2 == ''`, but i1/i2
        # are integer indices, so it never fired; the intent (per its comment)
        # was to skip chains with no taxID, which must be checked on taxid_list.
        if not taxid_list[i1] or not taxid_list[i2]: continue # no taxID means no pMSA

        # a paired MSA exists
        if taxid_list[i1]==taxid_list[i2]:

            h1, h2 = hash_list[i1], hash_list[i2]
            fn = params['COMPL_DIR']+'/pMSA/'+h1[:3]+'/'+h2[:3]+'/'+h1+'_'+h2+'.a3m.gz'

            if os.path.exists(fn):
                msa, ins, taxid = parse_a3m(fn, paired=True)
                a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins), taxid=taxid,
                               is_paired=torch.ones(msa.shape[0]).bool())
                res_range1 = (0,Ls[i1])
                res_range2 = (Ls[i1],msa.shape[1])

                # both hashes are new, add paired MSA to list
                if i1 not in idx_list and i2 not in idx_list:
                    a3m_list.append(a3m_new)
                    idx_list_groups.append([i1,i2])
                    res_range_groups.append([res_range1, res_range2])

                # one of the hashes is already in a multi-MSA
                # find that multi-MSA and join the new pMSA to it
                elif i1 in idx_list:
                    # which multi-MSA & sub-MSA has the hash with index `i1`?
                    i_a3m = np.where([i1 in group for group in idx_list_groups])[0][0]
                    i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i1)[0][0]

                    idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range1
                    a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)

                    idx_list_groups[i_a3m].append(i2)
                    L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
                    L_new = res_range2[1] - res_range2[0]
                    res_range_groups[i_a3m].append((L, L+L_new))

                elif i2 in idx_list:
                    # which multi-MSA & sub-MSA has the hash with index `i2`?
                    i_a3m = np.where([i2 in group for group in idx_list_groups])[0][0]
                    i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i2)[0][0]

                    idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range2
                    a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)

                    idx_list_groups[i_a3m].append(i1)
                    L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
                    L_new = res_range1[1] - res_range1[0]
                    res_range_groups[i_a3m].append((L, L+L_new))

    # add unpaired MSAs
    # ungroup hash indices now, since we're done making multi-MSAs
    idx_list = [x for group in idx_list_groups for x in group]
    for i in range(N):
        if i not in idx_list:
            fn = params['PDB_DIR'] + '/a3m/' + hash_list[i][:3] + '/' + hash_list[i] + '.a3m.gz'
            msa, ins, taxid = parse_a3m(fn)
            a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins),
                           taxid=taxid, is_paired=torch.ones(msa.shape[0]).bool())
            a3m_list.append(a3m_new)
            idx_list.append(i)

    Ls_out = [Ls[i] for i in idx_list]
    hashes_out = [hash_list[i] for i in idx_list]

    # join multi-MSAs & unpaired MSAs
    a3m_out = a3m_list[0]
    for i in range(1, len(a3m_list)):
        a3m_out = join_msas_by_taxid(a3m_out, a3m_list[i])

    return a3m_out, hashes_out, Ls_out
|
||||
|
||||
|
||||
def expand_multi_msa(a3m, hashes_in, hashes_out, Ls_in, Ls_out, params=None):
    """Expands a multi-MSA of unique chains into an MSA of a
    hetero-homo-oligomer in which some chains appear more than once. The query
    sequences (1st sequence of MSA) are concatenated directly along the
    residue dimention. The remaining sequences are offset-tiled (i.e. "padded &
    stacked") so that exact repeat sequences aren't paired.

    For example, if the original multi-MSA contains unique chains 1,2,3 but
    the final chain order is 1,2,1,3,3,1, this function will output an MSA like
    (where - denotes a block of gap characters):

        1 2 - 3 - -
        - - 1 - 3 -
        - - - - - 1

    Parameters
    ----------
    a3m : dict
        Contains torch.Tensors `msa` and `ins` (N_seq, L) and np.array `taxid` (Nseq,),
        representing the multi-MSA of unique chains.
    hashes_in : list of str
        Unique MSA hashes used in `a3m`.
    hashes_out : list of str
        Non-unique MSA hashes desired in expanded MSA.
    Ls_in : list of int
        Lengths of each chain in `a3m`
    Ls_out : list of int
        Lengths of each chain desired in expanded MSA.
    params : dict, optional
        Data loading parameters (currently unused; accepted so callers that
        pass it positionally keep working).

    Returns
    -------
    a3m : dict
        Contains torch.Tensors `msa` and `ins` of expanded MSA. No
        taxids because no further joining needs to be done.
    """
    assert(len(hashes_out)==len(Ls_out))
    assert(set(hashes_in)==set(hashes_out))
    assert(a3m['msa'].shape[1]==sum(Ls_in))

    # figure out which oligomeric repeat is represented by each hash in `hashes_out`
    # each new repeat will be offset in sequence dimension of final MSA
    counts = dict()
    n_copy = [] # n-th copy of this hash in `hashes`
    for h in hashes_out:
        if h in counts:
            counts[h] += 1
        else:
            counts[h] = 1
        n_copy.append(counts[h])

    # num sequences in source & destination MSAs
    N_in = a3m['msa'].shape[0]
    N_out = (N_in-1)*max(n_copy)+1 # concatenate query seqs, pad&stack the rest

    # source MSA
    msa_in, ins_in = a3m['msa'], a3m['ins']

    # initialize destination MSA to gap characters
    msa_out = torch.full((N_out, sum(Ls_out)), ChemData().UNKINDEX)
    ins_out = torch.full((N_out, sum(Ls_out)), 0)

    # for each destination chain
    for i_out, h_out in enumerate(hashes_out):
        # identify index of source chain
        i_in = np.where(np.array(hashes_in)==h_out)[0][0]

        # residue indexes
        i1_res_in = sum(Ls_in[:i_in])
        i2_res_in = sum(Ls_in[:i_in+1])
        i1_res_out = sum(Ls_out[:i_out])
        i2_res_out = sum(Ls_out[:i_out+1])

        # copy over query sequence
        # BUGFIX: the query insertion row previously copied from `msa_in`
        # (amino-acid tokens) instead of `ins_in` (insertion counts).
        msa_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
        ins_out[0, i1_res_out:i2_res_out] = ins_in[0, i1_res_in:i2_res_in]

        # offset non-query sequences along sequence dimension based on repeat number of a given hash
        i1_seq_out = 1+(n_copy[i_out]-1)*(N_in-1)
        i2_seq_out = 1+n_copy[i_out]*(N_in-1)
        # copy over non-query sequences
        msa_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = msa_in[1:, i1_res_in:i2_res_in]
        ins_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = ins_in[1:, i1_res_in:i2_res_in]

    # only 1st oligomeric repeat can be fully paired
    is_paired_out = torch.cat([a3m['is_paired'], torch.zeros((N_out-N_in,)).bool()])

    a3m_out = dict(msa=msa_out, ins=ins_out, is_paired=is_paired_out)
    a3m_out = remove_all_gap_seqs(a3m_out)

    return a3m_out
|
||||
|
||||
def load_multi_msa(chain_ids, Ls, chid2hash, chid2taxid, params):
    """Loads multi-MSA for an arbitrary number of protein chains. Tries to
    locate paired MSAs and pair sequences across all chains by taxonomic ID.
    Unpaired sequences are padded and stacked on the bottom.

    chain_ids/Ls are parallel lists of chain identifiers and lengths;
    chid2hash maps a chain id to its MSA hash, chid2taxid to its taxonomic id.
    """
    # collect hashes for every chain, and hash/taxid/length once per unique hash
    hashes = []
    hashes_unique, taxids_unique, Ls_unique = [], [], []
    for chid, L_ in zip(chain_ids, Ls):
        chain_hash = chid2hash[chid]
        hashes.append(chain_hash)
        if chain_hash not in hashes_unique:
            hashes_unique.append(chain_hash)
            taxids_unique.append(chid2taxid.get(chid))
            Ls_unique.append(L_)

    # loads multi-MSA for unique chains
    a3m_prot, hashes_unique, Ls_unique = \
        load_minimal_multi_msa(hashes_unique, taxids_unique, Ls_unique, params)

    # expands multi-MSA to repeat chains of homo-oligomers
    a3m_prot = expand_multi_msa(a3m_prot, hashes_unique, hashes, Ls_unique, Ls, params)

    return a3m_prot
|
||||
|
||||
def choose_multimsa_clusters(msa_seq_is_paired, params):
    """Returns indices of fully-paired sequences in a multi-MSA to use as seed
    clusters during MSA featurization.

    Returns None when more than a quarter of the sequences are fully paired
    (random seeding is then fine); otherwise returns one index tensor per
    recycle, each reserving up to MAXLAT//2 seed slots for paired rows.
    """
    if msa_seq_is_paired.float().mean() > 0.25:
        # enough fully paired sequences, just let MSAFeaturize choose randomly
        return None

    # ensure that half of the clusters are fully-paired sequences,
    # and let the rest be chosen randomly
    n_seed = params['MAXLAT'] // 2
    paired_idx = torch.where(msa_seq_is_paired)[0]
    return [
        paired_idx[torch.randperm(len(paired_idx))][:n_seed]
        for _ in range(params['MAXCYCLE'])
    ]
|
||||
|
||||
|
||||
#fd
|
||||
#fd
def get_bond_distances(bond_feats):
    """Graph distance (number of bonds) between every pair of atoms.

    Bond-feature values 1-4 are treated as covalent edges; the all-pairs
    shortest-path matrix over that graph is returned as a float tensor.
    Unreachable pairs come back as inf (scipy's convention).
    """
    covalent = (bond_feats > 0) * (bond_feats < 5)
    hop_counts = scipy.sparse.csgraph.shortest_path(covalent.long().numpy(), directed=False)
    # dist_matrix = torch.tensor(np.nan_to_num(dist_matrix, posinf=4.0)) # protein portion is inf and you don't want to mask it out
    return torch.from_numpy(hop_counts).float()
|
||||
|
||||
|
||||
def get_pdb(pdbfilename, plddtfilename, item, lddtcut, sccut):
    """Load a predicted structure and mask out low-confidence atoms.

    Sidechain atoms (columns 5+) are kept only where per-residue plddt exceeds
    `sccut`; whole residues below `lddtcut` are masked entirely. Returns a dict
    with 'xyz', 'mask', 'idx' tensors and the 'label' passed in as `item`.
    """
    xyz, mask, res_idx = parse_pdb(pdbfilename)
    plddt = np.load(plddtfilename)

    # update mask info with plddt (ignore sidechains if plddt < 90.0)
    confident = np.full_like(mask, False)
    confident[plddt > sccut] = True
    confident[:, :5] = True  # backbone atoms always pass the sidechain filter

    combined = np.logical_and(mask, confident)
    combined = np.logical_and(combined, (plddt > lddtcut)[:, None])

    return {'xyz': torch.tensor(xyz),
            'mask': torch.tensor(combined),
            'idx': torch.tensor(res_idx),
            'label': item}
|
||||
|
||||
def get_msa(a3mfilename, item, maxseq=5000):
    """Load an MSA from an a3m file.

    Parameters
    ----------
    a3mfilename : str
        Path to the a3m file to parse.
    item : any
        Label stored under 'label' in the returned dict.
    maxseq : int
        Maximum number of sequences to read from the file.

    Returns
    -------
    dict with 'msa', 'ins' tensors, 'taxIDs', and 'label'.
    """
    # BUGFIX: pass the caller's maxseq through — it was previously hard-coded
    # to 5000, silently ignoring the parameter.
    msa, ins, taxIDs = parse_a3m(a3mfilename, maxseq=maxseq)
    return {'msa': torch.tensor(msa), 'ins': torch.tensor(ins), 'taxIDs': taxIDs, 'label': item}
|
||||
208
rf2aa/data/merge_inputs.py
Normal file
208
rf2aa/data/merge_inputs.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import torch
|
||||
from hashlib import md5
|
||||
|
||||
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, merge_hetero_templates, get_term_feats, join_msas_by_taxid, expand_multi_msa
|
||||
from rf2aa.data.data_loader import RawInputData
|
||||
from rf2aa.util import center_and_realign_missing, same_chain_from_bond_feats, random_rot_trans, idx_from_Ls
|
||||
|
||||
|
||||
def merge_protein_inputs(protein_inputs, deterministic: bool = False):
    """Merge per-chain protein inputs into a single RawInputData.

    protein_inputs maps chain id -> input object; each input is assumed to
    expose msa/ins/taxids/xyz_t/mask_t/t1d/bond_feats plus length() and
    sequence_string() — TODO confirm against RawInputData.
    Returns (merged_input, [(chain_id, length), ...]); (None, []) when empty.
    NOTE(review): the single-chain branch mutates input.xyz_t in place.
    """
    if len(protein_inputs) == 0:
        return None,[]
    elif len(protein_inputs) == 1:
        # single chain: just randomly rotate/translate the first template
        chain = list(protein_inputs.keys())[0]
        input = list(protein_inputs.values())[0]
        xyz_t = input.xyz_t
        xyz_t[0:1] = random_rot_trans(xyz_t[0:1], deterministic=deterministic)
        input.xyz_t = xyz_t
        return input, [(chain, input.length())]
    # handle merging MSAs and such
    # first determine which sequence are identical, then which one have mergeable MSAs
    # then cat the templates, other feats
    else:
        a3m_list = [
            {"msa": input.msa,
             "ins": input.ins,
             "taxid": input.taxids
            }
            for input in protein_inputs.values()
        ]
        # chains are deduplicated by an md5 hash of their sequence string
        hash_list = [md5(input.sequence_string().encode()).hexdigest() for input in protein_inputs.values()]
        lengths_list = [input.length() for input in protein_inputs.values()]

        # keep the first occurrence of each distinct sequence
        seen = set()
        unique_indices = []
        for idx, hash in enumerate(hash_list):
            if hash not in seen:
                unique_indices.append(idx)
                seen.add(hash)

        unique_a3m = [a3m for i, a3m in enumerate(a3m_list) if i in unique_indices ]
        unique_hashes = [value for index, value in enumerate(hash_list) if index in unique_indices]
        unique_lengths_list = [value for index, value in enumerate(lengths_list) if index in unique_indices]

        if len(unique_a3m) >1:
            # hetero case: pair unique MSAs on taxid, then expand to repeats
            a3m_out = unique_a3m[0]
            for unq_a3m in unique_a3m[1:]:
                a3m_out = join_msas_by_taxid(a3m_out, unq_a3m)
            a3m_out = expand_multi_msa(a3m_out, unique_hashes, hash_list, unique_lengths_list, lengths_list)
        else:
            # homo-oligomer: pad-and-stack one MSA across all copies
            a3m = unique_a3m[0]
            msa, ins = a3m["msa"], a3m["ins"]
            a3m_out = merge_a3m_homo(msa, ins, len(hash_list))

        # merge templates
        max_template_dim = max([input.xyz_t.shape[0] for input in protein_inputs.values()])
        xyz_t_list = [input.xyz_t for input in protein_inputs.values()]
        mask_t_list = [input.mask_t for input in protein_inputs.values()]
        t1d_list = [input.t1d for input in protein_inputs.values()]
        ids = ["inference"] * len(t1d_list)
        xyz_t, t1d, mask_t, _ = merge_hetero_templates(xyz_t_list, t1d_list, mask_t_list, ids, lengths_list, deterministic=deterministic)

        # proteins carry no atomized frames or chirality constraints
        atom_frames = torch.zeros(0,3,2)
        chirals = torch.zeros(0,5)

        # block-diagonal bond features: chains share no inter-chain bonds
        L_total = sum(lengths_list)
        bond_feats = torch.zeros((L_total, L_total)).long()
        offset = 0
        for bf in [input.bond_feats for input in protein_inputs.values()]:
            L = bf.shape[0]
            bond_feats[offset:offset+L, offset:offset+L] = bf
            offset += L
        chain_lengths = list(zip(protein_inputs.keys(), lengths_list))

        merged_input = RawInputData(
            a3m_out["msa"],
            a3m_out["ins"],
            bond_feats,
            xyz_t[:max_template_dim],
            mask_t[:max_template_dim],
            t1d[:max_template_dim],
            chirals,
            atom_frames,
            taxids=None
        )
        return merged_input, chain_lengths
|
||||
|
||||
def merge_na_inputs(na_inputs):
    """Concatenate all nucleic-acid chain inputs into a single input object.

    Parameters
    ----------
    na_inputs : dict
        Maps chain id -> per-chain input object exposing ``length()``.

    Returns
    -------
    tuple
        ``(merged_input, chain_lengths)`` where ``merged_input`` is None when
        ``na_inputs`` is empty, and ``chain_lengths`` is a list of
        ``(chain_id, length)`` pairs in iteration order.
    """
    merged = None
    chain_lengths = []
    for chain_id, chain_input in na_inputs.items():
        # features are merged pairwise, accumulating left-to-right
        merged = merge_two_inputs(merged, chain_input)
        chain_lengths.append((chain_id, chain_input.length()))
    return merged, chain_lengths
|
||||
|
||||
def merge_sm_inputs(sm_inputs):
    """Concatenate all small-molecule inputs into a single input object.

    Parameters
    ----------
    sm_inputs : dict
        Maps chain id -> per-chain input object exposing ``length()``.

    Returns
    -------
    tuple
        ``(merged_input, chain_lengths)``; ``merged_input`` is None for an
        empty dict, ``chain_lengths`` lists ``(chain_id, length)`` pairs.
    """
    merged = None
    chain_lengths = []
    for chain_id, chain_input in sm_inputs.items():
        # trivially concatenate features, one molecule at a time
        merged = merge_two_inputs(merged, chain_input)
        chain_lengths.append((chain_id, chain_input.length()))
    return merged, chain_lengths
|
||||
|
||||
def merge_two_inputs(first_input, second_input):
    """Merge two RawInputData objects into one, concatenating all features.

    Either argument may be None, in which case the other is returned
    unchanged (both None returns None). MSAs are merged as a hetero pair,
    bond features become block-diagonal, templates are concatenated along
    the residue dimension, and chiral indices of the second input are
    offset by the first input's length.

    Parameters
    ----------
    first_input, second_input : RawInputData or None

    Returns
    -------
    RawInputData or None
        New merged object; neither argument is modified.
    """
    # trivial cases: one or both sides absent
    if first_input is None and second_input is None:
        return None
    if first_input is None:
        return second_input
    if second_input is None:
        return first_input

    Ls = [first_input.length(), second_input.length()]
    L_total = sum(Ls)

    # merge MSAs as a hetero pair (each side padded with gaps over the other)
    a3m_first = {
        "msa": first_input.msa,
        "ins": first_input.ins,
    }
    a3m_second = {
        "msa": second_input.msa,
        "ins": second_input.ins,
    }
    a3m = merge_a3m_hetero(a3m_first, a3m_second, Ls)

    # bond features are block-diagonal: no inter-input bonds at this stage
    bond_feats = torch.zeros((L_total, L_total)).long()
    offset = 0
    for bf in [first_input.bond_feats, second_input.bond_feats]:
        L = bf.shape[0]
        bond_feats[offset:offset+L, offset:offset+L] = bf
        offset += L

    # templates are concatenated along the residue dimension (dim=1)
    xyz_t = torch.cat([first_input.xyz_t, second_input.xyz_t], dim=1)
    t1d = torch.cat([first_input.t1d, second_input.t1d], dim=1)
    mask_t = torch.cat([first_input.mask_t, second_input.mask_t], dim=1)

    # chirals: every column except the last holds an index that must be
    # shifted past the first input's length (the last column is left as-is,
    # matching the original `[:, :-1]` slice).
    # BUGFIX: shift on a clone — the original mutated second_input.chirals
    # in place, corrupting the caller's object if it was ever merged twice.
    second_chirals = second_input.chirals
    if second_chirals.shape[0] > 0:
        second_chirals = second_chirals.clone()
        second_chirals[:, :-1] = second_chirals[:, :-1] + Ls[0]
    chirals = torch.cat([first_input.chirals, second_chirals])

    # atom frames need no offsetting here; plain concatenation
    atom_frames = torch.cat([first_input.atom_frames, second_input.atom_frames])

    # return new object
    return RawInputData(
        a3m["msa"],
        a3m["ins"],
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        chirals,
        atom_frames,
        taxids=None
    )
|
||||
|
||||
def merge_all(
    protein_inputs,
    na_inputs,
    sm_inputs,
    residues_to_atomize,
    deterministic: bool = False,
):
    """Merge protein, nucleic-acid and small-molecule inputs into one RawInputData.

    Merges each input family internally, then chains them together
    (protein -> NA -> small molecules), attaches chain lengths, terminus
    features, residue indices, and recentered/realigned template coordinates.

    Parameters
    ----------
    protein_inputs, na_inputs, sm_inputs : dict
        Per-chain inputs for each molecule class (any may be empty).
    residues_to_atomize : list
        Residues involved in covalent modifications; triggers a protein
        feature update after merging when non-empty.
    deterministic : bool
        Forwarded to the protein merge (template handling).

    Returns
    -------
    RawInputData
        Fully merged input with ``chain_lengths``, ``term_info`` and ``idx`` set.

    Raises
    ------
    ValueError
        If all three input dicts are empty.
    """
    protein_inputs, protein_chain_lengths = merge_protein_inputs(protein_inputs, deterministic=deterministic)

    na_inputs, na_chain_lengths = merge_na_inputs(na_inputs)
    sm_inputs, sm_chain_lengths = merge_sm_inputs(sm_inputs)
    if protein_inputs is None and na_inputs is None and sm_inputs is None:
        raise ValueError("No valid inputs were provided")
    running_inputs = merge_two_inputs(protein_inputs, na_inputs) #could handle pairing protein/NA MSAs here
    running_inputs = merge_two_inputs(running_inputs, sm_inputs)

    all_chain_lengths = protein_chain_lengths + na_chain_lengths + sm_chain_lengths
    running_inputs.chain_lengths = all_chain_lengths

    all_lengths = get_Ls_from_chain_lengths(running_inputs.chain_lengths)
    protein_lengths = get_Ls_from_chain_lengths(protein_chain_lengths)
    # terminus features only apply to protein chains (which come first);
    # zero them for NA and small-molecule positions
    term_info = get_term_feats(all_lengths)
    term_info[sum(protein_lengths):, :] = 0
    running_inputs.term_info = term_info

    xyz_t = running_inputs.xyz_t
    mask_t = running_inputs.mask_t

    # BUGFIX: removed accidental duplicated assignment
    # (`same_chain = same_chain = ...` in the original)
    same_chain = same_chain_from_bond_feats(running_inputs.bond_feats)
    ntempl = xyz_t.shape[0]
    # recenter each template and realign residues with missing coordinates
    xyz_t = torch.stack(
        [center_and_realign_missing(xyz_t[i], mask_t[i], same_chain=same_chain) for i in range(ntempl)]
    )
    xyz_t = torch.nan_to_num(xyz_t)
    running_inputs.xyz_t = xyz_t
    running_inputs.idx = idx_from_Ls(all_lengths)

    # after everything is merged need to add bond feats for covales
    # reindex protein feats function
    if residues_to_atomize:
        running_inputs.update_protein_features_after_atomize(residues_to_atomize)

    return running_inputs
|
||||
|
||||
def get_Ls_from_chain_lengths(chain_lengths):
    """Extract just the lengths from a list of (chain_id, length) pairs."""
    return [length for _, length in chain_lengths]
|
||||
|
||||
46
rf2aa/data/nucleic_acid.py
Normal file
46
rf2aa/data/nucleic_acid.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from rf2aa.data.parsers import parse_mixed_fasta, parse_multichain_fasta
|
||||
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, blank_template
|
||||
from rf2aa.data.data_loader import RawInputData
|
||||
from rf2aa.util import get_protein_bond_feats
|
||||
|
||||
def load_nucleic_acid(fasta_fn, input_type, model_runner):
    """Load a single DNA or RNA chain from a fasta file as RawInputData.

    Parameters
    ----------
    fasta_fn : str
        Path to the fasta file for ONE nucleic-acid chain (may contain
        multiple aligned sequences).
    input_type : str
        Either "dna" or "rna"; selects the integer alphabet for parsing.
    model_runner :
        Provides ``config.loader_params`` (uses "MAXSEQ" and "n_templ").

    Returns
    -------
    RawInputData
        MSA/insertion features, sequence-adjacent bond features, a blank
        template, and empty chiral / atom-frame tensors.

    Raises
    ------
    ValueError
        If ``input_type`` is not "dna"/"rna", or the fasta encodes more
        than one chain.
    """
    if input_type not in ["dna", "rna"]:
        raise ValueError("Only DNA and RNA inputs allowed for nucleic acids")
    if input_type == "dna":
        dna_alphabet = True
        rna_alphabet = False
    elif input_type == "rna":
        dna_alphabet = False
        rna_alphabet = True

    loader_params = model_runner.config.loader_params
    msa, ins, L = parse_multichain_fasta(fasta_fn, rna_alphabet=rna_alphabet, dna_alphabet=dna_alphabet)
    # subsample the MSA when it exceeds MAXSEQ, always keeping the query row
    # NOTE(review): uses np.random with no explicit seed here — confirm the
    # caller handles seeding if reproducibility matters.
    if (msa.shape[0] > loader_params["MAXSEQ"]):
        idxs_tokeep = np.random.permutation(msa.shape[0])[:loader_params["MAXSEQ"]]
        idxs_tokeep[0] = 0  # keep the query (row 0) in first position
        msa = msa[idxs_tokeep]
        ins = ins[idxs_tokeep]
    if len(L) > 1:
        raise ValueError("Please provide separate fasta files for each nucleic acid chain")
    L = L[0]
    # no structural templates for nucleic acids: use a blank (all-gap) template
    xyz_t, t1d, mask_t, _ = blank_template(loader_params["n_templ"], L)

    # sequence-adjacent bond features (same pattern as the protein backbone)
    bond_feats = get_protein_bond_feats(L)
    # nucleic-acid chains carry no chiral or atom-frame features
    chirals = torch.zeros(0, 5)
    atom_frames = torch.zeros(0, 3, 2)

    return RawInputData(
        torch.from_numpy(msa),
        torch.from_numpy(ins),
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        chirals,
        atom_frames,
        taxids=None,
    )
|
||||
812
rf2aa/data/parsers.py
Normal file
812
rf2aa/data/parsers.py
Normal file
@@ -0,0 +1,812 @@
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.spatial
|
||||
import string
|
||||
import os,re
|
||||
from os.path import exists
|
||||
import random
|
||||
import rf2aa.util as util
|
||||
import gzip
|
||||
import rf2aa
|
||||
from rf2aa.ffindex import *
|
||||
import torch
|
||||
from openbabel import openbabel
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
|
||||
def get_dislf(seq, xyz, mask):
    """Detect disulfide bridges from resolved cysteine SG atom positions.

    A pair of cysteines is reported when the distance between their SG
    atoms lies strictly between 1.7 and 2.3 (Angstrom).

    Returns a list of (res_i, res_j) residue-index tuples with res_i < res_j.
    """
    # residue indices of cysteines whose SG atom (slot 5) is resolved
    cys_idx = ((seq == ChemData().aa2num['CYS']) * mask[:, 5]).nonzero().squeeze(-1)  # cys[5]=='sg'
    sg_xyz = xyz[cys_idx, 5]
    n_cys = sg_xyz.shape[0]
    # all unordered pairs (upper triangle, diagonal excluded)
    ii, jj = torch.triu_indices(n_cys, n_cys, 1)
    pair_dist = torch.linalg.norm(sg_xyz[ii, :] - sg_xyz[jj, :], dim=-1)
    bonded = (pair_dist > 1.7) * (pair_dist < 2.3)
    return [
        (cys_idx[ii[k]].item(), cys_idx[jj[k]].item())
        for k in bonded.nonzero()
    ]
|
||||
|
||||
def read_template_pdb(L, pdb_fn, target_chain=None):
    """Read a single template structure from a PDB file.

    Parameters
    ----------
    L : int
        Length of the query; output tensors are allocated to this length and
        filled at residue positions (resNo - 1) found in the file.
    pdb_fn : str
        Path to the PDB file.
    target_chain : optional
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    tuple
        ``(xyz[None], t1d[None])`` — template coordinates (1, L, 36, 3) and
        1-D template features (one-hot sequence + confidence), each with a
        leading template dimension.
    """
    # get full sequence from given PDB
    # NOTE: seq_full / L_s are computed but not returned (see the commented-out
    # return below); kept for parity with read_multichain_pdb.
    seq_full = list()
    # BUGFIX: L_s and offset were never initialized, causing a NameError on
    # the second chain of a multi-chain PDB.
    L_s = list()
    offset = 0
    prev_chain=''
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            if line[12:16].strip() != "CA":
                continue
            if line[21] != prev_chain:
                # chain break: record the previous chain's length
                if len(seq_full) > 0:
                    L_s.append(len(seq_full)-offset)
                offset = len(seq_full)
                prev_chain = line[21]
            aa = line[17:20]
            seq_full.append(ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20)

    seq_full = torch.tensor(seq_full).long()

    xyz = torch.full((L, 36, 3), np.nan).float()
    seq = torch.full((L,), 20).long()
    conf = torch.zeros(L,1).float()

    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            aa_idx = ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20
            # residue numbers are assumed 1-based and within [1, L]
            idx = resNo - 1
            for i_atm, tgtatm in enumerate(ChemData().aa2long[aa_idx]):
                if tgtatm == atom:
                    xyz[idx, i_atm, :] = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
                    break
            seq[idx] = aa_idx

    # residue is "present" only if all three backbone atoms were read
    mask = torch.logical_not(torch.isnan(xyz[:,:3,0])) # (L, 3)
    mask = mask.all(dim=-1)[:,None]
    # fixed confidence 0.1 for present residues, 0 otherwise
    conf = torch.where(mask, torch.full((L,1),0.1), torch.zeros(L,1)).float()
    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
    t1d = torch.cat((seq_1hot, conf), -1)

    #return seq_full[None], ins[None], L_s, xyz[None], t1d[None]
    return xyz[None], t1d[None]
|
||||
|
||||
def read_multichain_pdb(pdb_fn, tmpl_chain=None, tmpl_conf=0.1):
    """Read a (two-chain) PDB and build MSA, template and disulfide features.

    Parameters
    ----------
    pdb_fn : str
        Path to the PDB file; residue numbers index the output (resNo - 1).
    tmpl_chain : str, optional
        Chain letter whose coordinates populate the template tensors; other
        chains only contribute to the native coordinates/sequence.
    tmpl_conf : float
        Confidence value assigned to template residues with full backbone.

    Returns
    -------
    tuple
        (msa, ins, L_s, xyz_t, mask_t, t1d, dslf)

    NOTE(review): the 3-row MSA masks chain 1 in row 1 and chain 2 in row 2
    using L_s[0], so this effectively assumes exactly two chains — confirm.
    """
    print ('read_multichain_pdb',tmpl_chain)

    # get full sequence from PDB, recording per-chain lengths in L_s
    seq_full = list()
    L_s = list()
    prev_chain=''
    offset = 0
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            if line[12:16].strip() != "CA":
                continue
            if line[21] != prev_chain:
                # chain break: close out the previous chain's length
                if len(seq_full) > 0:
                    L_s.append(len(seq_full)-offset)
                offset = len(seq_full)
                prev_chain = line[21]
            aa = line[17:20]
            # unknown residue names map to 20 (gap/UNK token)
            seq_full.append(ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20)
    L_s.append(len(seq_full) - offset)

    seq_full = torch.tensor(seq_full).long()
    L = len(seq_full)
    # 3-row MSA: full sequence, then each chain alone (other chain gapped out)
    msa = torch.stack((seq_full,seq_full,seq_full), dim=0)
    msa[1,:L_s[0]] = 20
    msa[2,L_s[0]:] = 20
    ins = torch.zeros_like(msa)

    # initialize coordinates near the ideal frame plus small random jitter
    xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1) + torch.rand(1,L,1,3)*5.0
    xyz_t = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1) + torch.rand(1,L,1,3)*5.0

    mask = torch.full((1, L, ChemData().NTOTAL), False)
    mask_t = torch.full((1, L, ChemData().NTOTAL), False)
    seq = torch.full((1, L,), 20).long()
    conf = torch.zeros(1, L,1).float()

    # second pass: fill atom coordinates; template tensors only get atoms
    # from tmpl_chain
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            outbatch = 0  # NOTE(review): assigned but never used

            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            aa_idx = ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20

            idx = resNo - 1

            for i_atm, tgtatm in enumerate(ChemData().aa2long[aa_idx]):
                if tgtatm == atom:
                    xyz_i = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
                    xyz[0, idx, i_atm, :] = xyz_i
                    mask[0, idx, i_atm] = True
                    if line[21] == tmpl_chain:
                        xyz_t[0, idx, i_atm, :] = xyz_i
                        mask_t[0, idx, i_atm] = True
                    break
            seq[0, idx] = aa_idx

    if (mask_t.any()):
        # NOTE(review): realignment is computed from the FULL structure
        # (xyz/mask), not from the template subset (xyz_t/mask_t) — confirm
        # this is intentional.
        xyz_t[0] = rf2aa.util.center_and_realign_missing(xyz[0], mask[0])

    dslf = get_dislf(seq[0], xyz[0], mask[0])

    # assign confidence 'CONF' to all residues with backbone in template
    conf = torch.where(mask_t[...,:3].all(dim=-1)[...,None], torch.full((1,L,1),tmpl_conf), torch.zeros(L,1)).float()

    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=ChemData().NAATOKENS-1).float()
    t1d = torch.cat((seq_1hot, conf), -1)

    return msa, ins, L_s, xyz_t, mask_t, t1d, dslf
|
||||
|
||||
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
    """Read a fasta file and integer-encode its sequences.

    Parameters
    ----------
    filename : str
        Path to the fasta file.
    maxseq : int
        Kept for API symmetry with the other parsers; NOT enforced here
        (preserved for backward compatibility).
    rmsa_alphabet : bool
        Use the RNA alphabet instead of the protein alphabet.

    Returns
    -------
    tuple
        ``(msa, ins)`` — uint8 arrays of shape (N, L); ``ins`` is all zeros.
    """
    msa = []
    ins = []

    # BUGFIX: context manager guarantees the file handle is closed
    # (the original leaked the open file object).
    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()

            if len(line) == 0:
                continue

            # NOTE: unlike parse_multichain_fasta, lowercase (insertion)
            # letters are NOT stripped here; sequences are taken verbatim.
            msa.append(line)

            # sequence length
            L = len(msa[-1])

            # insertion features are always zero for plain fasta input
            i = np.zeros((L))
            ins.append(i)

    # convert letters into numbers
    if rmsa_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    ins = np.array(ins, dtype=np.uint8)

    return msa,ins
|
||||
|
||||
# Parse a fasta file containing multiple chains separated by '/'
def parse_multichain_fasta(filename, maxseq=10000, rna_alphabet=False, dna_alphabet=False):
    """Parse a fasta whose sequences contain chains joined by '/'.

    Chain lengths are taken from the first sequence; lowercase (insertion)
    characters are stripped before encoding.

    Parameters
    ----------
    filename : str
    maxseq : int
        Maximum number of sequences to read.
    rna_alphabet, dna_alphabet : bool
        Select the RNA or DNA integer alphabet (protein alphabet otherwise).

    Returns
    -------
    tuple
        ``(msa, ins, L_s)`` — uint8 MSA, all-zero uint8 insertions, and the
        list of per-chain lengths.
    """
    msa = []
    ins = []

    # translation table that deletes all lowercase (insertion) characters
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    L_s = []
    # BUGFIX: context manager closes the file (original leaked the handle).
    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()

            if len(line) == 0:
                continue

            # remove lowercase letters and append to MSA
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B','D') # hacky...
            if L_s == []:
                # per-chain lengths come from the first sequence
                L_s = [len(x) for x in msa_i.split('/')]
            msa_i = msa_i.replace('/','')
            msa.append(msa_i)

            # sequence length
            L = len(msa[-1])

            # insertion features are always zero here
            i = np.zeros((L))
            ins.append(i)

            if (len(msa) >= maxseq):
                break

    # convert letters into numbers
    if rna_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGUN"), dtype='|S1').view(np.uint8)
    elif dna_alphabet:
        alphabet = np.array(list("00000000000000000000-0ACGTD00000"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)

    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    ins = np.array(ins, dtype=np.uint8)

    return msa,ins,L_s
|
||||
|
||||
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=10000):
    """Parse a paired protein/RNA fasta into one concatenated integer MSA.

    Each record is split on '/' into a protein half (msa1) and an RNA half
    (msa2); records with no '/' are split at the protein length of the first
    record. Fully-gapped halves count toward separate "unpaired" caps of
    maxseq // 2 each.

    Returns
    -------
    tuple
        ``(msa, ins)`` — protein and RNA halves encoded with their own
        alphabets and concatenated along the sequence axis; ``ins`` is zeros.

    NOTE: this definition is shadowed by an identical redefinition of
    parse_mixed_fasta later in this file (default maxseq=8000), so this
    copy is effectively dead code.
    """
    msa1,msa2 = [],[]

    fstream = open(filename,"r")
    # strips lowercase (insertion) characters
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    unpaired_r, unpaired_p = 0, 0

    for line in fstream:
        # skip labels
        if line[0] == '>':
            continue

        # remove right whitespaces
        line = line.rstrip()

        if len(line) == 0:
            continue

        # remove lowercase letters and append to MSA
        msa_i = line.translate(table)
        msa_i = msa_i.replace('B','D') # hacky...

        msas_i = msa_i.split('/')

        if (len(msas_i)==1):
            # no separator: split at the protein length of the first record
            msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]

        if (len(msa1)==0 or (
            len(msas_i[0])==len(msa1[0]) and len(msas_i[1])==len(msa2[0])
        )):
            # skip if we've already found half of our limit in unpaired protein seqs
            if sum([1 for x in msas_i[1] if x != '-']) == 0:
                unpaired_p += 1
                if unpaired_p > maxseq // 2:
                    continue

            # skip if we've already found half of our limit in unpaired rna seqs
            if sum([1 for x in msas_i[0] if x != '-']) == 0:
                unpaired_r += 1
                if unpaired_r > maxseq // 2:
                    continue

            msa1.append(msas_i[0])
            msa2.append(msas_i[1])
        else:
            # NOTE(review): the last argument repeats len(msas_i[1]) — it
            # likely was meant to be len(msa2[0]).
            print ("Len error",filename, len(msas_i[0]),len(msa1[0]),len(msas_i[1]),len(msas_i[1]))

        if (len(msa1) >= maxseq):
            break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa1[msa1 == alphabet[i]] = i
    msa1[msa1>=31] = 21 # anything unknown to 'X'

    alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa2[msa2 == alphabet[i]] = i
    msa2[msa2>=31] = 30 # anything unknown to 'N'

    # concatenate protein and RNA halves along the residue axis
    msa = np.concatenate((msa1,msa2),axis=-1)

    ins = np.zeros(msa.shape, dtype=np.uint8)

    return msa,ins
|
||||
|
||||
|
||||
# parse a fasta alignment IF it exists
# otherwise return single-sequence msa
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
    """Parse the alignment at ``filename`` when present; otherwise encode
    ``seq`` as a single-sequence MSA with all-zero insertions."""
    if exists(filename):
        return parse_fasta(filename, maxseq, rmsa_alphabet)
    # fall back to a 1 x L MSA built from the bare sequence
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8) # -0 are UNK/mask
    encoded = np.array([list(seq)], dtype='|S1').view(np.uint8)
    for code in range(alphabet.shape[0]):
        encoded[encoded == alphabet[code]] = code

    return (encoded, np.zeros_like(encoded))
|
||||
|
||||
|
||||
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=8000):
    """Parse a paired protein/RNA fasta into one concatenated integer MSA.

    NOTE: this is a duplicate definition that overrides the earlier
    parse_mixed_fasta in this file; the only difference is the default
    maxseq (8000 here vs 10000 above).

    Each record is split on '/' into a protein half (msa1) and an RNA half
    (msa2); records with no '/' are split at the protein length of the first
    record. Fully-gapped halves count toward separate "unpaired" caps of
    maxseq // 2 each.

    Returns
    -------
    tuple
        ``(msa, ins)`` — the two halves encoded with their own alphabets and
        concatenated along the sequence axis; ``ins`` is all zeros.
    """
    msa1,msa2 = [],[]

    fstream = open(filename,"r")
    # strips lowercase (insertion) characters
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    unpaired_r, unpaired_p = 0, 0

    for line in fstream:
        # skip labels
        if line[0] == '>':
            continue

        # remove right whitespaces
        line = line.rstrip()

        if len(line) == 0:
            continue

        # remove lowercase letters and append to MSA
        msa_i = line.translate(table)
        msa_i = msa_i.replace('B','D') # hacky...

        msas_i = msa_i.split('/')

        if (len(msas_i)==1):
            # no separator: split at the protein length of the first record
            msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]

        if (len(msa1)==0 or (
            len(msas_i[0])==len(msa1[0]) and len(msas_i[1])==len(msa2[0])
        )):
            # skip if we've already found half of our limit in unpaired protein seqs
            if sum([1 for x in msas_i[1] if x != '-']) == 0:
                unpaired_p += 1
                if unpaired_p > maxseq // 2:
                    continue

            # skip if we've already found half of our limit in unpaired rna seqs
            if sum([1 for x in msas_i[0] if x != '-']) == 0:
                unpaired_r += 1
                if unpaired_r > maxseq // 2:
                    continue

            msa1.append(msas_i[0])
            msa2.append(msas_i[1])
        else:
            # NOTE(review): the last argument repeats len(msas_i[1]) — it
            # likely was meant to be len(msa2[0]).
            print ("Len error",filename, len(msas_i[0]),len(msa1[0]),len(msas_i[1]),len(msas_i[1]))

        if (len(msa1) >= maxseq):
            break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa1[msa1 == alphabet[i]] = i
    msa1[msa1>=31] = 21 # anything unknown to 'X'

    alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa2[msa2 == alphabet[i]] = i
    msa2[msa2>=31] = 30 # anything unknown to 'N'

    # concatenate protein and RNA halves along the residue axis
    msa = np.concatenate((msa1,msa2),axis=-1)

    ins = np.zeros(msa.shape, dtype=np.uint8)

    return msa,ins
|
||||
|
||||
|
||||
# read A3M and convert letters into
# integers in the 0..20 range,
# also keep track of insertions
def parse_a3m(filename, maxseq=8000, paired=False):
    """Parse an A3M alignment (optionally gzipped) into integer MSA features.

    Lowercase letters are treated as insertions: they are removed from the
    encoded sequence and their counts are recorded per match position.

    Parameters
    ----------
    filename : str
        Path to the .a3m file; a '.gz' suffix triggers gzip decompression.
    maxseq : int
        Maximum number of sequences to read.
    paired : bool
        If True, fasta headers contain only a TaxID; otherwise the TaxID is
        extracted from the header metadata with a regex.

    Returns
    -------
    tuple
        ``(msa, ins, taxIDs)`` — uint8 MSA (unknowns mapped to gap=20),
        uint8 insertion counts, and an array of TaxID strings per sequence.
    """
    msa = []
    ins = []
    taxIDs = []
    # deletes lowercase (insertion) characters
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    # read file line by line
    if filename.split('.')[-1] == 'gz':
        fstream = gzip.open(filename, 'rt')
    else:
        fstream = open(filename, 'r')

    # BUGFIX: close the stream deterministically (original leaked the handle)
    with fstream:
        for i, line in enumerate(fstream):

            # skip labels
            if line[0] == '>':
                if paired: # paired MSAs only have a TAXID in the fasta header
                    taxIDs.append(line[1:].strip())
                else: # unpaired MSAs have all the metadata so use regex to pull out TAXID
                    if i == 0:
                        taxIDs.append("query")
                    else:
                        match = re.search( r'TaxID=(\d+)', line)
                        if match:
                            taxIDs.append(match.group(1))
                        else:
                            taxIDs.append("") # query sequence
                continue

            # remove right whitespaces
            line = line.rstrip()

            if len(line) == 0:
                continue

            # remove lowercase letters and append to MSA
            msa.append(line.translate(table))

            # sequence length
            L = len(msa[-1])

            # 0 - match or gap; 1 - insertion
            a = np.array([0 if c.isupper() or c=='-' else 1 for c in line])
            ins_row = np.zeros((L))

            if np.sum(a) > 0:
                # positions of insertions
                pos = np.where(a==1)[0]

                # shift by occurrence
                a = pos - np.arange(pos.shape[0])

                # position of insertions in cleaned sequence
                # and their length
                pos,num = np.unique(a, return_counts=True)

                # append to the matrix of insetions
                ins_row[pos] = num

            ins.append(ins_row)

            if (len(msa) >= maxseq):
                break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    # treat all unknown characters as gaps
    msa[msa > 20] = 20

    ins = np.array(ins, dtype=np.uint8)

    return msa,ins, np.array(taxIDs)
|
||||
|
||||
|
||||
# read and extract xyz coords of N,Ca,C atoms
# from a PDB file
def parse_pdb(filename, seq=False, lddt_mask=False):
    """Parse a PDB file into coordinate/mask arrays.

    Parameters
    ----------
    filename : str
    seq : bool
        If True, also return the integer-encoded sequence
        (delegates to parse_pdb_lines_w_seq).
    lddt_mask : bool
        Forwarded to parse_pdb_lines_w_seq to mask atoms by pLDDT.

    Returns
    -------
    tuple
        From parse_pdb_lines / parse_pdb_lines_w_seq.
    """
    # BUGFIX: close the file deterministically
    # (the original used open(...).readlines() and leaked the handle)
    with open(filename, 'r') as fh:
        lines = fh.readlines()
    if seq:
        return parse_pdb_lines_w_seq(lines, lddt_mask=lddt_mask)
    return parse_pdb_lines(lines)
|
||||
|
||||
def parse_pdb_lines_w_seq(lines, lddt_mask=False):
    """Parse PDB ATOM lines into coordinates, mask, residue numbers and sequence.

    Residues are detected by their CA (protein) or P (nucleic) atom; the
    B-factor column is read as per-residue pLDDT.

    Parameters
    ----------
    lines : list of str
        Raw PDB lines.
    lddt_mask : bool
        If True, additionally mask out side chains below pLDDT 0.85 and
        backbones below 0.70. NOTE(review): thresholds assume pLDDT on a
        0-1 scale in the B-factor column — confirm for your inputs.

    Returns
    -------
    tuple
        (xyz float32 (L, NTOTAL, 3) with NaNs zeroed, mask bool, residue
        numbers, integer-encoded sequence).
    """
    # indices of residues observed in the structure
    res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
    pdb_idx_s = [(r[0], int(r[1])) for r in res]
    idx_s = [int(r[1]) for r in res]
    plddt = [float(r[3]) for r in res]
    # unknown residue names map to 20
    seq = [ChemData().aa2num[r[2]] if r[2] in ChemData().aa2num.keys() else 20 for r in res]

    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), ChemData().NTOTAL, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
        # NOTE(review): list.index inside this loop is O(N^2) over atoms;
        # fine for single structures, slow for very large files.
        idx = pdb_idx_s.index((chain,resNo))
        for i_atm, tgtatm in enumerate(ChemData().aa2long[ChemData().aa2num[aa]]):
            if tgtatm == atom:
                xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[...,0]))
    xyz[np.isnan(xyz[...,0])] = 0.0
    if lddt_mask == True:
        plddt = np.array(plddt)
        mask_lddt = np.full_like(mask, False)
        # side-chain atoms (5:) require higher confidence than backbone (:5)
        mask_lddt[plddt > .85, 5:] = True
        mask_lddt[plddt > .70, :5] = True
        mask = np.logical_and(mask, mask_lddt)

    return xyz,mask,np.array(idx_s), np.array(seq)
|
||||
|
||||
#'''
|
||||
def parse_pdb_lines(lines):
    """Parse PDB ATOM lines into coordinates, atom mask and residue numbers.

    Same as parse_pdb_lines_w_seq but without returning the sequence or
    applying a pLDDT mask. Residues are detected by their CA (protein) or
    P (nucleic) atom.

    Returns
    -------
    tuple
        (xyz float32 (L, NTOTAL, 3) with NaNs zeroed, mask bool,
        residue numbers).
    """
    # indices of residues observed in the structure
    res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
    pdb_idx_s = [(r[0], int(r[1])) for r in res]
    idx_s = [int(r[1]) for r in res]

    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), ChemData().NTOTAL, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
        # NOTE(review): list.index inside this loop is O(N^2) over atoms
        idx = pdb_idx_s.index((chain,resNo))
        for i_atm, tgtatm in enumerate(ChemData().aa2long[ChemData().aa2num[aa]]):
            if tgtatm == atom:
                xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[...,0]))
    xyz[np.isnan(xyz[...,0])] = 0.0

    return xyz,mask,np.array(idx_s)
|
||||
|
||||
|
||||
def parse_templates(item, params):
    """Load and align hhsearch template hits for a query item.

    Reads the tabulated hhsearch output (.atab), per-hit statistics (.hhr),
    and the template structures from an ffindex database, keeping hits with
    at least 10 aligned, structurally resolved positions.

    Parameters
    ----------
    item : str
        Query identifier; locates ``<DIR>/hhr/<last2>/<item>.atab``.
    params : dict
        Requires 'FFDB' (ffindex path prefix) and 'DIR' (search output root).

    Returns
    -------
    tuple
        (xyz float32, mask bool, qmap int64, f0d float32, f1d float32, ids)
    """
    # init FFindexDB of templates
    ### and extract template IDs
    ### present in the DB
    ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
                     read_data(params['FFDB']+'_pdb.ffdata'))
    #ffids = set([i.name for i in ffdb.index])

    # process tabulated hhsearch output to get
    # matched positions and positional scores
    infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
    hits = []
    for l in open(infile, "r").readlines():
        if l[0]=='>':
            key = l[1:].split()[0]
            hits.append([key,[],[]])
        elif "score" in l or "dssp" in l:
            continue
        else:
            hi = l.split()[:5]+[0.0,0.0,0.0]
            hits[-1][1].append([int(hi[0]),int(hi[1])])
            hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    #  Identities, Similarity, Sum_probs, Template_Neff]
    lines = open(infile[:-4]+'hhr', "r").readlines()
    pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
    for i,posi in enumerate(pos):
        hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])

    # parse templates from FFDB
    for hi in hits:
        #if hi[0] not in ffids:
        #    continue
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry == None:
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines(data))

    # process hits
    counter = 0
    xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
    for data in hits:
        # hits that failed the ffindex lookup have no structure appended
        if len(data)<7:
            continue

        qi,ti = np.array(data[1]).T
        # keep only aligned positions that are resolved in the template
        _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue

        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        # qmap rows: (query position 0-based, template counter)
        qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
        counter += 1

    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(bool)
    # BUGFIX: np.long was removed in NumPy >= 1.24; use np.int64
    # (consistent with parse_templates_raw below)
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)

    return xyz,mask,qmap,f0d,f1d,ids
|
||||
|
||||
def parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=20):
    """Parse up to ``max_templ`` hhsearch template hits with sequences.

    Like parse_templates, but takes explicit .hhr/.atab paths and an open
    ffindex DB, caps the number of hits, and also returns the template
    sequences as tensors.

    Parameters
    ----------
    ffdb : FFindexDB
        Open template structure database.
    hhr_fn, atab_fn : str
        hhsearch per-hit statistics and tabulated alignment files.
    max_templ : int
        Maximum number of hits to read.

    Returns
    -------
    tuple
        (xyz, mask, qmap, f0d, f1d, seq, ids) — all torch tensors except
        ``ids`` (list of template names).
    """
    # process tabulated hhsearch output to get
    # matched positions and positional scores
    hits = []
    for l in open(atab_fn, "r").readlines():
        if l[0]=='>':
            # stop reading once the hit cap is reached
            if len(hits) == max_templ:
                break
            key = l[1:].split()[0]
            hits.append([key,[],[]])
        elif "score" in l or "dssp" in l:
            continue
        else:
            hi = l.split()[:5]+[0.0,0.0,0.0]
            hits[-1][1].append([int(hi[0]),int(hi[1])])
            hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    #  Identities, Similarity, Sum_probs, Template_Neff]
    lines = open(hhr_fn, "r").readlines()
    pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
    for i,posi in enumerate(pos[:len(hits)]):
        hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])

    # parse templates from FFDB
    for hi in hits:
        #if hi[0] not in ffids:
        #    continue
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry == None:
            print ("Failed to find %s in *_pdb.ffindex"%hi[0])
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines_w_seq(data))

    # process hits
    counter = 0
    xyz,qmap,mask,f0d,f1d,ids,seq = [],[],[],[],[],[],[]
    for data in hits:
        # hits that failed the ffindex lookup have no structure appended
        if len(data)<7:
            continue
        # print ("Process %s..."%data[0])

        qi,ti = np.array(data[1]).T
        # keep only aligned positions that are resolved in the template
        _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue

        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        seq.append(data[-1][sel2])
        # qmap rows: (query position 0-based, template counter)
        qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
        counter += 1

    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(bool)
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)
    seq = np.hstack(seq).astype(np.int64)
    ids = ids  # no-op; kept as in original

    return torch.from_numpy(xyz), torch.from_numpy(mask), torch.from_numpy(qmap), \
           torch.from_numpy(f0d), torch.from_numpy(f1d), torch.from_numpy(seq), ids
|
||||
|
||||
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
    """Build padded per-template feature tensors for a query of length ``qlen``.

    Reads raw template hits and scatters them onto fixed-length tensors,
    recentering/realigning each template. With no hits, returns a single
    all-gap blank template.

    Parameters
    ----------
    qlen : int
        Query length; templates are padded/indexed to this length.
    ffdb, hhr_fn, atab_fn :
        Forwarded to parse_templates_raw.
    n_templ : int
        Number of templates to return (the top hits in hhsearch order).

    Returns
    -------
    tuple
        (xyz (npick, qlen, NTOTAL, 3), mask, t1d one-hot-seq + confidence).
    """
    xyz_t, mask_t, qmap, t1d, seq, ids = parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=max(n_templ, 20))
    ntmplatoms = xyz_t.shape[1]

    npick = min(n_templ, len(ids))
    if npick < 1: # no templates
        # single blank template: all-gap sequence, zero confidence
        xyz = torch.full((1,qlen,ChemData().NTOTAL,3),np.nan).float()
        mask = torch.full((1,qlen,ChemData().NTOTAL),False)
        t1d = torch.nn.functional.one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float() # all gaps
        t1d = torch.cat((t1d, torch.zeros((1,qlen,1)).float()), -1)
        return xyz, mask, t1d

    # deterministically take the top npick hits (hhsearch ranking order)
    sample = torch.arange(npick)
    #
    xyz = torch.full((npick, qlen, ChemData().NTOTAL, 3), np.nan).float()
    mask = torch.full((npick, qlen, ChemData().NTOTAL), False)
    f1d = torch.full((npick, qlen), 20).long()
    f1d_val = torch.zeros((npick, qlen, 1)).float()
    #
    for i, nt in enumerate(sample):
        # rows of qmap belonging to template nt; column 0 holds the query position
        sel = torch.where(qmap[:,1] == nt)[0]
        pos = qmap[sel, 0]
        xyz[i, pos] = xyz_t[sel]
        mask[i, pos, :ntmplatoms] = mask_t[sel].bool()
        f1d[i, pos] = seq[sel]
        # column 2 of the raw positional scores is used as per-residue confidence
        f1d_val[i,pos] = t1d[sel, 2].unsqueeze(-1)
        xyz[i] = util.center_and_realign_missing(xyz[i], mask[i], seq=f1d[i])

    f1d = torch.nn.functional.one_hot(f1d, num_classes=ChemData().NAATOKENS-1).float()
    f1d = torch.cat((f1d, f1d_val), dim=-1)

    return xyz, mask, f1d
|
||||
|
||||
|
||||
def clean_sdffile(filename):
    """Return the SDF file contents as a string with two-letter element
    symbols normalized (second letter lowercased, e.g. FE -> Fe) so
    openbabel can parse them correctly."""
    with open(filename) as fh:
        raw = fh.readlines()
    # the counts line (4th line) holds the atom count in its first 3 columns
    n_atoms = int(raw[3][:3])
    fixed = []
    for line_no, line in enumerate(raw):
        if 4 <= line_no < 4 + n_atoms:
            # column 32 (0-based) is the 2nd character of the element symbol
            fixed.append(line[:32] + line[32].lower() + line[33:])
        else:
            fixed.append(line)
    return ''.join(fixed)
|
||||
|
||||
def parse_mol(filename, filetype="mol2", string=False, remove_H=True, find_automorphs=True, generate_conformer: bool = False):
    """Parse small molecule ligand.

    Parameters
    ----------
    filename : str
        Path to the molecule file, or the raw molecule data if `string` is True.
    filetype : str
        Openbabel input format name (e.g. "mol2", "sdf", "smiles").
    string : bool
        If True, `filename` is a string containing the molecule data.
    remove_H : bool
        Whether to remove hydrogen atoms.
    find_automorphs : bool
        Whether to enumerate atom symmetry permutations.
    generate_conformer : bool
        If True, build 3D coordinates and relax them with the MMFF94 force field.

    Returns
    -------
    obmol : OBMol
        openbabel molecule object representing the ligand
    msa : torch.Tensor (N_atoms,) long
        Integer-encoded "sequence" (atom types) of ligand
    ins : torch.Tensor (N_atoms,) long
        Insertion features (all zero) for RF input
    atom_coords : torch.Tensor (N_symmetry, N_atoms, 3) float
        Atom coordinates
    mask : torch.Tensor (N_symmetry, N_atoms) bool
        Boolean mask for whether atom exists

    Raises
    ------
    ValueError
        If `generate_conformer` is set and the MMFF94 force field cannot be
        set up for the molecule.
    """
    obConversion = openbabel.OBConversion()
    obConversion.SetInFormat(filetype)
    obmol = openbabel.OBMol()
    if string:
        obConversion.ReadString(obmol,filename)
    elif filetype=='sdf':
        # SDF files may carry upper-cased element symbols (e.g. "FE"); fix the
        # case first so openbabel assigns the right atomic numbers.
        molstring = clean_sdffile(filename)
        obConversion.ReadString(obmol,molstring)
    else:
        obConversion.ReadFile(obmol,filename)
    if generate_conformer:
        # Build an initial 3D geometry, then relax it with MMFF94 if possible.
        builder = openbabel.OBBuilder()
        builder.Build(obmol)
        ff = openbabel.OBForceField.FindForceField("mmff94")
        did_setup = ff.Setup(obmol)
        if did_setup:
            ff.FastRotorSearch()
            ff.GetCoordinates(obmol)
        else:
            raise ValueError(f"Failed to generate 3D coordinates for molecule (unknown).")
    if remove_H:
        obmol.DeleteHydrogens()
        # the above sometimes fails to get all the hydrogens
        # (manual sweep; openbabel atom indices are 1-based, and deleting an
        # atom renumbers the rest, hence the while loop instead of a for).
        i = 1
        while i < obmol.NumAtoms()+1:
            if obmol.GetAtom(i).GetAtomicNum()==1:
                obmol.DeleteAtom(obmol.GetAtom(i))
            else:
                i += 1
    # Map each atom's element to an RF2AA token; unknown elements fall back to 'ATM'.
    atomtypes = [ChemData().atomnum2atomtype.get(obmol.GetAtom(i).GetAtomicNum(), 'ATM')
                 for i in range(1, obmol.NumAtoms()+1)]
    msa = torch.tensor([ChemData().aa2num[x] for x in atomtypes])
    ins = torch.zeros_like(msa)

    atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
                                for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
    mask = torch.full(atom_coords.shape[:-1], True) # (1, natoms,)

    if find_automorphs:
        # Expand the leading dim to one copy of the coordinates per atom
        # symmetry permutation.
        atom_coords, mask = util.get_automorphs(obmol, atom_coords[0], mask[0])

    return obmol, msa, ins, atom_coords, mask
35
rf2aa/data/preprocessing.py
Normal file
35
rf2aa/data/preprocessing.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import os
|
||||
from hydra import initialize, compose
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
|
||||
#from rf2aa.run_inference import ModelRunner
|
||||
|
||||
|
||||
def make_msa(
    fasta_file,
    chain,
    model_runner
):
    """Run the sequence/template search pipeline for one chain.

    Outputs are cached under <output_path>/<job_name>/<chain>; if all three
    expected files already exist, the (expensive) search is skipped.

    Parameters
    ----------
    fasta_file : str or Path
        Query sequence in FASTA format.
    chain : str
        Chain identifier; used only to namespace the output directory.
    model_runner
        Object whose `.config` provides `output_path`, `job_name`, and
        `database_params` (command, sequencedb, num_cpus, mem, hhdb).

    Returns
    -------
    (out_a3m, out_hhr, out_atab) : tuple of Path
        Paths to the MSA, hhsearch hit file, and alignment table.

    Raises
    ------
    subprocess.CalledProcessError
        If the search command exits with a nonzero status.
    """
    out_dir_base = Path(model_runner.config.output_path)
    # 'hash' would shadow the builtin; this is just the job's directory name.
    job_hash = model_runner.config.job_name
    out_dir = out_dir_base / job_hash / chain
    out_dir.mkdir(parents=True, exist_ok=True)

    command = model_runner.config.database_params.command
    search_base = model_runner.config.database_params.sequencedb
    num_cpus = model_runner.config.database_params.num_cpus
    ram_gb = model_runner.config.database_params.mem
    template_database = model_runner.config.database_params.hhdb

    out_a3m = out_dir / "t000_.msa0.a3m"
    out_atab = out_dir / "t000_.atab"
    out_hhr = out_dir / "t000_.hhr"
    # Cache hit: all three outputs already exist, so skip the search.
    if out_a3m.exists() and out_atab.exists() and out_hhr.exists():
        return out_a3m, out_hhr, out_atab

    search_command = f"./{command} {fasta_file} {out_dir} {num_cpus} {ram_gb} {search_base} {template_database}"
    print(search_command)
    # check=True: fail loudly here rather than silently returning paths to
    # output files that were never produced.
    subprocess.run(search_command, shell=True, check=True)
    return out_a3m, out_hhr, out_atab
||||
93
rf2aa/data/protein.py
Normal file
93
rf2aa/data/protein.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
|
||||
from rf2aa.data.data_loader import RawInputData
|
||||
from rf2aa.data.data_loader_utils import blank_template, TemplFeaturize
|
||||
from rf2aa.data.parsers import parse_a3m, parse_templates_raw
|
||||
from rf2aa.data.preprocessing import make_msa
|
||||
from rf2aa.util import get_protein_bond_feats
|
||||
|
||||
|
||||
def get_templates(
    qlen,
    ffdb,
    hhr_fn,
    atab_fn,
    seqID_cut,
    n_templ,
    pick_top: bool = True,
    offset: int = 0,
    random_noise: float = 5.0,
    deterministic: bool = False,
):
    """Parse raw template hits and featurize them for a query of length `qlen`.

    Wraps `parse_templates_raw` + `TemplFeaturize`: hits above `seqID_cut`
    sequence identity are filtered out and up to `n_templ` templates are
    featurized.
    """
    parsed = parse_templates_raw(ffdb, hhr_fn=hhr_fn, atab_fn=atab_fn)
    xyz, mask, qmap, f0d, f1d, seq, ids = parsed

    # TemplFeaturize expects batched tensors; ids stays a flat list.
    tplt = {
        key: tensor.unsqueeze(0)
        for key, tensor in zip(
            ("xyz", "mask", "qmap", "f0d", "f1d", "seq"),
            (xyz, mask, qmap, f0d, f1d, seq),
        )
    }
    tplt["ids"] = ids

    return TemplFeaturize(
        tplt,
        qlen,
        {"SEQID": seqID_cut},
        offset=offset,
        npick=n_templ,
        pick_top=pick_top,
        random_noise=random_noise,
        deterministic=deterministic,
    )
|
||||
|
||||
def load_protein(msa_file, hhr_fn, atab_fn, model_runner):
    """Load one protein chain's MSA and (optional) templates as RawInputData.

    If either template file path is None, a blank template is substituted.
    """
    msa, ins, taxIDs = parse_a3m(msa_file)
    # NOTE: this next line is a bug, but is the way that
    # the code is written in the original implementation!
    ins[0] = msa[0]

    L = msa.shape[1]
    have_templates = hhr_fn is not None and atab_fn is not None
    if have_templates:
        xyz_t, t1d, mask_t, _ = get_templates(
            L,
            model_runner.ffdb,
            hhr_fn,
            atab_fn,
            seqID_cut=model_runner.config.loader_params.seqid,
            n_templ=model_runner.config.loader_params.n_templ,
            deterministic=model_runner.deterministic,
        )
    else:
        print("No templates provided")
        xyz_t, t1d, mask_t, _ = blank_template(1, L)

    # Proteins carry no chirality constraints or atom frames (small-molecule
    # features), and bond features follow the standard backbone connectivity.
    bond_feats = get_protein_bond_feats(L)
    empty_chirals = torch.zeros(0, 5)
    empty_frames = torch.zeros(0, 3, 2)
    return RawInputData(
        torch.from_numpy(msa),
        torch.from_numpy(ins),
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        empty_chirals,
        empty_frames,
        taxids=taxIDs,
    )
||||
def generate_msa_and_load_protein(fasta_file, chain, model_runner):
    """Run the MSA/template search for one chain, then load it as RawInputData."""
    search_outputs = make_msa(fasta_file, chain, model_runner)
    msa_path, hhr_path, atab_path = (str(p) for p in search_outputs)
    return load_protein(msa_path, hhr_path, atab_path, model_runner)
||||
41
rf2aa/data/small_molecule.py
Normal file
41
rf2aa/data/small_molecule.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
|
||||
from rf2aa.data.data_loader import RawInputData
|
||||
from rf2aa.data.data_loader_utils import blank_template
|
||||
from rf2aa.data.parsers import parse_mol
|
||||
from rf2aa.kinematics import get_chirals
|
||||
from rf2aa.util import get_bond_feats, get_nxgraph, get_atom_frames
|
||||
|
||||
|
||||
def load_small_molecule(input_file, input_type, model_runner):
    """Parse a small-molecule input and compute its RF2AA input features.

    For "smiles" inputs, `input_file` is the SMILES string itself rather
    than a file path; a 3D conformer is always generated.
    """
    obmol, msa, ins, xyz, mask = parse_mol(
        input_file,
        filetype=input_type,
        string=(input_type == "smiles"),
        generate_conformer=True,
    )
    return compute_features_from_obmol(obmol, msa, xyz, model_runner)
||||
def compute_features_from_obmol(obmol, msa, xyz, model_runner):
    """Build RawInputData for a parsed small molecule.

    Derives bond features, chirality constraints, and atom frames from the
    openbabel molecule, and pairs them with blank templates.
    """
    n_atoms = msa.shape[0]
    bond_feats = get_bond_feats(obmol)
    # Small molecules have no insertions; zero them out.
    ins = torch.zeros_like(msa)

    # No structural templates exist for ligands — use blanks.
    xyz_t, t1d, mask_t, _ = blank_template(
        model_runner.config.loader_params.n_templ,
        n_atoms,
        deterministic=model_runner.deterministic,
    )
    # Chirality constraints come from the first symmetry copy of the coords.
    chirals = get_chirals(obmol, xyz[0])
    atom_frames = get_atom_frames(msa, get_nxgraph(obmol))
    return RawInputData(
        msa.unsqueeze(0),
        ins.unsqueeze(0),
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        chirals,
        atom_frames,
        taxids=None,
    )
|
||||
def remove_leaving_atoms(input, is_leaving):
    """Return `input` restricted to atoms that are NOT leaving-group atoms.

    `is_leaving` is a boolean mask over atoms; the complement selects the
    atoms to keep via `keep_features`. (The parameter name `input` shadows
    the builtin but is kept for interface compatibility.)
    """
    return input.keep_features(~is_leaving)
|
||||
Reference in New Issue
Block a user