Configure ImmuneBuilder pipeline for WES execution
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
- Update container image to harbor.cluster.omic.ai/omic/immunebuilder:latest
- Update input/output paths to S3 (s3://omic/eureka/immunebuilder/)
- Remove local mount containerOptions (not needed in k8s)
- Update homepage to Gitea repo URL
- Clean history to remove large model weight blobs
This commit is contained in:
230
ImmuneBuilder/models.py
Normal file
230
ImmuneBuilder/models.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from ImmuneBuilder.rigids import Rigid, Rot, rigid_body_identity, vec_from_tensor, global_frames_from_bb_frame_and_torsion_angles, all_atoms_from_global_reference_frames
|
||||
|
||||
class InvariantPointAttention(torch.nn.Module):
    """Multi-head attention whose logits combine scalar, edge and point terms.

    The attention logits are the sum of a scaled dot product between scalar
    queries and keys, a linear bias computed from the edge features, and the
    negated, weighted distance between query and key points after they have
    been transformed by each residue's rigid frame.

    Args:
        node_dim: Width of the per-node feature vectors.
        edge_dim: Width of the per-pair feature vectors.
        heads: Number of attention heads.
        head_dim: Width of the scalar query/key/value of each head.
        n_query_points: Number of 3D query/key points per head.
        n_value_points: Number of 3D value points per head.
        **kwargs: Accepted and ignored by this class.
    """

    def __init__(self, node_dim, edge_dim, heads=12, head_dim=16, n_query_points=4, n_value_points=8, **kwargs):
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.n_query_points = n_query_points

        scalar_width = heads * head_dim
        point_qk_width = 3 * n_query_points * heads
        point_v_width = 3 * n_value_points * heads
        # Per head the output holds edge_dim attended edge features, head_dim
        # scalar values and 4 numbers (norm, x, y, z) for every value point.
        concat_width = heads * edge_dim + heads * head_dim + heads * n_value_points * 4

        # Every head starts at log(exp(1) - 1).
        # NOTE(review): forward() multiplies by this parameter directly; no
        # softplus is applied anywhere in this class.
        per_head_ones = torch.full((heads,), 1.)
        self.point_weight = torch.nn.Parameter(torch.log(torch.exp(per_head_ones) - 1.))

        # Layers are created in a fixed order so seeded initialisation is stable.
        self.to_scalar_qkv = torch.nn.Linear(node_dim, 3 * scalar_width, bias=False)
        self.to_vector_qk = torch.nn.Linear(node_dim, 2 * point_qk_width, bias=False)
        self.to_vector_v = torch.nn.Linear(node_dim, point_v_width, bias=False)
        self.to_scalar_edge_attention_bias = torch.nn.Linear(edge_dim, heads, bias=False)
        self.final_linear = torch.nn.Linear(concat_width, node_dim)

        # A zeroed output projection makes the freshly built layer return its
        # input unchanged (forward() adds the projection to node_features).
        with torch.no_grad():
            self.final_linear.weight.zero_()
            self.final_linear.bias.zero_()

    def forward(self, node_features, edge_features, rigid):
        """Return ``node_features`` plus the projected attention output.

        Args:
            node_features: 2-D tensor ``(n, node_dim)``.
            edge_features: 3-D tensor ``(n, n, edge_dim)``.
            rigid: Per-node rigid frames (project ``Rigid`` type) supporting
                ``unsqueeze``, ``inv`` and ``@`` on vectors.

        Returns:
            Tensor with the same shape as ``node_features``.
        """
        n_heads = self.heads

        def split_scalar_heads(t):
            return rearrange(t, 'n (h d) -> h n d', h=n_heads)

        def split_point_heads(t):
            return vec_from_tensor(rearrange(t, 'n (h p d) -> h n p d', h=n_heads, d=3))

        # Classic attention on nodes
        q_raw, k_raw, v_raw = self.to_scalar_qkv(node_features).chunk(3, dim=-1)
        scalar_q = split_scalar_heads(q_raw)
        scalar_k = split_scalar_heads(k_raw)
        scalar_v = split_scalar_heads(v_raw)
        qk_dot = torch.einsum('h i d, h j d -> h i j', scalar_q, scalar_k)
        node_scalar = qk_dot * self.head_dim ** (-1 / 2)

        # Linear bias on edges
        per_pair_bias = self.to_scalar_edge_attention_bias(edge_features)
        edge_bias = rearrange(per_pair_bias, 'i j h -> h i j')

        # Reference frame attention
        wc = (2 / self.n_query_points) ** (1 / 2) / 6
        q_points_raw, k_points_raw = self.to_vector_qk(node_features).chunk(2, dim=-1)
        vector_q = split_point_heads(q_points_raw)
        vector_k = split_point_heads(k_points_raw)

        # Add head and point dimensions to the rigid frames.
        frames = rigid.unsqueeze(0).unsqueeze(-1)

        global_vector_k = frames @ vector_k
        global_vector_q = frames @ vector_q
        pairwise_point_dist = global_vector_q.unsqueeze(-2).dist(global_vector_k.unsqueeze(-3))
        head_weight = rearrange(self.point_weight, "h -> h () ()")
        global_frame_distance = wc * pairwise_point_dist.sum(-1) * head_weight

        # Combining attentions
        logits = 3 ** (-1 / 2) * (node_scalar + edge_bias - global_frame_distance)
        attention_matrix = logits.softmax(-1)

        # Obtaining outputs
        attention_per_query = rearrange(attention_matrix, 'h i j -> i h () j')
        edges_per_query = rearrange(edge_features, 'i j d -> i () d j')
        edge_output = (attention_per_query * edges_per_query).sum(-1)
        scalar_node_output = torch.einsum('h i j, h j d -> i h d', attention_matrix, scalar_v)

        vector_v = split_point_heads(self.to_vector_v(node_features))
        global_vector_v = frames @ vector_v
        attended_global_vector_v = global_vector_v.map(
            lambda coord: torch.einsum('h i j, h j p -> h i p', attention_matrix, coord))
        # Bring the attended points back into each node's own frame.
        local_points = frames.inv() @ attended_global_vector_v
        vector_node_output = torch.stack(
            [local_points.norm(), local_points.x, local_points.y, local_points.z], dim=-1)

        # Concatenate along heads and points
        combined = torch.cat(
            [
                rearrange(edge_output, 'n h d -> n (h d)'),
                rearrange(scalar_node_output, 'n h d -> n (h d)'),
                rearrange(vector_node_output, 'h n p d -> n (h p d)'),
            ],
            dim=-1,
        )

        return node_features + self.final_linear(combined)
|
||||
|
||||
|
||||
class BackboneUpdate(torch.nn.Module):
    """Predict a rigid-body correction for every node.

    One linear layer emits six numbers per node. The first three are treated
    as the (b, c, d) components of a quaternion whose leading component is 1
    before normalisation; the last three are the translation.

    Args:
        node_dim: Width of the incoming node feature vectors.
    """

    def __init__(self, node_dim):
        super().__init__()

        self.to_correction = torch.nn.Linear(node_dim, 6)

    def forward(self, node_features, update_mask=None):
        """Build the per-node ``Rigid`` correction.

        Args:
            node_features: Tensor whose last dimension is ``node_dim``.
            update_mask: Optional tensor multiplied into the prediction as
                ``update_mask[:, None]``. Where it is 0, the rotation becomes
                the identity and the translation becomes zero.
                Presumably one entry per node; confirm against callers.

        Returns:
            ``Rigid`` built from the predicted translation and rotation.
        """
        # Predict quaternion components and translation vector
        correction = self.to_correction(node_features)
        rot, t = correction.chunk(2, dim=-1)

        # Not every residue has to be updated
        if update_mask is not None:
            mask_column = update_mask[:, None]
            rot = mask_column * rot
            t = mask_column * t

        # Normalise the quaternion (1, b, c, d)
        norm = (1 + rot.pow(2).sum(-1, keepdim=True)).pow(1 / 2)
        a = (1 / norm).squeeze(-1)
        b, c, d = (part.squeeze(-1) for part in (rot / norm).chunk(3, dim=-1))

        # Shared terms of the quaternion-to-rotation-matrix formula
        aa, bb, cc, dd = a ** 2, b ** 2, c ** 2, d ** 2
        two_bc, two_ad = 2 * b * c, 2 * a * d
        two_bd, two_ac = 2 * b * d, 2 * a * c
        two_cd, two_ab = 2 * c * d, 2 * a * b

        # Rotation matrix entries, row by row
        R = Rot(
            aa + bb - cc - dd, two_bc - two_ad, two_bd + two_ac,
            two_bc + two_ad, aa - bb + cc - dd, two_cd - two_ab,
            two_bd - two_ac, two_cd + two_ab, aa - bb - cc + dd,
        )

        return Rigid(vec_from_tensor(t), R)
|
||||
|
||||
|
||||
class TorsionAngles(torch.nn.Module):
    """Predict unit-length 2-vectors for five torsion angles per node.

    The two inputs are concatenated, passed through two residual MLP blocks
    and projected to ten numbers that are read as five (x, y) pairs.

    Args:
        node_dim: Width of EACH of the two inputs given to ``forward``; the
            internal width is ``2 * node_dim``.
    """

    def __init__(self, node_dim):
        super().__init__()
        width = 2 * node_dim

        self.residual1 = torch.nn.Sequential(
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width)
        )

        self.residual2 = torch.nn.Sequential(
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width)
        )

        self.final_pred = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(width, 10)
        )

        # Zero the last layer of both residual branches so that each block
        # starts out as the identity mapping.
        with torch.no_grad():
            self.residual1[-1].weight.fill_(0.0)
            self.residual2[-1].weight.fill_(0.0)
            self.residual1[-1].bias.fill_(0.0)
            self.residual2[-1].bias.fill_(0.0)

    def forward(self, node_features, s_i):
        """Return normalised torsion pairs and their pre-normalisation norms.

        Args:
            node_features: Tensor ``(..., node_dim)``.
            s_i: Tensor ``(..., node_dim)`` with the same leading dimensions.

        Returns:
            Tuple ``(torsions, norm)``. ``torsions`` is ``(..., 5, 2)`` with
            unit length along the last axis, and ``norm`` is ``(..., 5, 1)``,
            the length of each pair before normalisation. A pair whose norm is
            exactly zero is returned as zeros instead of NaN.
        """
        full_feat = torch.cat([node_features, s_i], dim=-1)

        full_feat = full_feat + self.residual1(full_feat)
        full_feat = full_feat + self.residual2(full_feat)

        raw = self.final_pred(full_feat)
        # Split the last axis into (torsion, 2); leading axes are kept as-is,
        # so both (n, ...) and batched inputs are accepted.
        torsions = raw.reshape(*raw.shape[:-1], -1, 2)
        norm = torch.norm(torsions, dim=-1, keepdim=True)

        # Guard the division: a zero norm would otherwise yield NaN torsions.
        # The floor is the smallest normal number of the dtype, so every pair
        # with a representable non-tiny norm is divided by its exact norm.
        safe_norm = norm.clamp(min=torch.finfo(norm.dtype).tiny)

        return torsions / safe_norm, norm
