Files
immunebuilder/ImmuneBuilder/models.py
Olamide Isreal 8887cbe592
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
Configure ImmuneBuilder pipeline for WES execution
- Update container image to harbor.cluster.omic.ai/omic/immunebuilder:latest
- Update input/output paths to S3 (s3://omic/eureka/immunebuilder/)
- Remove local mount containerOptions (not needed in k8s)
- Update homepage to Gitea repo URL
- Clean history to remove large model weight blobs
2026-03-16 15:31:53 +01:00

231 lines
10 KiB
Python

import torch
from einops import rearrange
from ImmuneBuilder.rigids import Rigid, Rot, rigid_body_identity, vec_from_tensor, global_frames_from_bb_frame_and_torsion_angles, all_atoms_from_global_reference_frames
class InvariantPointAttention(torch.nn.Module):
    """Attention over residues mixing scalar, edge and 3D point terms.

    Attention logits are the sum of a scaled dot-product term on node
    features, a linear bias computed from edge features, and a negative
    distance term between query/key points that are placed in the global
    frame through ``rigid``.  The attended edge features, scalar values and
    value points (mapped back into each residue's local frame) are
    concatenated and projected back to ``node_dim`` as a residual update.

    Args:
        node_dim: Size of the per-residue feature vector.
        edge_dim: Size of the per-pair feature vector.
        heads: Number of attention heads.
        head_dim: Scalar query/key/value width per head.
        n_query_points: Number of 3D query/key points per head.
        n_value_points: Number of 3D value points per head.
        **kwargs: Accepted and ignored by this class.
    """

    def __init__(self, node_dim, edge_dim, heads=12, head_dim=16, n_query_points=4, n_value_points=8, **kwargs):
        super().__init__()

        self.heads = heads
        self.head_dim = head_dim
        self.n_query_points = n_query_points

        scalar_width = heads * head_dim
        query_point_width = 3 * n_query_points * heads
        value_point_width = 3 * n_value_points * heads
        # Each value point contributes 4 numbers to the output: its norm and x, y, z.
        concat_width = heads * edge_dim + heads * head_dim + heads * n_value_points * 4

        # One weight per head, initialised to log(e - 1) (inverse softplus of 1).
        # NOTE(review): forward() multiplies by the raw parameter; no softplus
        # is applied anywhere in this class.
        initial_point_weight = torch.full((heads,), 1.).exp().sub(1.).log()
        self.point_weight = torch.nn.Parameter(initial_point_weight)

        self.to_scalar_qkv = torch.nn.Linear(node_dim, 3 * scalar_width, bias=False)
        self.to_vector_qk = torch.nn.Linear(node_dim, 2 * query_point_width, bias=False)
        self.to_vector_v = torch.nn.Linear(node_dim, value_point_width, bias=False)
        self.to_scalar_edge_attention_bias = torch.nn.Linear(edge_dim, heads, bias=False)
        self.final_linear = torch.nn.Linear(concat_width, node_dim)

        # Zero the output projection so a fresh layer returns its input unchanged.
        with torch.no_grad():
            for tensor in (self.final_linear.weight, self.final_linear.bias):
                tensor.fill_(0.0)

    def forward(self, node_features, edge_features, rigid):
        """Return ``node_features`` plus the attention update.

        Args:
            node_features: 2D tensor, one row per residue (the einops
                patterns below require exactly ``n x node_dim``).
            edge_features: 3D tensor indexed ``[i, j, edge_dim]``.
            rigid: Per-residue frames from ``ImmuneBuilder.rigids``; must
                support ``unsqueeze``, ``@`` and ``inv`` (presumably a
                ``Rigid`` with one frame per residue -- confirm with caller).
        """
        n_heads = self.heads

        def to_points(flat):
            # (n, h * p * 3) -> project vector type indexed (h, n, p)
            return vec_from_tensor(rearrange(flat, 'n (h p d) -> h n p d', h=n_heads, d=3))

        # Classic attention on nodes
        q_scalar, k_scalar, v_scalar = (
            rearrange(part, 'n (h d) -> h n d', h=n_heads)
            for part in self.to_scalar_qkv(node_features).chunk(3, dim=-1)
        )
        node_scalar = torch.einsum('h i d, h j d -> h i j', q_scalar, k_scalar) * self.head_dim ** (-1 / 2)

        # Linear bias on edges
        edge_bias = rearrange(self.to_scalar_edge_attention_bias(edge_features), 'i j h -> h i j')

        # Reference frame attention
        wc = (2 / self.n_query_points) ** (1 / 2) / 6
        q_points, k_points = (to_points(part) for part in self.to_vector_qk(node_features).chunk(2, dim=-1))

        # Add a leading head axis and a trailing point axis to the frames.
        frames = rigid.unsqueeze(0).unsqueeze(-1)
        k_global = frames @ k_points
        q_global = frames @ q_points

        # Pairwise query/key point distances, summed over points, weighted per head.
        pairwise = q_global.unsqueeze(-2).dist(k_global.unsqueeze(-3))
        per_head_weight = rearrange(self.point_weight, "h -> h () ()")
        global_frame_distance = wc * pairwise.sum(-1) * per_head_weight

        # Combining attentions
        logits = 3 ** (-1 / 2) * (node_scalar + edge_bias - global_frame_distance)
        attention_matrix = logits.softmax(-1)

        # Obtaining outputs
        attention_per_residue = rearrange(attention_matrix, 'h i j -> i h () j')
        edges_per_residue = rearrange(edge_features, 'i j d -> i () d j')
        edge_output = (attention_per_residue * edges_per_residue).sum(-1)

        scalar_node_output = torch.einsum('h i j, h j d -> i h d', attention_matrix, v_scalar)

        v_global = frames @ to_points(self.to_vector_v(node_features))
        v_attended = v_global.map(
            lambda component: torch.einsum('h i j, h j p -> h i p', attention_matrix, component))
        # Back into each residue's local frame; keep the norm next to x, y, z.
        v_local = frames.inv() @ v_attended
        vector_node_output = torch.stack([v_local.norm(), v_local.x, v_local.y, v_local.z], dim=-1)

        # Concatenate along heads and points
        combined = torch.cat(
            [
                rearrange(edge_output, 'n h d -> n (h d)'),
                rearrange(scalar_node_output, 'n h d -> n (h d)'),
                rearrange(vector_node_output, 'h n p d -> n (h p d)'),
            ],
            dim=-1,
        )

        return node_features + self.final_linear(combined)
class BackboneUpdate(torch.nn.Module):
    """Predict a per-residue rigid-body correction from node features.

    A single linear layer emits six numbers per residue: three quaternion
    components ``(b, c, d)`` and a three-component translation.  The
    quaternion is completed as ``(1, b, c, d)`` and normalised, so an
    all-zero prediction corresponds to the identity rotation.
    """

    def __init__(self, node_dim):
        super().__init__()

        self.to_correction = torch.nn.Linear(node_dim, 6)

    def forward(self, node_features, update_mask=None):
        """Return a ``Rigid`` holding the predicted translation and rotation.

        Args:
            node_features: Tensor whose last axis has size ``node_dim``.
            update_mask: Optional tensor multiplied (as ``mask[:, None]``)
                into both outputs; a zero entry yields zero translation and
                the identity rotation for that residue.
        """
        # Predict quaternions and translation vector
        correction = self.to_correction(node_features)
        rot, t = correction.chunk(2, dim=-1)

        # I may not want to update all residues
        if update_mask is not None:
            column_mask = update_mask[:, None]
            rot = column_mask * rot
            t = column_mask * t

        # Normalize quaternions: (1, b, c, d) / sqrt(1 + b^2 + c^2 + d^2)
        norm = (1 + rot.pow(2).sum(-1, keepdim=True)).pow(1 / 2)
        b, c, d = (rot / norm).unbind(dim=-1)
        a = (1 / norm).squeeze(-1)

        # Make rotation matrix from quaternions
        aa, bb, cc, dd = a ** 2, b ** 2, c ** 2, d ** 2
        bc, bd, cd = 2 * b * c, 2 * b * d, 2 * c * d
        ab, ac, ad = 2 * a * b, 2 * a * c, 2 * a * d

        rotation = Rot(
            (aa + bb - cc - dd), (bc - ad), (bd + ac),
            (bc + ad), (aa - bb + cc - dd), (cd - ab),
            (bd - ac), (cd + ab), (aa - bb - cc + dd)
        )

        return Rigid(vec_from_tensor(t), rotation)
class TorsionAngles(torch.nn.Module):
    """Predict unit-length 2D torsion vectors from two feature tensors.

    The two inputs are concatenated on the last axis, refined by two
    residual MLP blocks and projected to 10 numbers, which are read as five
    2-component vectors (presumably a sine/cosine style encoding of five
    torsion angles -- confirm against ``ImmuneBuilder.rigids``).

    The last linear layer of each residual block is zero-initialised, so a
    freshly constructed block leaves its input unchanged.
    """

    # Lower bound on the normalisation denominator.  Without it an exactly
    # zero prediction would be divided by zero and turn into NaN.
    eps = 1e-8

    def __init__(self, node_dim):
        super().__init__()

        width = 2 * node_dim
        self.residual1 = self._residual_mlp(width)
        self.residual2 = self._residual_mlp(width)

        self.final_pred = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(width, 10)
        )

        with torch.no_grad():
            self.residual1[-1].weight.fill_(0.0)
            self.residual2[-1].weight.fill_(0.0)
            self.residual1[-1].bias.fill_(0.0)
            self.residual2[-1].bias.fill_(0.0)

    @staticmethod
    def _residual_mlp(width):
        # Three equal-width linear layers with ReLU between them.
        return torch.nn.Sequential(
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width)
        )

    def forward(self, node_features, s_i):
        """Return ``(unit_torsions, norm)``.

        Args:
            node_features: Tensor with last axis of size ``node_dim``.
            s_i: Tensor with the same leading shape and last axis ``node_dim``.

        Returns:
            A pair ``(torsions, norm)``.  ``torsions`` has shape
            ``(..., 5, 2)`` and unit length on the last axis (all zeros where
            the raw prediction is exactly zero).  ``norm`` has shape
            ``(..., 5, 1)`` and holds the unclamped length of the raw
            prediction.
        """
        full_feat = torch.cat([node_features, s_i], dim=-1)

        full_feat = full_feat + self.residual1(full_feat)
        full_feat = full_feat + self.residual2(full_feat)

        raw = self.final_pred(full_feat)
        # Split the last axis into (torsion, component); any leading axes are kept.
        torsions = raw.reshape(*raw.shape[:-1], -1, 2)
        norm = torch.norm(torsions, dim=-1, keepdim=True)

        # Clamp only the denominator; callers still receive the true norm.
        return torsions / norm.clamp(min=self.eps), norm
