Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
This commit is contained in:
418
rf2aa/model/RoseTTAFoldModel.py
Normal file
418
rf2aa/model/RoseTTAFoldModel.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import assertpy
|
||||
from assertpy import assert_that
|
||||
from icecream import ic
|
||||
from rf2aa.model.layers.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, recycling_factory
|
||||
from rf2aa.model.Track_module import IterativeSimulator
|
||||
from rf2aa.model.layers.AuxiliaryPredictor import (
|
||||
DistanceNetwork,
|
||||
MaskedTokenNetwork,
|
||||
LDDTNetwork,
|
||||
PAENetwork,
|
||||
BinderNetwork,
|
||||
)
|
||||
from rf2aa.tensor_util import assert_shape, assert_equal
|
||||
import rf2aa.util
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
|
||||
|
||||
def get_shape(t):
|
||||
if hasattr(t, "shape"):
|
||||
return t.shape
|
||||
if type(t) is tuple:
|
||||
return [get_shape(e) for e in t]
|
||||
else:
|
||||
return type(t)
|
||||
|
||||
|
||||
class RoseTTAFoldModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
symmetrize_repeats=None, # whether to symmetrize repeats in the pair track
|
||||
repeat_length=None, # if symmetrizing repeats, what length are they?
|
||||
symmsub_k=None, # if symmetrizing repeats, which diagonals?
|
||||
sym_method=None, # if symmetrizing repeats, which block symmetrization method?
|
||||
main_block=None, # if copying template blocks along main diag, which block is main block? (the one w/ motif)
|
||||
copy_main_block_template=None, # whether or not to copy main block template along main diag
|
||||
n_extra_block=4,
|
||||
n_main_block=8,
|
||||
n_ref_block=4,
|
||||
n_finetune_block=0,
|
||||
d_msa=256,
|
||||
d_msa_full=64,
|
||||
d_pair=128,
|
||||
d_templ=64,
|
||||
n_head_msa=8,
|
||||
n_head_pair=4,
|
||||
n_head_templ=4,
|
||||
d_hidden=32,
|
||||
d_hidden_templ=64,
|
||||
d_t1d=0,
|
||||
p_drop=0.15,
|
||||
additional_dt1d=0,
|
||||
recycling_type="msa_pair",
|
||||
SE3_param={}, SE3_ref_param={},
|
||||
atom_type_index=None,
|
||||
aamask=None,
|
||||
ljlk_parameters=None,
|
||||
lj_correction_parameters=None,
|
||||
cb_len=None,
|
||||
cb_ang=None,
|
||||
cb_tor=None,
|
||||
num_bonds=None,
|
||||
lj_lin=0.6,
|
||||
use_chiral_l1=True,
|
||||
use_lj_l1=False,
|
||||
use_atom_frames=True,
|
||||
use_same_chain=False,
|
||||
enable_same_chain=False,
|
||||
refiner_topk=64,
|
||||
get_quaternion=False,
|
||||
# New for diffusion
|
||||
freeze_track_motif=False,
|
||||
assert_single_sequence_input=False,
|
||||
fit=False,
|
||||
tscale=1.0
|
||||
):
|
||||
super(RoseTTAFoldModule, self).__init__()
|
||||
self.freeze_track_motif = freeze_track_motif
|
||||
self.assert_single_sequence_input = assert_single_sequence_input
|
||||
self.recycling_type = recycling_type
|
||||
#
|
||||
# Input Embeddings
|
||||
d_state = SE3_param["l0_out_features"]
|
||||
self.latent_emb = MSA_emb(
|
||||
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop, use_same_chain=use_same_chain,
|
||||
enable_same_chain=enable_same_chain
|
||||
)
|
||||
self.full_emb = Extra_emb(
|
||||
d_msa=d_msa_full, d_init=ChemData().NAATOKENS - 1 + 4, p_drop=p_drop
|
||||
)
|
||||
self.bond_emb = Bond_emb(d_pair=d_pair, d_init=ChemData().NBTYPES)
|
||||
|
||||
self.templ_emb = Templ_emb(d_t1d=d_t1d,
|
||||
d_pair=d_pair,
|
||||
d_templ=d_templ,
|
||||
d_state=d_state,
|
||||
n_head=n_head_templ,
|
||||
d_hidden=d_hidden_templ,
|
||||
p_drop=0.25,
|
||||
symmetrize_repeats=symmetrize_repeats, # repeat protein stuff
|
||||
repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k,
|
||||
sym_method=sym_method,
|
||||
main_block=main_block,
|
||||
copy_main_block=copy_main_block_template,
|
||||
additional_dt1d=additional_dt1d)
|
||||
|
||||
# Update inputs with outputs from previous round
|
||||
|
||||
self.recycle = recycling_factory[recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
|
||||
#
|
||||
self.simulator = IterativeSimulator(
|
||||
n_extra_block=n_extra_block,
|
||||
n_main_block=n_main_block,
|
||||
n_ref_block=n_ref_block,
|
||||
n_finetune_block=n_finetune_block,
|
||||
d_msa=d_msa,
|
||||
d_msa_full=d_msa_full,
|
||||
d_pair=d_pair,
|
||||
d_hidden=d_hidden,
|
||||
n_head_msa=n_head_msa,
|
||||
n_head_pair=n_head_pair,
|
||||
SE3_param=SE3_param,
|
||||
SE3_ref_param=SE3_ref_param,
|
||||
p_drop=p_drop,
|
||||
atom_type_index=atom_type_index, # change if encoding elements instead of atomtype
|
||||
aamask=aamask,
|
||||
ljlk_parameters=ljlk_parameters,
|
||||
lj_correction_parameters=lj_correction_parameters,
|
||||
num_bonds=num_bonds,
|
||||
cb_len=cb_len,
|
||||
cb_ang=cb_ang,
|
||||
cb_tor=cb_tor,
|
||||
lj_lin=lj_lin,
|
||||
use_lj_l1=use_lj_l1,
|
||||
use_chiral_l1=use_chiral_l1,
|
||||
symmetrize_repeats=symmetrize_repeats,
|
||||
repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k,
|
||||
sym_method=sym_method,
|
||||
main_block=main_block,
|
||||
use_same_chain=use_same_chain,
|
||||
enable_same_chain=enable_same_chain,
|
||||
refiner_topk=refiner_topk
|
||||
)
|
||||
|
||||
##
|
||||
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
|
||||
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
|
||||
self.lddt_pred = LDDTNetwork(d_state)
|
||||
self.pae_pred = PAENetwork(d_pair)
|
||||
self.pde_pred = PAENetwork(
|
||||
d_pair
|
||||
) # distance error, but use same architecture as aligned error
|
||||
# binder predictions are made on top of the pair features, just like
|
||||
# PAE predictions are. It's not clear if this is the best place to insert
|
||||
# this prediction head.
|
||||
# self.binder_network = BinderNetwork(d_pair, d_state)
|
||||
|
||||
self.bind_pred = BinderNetwork() #fd - expose n_hidden as variable?
|
||||
|
||||
self.use_atom_frames = use_atom_frames
|
||||
self.enable_same_chain = enable_same_chain
|
||||
self.get_quaternion = get_quaternion
|
||||
self.verbose_checks = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
msa_latent,
|
||||
msa_full,
|
||||
seq,
|
||||
seq_unmasked,
|
||||
xyz,
|
||||
sctors,
|
||||
idx,
|
||||
bond_feats,
|
||||
dist_matrix,
|
||||
chirals,
|
||||
atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
|
||||
msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None, is_motif=None,
|
||||
return_raw=False,
|
||||
use_checkpoint=False,
|
||||
return_infer=False, #fd ?
|
||||
p2p_crop=-1, topk_crop=-1, # striping
|
||||
symmids=None, symmsub=None, symmRs=None, symmmeta=None, # symmetry
|
||||
):
|
||||
# ic(get_shape(msa_latent))
|
||||
# ic(get_shape(msa_full))
|
||||
# ic(get_shape(seq))
|
||||
# ic(get_shape(seq_unmasked))
|
||||
# ic(get_shape(xyz))
|
||||
# ic(get_shape(sctors))
|
||||
# ic(get_shape(idx))
|
||||
# ic(get_shape(bond_feats))
|
||||
# ic(get_shape(chirals))
|
||||
# ic(get_shape(atom_frames))
|
||||
# ic(get_shape(t1d))
|
||||
# ic(get_shape(t2d))
|
||||
# ic(get_shape(xyz_t))
|
||||
# ic(get_shape(alpha_t))
|
||||
# ic(get_shape(mask_t))
|
||||
# ic(get_shape(same_chain))
|
||||
# ic(get_shape(msa_prev))
|
||||
# ic(get_shape(pair_prev))
|
||||
# ic(get_shape(mask_recycle))
|
||||
# ic()
|
||||
# ic()
|
||||
B, N, L = msa_latent.shape[:3]
|
||||
A = atom_frames.shape[1]
|
||||
dtype = msa_latent.dtype
|
||||
|
||||
if self.assert_single_sequence_input:
|
||||
assert_shape(msa_latent, (1, 1, L, 164))
|
||||
assert_shape(msa_full, (1, 1, L, 83))
|
||||
assert_shape(seq, (1, L))
|
||||
assert_shape(seq_unmasked, (1, L))
|
||||
assert_shape(xyz, (1, L, ChemData().NTOTAL, 3))
|
||||
assert_shape(sctors, (1, L, 20, 2))
|
||||
assert_shape(idx, (1, L))
|
||||
assert_shape(bond_feats, (1, L, L))
|
||||
assert_shape(dist_matrix, (1, L, L))
|
||||
# assert_shape(chirals, (1, 0))
|
||||
# assert_shape(atom_frames, (1, 4, L)) # This is set to 4 for the recycle count, but that can't be right
|
||||
assert_shape(atom_frames, (1, A, 3, 2)) # What is 4?
|
||||
assert_shape(t1d, (1, 1, L, 80))
|
||||
assert_shape(t2d, (1, 1, L, L, 68))
|
||||
assert_shape(xyz_t, (1, 1, L, 3))
|
||||
assert_shape(alpha_t, (1, 1, L, 60))
|
||||
assert_shape(mask_t, (1, 1, L, L))
|
||||
assert_shape(same_chain, (1, L, L))
|
||||
device = msa_latent.device
|
||||
assert_that(msa_full.device).is_equal_to(device)
|
||||
assert_that(seq.device).is_equal_to(device)
|
||||
assert_that(seq_unmasked.device).is_equal_to(device)
|
||||
assert_that(xyz.device).is_equal_to(device)
|
||||
assert_that(sctors.device).is_equal_to(device)
|
||||
assert_that(idx.device).is_equal_to(device)
|
||||
assert_that(bond_feats.device).is_equal_to(device)
|
||||
assert_that(dist_matrix.device).is_equal_to(device)
|
||||
assert_that(atom_frames.device).is_equal_to(device)
|
||||
assert_that(t1d.device).is_equal_to(device)
|
||||
assert_that(t2d.device).is_equal_to(device)
|
||||
assert_that(xyz_t.device).is_equal_to(device)
|
||||
assert_that(alpha_t.device).is_equal_to(device)
|
||||
assert_that(mask_t.device).is_equal_to(device)
|
||||
assert_that(same_chain.device).is_equal_to(device)
|
||||
|
||||
if self.verbose_checks:
|
||||
#ic(is_motif.shape)
|
||||
is_sm = rf2aa.util.is_atom(seq[0]) # (L)
|
||||
#is_protein_motif = is_motif & ~is_sm
|
||||
#if is_motif.any():
|
||||
# motif_protein_i = torch.where(is_motif)[0][0]
|
||||
#is_motif_sm = is_motif & is_sm
|
||||
#if is_sm.any():
|
||||
# motif_sm_i = torch.where(is_motif_sm)[0][0]
|
||||
#diffused_protein_i = torch.where(~is_sm & ~is_motif)[0][0]
|
||||
|
||||
"""
|
||||
msa_full: NSEQ,N_INDEL,N_TERMINUS,
|
||||
msa_masked: NSEQ,NSEQ,N_INDEL,N_INDEL,N_TERMINUS
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
NINDEL = 1
|
||||
NTERMINUS = 2
|
||||
NMSAFULL = ChemData().NAATOKENS + NINDEL + NTERMINUS
|
||||
NMSAMASKED = ChemData().NAATOKENS + ChemData().NAATOKENS + NINDEL + NINDEL + NTERMINUS
|
||||
assert_that(msa_latent.shape[-1]).is_equal_to(NMSAMASKED)
|
||||
assert_that(msa_full.shape[-1]).is_equal_to(NMSAFULL)
|
||||
|
||||
msa_full_seq = np.r_[0:ChemData().NAATOKENS]
|
||||
msa_full_indel = np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
|
||||
msa_full_term = np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]
|
||||
|
||||
msa_latent_seq1 = np.r_[0:ChemData().NAATOKENS]
|
||||
msa_latent_seq2 = np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]
|
||||
msa_latent_indel1 = np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
|
||||
msa_latent_indel2 = np.r_[
|
||||
2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL
|
||||
]
|
||||
msa_latent_terminus = np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
|
||||
|
||||
#i_name = [(diffused_protein_i, "diffused_protein")]
|
||||
#if is_sm.any():
|
||||
# i_name.insert(0, (motif_sm_i, "motif_sm"))
|
||||
#if is_motif.any():
|
||||
# i_name.insert(0, (motif_protein_i, "motif_protein"))
|
||||
i_name = [(0, "tst")]
|
||||
for i, name in i_name:
|
||||
ic(f"------------------{name}:{i}----------------")
|
||||
msa_full_seq = msa_full[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
|
||||
msa_full_indel = msa_full[
|
||||
0, 0, i, np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
|
||||
]
|
||||
msa_full_term = msa_full[0, 0, i, np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]]
|
||||
|
||||
msa_latent_seq1 = msa_latent[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
|
||||
msa_latent_seq2 = msa_latent[0, 0, i, np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]]
|
||||
msa_latent_indel1 = msa_latent[
|
||||
0, 0, i, np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
|
||||
]
|
||||
msa_latent_indel2 = msa_latent[
|
||||
0,
|
||||
0,
|
||||
i,
|
||||
np.r_[2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL],
|
||||
]
|
||||
msa_latent_term = msa_latent[
|
||||
0, 0, i, np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
|
||||
]
|
||||
|
||||
assert_equal(msa_full_seq, msa_latent_seq1)
|
||||
assert_equal(msa_full_seq, msa_latent_seq2)
|
||||
assert_equal(msa_full_indel, msa_latent_indel1)
|
||||
assert_equal(msa_full_indel, msa_latent_indel2)
|
||||
assert_equal(msa_full_term, msa_latent_term)
|
||||
# if 'motif' in name:
|
||||
msa_cat = torch.where(msa_full_seq)[0]
|
||||
ic(msa_cat, seq[0, i])
|
||||
assert_equal(seq[0, i : i + 1], msa_cat)
|
||||
assert_equal(seq[0, i], seq_unmasked[0, i])
|
||||
ic(
|
||||
name,
|
||||
# torch.where(msa_latent[0,0,i,:80]),
|
||||
# torch.where(msa_full[0,0,i]),
|
||||
seq[0, i],
|
||||
seq_unmasked[0, i],
|
||||
torch.where(t1d[0, 0, i]),
|
||||
xyz[0, i, :4, 0],
|
||||
xyz_t[0, 0, i, 0],
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
#if self.enable_same_chain == False:
|
||||
# same_chain = None
|
||||
msa_latent, pair, state = self.latent_emb(
|
||||
msa_latent, seq, idx, bond_feats, dist_matrix, same_chain=same_chain
|
||||
)
|
||||
msa_full = self.full_emb(msa_full, seq, idx)
|
||||
pair = pair + self.bond_emb(bond_feats)
|
||||
|
||||
msa_latent, pair, state = msa_latent.to(dtype), pair.to(dtype), state.to(dtype)
|
||||
msa_full = msa_full.to(dtype)
|
||||
|
||||
#
|
||||
# Do recycling
|
||||
if msa_prev is None:
|
||||
msa_prev = torch.zeros_like(msa_latent[:,0])
|
||||
if pair_prev is None:
|
||||
pair_prev = torch.zeros_like(pair)
|
||||
if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
|
||||
state_prev = torch.zeros_like(state)
|
||||
|
||||
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
|
||||
msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)
|
||||
|
||||
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
|
||||
pair = pair + pair_recycle
|
||||
state = state + state_recycle # if state is not recycled these will be zeros
|
||||
|
||||
# add template embedding
|
||||
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop)
|
||||
|
||||
# Predict coordinates from given inputs
|
||||
is_motif = is_motif if self.freeze_track_motif else torch.zeros_like(seq).bool()[0]
|
||||
msa, pair, xyz, alpha_s, xyz_allatom, state, symmsub, quat = self.simulator(
|
||||
seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx,
|
||||
symmids, symmsub, symmRs, symmmeta,
|
||||
bond_feats, dist_matrix, same_chain, chirals, is_motif, atom_frames,
|
||||
use_checkpoint=use_checkpoint, use_atom_frames=self.use_atom_frames,
|
||||
p2p_crop=p2p_crop, topk_crop=topk_crop
|
||||
)
|
||||
|
||||
if return_raw:
|
||||
# get last structure
|
||||
xyz_last = xyz_allatom[-1].unsqueeze(0)
|
||||
return msa[:,0], pair, xyz_last, alpha_s[-1], None
|
||||
|
||||
# predict masked amino acids
|
||||
logits_aa = self.aa_pred(msa)
|
||||
|
||||
# predict distogram & orientograms
|
||||
logits = self.c6d_pred(pair)
|
||||
|
||||
# Predict LDDT
|
||||
lddt = self.lddt_pred(state)
|
||||
|
||||
if self.verbose_checks:
|
||||
pseq_0 = logits_aa.permute(0, 2, 1)
|
||||
ic(pseq_0.shape)
|
||||
pseq_0 = pseq_0[0]
|
||||
ic(
|
||||
f"motif sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[is_motif], dim=-1).tolist())}"
|
||||
)
|
||||
ic(
|
||||
f"diffused sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[~is_motif], dim=-1).tolist())}"
|
||||
)
|
||||
|
||||
logits_pae = logits_pde = p_bind = None
|
||||
# predict aligned error and distance error
|
||||
logits_pae = self.pae_pred(pair)
|
||||
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
|
||||
|
||||
#fd predict bind/no-bind
|
||||
p_bind = self.bind_pred(logits_pae,same_chain)
|
||||
|
||||
if self.get_quaternion:
|
||||
return (
|
||||
logits, logits_aa, logits_pae, logits_pde, p_bind,
|
||||
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state, quat
|
||||
)
|
||||
else:
|
||||
return (
|
||||
logits, logits_aa, logits_pae, logits_pde, p_bind,
|
||||
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state
|
||||
)
|
||||
Reference in New Issue
Block a user