Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
This commit is contained in:
418
rf2aa/model/RoseTTAFoldModel.py
Normal file
418
rf2aa/model/RoseTTAFoldModel.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import assertpy
|
||||
from assertpy import assert_that
|
||||
from icecream import ic
|
||||
from rf2aa.model.layers.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, recycling_factory
|
||||
from rf2aa.model.Track_module import IterativeSimulator
|
||||
from rf2aa.model.layers.AuxiliaryPredictor import (
|
||||
DistanceNetwork,
|
||||
MaskedTokenNetwork,
|
||||
LDDTNetwork,
|
||||
PAENetwork,
|
||||
BinderNetwork,
|
||||
)
|
||||
from rf2aa.tensor_util import assert_shape, assert_equal
|
||||
import rf2aa.util
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
|
||||
|
||||
def get_shape(t):
|
||||
if hasattr(t, "shape"):
|
||||
return t.shape
|
||||
if type(t) is tuple:
|
||||
return [get_shape(e) for e in t]
|
||||
else:
|
||||
return type(t)
|
||||
|
||||
|
||||
class RoseTTAFoldModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
symmetrize_repeats=None, # whether to symmetrize repeats in the pair track
|
||||
repeat_length=None, # if symmetrizing repeats, what length are they?
|
||||
symmsub_k=None, # if symmetrizing repeats, which diagonals?
|
||||
sym_method=None, # if symmetrizing repeats, which block symmetrization method?
|
||||
main_block=None, # if copying template blocks along main diag, which block is main block? (the one w/ motif)
|
||||
copy_main_block_template=None, # whether or not to copy main block template along main diag
|
||||
n_extra_block=4,
|
||||
n_main_block=8,
|
||||
n_ref_block=4,
|
||||
n_finetune_block=0,
|
||||
d_msa=256,
|
||||
d_msa_full=64,
|
||||
d_pair=128,
|
||||
d_templ=64,
|
||||
n_head_msa=8,
|
||||
n_head_pair=4,
|
||||
n_head_templ=4,
|
||||
d_hidden=32,
|
||||
d_hidden_templ=64,
|
||||
d_t1d=0,
|
||||
p_drop=0.15,
|
||||
additional_dt1d=0,
|
||||
recycling_type="msa_pair",
|
||||
SE3_param={}, SE3_ref_param={},
|
||||
atom_type_index=None,
|
||||
aamask=None,
|
||||
ljlk_parameters=None,
|
||||
lj_correction_parameters=None,
|
||||
cb_len=None,
|
||||
cb_ang=None,
|
||||
cb_tor=None,
|
||||
num_bonds=None,
|
||||
lj_lin=0.6,
|
||||
use_chiral_l1=True,
|
||||
use_lj_l1=False,
|
||||
use_atom_frames=True,
|
||||
use_same_chain=False,
|
||||
enable_same_chain=False,
|
||||
refiner_topk=64,
|
||||
get_quaternion=False,
|
||||
# New for diffusion
|
||||
freeze_track_motif=False,
|
||||
assert_single_sequence_input=False,
|
||||
fit=False,
|
||||
tscale=1.0
|
||||
):
|
||||
super(RoseTTAFoldModule, self).__init__()
|
||||
self.freeze_track_motif = freeze_track_motif
|
||||
self.assert_single_sequence_input = assert_single_sequence_input
|
||||
self.recycling_type = recycling_type
|
||||
#
|
||||
# Input Embeddings
|
||||
d_state = SE3_param["l0_out_features"]
|
||||
self.latent_emb = MSA_emb(
|
||||
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop, use_same_chain=use_same_chain,
|
||||
enable_same_chain=enable_same_chain
|
||||
)
|
||||
self.full_emb = Extra_emb(
|
||||
d_msa=d_msa_full, d_init=ChemData().NAATOKENS - 1 + 4, p_drop=p_drop
|
||||
)
|
||||
self.bond_emb = Bond_emb(d_pair=d_pair, d_init=ChemData().NBTYPES)
|
||||
|
||||
self.templ_emb = Templ_emb(d_t1d=d_t1d,
|
||||
d_pair=d_pair,
|
||||
d_templ=d_templ,
|
||||
d_state=d_state,
|
||||
n_head=n_head_templ,
|
||||
d_hidden=d_hidden_templ,
|
||||
p_drop=0.25,
|
||||
symmetrize_repeats=symmetrize_repeats, # repeat protein stuff
|
||||
repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k,
|
||||
sym_method=sym_method,
|
||||
main_block=main_block,
|
||||
copy_main_block=copy_main_block_template,
|
||||
additional_dt1d=additional_dt1d)
|
||||
|
||||
# Update inputs with outputs from previous round
|
||||
|
||||
self.recycle = recycling_factory[recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
|
||||
#
|
||||
self.simulator = IterativeSimulator(
|
||||
n_extra_block=n_extra_block,
|
||||
n_main_block=n_main_block,
|
||||
n_ref_block=n_ref_block,
|
||||
n_finetune_block=n_finetune_block,
|
||||
d_msa=d_msa,
|
||||
d_msa_full=d_msa_full,
|
||||
d_pair=d_pair,
|
||||
d_hidden=d_hidden,
|
||||
n_head_msa=n_head_msa,
|
||||
n_head_pair=n_head_pair,
|
||||
SE3_param=SE3_param,
|
||||
SE3_ref_param=SE3_ref_param,
|
||||
p_drop=p_drop,
|
||||
atom_type_index=atom_type_index, # change if encoding elements instead of atomtype
|
||||
aamask=aamask,
|
||||
ljlk_parameters=ljlk_parameters,
|
||||
lj_correction_parameters=lj_correction_parameters,
|
||||
num_bonds=num_bonds,
|
||||
cb_len=cb_len,
|
||||
cb_ang=cb_ang,
|
||||
cb_tor=cb_tor,
|
||||
lj_lin=lj_lin,
|
||||
use_lj_l1=use_lj_l1,
|
||||
use_chiral_l1=use_chiral_l1,
|
||||
symmetrize_repeats=symmetrize_repeats,
|
||||
repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k,
|
||||
sym_method=sym_method,
|
||||
main_block=main_block,
|
||||
use_same_chain=use_same_chain,
|
||||
enable_same_chain=enable_same_chain,
|
||||
refiner_topk=refiner_topk
|
||||
)
|
||||
|
||||
##
|
||||
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
|
||||
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
|
||||
self.lddt_pred = LDDTNetwork(d_state)
|
||||
self.pae_pred = PAENetwork(d_pair)
|
||||
self.pde_pred = PAENetwork(
|
||||
d_pair
|
||||
) # distance error, but use same architecture as aligned error
|
||||
# binder predictions are made on top of the pair features, just like
|
||||
# PAE predictions are. It's not clear if this is the best place to insert
|
||||
# this prediction head.
|
||||
# self.binder_network = BinderNetwork(d_pair, d_state)
|
||||
|
||||
self.bind_pred = BinderNetwork() #fd - expose n_hidden as variable?
|
||||
|
||||
self.use_atom_frames = use_atom_frames
|
||||
self.enable_same_chain = enable_same_chain
|
||||
self.get_quaternion = get_quaternion
|
||||
self.verbose_checks = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
msa_latent,
|
||||
msa_full,
|
||||
seq,
|
||||
seq_unmasked,
|
||||
xyz,
|
||||
sctors,
|
||||
idx,
|
||||
bond_feats,
|
||||
dist_matrix,
|
||||
chirals,
|
||||
atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
|
||||
msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None, is_motif=None,
|
||||
return_raw=False,
|
||||
use_checkpoint=False,
|
||||
return_infer=False, #fd ?
|
||||
p2p_crop=-1, topk_crop=-1, # striping
|
||||
symmids=None, symmsub=None, symmRs=None, symmmeta=None, # symmetry
|
||||
):
|
||||
# ic(get_shape(msa_latent))
|
||||
# ic(get_shape(msa_full))
|
||||
# ic(get_shape(seq))
|
||||
# ic(get_shape(seq_unmasked))
|
||||
# ic(get_shape(xyz))
|
||||
# ic(get_shape(sctors))
|
||||
# ic(get_shape(idx))
|
||||
# ic(get_shape(bond_feats))
|
||||
# ic(get_shape(chirals))
|
||||
# ic(get_shape(atom_frames))
|
||||
# ic(get_shape(t1d))
|
||||
# ic(get_shape(t2d))
|
||||
# ic(get_shape(xyz_t))
|
||||
# ic(get_shape(alpha_t))
|
||||
# ic(get_shape(mask_t))
|
||||
# ic(get_shape(same_chain))
|
||||
# ic(get_shape(msa_prev))
|
||||
# ic(get_shape(pair_prev))
|
||||
# ic(get_shape(mask_recycle))
|
||||
# ic()
|
||||
# ic()
|
||||
B, N, L = msa_latent.shape[:3]
|
||||
A = atom_frames.shape[1]
|
||||
dtype = msa_latent.dtype
|
||||
|
||||
if self.assert_single_sequence_input:
|
||||
assert_shape(msa_latent, (1, 1, L, 164))
|
||||
assert_shape(msa_full, (1, 1, L, 83))
|
||||
assert_shape(seq, (1, L))
|
||||
assert_shape(seq_unmasked, (1, L))
|
||||
assert_shape(xyz, (1, L, ChemData().NTOTAL, 3))
|
||||
assert_shape(sctors, (1, L, 20, 2))
|
||||
assert_shape(idx, (1, L))
|
||||
assert_shape(bond_feats, (1, L, L))
|
||||
assert_shape(dist_matrix, (1, L, L))
|
||||
# assert_shape(chirals, (1, 0))
|
||||
# assert_shape(atom_frames, (1, 4, L)) # This is set to 4 for the recycle count, but that can't be right
|
||||
assert_shape(atom_frames, (1, A, 3, 2)) # What is 4?
|
||||
assert_shape(t1d, (1, 1, L, 80))
|
||||
assert_shape(t2d, (1, 1, L, L, 68))
|
||||
assert_shape(xyz_t, (1, 1, L, 3))
|
||||
assert_shape(alpha_t, (1, 1, L, 60))
|
||||
assert_shape(mask_t, (1, 1, L, L))
|
||||
assert_shape(same_chain, (1, L, L))
|
||||
device = msa_latent.device
|
||||
assert_that(msa_full.device).is_equal_to(device)
|
||||
assert_that(seq.device).is_equal_to(device)
|
||||
assert_that(seq_unmasked.device).is_equal_to(device)
|
||||
assert_that(xyz.device).is_equal_to(device)
|
||||
assert_that(sctors.device).is_equal_to(device)
|
||||
assert_that(idx.device).is_equal_to(device)
|
||||
assert_that(bond_feats.device).is_equal_to(device)
|
||||
assert_that(dist_matrix.device).is_equal_to(device)
|
||||
assert_that(atom_frames.device).is_equal_to(device)
|
||||
assert_that(t1d.device).is_equal_to(device)
|
||||
assert_that(t2d.device).is_equal_to(device)
|
||||
assert_that(xyz_t.device).is_equal_to(device)
|
||||
assert_that(alpha_t.device).is_equal_to(device)
|
||||
assert_that(mask_t.device).is_equal_to(device)
|
||||
assert_that(same_chain.device).is_equal_to(device)
|
||||
|
||||
if self.verbose_checks:
|
||||
#ic(is_motif.shape)
|
||||
is_sm = rf2aa.util.is_atom(seq[0]) # (L)
|
||||
#is_protein_motif = is_motif & ~is_sm
|
||||
#if is_motif.any():
|
||||
# motif_protein_i = torch.where(is_motif)[0][0]
|
||||
#is_motif_sm = is_motif & is_sm
|
||||
#if is_sm.any():
|
||||
# motif_sm_i = torch.where(is_motif_sm)[0][0]
|
||||
#diffused_protein_i = torch.where(~is_sm & ~is_motif)[0][0]
|
||||
|
||||
"""
|
||||
msa_full: NSEQ,N_INDEL,N_TERMINUS,
|
||||
msa_masked: NSEQ,NSEQ,N_INDEL,N_INDEL,N_TERMINUS
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
NINDEL = 1
|
||||
NTERMINUS = 2
|
||||
NMSAFULL = ChemData().NAATOKENS + NINDEL + NTERMINUS
|
||||
NMSAMASKED = ChemData().NAATOKENS + ChemData().NAATOKENS + NINDEL + NINDEL + NTERMINUS
|
||||
assert_that(msa_latent.shape[-1]).is_equal_to(NMSAMASKED)
|
||||
assert_that(msa_full.shape[-1]).is_equal_to(NMSAFULL)
|
||||
|
||||
msa_full_seq = np.r_[0:ChemData().NAATOKENS]
|
||||
msa_full_indel = np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
|
||||
msa_full_term = np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]
|
||||
|
||||
msa_latent_seq1 = np.r_[0:ChemData().NAATOKENS]
|
||||
msa_latent_seq2 = np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]
|
||||
msa_latent_indel1 = np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
|
||||
msa_latent_indel2 = np.r_[
|
||||
2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL
|
||||
]
|
||||
msa_latent_terminus = np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
|
||||
|
||||
#i_name = [(diffused_protein_i, "diffused_protein")]
|
||||
#if is_sm.any():
|
||||
# i_name.insert(0, (motif_sm_i, "motif_sm"))
|
||||
#if is_motif.any():
|
||||
# i_name.insert(0, (motif_protein_i, "motif_protein"))
|
||||
i_name = [(0, "tst")]
|
||||
for i, name in i_name:
|
||||
ic(f"------------------{name}:{i}----------------")
|
||||
msa_full_seq = msa_full[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
|
||||
msa_full_indel = msa_full[
|
||||
0, 0, i, np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
|
||||
]
|
||||
msa_full_term = msa_full[0, 0, i, np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]]
|
||||
|
||||
msa_latent_seq1 = msa_latent[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
|
||||
msa_latent_seq2 = msa_latent[0, 0, i, np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]]
|
||||
msa_latent_indel1 = msa_latent[
|
||||
0, 0, i, np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
|
||||
]
|
||||
msa_latent_indel2 = msa_latent[
|
||||
0,
|
||||
0,
|
||||
i,
|
||||
np.r_[2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL],
|
||||
]
|
||||
msa_latent_term = msa_latent[
|
||||
0, 0, i, np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
|
||||
]
|
||||
|
||||
assert_equal(msa_full_seq, msa_latent_seq1)
|
||||
assert_equal(msa_full_seq, msa_latent_seq2)
|
||||
assert_equal(msa_full_indel, msa_latent_indel1)
|
||||
assert_equal(msa_full_indel, msa_latent_indel2)
|
||||
assert_equal(msa_full_term, msa_latent_term)
|
||||
# if 'motif' in name:
|
||||
msa_cat = torch.where(msa_full_seq)[0]
|
||||
ic(msa_cat, seq[0, i])
|
||||
assert_equal(seq[0, i : i + 1], msa_cat)
|
||||
assert_equal(seq[0, i], seq_unmasked[0, i])
|
||||
ic(
|
||||
name,
|
||||
# torch.where(msa_latent[0,0,i,:80]),
|
||||
# torch.where(msa_full[0,0,i]),
|
||||
seq[0, i],
|
||||
seq_unmasked[0, i],
|
||||
torch.where(t1d[0, 0, i]),
|
||||
xyz[0, i, :4, 0],
|
||||
xyz_t[0, 0, i, 0],
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
#if self.enable_same_chain == False:
|
||||
# same_chain = None
|
||||
msa_latent, pair, state = self.latent_emb(
|
||||
msa_latent, seq, idx, bond_feats, dist_matrix, same_chain=same_chain
|
||||
)
|
||||
msa_full = self.full_emb(msa_full, seq, idx)
|
||||
pair = pair + self.bond_emb(bond_feats)
|
||||
|
||||
msa_latent, pair, state = msa_latent.to(dtype), pair.to(dtype), state.to(dtype)
|
||||
msa_full = msa_full.to(dtype)
|
||||
|
||||
#
|
||||
# Do recycling
|
||||
if msa_prev is None:
|
||||
msa_prev = torch.zeros_like(msa_latent[:,0])
|
||||
if pair_prev is None:
|
||||
pair_prev = torch.zeros_like(pair)
|
||||
if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
|
||||
state_prev = torch.zeros_like(state)
|
||||
|
||||
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
|
||||
msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)
|
||||
|
||||
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
|
||||
pair = pair + pair_recycle
|
||||
state = state + state_recycle # if state is not recycled these will be zeros
|
||||
|
||||
# add template embedding
|
||||
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop)
|
||||
|
||||
# Predict coordinates from given inputs
|
||||
is_motif = is_motif if self.freeze_track_motif else torch.zeros_like(seq).bool()[0]
|
||||
msa, pair, xyz, alpha_s, xyz_allatom, state, symmsub, quat = self.simulator(
|
||||
seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx,
|
||||
symmids, symmsub, symmRs, symmmeta,
|
||||
bond_feats, dist_matrix, same_chain, chirals, is_motif, atom_frames,
|
||||
use_checkpoint=use_checkpoint, use_atom_frames=self.use_atom_frames,
|
||||
p2p_crop=p2p_crop, topk_crop=topk_crop
|
||||
)
|
||||
|
||||
if return_raw:
|
||||
# get last structure
|
||||
xyz_last = xyz_allatom[-1].unsqueeze(0)
|
||||
return msa[:,0], pair, xyz_last, alpha_s[-1], None
|
||||
|
||||
# predict masked amino acids
|
||||
logits_aa = self.aa_pred(msa)
|
||||
|
||||
# predict distogram & orientograms
|
||||
logits = self.c6d_pred(pair)
|
||||
|
||||
# Predict LDDT
|
||||
lddt = self.lddt_pred(state)
|
||||
|
||||
if self.verbose_checks:
|
||||
pseq_0 = logits_aa.permute(0, 2, 1)
|
||||
ic(pseq_0.shape)
|
||||
pseq_0 = pseq_0[0]
|
||||
ic(
|
||||
f"motif sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[is_motif], dim=-1).tolist())}"
|
||||
)
|
||||
ic(
|
||||
f"diffused sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[~is_motif], dim=-1).tolist())}"
|
||||
)
|
||||
|
||||
logits_pae = logits_pde = p_bind = None
|
||||
# predict aligned error and distance error
|
||||
logits_pae = self.pae_pred(pair)
|
||||
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
|
||||
|
||||
#fd predict bind/no-bind
|
||||
p_bind = self.bind_pred(logits_pae,same_chain)
|
||||
|
||||
if self.get_quaternion:
|
||||
return (
|
||||
logits, logits_aa, logits_pae, logits_pde, p_bind,
|
||||
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state, quat
|
||||
)
|
||||
else:
|
||||
return (
|
||||
logits, logits_aa, logits_pae, logits_pde, p_bind,
|
||||
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state
|
||||
)
|
||||
1220
rf2aa/model/Track_module.py
Normal file
1220
rf2aa/model/Track_module.py
Normal file
File diff suppressed because it is too large
Load Diff
475
rf2aa/model/layers/Attention_module.py
Normal file
475
rf2aa/model/layers/Attention_module.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from opt_einsum import contract as einsum
|
||||
from rf2aa.util_module import init_lecun_normal
|
||||
class FeedForwardLayer(nn.Module):
|
||||
def __init__(self, d_model, r_ff, p_drop=0.1):
|
||||
super(FeedForwardLayer, self).__init__()
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.linear1 = nn.Linear(d_model, d_model*r_ff)
|
||||
self.dropout = nn.Dropout(p_drop)
|
||||
self.linear2 = nn.Linear(d_model*r_ff, d_model)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# initialize linear layer right before ReLu: He initializer (kaiming normal)
|
||||
nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
|
||||
nn.init.zeros_(self.linear1.bias)
|
||||
|
||||
# initialize linear layer right before residual connection: zero initialize
|
||||
nn.init.zeros_(self.linear2.weight)
|
||||
nn.init.zeros_(self.linear2.bias)
|
||||
|
||||
def forward(self, src):
|
||||
src = self.norm(src)
|
||||
src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
|
||||
return src
|
||||
|
||||
class Attention(nn.Module):
|
||||
# calculate multi-head attention
|
||||
def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
|
||||
super(Attention, self).__init__()
|
||||
self.h = n_head
|
||||
self.dim = d_hidden
|
||||
#
|
||||
self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
|
||||
self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
|
||||
self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
|
||||
#
|
||||
self.to_out = nn.Linear(n_head*d_hidden, d_out)
|
||||
self.scaling = 1/math.sqrt(d_hidden)
|
||||
#
|
||||
# initialize all parameters properly
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
nn.init.zeros_(self.to_out.bias)
|
||||
|
||||
def forward(self, query, key, value):
|
||||
B, Q = query.shape[:2]
|
||||
B, K = key.shape[:2]
|
||||
#
|
||||
query = self.to_q(query).reshape(B, Q, self.h, self.dim)
|
||||
key = self.to_k(key).reshape(B, K, self.h, self.dim)
|
||||
value = self.to_v(value).reshape(B, K, self.h, self.dim)
|
||||
#
|
||||
query = query * self.scaling
|
||||
attn = einsum('bqhd,bkhd->bhqk', query, key)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
#
|
||||
out = einsum('bhqk,bkhd->bqhd', attn, value)
|
||||
out = out.reshape(B, Q, self.h*self.dim)
|
||||
#
|
||||
out = self.to_out(out)
|
||||
|
||||
return out
|
||||
|
||||
# MSA Attention (row/column) from AlphaFold architecture
|
||||
class SequenceWeight(nn.Module):
|
||||
def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
|
||||
super(SequenceWeight, self).__init__()
|
||||
self.h = n_head
|
||||
self.dim = d_hidden
|
||||
self.scale = 1.0 / math.sqrt(self.dim)
|
||||
|
||||
self.to_query = nn.Linear(d_msa, n_head*d_hidden)
|
||||
self.to_key = nn.Linear(d_msa, n_head*d_hidden)
|
||||
self.dropout = nn.Dropout(p_drop)
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_query.weight)
|
||||
nn.init.xavier_uniform_(self.to_key.weight)
|
||||
|
||||
def forward(self, msa):
|
||||
B, N, L = msa.shape[:3]
|
||||
|
||||
tar_seq = msa[:,0]
|
||||
|
||||
q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
|
||||
k = self.to_key(msa).view(B, N, L, self.h, self.dim)
|
||||
|
||||
q = q * self.scale
|
||||
attn = einsum('bqihd,bkihd->bkihq', q, k)
|
||||
attn = F.softmax(attn, dim=1)
|
||||
return self.dropout(attn)
|
||||
|
||||
class MSARowAttentionWithBias(nn.Module):
|
||||
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
|
||||
super(MSARowAttentionWithBias, self).__init__()
|
||||
self.norm_msa = nn.LayerNorm(d_msa)
|
||||
self.norm_pair = nn.LayerNorm(d_pair)
|
||||
#
|
||||
self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
|
||||
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||
self.to_b = nn.Linear(d_pair, n_head, bias=False)
|
||||
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
|
||||
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
|
||||
|
||||
self.scaling = 1/math.sqrt(d_hidden)
|
||||
self.h = n_head
|
||||
self.dim = d_hidden
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# bias: normal distribution
|
||||
self.to_b = init_lecun_normal(self.to_b)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.to_g.weight)
|
||||
nn.init.ones_(self.to_g.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
nn.init.zeros_(self.to_out.bias)
|
||||
|
||||
def forward(self, msa, pair): # TODO: make this as tied-attention
|
||||
B, N, L = msa.shape[:3]
|
||||
#
|
||||
msa = self.norm_msa(msa)
|
||||
pair = self.norm_pair(pair)
|
||||
#
|
||||
seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
|
||||
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
|
||||
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
|
||||
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
|
||||
bias = self.to_b(pair) # (B, L, L, h)
|
||||
gate = torch.sigmoid(self.to_g(msa))
|
||||
#
|
||||
query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
|
||||
key = key * self.scaling
|
||||
attn = einsum('bsqhd,bskhd->bqkh', query, key)
|
||||
attn = attn + bias
|
||||
attn = F.softmax(attn, dim=-2)
|
||||
#
|
||||
out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
|
||||
out = gate * out
|
||||
#
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
class MSAColAttention(nn.Module):
|
||||
def __init__(self, d_msa=256, n_head=8, d_hidden=32):
|
||||
super(MSAColAttention, self).__init__()
|
||||
self.norm_msa = nn.LayerNorm(d_msa)
|
||||
#
|
||||
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
|
||||
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
|
||||
|
||||
self.scaling = 1/math.sqrt(d_hidden)
|
||||
self.h = n_head
|
||||
self.dim = d_hidden
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.to_g.weight)
|
||||
nn.init.ones_(self.to_g.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
nn.init.zeros_(self.to_out.bias)
|
||||
|
||||
def forward(self, msa):
|
||||
B, N, L = msa.shape[:3]
|
||||
#
|
||||
msa = self.norm_msa(msa)
|
||||
#
|
||||
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
|
||||
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
|
||||
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
|
||||
gate = torch.sigmoid(self.to_g(msa))
|
||||
#
|
||||
query = query * self.scaling
|
||||
attn = einsum('bqihd,bkihd->bihqk', query, key)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
#
|
||||
out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
|
||||
out = gate * out
|
||||
#
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
class MSAColGlobalAttention(nn.Module):
|
||||
def __init__(self, d_msa=64, n_head=8, d_hidden=8):
|
||||
super(MSAColGlobalAttention, self).__init__()
|
||||
self.norm_msa = nn.LayerNorm(d_msa)
|
||||
#
|
||||
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||
self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
|
||||
self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
|
||||
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
|
||||
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
|
||||
|
||||
self.scaling = 1/math.sqrt(d_hidden)
|
||||
self.h = n_head
|
||||
self.dim = d_hidden
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.to_g.weight)
|
||||
nn.init.ones_(self.to_g.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
nn.init.zeros_(self.to_out.bias)
|
||||
|
||||
def forward(self, msa):
|
||||
B, N, L = msa.shape[:3]
|
||||
#
|
||||
msa = self.norm_msa(msa)
|
||||
#
|
||||
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
|
||||
query = query.mean(dim=1) # (B, L, h, dim)
|
||||
key = self.to_k(msa) # (B, N, L, dim)
|
||||
value = self.to_v(msa) # (B, N, L, dim)
|
||||
gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
|
||||
#
|
||||
query = query * self.scaling
|
||||
attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
#
|
||||
out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
|
||||
out = gate * out # (B, N, L, h*dim)
|
||||
#
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
# TriangleAttention & TriangleMultiplication from AlphaFold architecture
|
||||
class TriangleAttention(nn.Module):
|
||||
def __init__(self, d_pair, n_head=4, d_hidden=32, p_drop=0.1, start_node=True):
|
||||
super(TriangleAttention, self).__init__()
|
||||
self.norm = nn.LayerNorm(d_pair)
|
||||
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||
|
||||
self.to_b = nn.Linear(d_pair, n_head, bias=False)
|
||||
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
|
||||
|
||||
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
|
||||
|
||||
self.scaling = 1/math.sqrt(d_hidden)
|
||||
|
||||
self.h = n_head
|
||||
self.dim = d_hidden
|
||||
self.start_node=start_node
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# bias: normal distribution
|
||||
self.to_b = init_lecun_normal(self.to_b)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.to_g.weight)
|
||||
nn.init.ones_(self.to_g.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
nn.init.zeros_(self.to_out.bias)
|
||||
|
||||
def forward(self, pair):
|
||||
B, L = pair.shape[:2]
|
||||
|
||||
pair = self.norm(pair)
|
||||
|
||||
# input projection
|
||||
query = self.to_q(pair).reshape(B, L, L, self.h, -1)
|
||||
key = self.to_k(pair).reshape(B, L, L, self.h, -1)
|
||||
value = self.to_v(pair).reshape(B, L, L, self.h, -1)
|
||||
bias = self.to_b(pair) # (B, L, L, h)
|
||||
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
|
||||
|
||||
# attention
|
||||
query = query * self.scaling
|
||||
if self.start_node:
|
||||
attn = einsum('bijhd,bikhd->bijkh', query, key)
|
||||
else:
|
||||
attn = einsum('bijhd,bkjhd->bijkh', query, key)
|
||||
attn = attn + bias.unsqueeze(1).expand(-1,L,-1,-1,-1) # (bijkh)
|
||||
attn = F.softmax(attn, dim=-2)
|
||||
if self.start_node:
|
||||
out = einsum('bijkh,bikhd->bijhd', attn, value).reshape(B, L, L, -1)
|
||||
else:
|
||||
out = einsum('bijkh,bkjhd->bijhd', attn, value).reshape(B, L, L, -1)
|
||||
out = gate * out # gated attention
|
||||
|
||||
# output projection
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
class TriangleMultiplication(nn.Module):
|
||||
def __init__(self, d_pair, d_hidden=128, outgoing=True):
|
||||
super(TriangleMultiplication, self).__init__()
|
||||
self.norm = nn.LayerNorm(d_pair)
|
||||
self.left_proj = nn.Linear(d_pair, d_hidden)
|
||||
self.right_proj = nn.Linear(d_pair, d_hidden)
|
||||
self.left_gate = nn.Linear(d_pair, d_hidden)
|
||||
self.right_gate = nn.Linear(d_pair, d_hidden)
|
||||
#
|
||||
self.gate = nn.Linear(d_pair, d_pair)
|
||||
self.norm_out = nn.LayerNorm(d_hidden)
|
||||
self.out_proj = nn.Linear(d_hidden, d_pair)
|
||||
|
||||
self.outgoing = outgoing
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# normal distribution for regular linear weights
|
||||
self.left_proj = init_lecun_normal(self.left_proj)
|
||||
self.right_proj = init_lecun_normal(self.right_proj)
|
||||
|
||||
# Set Bias of Linear layers to zeros
|
||||
nn.init.zeros_(self.left_proj.bias)
|
||||
nn.init.zeros_(self.right_proj.bias)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.left_gate.weight)
|
||||
nn.init.ones_(self.left_gate.bias)
|
||||
|
||||
nn.init.zeros_(self.right_gate.weight)
|
||||
nn.init.ones_(self.right_gate.bias)
|
||||
|
||||
nn.init.zeros_(self.gate.weight)
|
||||
nn.init.ones_(self.gate.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.out_proj.weight)
|
||||
nn.init.zeros_(self.out_proj.bias)
|
||||
|
||||
def forward(self, pair):
|
||||
B, L = pair.shape[:2]
|
||||
pair = self.norm(pair)
|
||||
|
||||
left = self.left_proj(pair) # (B, L, L, d_h)
|
||||
left_gate = torch.sigmoid(self.left_gate(pair))
|
||||
left = left_gate * left
|
||||
|
||||
right = self.right_proj(pair) # (B, L, L, d_h)
|
||||
right_gate = torch.sigmoid(self.right_gate(pair))
|
||||
right = right_gate * right
|
||||
|
||||
if self.outgoing:
|
||||
out = einsum('bikd,bjkd->bijd', left, right/float(L))
|
||||
else:
|
||||
out = einsum('bkid,bkjd->bijd', left, right/float(L))
|
||||
out = self.norm_out(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
gate = torch.sigmoid(self.gate(pair)) # (B, L, L, d_pair)
|
||||
out = gate * out
|
||||
return out
|
||||
|
||||
# Instead of triangle attention, use Tied axail attention with bias from coordinates..?
|
||||
class BiasedAxialAttention(nn.Module):
|
||||
def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
|
||||
super(BiasedAxialAttention, self).__init__()
|
||||
#
|
||||
self.is_row = is_row
|
||||
self.norm_pair = nn.LayerNorm(d_pair)
|
||||
self.norm_bias = nn.LayerNorm(d_bias)
|
||||
|
||||
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||
self.to_b = nn.Linear(d_bias, n_head, bias=False)
|
||||
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
|
||||
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
|
||||
|
||||
self.scaling = 1/math.sqrt(d_hidden)
|
||||
self.h = n_head
|
||||
self.dim = d_hidden
|
||||
|
||||
# initialize all parameters properly
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||
nn.init.xavier_uniform_(self.to_q.weight)
|
||||
nn.init.xavier_uniform_(self.to_k.weight)
|
||||
nn.init.xavier_uniform_(self.to_v.weight)
|
||||
|
||||
# bias: normal distribution
|
||||
self.to_b = init_lecun_normal(self.to_b)
|
||||
|
||||
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||
nn.init.zeros_(self.to_g.weight)
|
||||
nn.init.ones_(self.to_g.bias)
|
||||
|
||||
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
nn.init.zeros_(self.to_out.bias)
|
||||
|
||||
def forward(self, pair, bias):
|
||||
# pair: (B, L, L, d_pair)
|
||||
B, L = pair.shape[:2]
|
||||
|
||||
if self.is_row:
|
||||
pair = pair.permute(0,2,1,3)
|
||||
bias = bias.permute(0,2,1,3)
|
||||
|
||||
pair = self.norm_pair(pair)
|
||||
bias = self.norm_bias(bias)
|
||||
|
||||
query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
|
||||
key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
|
||||
value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
|
||||
bias = self.to_b(bias) # (B, L, L, h)
|
||||
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
|
||||
|
||||
query = query * self.scaling
|
||||
key = key / L # normalize for tied attention
|
||||
attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
|
||||
attn = attn + bias # apply bias
|
||||
attn = F.softmax(attn, dim=-2) # (B, L, L, h)
|
||||
|
||||
out = einsum('bijh,bnjhd->bnihd', attn, value).reshape(B, L, L, -1)
|
||||
out = gate * out
|
||||
|
||||
out = self.to_out(out)
|
||||
if self.is_row:
|
||||
out = out.permute(0,2,1,3)
|
||||
return out
|
||||
|
||||
111
rf2aa/model/layers/AuxiliaryPredictor.py
Normal file
111
rf2aa/model/layers/AuxiliaryPredictor.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
|
||||
class DistanceNetwork(nn.Module):
|
||||
def __init__(self, n_feat, p_drop=0.0):
|
||||
super(DistanceNetwork, self).__init__()
|
||||
#HACK: dimensions are hard coded here
|
||||
self.proj_symm = nn.Linear(n_feat, 61+37) # must match bin counts defined in kinematics.py
|
||||
self.proj_asymm = nn.Linear(n_feat, 37+19)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
# initialize linear layer for final logit prediction
|
||||
nn.init.zeros_(self.proj_symm.weight)
|
||||
nn.init.zeros_(self.proj_asymm.weight)
|
||||
nn.init.zeros_(self.proj_symm.bias)
|
||||
nn.init.zeros_(self.proj_asymm.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# input: pair info (B, L, L, C)
|
||||
|
||||
# predict theta, phi (non-symmetric)
|
||||
logits_asymm = self.proj_asymm(x)
|
||||
logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
|
||||
logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)
|
||||
|
||||
# predict dist, omega
|
||||
logits_symm = self.proj_symm(x)
|
||||
logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
|
||||
logits_dist = logits_symm[:,:,:,:61].permute(0,3,1,2)
|
||||
logits_omega = logits_symm[:,:,:,61:].permute(0,3,1,2)
|
||||
|
||||
return logits_dist, logits_omega, logits_theta, logits_phi
|
||||
|
||||
class MaskedTokenNetwork(nn.Module):
|
||||
def __init__(self, n_feat, p_drop=0.0):
|
||||
super(MaskedTokenNetwork, self).__init__()
|
||||
|
||||
#fd note this predicts probability for the mask token (which is never in ground truth)
|
||||
# it should be ok though(?)
|
||||
self.proj = nn.Linear(n_feat, ChemData().NAATOKENS)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, L = x.shape[:3]
|
||||
logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)
|
||||
|
||||
return logits
|
||||
|
||||
class LDDTNetwork(nn.Module):
|
||||
def __init__(self, n_feat, n_bin_lddt=50):
|
||||
super(LDDTNetwork, self).__init__()
|
||||
self.proj = nn.Linear(n_feat, n_bin_lddt)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
def forward(self, x):
|
||||
logits = self.proj(x) # (B, L, 50)
|
||||
|
||||
return logits.permute(0,2,1)
|
||||
|
||||
class PAENetwork(nn.Module):
|
||||
def __init__(self, n_feat, n_bin_pae=64):
|
||||
super(PAENetwork, self).__init__()
|
||||
self.proj = nn.Linear(n_feat, n_bin_pae)
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
def forward(self, x):
|
||||
logits = self.proj(x) # (B, L, L, 64)
|
||||
|
||||
return logits.permute(0,3,1,2)
|
||||
|
||||
class BinderNetwork(nn.Module):
|
||||
def __init__(self, n_bin_pae=64):
|
||||
super(BinderNetwork, self).__init__()
|
||||
self.classify = torch.nn.Linear(n_bin_pae, 1)
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
nn.init.zeros_(self.classify.weight)
|
||||
nn.init.zeros_(self.classify.bias)
|
||||
|
||||
def forward(self, pae, same_chain):
|
||||
logits = pae.permute(0,2,3,1)
|
||||
logits_inter = torch.mean( logits[same_chain==0], dim=0 ).nan_to_num() # all zeros if single chain
|
||||
prob = torch.sigmoid( self.classify( logits_inter ) )
|
||||
return prob
|
||||
|
||||
aux_predictor_factory = {
|
||||
"c6d": DistanceNetwork,
|
||||
"mlm": MaskedTokenNetwork,
|
||||
"plddt": LDDTNetwork,
|
||||
"pae": PAENetwork,
|
||||
"binder": BinderNetwork
|
||||
}
|
||||
458
rf2aa/model/layers/Embeddings.py
Normal file
458
rf2aa/model/layers/Embeddings.py
Normal file
@@ -0,0 +1,458 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from opt_einsum import contract as einsum
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from rf2aa.util import *
|
||||
from rf2aa.util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal, get_res_atom_dist
|
||||
from rf2aa.model.layers.Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
|
||||
from rf2aa.model.Track_module import PairStr2Pair, PositionalEncoding2D
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
|
||||
# Module contains classes and functions to generate initial embeddings
|
||||
|
||||
class MSA_emb(nn.Module):
|
||||
# Get initial seed MSA embedding
|
||||
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=0,
|
||||
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False, enable_same_chain=False):
|
||||
if (d_init==0):
|
||||
d_init = 2*ChemData().NAATOKENS+2+2
|
||||
|
||||
super(MSA_emb, self).__init__()
|
||||
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
|
||||
self.emb_q = nn.Embedding(ChemData().NAATOKENS, d_msa) # embedding for query sequence -- used for MSA embedding
|
||||
self.emb_left = nn.Embedding(ChemData().NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
|
||||
self.emb_right = nn.Embedding(ChemData().NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
|
||||
self.emb_state = nn.Embedding(ChemData().NAATOKENS, d_state)
|
||||
self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos,
|
||||
maxpos_atom=maxpos_atom, p_drop=p_drop, use_same_chain=use_same_chain,
|
||||
enable_same_chain=enable_same_chain)
|
||||
self.enable_same_chain = enable_same_chain
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
self.emb = init_lecun_normal(self.emb)
|
||||
self.emb_q = init_lecun_normal(self.emb_q)
|
||||
self.emb_left = init_lecun_normal(self.emb_left)
|
||||
self.emb_right = init_lecun_normal(self.emb_right)
|
||||
self.emb_state = init_lecun_normal(self.emb_state)
|
||||
|
||||
nn.init.zeros_(self.emb.bias)
|
||||
|
||||
|
||||
def _msa_emb(self, msa, seq):
|
||||
N = msa.shape[1]
|
||||
msa = self.emb(msa) # (B, N, L, d_pair) # MSA embedding
|
||||
tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_pair) -- query embedding
|
||||
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
|
||||
|
||||
return msa
|
||||
|
||||
def _pair_emb(self, seq, idx, bond_feats, dist_matrix, same_chain=None):
|
||||
left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair)
|
||||
right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair)
|
||||
pair = left + right # (B, L, L, d_pair)
|
||||
pair = pair + self.pos(seq, idx, bond_feats, dist_matrix, same_chain=same_chain) # add relative position
|
||||
|
||||
return pair
|
||||
|
||||
def _state_emb(self, seq):
|
||||
return self.emb_state(seq)
|
||||
|
||||
def forward(self, msa, seq, idx, bond_feats, dist_matrix, same_chain=None):
|
||||
# Inputs:
|
||||
# - msa: Input MSA (B, N, L, d_init)
|
||||
# - seq: Input Sequence (B, L)
|
||||
# - idx: Residue index
|
||||
# - bond_feats: Bond features (B, L, L)
|
||||
# Outputs:
|
||||
# - msa: Initial MSA embedding (B, N, L, d_msa)
|
||||
# - pair: Initial Pair embedding (B, L, L, d_pair)
|
||||
|
||||
if self.enable_same_chain == False:
|
||||
same_chain = None
|
||||
|
||||
msa = self._msa_emb(msa, seq)
|
||||
|
||||
# pair embedding
|
||||
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix, same_chain=same_chain)
|
||||
# state embedding
|
||||
state = self._state_emb(seq)
|
||||
return msa, pair, state
|
||||
|
||||
class MSA_emb_nostate(MSA_emb):
|
||||
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=0, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False):
|
||||
super().__init__(d_msa, d_pair, d_state, d_init, minpos, maxpos, maxpos_atom, p_drop, use_same_chain)
|
||||
if d_init==0:
|
||||
d_init = 2*ChemData().NAATOKENS + 2 + 2
|
||||
self.emb_state = None # emb state is just the identity
|
||||
|
||||
def forward(self, msa, seq, idx, bond_feats, dist_matrix):
|
||||
msa = self._msa_emb(msa, seq)
|
||||
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix)
|
||||
return msa, pair, None
|
||||
|
||||
class Extra_emb(nn.Module):
|
||||
# Get initial seed MSA embedding
|
||||
def __init__(self, d_msa=256, d_init=0, p_drop=0.1):
|
||||
super(Extra_emb, self).__init__()
|
||||
if d_init==0:
|
||||
d_init=ChemData().NAATOKENS-1+4
|
||||
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
|
||||
self.emb_q = nn.Embedding(ChemData().NAATOKENS, d_msa) # embedding for query sequence
|
||||
#self.drop = nn.Dropout(p_drop)
|
||||
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
self.emb = init_lecun_normal(self.emb)
|
||||
nn.init.zeros_(self.emb.bias)
|
||||
|
||||
def forward(self, msa, seq, idx):
|
||||
# Inputs:
|
||||
# - msa: Input MSA (B, N, L, d_init)
|
||||
# - seq: Input Sequence (B, L)
|
||||
# - idx: Residue index
|
||||
# Outputs:
|
||||
# - msa: Initial MSA embedding (B, N, L, d_msa)
|
||||
N = msa.shape[1] # number of sequenes in MSA
|
||||
msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
|
||||
seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
|
||||
msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
|
||||
#return self.drop(msa)
|
||||
return (msa)
|
||||
|
||||
class Bond_emb(nn.Module):
|
||||
def __init__(self, d_pair=128, d_init=0):
|
||||
super(Bond_emb, self).__init__()
|
||||
|
||||
if d_init==0:
|
||||
d_init = ChemData().NBTYPES
|
||||
|
||||
self.emb = nn.Linear(d_init, d_pair)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
self.emb = init_lecun_normal(self.emb)
|
||||
nn.init.zeros_(self.emb.bias)
|
||||
|
||||
def forward(self, bond_feats):
|
||||
bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=ChemData().NBTYPES)
|
||||
return self.emb(bond_feats.float())
|
||||
|
||||
class TemplatePairStack(nn.Module):
|
||||
def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=32, d_t1d=22, d_state=32, p_drop=0.25,
|
||||
symmetrize_repeats=False, repeat_length=None, symmsub_k=1, sym_method=None):
|
||||
|
||||
super(TemplatePairStack, self).__init__()
|
||||
self.n_block = n_block
|
||||
self.proj_t1d = nn.Linear(d_t1d, d_state)
|
||||
|
||||
proc_s = [PairStr2Pair(d_pair=d_templ,
|
||||
n_head=n_head,
|
||||
d_hidden=d_hidden,
|
||||
d_state=d_state,
|
||||
p_drop=p_drop,
|
||||
symmetrize_repeats=symmetrize_repeats,
|
||||
repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k,
|
||||
sym_method=sym_method) for i in range(n_block)]
|
||||
|
||||
self.block = nn.ModuleList(proc_s)
|
||||
self.norm = nn.LayerNorm(d_templ)
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
self.proj_t1d = init_lecun_normal(self.proj_t1d)
|
||||
nn.init.zeros_(self.proj_t1d.bias)
|
||||
|
||||
def forward(self, templ, rbf_feat, t1d, use_checkpoint=False, p2p_crop=-1):
|
||||
B, T, L = templ.shape[:3]
|
||||
templ = templ.reshape(B*T, L, L, -1)
|
||||
t1d = t1d.reshape(B*T, L, -1)
|
||||
state = self.proj_t1d(t1d)
|
||||
|
||||
for i_block in range(self.n_block):
|
||||
if use_checkpoint:
|
||||
templ = checkpoint.checkpoint(
|
||||
create_custom_forward(self.block[i_block]),
|
||||
templ, rbf_feat, state, p2p_crop,
|
||||
use_reentrant=True
|
||||
)
|
||||
else:
|
||||
templ = self.block[i_block](templ, rbf_feat, state)
|
||||
return self.norm(templ).reshape(B, T, L, L, -1)
|
||||
|
||||
|
||||
def copy_main_2d(pair, Leff, idx):
|
||||
"""
|
||||
Copies the "main unit" of a block in generic 2D representation of shape (...,L,L,h)
|
||||
along the main diagonal
|
||||
"""
|
||||
start = idx*Leff
|
||||
end = (idx+1)*Leff
|
||||
|
||||
# grab the main block
|
||||
main = torch.clone( pair[..., start:end, start:end, :] )
|
||||
|
||||
# copy it around the main diag
|
||||
L = pair.shape[-2]
|
||||
assert L%Leff == 0
|
||||
N = L//Leff
|
||||
|
||||
for i_block in range(N):
|
||||
start = i_block*Leff
|
||||
stop = (i_block+1)*Leff
|
||||
|
||||
pair[...,start:stop, start:stop, :] = main
|
||||
|
||||
return pair
|
||||
|
||||
|
||||
def copy_main_1d(single, Leff, idx):
|
||||
"""
|
||||
Copies the "main unit" of a block in generic 1D representation of shape (...,L,h)
|
||||
to all other (non-main) blocks
|
||||
|
||||
Parameters:
|
||||
single (torch.tensor, required): Shape [...,L,h] "1D" tensor
|
||||
"""
|
||||
main_start = idx*Leff
|
||||
main_end = (idx+1)*Leff
|
||||
|
||||
# grab main block
|
||||
main = torch.clone(single[..., main_start:main_end, :])
|
||||
|
||||
# copy it around
|
||||
L = single.shape[-2]
|
||||
assert L%Leff == 0
|
||||
N = L//Leff
|
||||
|
||||
for i_block in range(N):
|
||||
start = i_block*Leff
|
||||
end = (i_block+1)*Leff
|
||||
|
||||
single[..., start:end, :] = main
|
||||
|
||||
return single
|
||||
|
||||
|
||||
class Templ_emb(nn.Module):
|
||||
# Get template embedding
|
||||
# Features are
|
||||
# t2d:
|
||||
# - 61 distogram bins + 6 orientations (67)
|
||||
# - Mask (missing/unaligned) (1)
|
||||
# t1d:
|
||||
# - tiled AA sequence (20 standard aa + gap)
|
||||
# - confidence (1)
|
||||
#
|
||||
def __init__(self, d_t1d=0, d_t2d=67+1, d_tor=0, d_pair=128, d_state=32,
|
||||
n_block=2, d_templ=64,
|
||||
n_head=4, d_hidden=16, p_drop=0.25,
|
||||
symmetrize_repeats=False, repeat_length=None, symmsub_k=1, sym_method='mean',
|
||||
main_block=None, copy_main_block=None, additional_dt1d=0):
|
||||
if d_t1d==0:
|
||||
d_t1d=(ChemData().NAATOKENS-1)+1
|
||||
if d_tor==0:
|
||||
d_tor=3*ChemData().NTOTALDOFS
|
||||
|
||||
self.main_block = main_block
|
||||
self.symmetrize_repeats = symmetrize_repeats
|
||||
self.copy_main_block = copy_main_block
|
||||
self.repeat_length = repeat_length
|
||||
d_t1d += additional_dt1d
|
||||
|
||||
super(Templ_emb, self).__init__()
|
||||
# process 2D features
|
||||
self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
|
||||
|
||||
self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
|
||||
d_hidden=d_hidden, d_t1d=d_t1d, d_state=d_state, p_drop=p_drop,
|
||||
symmetrize_repeats=symmetrize_repeats, repeat_length=repeat_length,
|
||||
symmsub_k=symmsub_k, sym_method=sym_method)
|
||||
|
||||
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
|
||||
|
||||
# process torsion angles
|
||||
self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
|
||||
self.proj_t1d = nn.Linear(d_templ, d_templ)
|
||||
#self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
|
||||
# d_hidden=d_hidden, p_drop=p_drop)
|
||||
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
self.emb = init_lecun_normal(self.emb)
|
||||
nn.init.zeros_(self.emb.bias)
|
||||
|
||||
nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
|
||||
nn.init.zeros_(self.emb_t1d.bias)
|
||||
|
||||
self.proj_t1d = init_lecun_normal(self.proj_t1d)
|
||||
nn.init.zeros_(self.proj_t1d.bias)
|
||||
|
||||
def _get_templ_emb(self, t1d, t2d):
|
||||
B, T, L, _ = t1d.shape
|
||||
# Prepare 2D template features
|
||||
left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
|
||||
right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
|
||||
#
|
||||
templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 88)
|
||||
return self.emb(templ) # Template templures (B, T, L, L, d_templ)
|
||||
|
||||
def _get_templ_rbf(self, xyz_t, mask_t):
|
||||
B, T, L = xyz_t.shape[:3]
|
||||
|
||||
# process each template features
|
||||
xyz_t = xyz_t.reshape(B*T, L, 3).contiguous()
|
||||
mask_t = mask_t.reshape(B*T, L, L)
|
||||
assert(xyz_t.is_contiguous())
|
||||
rbf_feat = rbf(torch.cdist(xyz_t, xyz_t)) * mask_t[...,None] # (B*T, L, L, d_rbf)
|
||||
return rbf_feat
|
||||
|
||||
def forward(self, t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=False, p2p_crop=-1):
|
||||
# Input
|
||||
# - t1d: 1D template info (B, T, L, 30)
|
||||
# - t2d: 2D template info (B, T, L, L, 44)
|
||||
# - alpha_t: torsion angle info (B, T, L, 30) - DOUBLE-CHECK
|
||||
# - xyz_t: template CA coordinates (B, T, L, 3)
|
||||
# - mask_t: is valid residue pair? (B, T, L, L)
|
||||
# - pair: query pair features (B, L, L, d_pair)
|
||||
# - state: query state features (B, L, d_state)
|
||||
B, T, L, _ = t1d.shape
|
||||
|
||||
templ = self._get_templ_emb(t1d, t2d)
|
||||
# this looks a lot like a bug but it is not
|
||||
# mask_t has already been updated by same_chain in the train_EMA script so pairwise distances between
|
||||
# protein chains are ignored
|
||||
rbf_feat = self._get_templ_rbf(xyz_t, mask_t)
|
||||
|
||||
# process each template pair feature
|
||||
templ = self.templ_stack(templ, rbf_feat, t1d, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop) # (B, T, L,L, d_templ)
|
||||
|
||||
# DJ - repeat protein symmetrization (2D)
|
||||
if self.copy_main_block:
|
||||
assert not (self.main_block is None)
|
||||
assert self.symmetrize_repeats
|
||||
# copy the main repeat unit internally down the pair representation diagonal
|
||||
templ = copy_main_2d(templ, self.repeat_length, self.main_block)
|
||||
|
||||
# Prepare 1D template torsion angle features
|
||||
t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 30+3*17)
|
||||
# process each template features
|
||||
t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))
|
||||
|
||||
# DJ - repeat protein symmetrization (1D)
|
||||
if self.copy_main_block:
|
||||
# already made assertions above
|
||||
# copy main unit down single rep
|
||||
t1d = copy_main_1d(t1d, self.repeat_length, self.main_block)
|
||||
|
||||
# mixing query state features to template state features
|
||||
state = state.reshape(B*L, 1, -1)
|
||||
t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
|
||||
if use_checkpoint:
|
||||
out = checkpoint.checkpoint(
|
||||
create_custom_forward(self.attn_tor), state, t1d, t1d, use_reentrant=True
|
||||
)
|
||||
out = out.reshape(B, L, -1)
|
||||
else:
|
||||
out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1)
|
||||
state = state.reshape(B, L, -1)
|
||||
state = state + out
|
||||
|
||||
# mixing query pair features to template information (Template pointwise attention)
|
||||
pair = pair.reshape(B*L*L, 1, -1)
|
||||
templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
|
||||
if use_checkpoint:
|
||||
out = checkpoint.checkpoint(
|
||||
create_custom_forward(self.attn), pair, templ, templ, use_reentrant=True
|
||||
)
|
||||
out = out.reshape(B, L, L, -1)
|
||||
else:
|
||||
out = self.attn(pair, templ, templ).reshape(B, L, L, -1)
|
||||
#
|
||||
pair = pair.reshape(B, L, L, -1)
|
||||
pair = pair + out
|
||||
|
||||
return pair, state
|
||||
|
||||
|
||||
class Recycling(nn.Module):
|
||||
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
|
||||
super(Recycling, self).__init__()
|
||||
self.proj_dist = nn.Linear(d_rbf, d_pair)
|
||||
self.norm_pair = nn.LayerNorm(d_pair)
|
||||
self.norm_msa = nn.LayerNorm(d_msa)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
#self.emb_rbf = init_lecun_normal(self.emb_rbf)
|
||||
#nn.init.zeros_(self.emb_rbf.bias)
|
||||
self.proj_dist = init_lecun_normal(self.proj_dist)
|
||||
nn.init.zeros_(self.proj_dist.bias)
|
||||
|
||||
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
|
||||
B, L = msa.shape[:2]
|
||||
msa = self.norm_msa(msa)
|
||||
pair = self.norm_pair(pair)
|
||||
|
||||
Ca = xyz[:,:,1]
|
||||
dist_CA = rbf(
|
||||
torch.cdist(Ca, Ca)
|
||||
).reshape(B,L,L,-1)
|
||||
|
||||
if mask_recycle != None:
|
||||
dist_CA = mask_recycle[...,None].float()*dist_CA
|
||||
|
||||
pair = pair + self.proj_dist(dist_CA)
|
||||
|
||||
return msa, pair, state # state is just zeros
|
||||
|
||||
class RecyclingAllFeatures(nn.Module):
|
||||
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
|
||||
super(RecyclingAllFeatures, self).__init__()
|
||||
self.proj_dist = nn.Linear(d_rbf+d_state*2, d_pair)
|
||||
self.norm_pair = nn.LayerNorm(d_pair)
|
||||
self.proj_sctors = nn.Linear(2*ChemData().NTOTALDOFS, d_msa)
|
||||
self.norm_msa = nn.LayerNorm(d_msa)
|
||||
self.norm_state = nn.LayerNorm(d_state)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
self.proj_dist = init_lecun_normal(self.proj_dist)
|
||||
nn.init.zeros_(self.proj_dist.bias)
|
||||
self.proj_sctors = init_lecun_normal(self.proj_sctors)
|
||||
nn.init.zeros_(self.proj_sctors.bias)
|
||||
|
||||
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
|
||||
B, L = pair.shape[:2]
|
||||
state = self.norm_state(state)
|
||||
|
||||
left = state.unsqueeze(2).expand(-1,-1,L,-1)
|
||||
right = state.unsqueeze(1).expand(-1,L,-1,-1)
|
||||
|
||||
Ca_or_P = xyz[:,:,1].contiguous()
|
||||
|
||||
dist = rbf(torch.cdist(Ca_or_P, Ca_or_P))
|
||||
if mask_recycle != None:
|
||||
dist = mask_recycle[...,None].float()*dist
|
||||
dist = torch.cat((dist, left, right), dim=-1)
|
||||
dist = self.proj_dist(dist)
|
||||
pair = dist + self.norm_pair(pair)
|
||||
|
||||
sctors = self.proj_sctors(sctors.reshape(B,-1,2*ChemData().NTOTALDOFS))
|
||||
msa = sctors + self.norm_msa(msa)
|
||||
|
||||
return msa, pair, state
|
||||
|
||||
recycling_factory = {
|
||||
"msa_pair": Recycling,
|
||||
"all": RecyclingAllFeatures
|
||||
}
|
||||
100
rf2aa/model/layers/SE3_network.py
Normal file
100
rf2aa/model/layers/SE3_network.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from icecream import ic
|
||||
import inspect
|
||||
|
||||
import sys, os
|
||||
#script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
#sys.path.insert(0,script_dir+'SE3Transformer')
|
||||
|
||||
from rf2aa.util import xyz_frame_from_rotation_mask
|
||||
from rf2aa.util_module import init_lecun_normal_param, \
|
||||
make_full_graph, rbf, init_lecun_normal
|
||||
from rf2aa.loss.loss import calc_chiral_grads
|
||||
from rf2aa.model.layers.Attention_module import FeedForwardLayer
|
||||
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.util_module import get_seqsep_protein_sm
|
||||
|
||||
se3_transformer_path = inspect.getfile(SE3Transformer)
|
||||
se3_fiber_path = inspect.getfile(Fiber)
|
||||
assert 'rf2aa' in se3_transformer_path
|
||||
|
||||
class SE3TransformerWrapper(nn.Module):
|
||||
"""SE(3) equivariant GCN with attention"""
|
||||
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
|
||||
l0_in_features=32, l0_out_features=32,
|
||||
l1_in_features=3, l1_out_features=2,
|
||||
num_edge_features=32):
|
||||
super().__init__()
|
||||
# Build the network
|
||||
self.l1_in = l1_in_features
|
||||
self.l1_out = l1_out_features
|
||||
#
|
||||
fiber_edge = Fiber({0: num_edge_features})
|
||||
if l1_out_features > 0:
|
||||
if l1_in_features > 0:
|
||||
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
|
||||
fiber_hidden = Fiber.create(num_degrees, num_channels)
|
||||
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
|
||||
else:
|
||||
fiber_in = Fiber({0: l0_in_features})
|
||||
fiber_hidden = Fiber.create(num_degrees, num_channels)
|
||||
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
|
||||
else:
|
||||
if l1_in_features > 0:
|
||||
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
|
||||
fiber_hidden = Fiber.create(num_degrees, num_channels)
|
||||
fiber_out = Fiber({0: l0_out_features})
|
||||
else:
|
||||
fiber_in = Fiber({0: l0_in_features})
|
||||
fiber_hidden = Fiber.create(num_degrees, num_channels)
|
||||
fiber_out = Fiber({0: l0_out_features})
|
||||
|
||||
self.se3 = SE3Transformer(num_layers=num_layers,
|
||||
fiber_in=fiber_in,
|
||||
fiber_hidden=fiber_hidden,
|
||||
fiber_out = fiber_out,
|
||||
num_heads=n_heads,
|
||||
channels_div=div,
|
||||
fiber_edge=fiber_edge,
|
||||
populate_edge="arcsin",
|
||||
final_layer="lin",
|
||||
use_layer_norm=True)
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
def reset_parameter(self):
|
||||
|
||||
# make sure linear layer before ReLu are initialized with kaiming_normal_
|
||||
for n, p in self.se3.named_parameters():
|
||||
if "bias" in n:
|
||||
nn.init.zeros_(p)
|
||||
elif len(p.shape) == 1:
|
||||
continue
|
||||
else:
|
||||
if "radial_func" not in n:
|
||||
p = init_lecun_normal_param(p)
|
||||
else:
|
||||
if "net.6" in n:
|
||||
nn.init.zeros_(p)
|
||||
else:
|
||||
nn.init.kaiming_normal_(p, nonlinearity='relu')
|
||||
|
||||
# make last layers to be zero-initialized
|
||||
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
|
||||
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
|
||||
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
|
||||
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
|
||||
nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
|
||||
if self.l1_out > 0:
|
||||
nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
|
||||
|
||||
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
|
||||
if self.l1_in > 0:
|
||||
node_features = {'0': type_0_features, '1': type_1_features}
|
||||
else:
|
||||
node_features = {'0': type_0_features}
|
||||
edge_features = {'0': edge_features}
|
||||
return self.se3(G, node_features, edge_features)
|
||||
|
||||
Reference in New Issue
Block a user