class StructureUpdate(torch.nn.Module):
    """One refinement step: attention, transition MLP, backbone frame update.

    Args:
        node_dim: Size of the per-residue feature vector.
        edge_dim: Size of the per-pair feature vector.
        dropout: Dropout probability applied before each layer norm.
        **kwargs: Forwarded unchanged to ``InvariantPointAttention``.

    The ``torsion_angles`` submodule is constructed here but is not called
    by ``forward``.
    """

    def __init__(self, node_dim, edge_dim, dropout=0.0, **kwargs):
        super().__init__()

        hidden_dim = 2 * node_dim

        self.IPA = InvariantPointAttention(node_dim, edge_dim, **kwargs)

        self.norm1 = self._dropout_layer_norm(node_dim, dropout)
        self.norm2 = self._dropout_layer_norm(node_dim, dropout)

        self.residual = torch.nn.Sequential(
            torch.nn.Linear(node_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, node_dim)
        )

        self.torsion_angles = TorsionAngles(node_dim)
        self.backbone_update = BackboneUpdate(node_dim)

        # Zero the last transition layer so the residual branch starts at zero.
        output_layer = self.residual[-1]
        with torch.no_grad():
            output_layer.weight.fill_(0.0)
            output_layer.bias.fill_(0.0)

    @staticmethod
    def _dropout_layer_norm(width, dropout):
        # Dropout followed by layer normalisation over the feature axis.
        return torch.nn.Sequential(
            torch.nn.Dropout(dropout),
            torch.nn.LayerNorm(width)
        )

    def forward(self, node_features, edge_features, rigid_pred, update_mask=None):
        """Return ``(updated_node_features, updated_rigid)``.

        The new frames are ``rigid_pred`` composed (via ``@``) with the
        correction predicted from the updated node features; ``update_mask``
        is passed through to the backbone update.
        """
        attended = self.IPA(node_features, edge_features, rigid_pred)
        attended = self.norm1(attended)

        transitioned = attended + self.residual(attended)
        s_i = self.norm2(transitioned)

        correction = self.backbone_update(s_i, update_mask)
        rigid_new = rigid_pred @ correction

        return s_i, rigid_new