|
||||
|
||||
|
||||
class StructureUpdate(torch.nn.Module):
    """One structure-refinement layer: attention, transition, frame update.

    Args:
        node_dim: Width of the node feature vectors.
        edge_dim: Width of the edge feature vectors.
        dropout: Dropout probability applied before each layer norm.
        **kwargs: Forwarded unchanged to ``InvariantPointAttention``.
    """

    def __init__(self, node_dim, edge_dim, dropout=0.0, **kwargs):
        super().__init__()
        self.IPA = InvariantPointAttention(node_dim, edge_dim, **kwargs)

        def dropout_then_layernorm():
            return torch.nn.Sequential(
                torch.nn.Dropout(dropout),
                torch.nn.LayerNorm(node_dim),
            )

        self.norm1 = dropout_then_layernorm()
        self.norm2 = dropout_then_layernorm()

        hidden_dim = 2 * node_dim
        self.residual = torch.nn.Sequential(
            torch.nn.Linear(node_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, node_dim),
        )

        # Not called in this class's forward(); exposed as a submodule.
        self.torsion_angles = TorsionAngles(node_dim)
        self.backbone_update = BackboneUpdate(node_dim)

        # Zero the last transition layer so the residual branch starts at 0.
        with torch.no_grad():
            transition_out = self.residual[-1]
            transition_out.weight.zero_()
            transition_out.bias.zero_()

    def forward(self, node_features, edge_features, rigid_pred, update_mask=None):
        """Refine node features and compose a correction onto the frames.

        Args:
            node_features: Node features passed to the attention layer.
            edge_features: Edge features passed to the attention layer.
            rigid_pred: Current rigid frames (project ``Rigid`` type).
            update_mask: Optional mask forwarded to ``BackboneUpdate``.

        Returns:
            Tuple ``(s_i, rigid_new)``: the updated node features and
            ``rigid_pred`` composed (``@``) with the predicted correction.
        """
        attended = self.IPA(node_features, edge_features, rigid_pred)
        normed = self.norm1(attended)
        s_i = self.norm2(normed + self.residual(normed))

        correction = self.backbone_update(s_i, update_mask)
        rigid_new = rigid_pred @ correction

        return s_i, rigid_new