class StructureModule(torch.nn.Module):
    """Stack of ``StructureUpdate`` layers that turns features into atoms.

    Args:
        node_dim: Size of the raw per-residue input features.
        n_layers: Number of ``StructureUpdate`` layers.
        rel_pos_dim: Relative sequence offsets are clamped to
            ``[-rel_pos_dim, rel_pos_dim]`` before one-hot encoding.
        embed_dim: Width of node features inside the stack.  Edge features
            have the same width: ``embed_dim - 1`` learned channels plus one
            distance channel.
        **kwargs: Forwarded to every ``StructureUpdate``.

    NOTE(review): ``propagate_rotation_gradient`` is passed to each layer,
    but ``StructureUpdate`` forwards it to ``InvariantPointAttention``, which
    accepts and ignores extra keyword arguments, so within this file the
    flag has no effect.
    """

    def __init__(self, node_dim=23, n_layers=8, rel_pos_dim=64, embed_dim=128, **kwargs):
        super().__init__()

        self.n_layers = n_layers
        self.rel_pos_dim = rel_pos_dim
        self.node_embed = torch.nn.Linear(node_dim, embed_dim)
        self.edge_embed = torch.nn.Linear(2 * rel_pos_dim + 1, embed_dim - 1)

        self.layers = torch.nn.ModuleList(
            [StructureUpdate(node_dim=embed_dim,
                             edge_dim=embed_dim,
                             propagate_rotation_gradient=(i == n_layers - 1),
                             **kwargs)
             for i in range(n_layers)])

    @staticmethod
    def _remove_clashing_side_chains(all_atoms):
        # Set side-chain atoms (index 5 onwards) to NaN, in place, for every
        # residue that has an atom closer than 0.2 to any other atom.
        ds = torch.linalg.norm(all_atoms[None, :, None] - all_atoms[:, None, :, None], dim=-1)
        # NaN distances (missing atoms) and exact zeros (an atom against
        # itself, or exactly coincident atoms) are replaced by a large value
        # so they never count as a clash.
        ds[torch.isnan(ds) | (ds == 0.0)] = 10
        min_ds = ds.min(dim=-1).values.min(dim=-1).values.min(dim=-1).values
        all_atoms[min_ds < 0.2, 5:, :] = float("Nan")
        return all_atoms

    def forward(self, node_features, sequence):
        """Return ``(all_atoms, final_node_features)``.

        Args:
            node_features: Raw per-residue features; the residue axis is the
                second-to-last one.
            sequence: Residue sequence; its length sets the number of frames
                and it is passed to the ``rigids`` helpers that place atoms.

        Returns:
            ``all_atoms`` as produced by
            ``all_atoms_from_global_reference_frames``, with clashing side
            chains set to NaN, and the node features after the last layer.
        """
        device = node_features.device
        rigid_in = rigid_body_identity(len(sequence)).to(device)

        # Clamped relative sequence offsets, built directly on the input
        # device, then shifted to be valid one-hot class indices.
        positions = torch.arange(node_features.shape[-2], device=device)
        relative_positions = positions[None] - positions[:, None]
        relative_positions = relative_positions.clamp(min=-self.rel_pos_dim, max=self.rel_pos_dim) + self.rel_pos_dim
        rel_pos_embeddings = torch.nn.functional.one_hot(relative_positions, num_classes=2 * self.rel_pos_dim + 1)
        rel_pos_embeddings = rel_pos_embeddings.to(dtype=node_features.dtype, device=device)
        rel_pos_embeddings = self.edge_embed(rel_pos_embeddings)

        # Embed once; the same embedding also feeds the torsion head below.
        initial_node_features = self.node_embed(node_features)
        new_node_features = initial_node_features

        for layer in self.layers:
            # Edge features are rebuilt each layer from the current frame origins.
            origin_distances = rigid_in.origin.unsqueeze(-1).dist(rigid_in.origin).unsqueeze(-1)
            edge_features = torch.cat([origin_distances, rel_pos_embeddings], dim=-1)
            new_node_features, rigid_in = layer(new_node_features, edge_features, rigid_in)

        # Only the last layer's torsion head is used.
        torsions, _ = self.layers[-1].torsion_angles(initial_node_features, new_node_features)

        all_reference_frames = global_frames_from_bb_frame_and_torsion_angles(rigid_in, torsions, sequence)
        all_atoms = all_atoms_from_global_reference_frames(all_reference_frames, sequence)

        # Remove atoms of side chains with outrageous clashes
        all_atoms = self._remove_clashing_side_chains(all_atoms)

        return all_atoms, new_node_features