|
||||
|
||||
|
||||
class StructureModule(torch.nn.Module):
    """Stack of ``StructureUpdate`` layers that outputs atom coordinates.

    Args:
        node_dim: Width of the input node features.
        n_layers: Number of ``StructureUpdate`` layers.
        rel_pos_dim: Relative positions are clamped to
            ``[-rel_pos_dim, rel_pos_dim]`` before one-hot encoding.
        embed_dim: Width of the embedded node features. Edge embeddings are
            ``embed_dim - 1`` wide because ``forward`` prepends one distance
            channel.
        **kwargs: Forwarded to every ``StructureUpdate``.
    """

    def __init__(self, node_dim=23, n_layers=8, rel_pos_dim=64, embed_dim=128, **kwargs):
        super().__init__()
        self.n_layers = n_layers
        self.rel_pos_dim = rel_pos_dim
        self.node_embed = torch.nn.Linear(node_dim, embed_dim)
        self.edge_embed = torch.nn.Linear(2 * rel_pos_dim + 1, embed_dim - 1)

        # ``propagate_rotation_gradient`` is True only for the last layer; it
        # is passed down through the layer's **kwargs.
        self.layers = torch.nn.ModuleList(
            [StructureUpdate(node_dim=embed_dim,
                             edge_dim=embed_dim,
                             propagate_rotation_gradient=(i == n_layers - 1),
                             **kwargs)
             for i in range(n_layers)])

    def forward(self, node_features, sequence):
        """Predict atom coordinates and final node features.

        Args:
            node_features: Tensor whose second-to-last dimension indexes
                nodes and whose last dimension is ``node_dim``.
            sequence: Sized object; ``len(sequence)`` sets the number of rigid
                frames and the object is passed on to the coordinate helpers.
                Presumably the residue sequence; confirm against callers.

        Returns:
            Tuple ``(all_atoms, new_node_features)``. In ``all_atoms``, the
            atoms at index 5 and above are set to NaN for residues involved in
            a severe clash.
        """
        device = node_features.device
        rigid_in = rigid_body_identity(len(sequence)).to(device)

        # Build the index tensors directly on the target device so the
        # one-hot tensor never has to be copied across devices.
        positions = torch.arange(node_features.shape[-2], device=device)
        relative_positions = positions[None] - positions[:, None]
        relative_positions = relative_positions.clamp(min=-self.rel_pos_dim, max=self.rel_pos_dim) + self.rel_pos_dim

        rel_pos_embeddings = torch.nn.functional.one_hot(relative_positions, num_classes=2 * self.rel_pos_dim + 1)
        rel_pos_embeddings = rel_pos_embeddings.to(dtype=node_features.dtype)
        rel_pos_embeddings = self.edge_embed(rel_pos_embeddings)

        # Embed once; the result feeds the first layer and the torsion head.
        embedded_node_features = self.node_embed(node_features)
        new_node_features = embedded_node_features

        for layer in self.layers:
            # Pairwise distances between the current frame origins are
            # recomputed for every layer because the frames are updated.
            origin_distances = rigid_in.origin.unsqueeze(-1).dist(rigid_in.origin).unsqueeze(-1)
            edge_features = torch.cat([origin_distances, rel_pos_embeddings], dim=-1)
            new_node_features, rigid_in = layer(new_node_features, edge_features, rigid_in)

        # Only the last layer's torsion head is used.
        torsions, _ = self.layers[-1].torsion_angles(embedded_node_features, new_node_features)

        all_reference_frames = global_frames_from_bb_frame_and_torsion_angles(rigid_in, torsions, sequence)
        all_atoms = all_atoms_from_global_reference_frames(all_reference_frames, sequence)

        # Remove atoms of side chains with outrageous clashes.
        # ds[i, j, p, q] is the distance between atom p of residue i and atom q
        # of residue j.
        ds = torch.linalg.norm(all_atoms[None, :, None] - all_atoms[:, None, :, None], dim=-1)
        # NaN distances and exact zeros (e.g. an atom paired with itself) are
        # replaced by a large value so they cannot win the minimum.
        ds[torch.isnan(ds) | (ds == 0.0)] = 10
        # Closest distance from any atom of each residue to any other atom.
        min_ds = ds.min(dim=-1)[0].min(dim=-1)[0].min(dim=-1)[0]
        # NOTE(review): index 5 onwards is presumably the side chain beyond
        # the backbone atoms; confirm against the atom ordering in rigids.
        all_atoms[min_ds < 0.2, 5:, :] = float("nan")

        return all_atoms, new_node_features
|
||||
Reference in New Issue
Block a user