Files
immunebuilder/ImmuneBuilder/rigids.py
Olamide Isreal 8887cbe592
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
Configure ImmuneBuilder pipeline for WES execution
- Update container image to harbor.cluster.omic.ai/omic/immunebuilder:latest
- Update input/output paths to S3 (s3://omic/eureka/immunebuilder/)
- Remove local mount containerOptions (not needed in k8s)
- Update homepage to Gitea repo URL
- Clean history to remove large model weight blobs
2026-03-16 15:31:53 +01:00

314 lines
13 KiB
Python

import torch
from ImmuneBuilder.constants import rigid_group_atom_positions2, chi2_centers, chi3_centers, chi4_centers, rel_pos, \
residue_atoms_mask
class Vector:
    """Batch of 3-D vectors stored as three separate component tensors.

    ``x``, ``y`` and ``z`` must all have the same shape; that common shape is
    exposed as ``self.shape``. Every operation acts element-wise over that
    shape and returns a new ``Vector`` (or a plain tensor for scalar-valued
    results such as ``@``, ``norm`` and ``dist``).
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
        self.shape = x.shape
        assert (x.shape == y.shape) and (y.shape == z.shape), "x y and z should have the same shape"

    def __add__(self, vec):
        """Component-wise sum with another ``Vector``."""
        return Vector(vec.x + self.x, vec.y + self.y, vec.z + self.z)

    def __sub__(self, vec):
        """Component-wise difference ``self - vec``."""
        return Vector(-vec.x + self.x, -vec.y + self.y, -vec.z + self.z)

    def __mul__(self, param):
        """Scale every component by ``param`` (``vec * param``)."""
        return Vector(param * self.x, param * self.y, param * self.z)

    def __rmul__(self, param):
        """Scale with the factor on the left (``param * vec``).

        Scaling is commutative, so this simply reuses ``__mul__``. Without it
        ``2 * vec`` raised ``TypeError`` while ``vec * 2`` worked.
        """
        return self.__mul__(param)

    def __matmul__(self, vec):
        """Dot product with another ``Vector``; returns a tensor, not a ``Vector``."""
        return vec.x * self.x + vec.y * self.y + vec.z * self.z

    def norm(self):
        """Euclidean length, with 1e-8 added under the square root.

        NOTE(review): the epsilon presumably keeps the gradient finite for
        zero-length vectors - confirm.
        """
        return (self.x ** 2 + self.y ** 2 + self.z ** 2 + 1e-8) ** (1 / 2)

    def cross(self, other):
        """Cross product ``self x other``."""
        a = (self.y * other.z - self.z * other.y)
        b = (self.z * other.x - self.x * other.z)
        c = (self.x * other.y - self.y * other.x)
        return Vector(a, b, c)

    def dist(self, other):
        """Euclidean distance to ``other`` (same 1e-8 term under the root as ``norm``)."""
        return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2 + (self.z - other.z) ** 2 + 1e-8) ** (1 / 2)

    def unsqueeze(self, dim):
        """Insert a size-one dimension at ``dim`` in every component."""
        return Vector(self.x.unsqueeze(dim), self.y.unsqueeze(dim), self.z.unsqueeze(dim))

    def squeeze(self, dim):
        """Remove the size-one dimension ``dim`` from every component."""
        return Vector(self.x.squeeze(dim), self.y.squeeze(dim), self.z.squeeze(dim))

    def map(self, func):
        """Apply ``func`` to each component independently and rebuild a ``Vector``."""
        return Vector(func(self.x), func(self.y), func(self.z))

    def to(self, device):
        """Return a copy whose components have been passed through ``.to(device)``."""
        return Vector(self.x.to(device), self.y.to(device), self.z.to(device))

    def __str__(self):
        return "Vector(x={},\ny={},\nz={})\n".format(self.x, self.y, self.z)

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        """Index every component with ``key``."""
        return Vector(self.x[key], self.y[key], self.z[key])
class Rot:
    """Batch of 3x3 matrices stored entry by entry.

    The entry ``self.rc`` is row ``r``, column ``c`` (``xy`` is row x,
    column y). ``self.shape`` is the shape of the ``xx`` entry. ``inv``
    returns the transpose, which is the inverse only for orthogonal matrices.
    """

    # Row-major entry names; this is also the positional order of __init__.
    _ENTRIES = ("xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz")

    def __init__(self, xx, xy, xz, yx, yy, yz, zx, zy, zz):
        self.xx, self.xy, self.xz = xx, xy, xz
        self.yx, self.yy, self.yz = yx, yy, yz
        self.zx, self.zy, self.zz = zx, zy, zz
        self.shape = xx.shape

    def _entrywise(self, fn):
        # Build a new Rot by applying fn to each of the nine entries in row-major order.
        return Rot(*[fn(getattr(self, name)) for name in self._ENTRIES])

    def _rows(self):
        # The three rows, each as an (x, y, z) tuple of entries.
        return (
            (self.xx, self.xy, self.xz),
            (self.yx, self.yy, self.yz),
            (self.zx, self.zy, self.zz),
        )

    def __matmul__(self, other):
        """Matrix-vector product for a ``Vector``, matrix-matrix product for a ``Rot``.

        Raises ``ValueError`` for any other operand type.
        """
        if isinstance(other, Vector):
            return Vector(*[
                other.x * row[0] + other.y * row[1] + other.z * row[2]
                for row in self._rows()
            ])
        if isinstance(other, Rot):
            columns = (
                (other.xx, other.yx, other.zx),
                (other.xy, other.yy, other.zy),
                (other.xz, other.yz, other.zz),
            )
            # Iterating rows in the outer loop yields entries in row-major order.
            return Rot(*[
                row[0] * col[0] + row[1] * col[1] + row[2] * col[2]
                for row in self._rows()
                for col in columns
            ])
        raise ValueError("Matmul against {}".format(type(other)))

    def inv(self):
        """Return the transpose (equal to the inverse when the matrix is a rotation)."""
        return Rot(
            self.xx, self.yx, self.zx,
            self.xy, self.yy, self.zy,
            self.xz, self.yz, self.zz,
        )

    def det(self):
        """Determinant of each matrix in the batch (rule of Sarrus)."""
        positive = self.xx * self.yy * self.zz + self.xy * self.yz * self.zx + self.yx * self.zy * self.xz
        return positive - self.xz * self.yy * self.zx - self.xy * self.yx * self.zz - self.xx * self.zy * self.yz

    def unsqueeze(self, dim):
        """Insert a size-one dimension at ``dim`` in every entry."""
        return self._entrywise(lambda entry: entry.unsqueeze(dim=dim))

    def squeeze(self, dim):
        """Remove the size-one dimension ``dim`` from every entry."""
        return self._entrywise(lambda entry: entry.squeeze(dim=dim))

    def detach(self):
        """Return a copy whose entries are detached from the autograd graph."""
        return self._entrywise(lambda entry: entry.detach())

    def to(self, device):
        """Return a copy whose entries have been passed through ``.to(device)``."""
        return self._entrywise(lambda entry: entry.to(device))

    def __str__(self):
        template = "Rot(xx={},\nxy={},\nxz={},\nyx={},\nyy={},\nyz={},\nzx={},\nzy={},\nzz={})\n"
        return template.format(*[getattr(self, name) for name in self._ENTRIES])

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        """Index every entry with ``key``."""
        return self._entrywise(lambda entry: entry[key])
class Rigid:
    """Rigid transform made of a rotation ``rot`` and a translation ``origin``.

    Applying it to a ``Vector`` computes ``rot @ v + origin``. ``self.shape``
    is the shape of ``origin``.
    """

    def __init__(self, origin, rot):
        self.origin = origin
        self.rot = rot
        self.shape = self.origin.shape

    def __matmul__(self, other):
        """Apply to a ``Vector`` or compose with another ``Rigid``.

        ``self @ other_rigid`` is the transform that applies ``other_rigid``
        first and ``self`` second. Raises ``TypeError`` for any other operand.
        """
        if isinstance(other, Vector):
            return self.rot @ other + self.origin
        elif isinstance(other, Rigid):
            return Rigid(self.rot @ other.origin + self.origin, self.rot @ other.rot)
        else:
            raise TypeError(f"can't multiply rigid by object of type {type(other)}")

    def inv(self):
        """Return the inverse transform: rotation ``rot.inv()``, translation ``-(rot.inv() @ origin)``."""
        inv_rot = self.rot.inv()
        t = inv_rot @ self.origin
        return Rigid(Vector(-t.x, -t.y, -t.z), inv_rot)

    def unsqueeze(self, dim=None):
        """Insert a size-one dimension at ``dim`` in origin and rotation.

        ``dim`` is forwarded unchanged to the components.
        NOTE(review): the ``None`` default is presumably rejected by torch, so
        callers are expected to pass an explicit ``dim`` - confirm.
        """
        return Rigid(self.origin.unsqueeze(dim=dim), self.rot.unsqueeze(dim=dim))

    def squeeze(self, dim=None):
        """Remove size-one dimensions from origin and rotation.

        With an explicit ``dim`` only that dimension is squeezed. With the
        default ``dim=None`` every size-one dimension is removed; previously
        ``None`` was forwarded to ``Tensor.squeeze`` as if it were a dimension,
        so calling ``squeeze()`` without arguments could not succeed.
        """
        if dim is None:
            # Call the argument-less squeeze on each component tensor directly,
            # because Vector.squeeze / Rot.squeeze require an explicit dim.
            origin = self.origin.map(lambda component: component.squeeze())
            r = self.rot
            rot = Rot(
                r.xx.squeeze(), r.xy.squeeze(), r.xz.squeeze(),
                r.yx.squeeze(), r.yy.squeeze(), r.yz.squeeze(),
                r.zx.squeeze(), r.zy.squeeze(), r.zz.squeeze(),
            )
            return Rigid(origin, rot)
        return Rigid(self.origin.squeeze(dim=dim), self.rot.squeeze(dim=dim))

    def to(self, device):
        """Return a copy with origin and rotation passed through ``.to(device)``."""
        return Rigid(self.origin.to(device), self.rot.to(device))

    def __str__(self):
        return "Rigid(origin={},\nrot={})".format(self.origin, self.rot)

    def __repr__(self):
        return str(self)

    def __getitem__(self, key):
        """Index origin and rotation with ``key``."""
        return Rigid(self.origin[key], self.rot[key])
def rigid_body_identity(shape, device=None):
    """Return the identity ``Rigid`` (zero translation, identity rotation) of batch shape ``shape``.

    Every component is its own freshly allocated tensor. The previous
    ``*3 * [tensor]`` construction reused one tensor object for x, y and z
    (and for each group of off-diagonal zeros), so an in-place edit of one
    component silently changed the others.

    Args:
        shape: shape passed to ``torch.zeros`` / ``torch.ones`` for each component.
        device: optional device on which to allocate; ``None`` keeps torch's default.
    """

    def zeros():
        return torch.zeros(shape, device=device)

    def ones():
        return torch.ones(shape, device=device)

    origin = Vector(zeros(), zeros(), zeros())
    rot = Rot(
        ones(), zeros(), zeros(),
        zeros(), ones(), zeros(),
        zeros(), zeros(), ones(),
    )
    return Rigid(origin, rot)
def vec_from_tensor(tens):
    """Split a tensor whose last dimension has size 3 into a ``Vector``.

    Index 0, 1 and 2 of the last dimension become x, y and z respectively.
    """
    assert tens.shape[-1] == 3, "What dimension you in?"
    x, y, z = (tens.select(-1, axis) for axis in range(3))
    return Vector(x, y, z)
def rigid_from_three_points(origin, y_x_plane, x_axis):
    """Build a frame located at ``origin`` from three points (Gram-Schmidt).

    The first axis points from ``origin`` towards ``x_axis``. The second is
    the part of ``y_x_plane - origin`` orthogonal to the first, normalised.
    The third is their cross product. The three axes become the columns of
    the rotation matrix.
    """
    to_x = x_axis - origin
    to_plane = y_x_plane - origin

    e1 = to_x * (1 / to_x.norm())
    # Remove the component of to_plane that lies along e1.
    orthogonal = to_plane - e1 * (e1 @ to_plane)
    e2 = orthogonal * (1 / orthogonal.norm())
    e3 = e1.cross(e2)

    frame = Rot(
        xx=e1.x, xy=e2.x, xz=e3.x,
        yx=e1.y, yy=e2.y, yz=e3.y,
        zx=e1.z, zy=e2.z, zz=e3.z,
    )
    return Rigid(origin, frame)
def rigid_from_tensor(tens):
    """Build a ``Rigid`` from a tensor of 3-D points.

    Along the second-to-last dimension, index 0 is used as the frame origin,
    index 1 as the point fixing the x-y plane and index 2 as the point on the
    x axis (the argument order of ``rigid_from_three_points``).
    """
    assert (tens.shape[-1] == 3), "I want 3D points"
    origin, plane_point, axis_point = (vec_from_tensor(tens[..., row, :]) for row in range(3))
    return rigid_from_three_points(origin, plane_point, axis_point)
def stack_rigids(rigids, **kwargs):
    """Stack several ``Rigid`` objects into one by ``torch.stack``-ing each component.

    ``kwargs`` (for example ``dim``) are forwarded to every ``torch.stack`` call.
    """
    # Probably best to avoid using very much

    def stacked(part, names):
        # One stacked tensor per named component of rig.<part>, in the given order.
        return [
            torch.stack([getattr(getattr(rig, part), name) for rig in rigids], **kwargs)
            for name in names
        ]

    origin = Vector(*stacked("origin", ("x", "y", "z")))
    rot = Rot(*stacked("rot", ("xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz")))
    return Rigid(origin, rot)
def rotate_x_axis_to_new_vector(new_vector):
    """Return a ``Rigid`` with zero translation whose matrix sends the x axis along ``new_vector``.

    The first column of the matrix is ``new_vector`` divided by its length
    (with 1e-16 added under the square root), so applying it to (1, 0, 0)
    gives that unit vector. ``new_vector`` holds x, y, z along its last
    dimension.
    """
    vx, vy, vz = new_vector[..., 0], new_vector[..., 1], new_vector[..., 2]

    # Length of the input; the epsilon avoids dividing by zero.
    length = (vx ** 2 + vz ** 2 + vy ** 2 + 1e-16) ** (1 / 2)
    uz, uy, neg_ux = vz / length, vy / length, -vx / length

    # The returned transform carries no translation.
    zero_origin = vec_from_tensor(torch.zeros_like(new_vector))

    # Shared factor of the lower-right 2x2 block; the epsilon guards the case uy = uz = 0.
    k = (1 - neg_ux) / (uz ** 2 + uy ** 2 + 1e-8)
    rotation = Rot(
        xx=-neg_ux, xy=uy, xz=-uz,
        yx=uy, yy=1 - k * uy ** 2, yz=uz * uy * k,
        zx=uz, zy=-uz * uy * k, zz=k * uz ** 2 - 1,
    )
    return Rigid(zero_origin, rotation)
def rigid_transformation_from_torsion_angles(torsion_angles, distance_to_new_origin):
    """Build the local ``Rigid`` for a torsion: a shift along x plus a matrix built from the torsion pair.

    Args:
        torsion_angles: tensor whose last dimension holds two values; they fill
            the lower-right 2x2 block of the matrix as ``[[t0, t1], [t1, -t0]]``.
            NOTE(review): presumably a normalised (cos, sin)-style pair - confirm.
        distance_to_new_origin: x component of the translation; y and z are zero.
    """
    batch_shape = torsion_angles.shape[:-1]
    device = torsion_angles.device
    zeros = torch.zeros(batch_shape, device=device)
    ones = torch.ones(batch_shape, device=device)

    first = torsion_angles[..., 0]
    second = torsion_angles[..., 1]
    rotation = Rot(
        xx=-ones, xy=zeros, xz=zeros,
        yx=zeros, yy=first, yz=second,
        zx=zeros, zy=second, zz=-first,
    )
    translation = Vector(distance_to_new_origin, zeros, zeros)
    return Rigid(translation, rotation)
def global_frames_from_bb_frame_and_torsion_angles(bb_frame, torsion_angles, seq):
    """Compose the backbone frame with per-torsion local frames into six global frames.

    The frames are stacked along a new last dimension in the order backbone,
    psi, chi1, chi2, chi3, chi4. Psi and chi1 hang off the backbone frame;
    each later chi frame hangs off the previous chi frame. Index ``i`` along
    the second dimension of ``torsion_angles`` selects the torsion for
    psi (0) and chi1..chi4 (1..4).
    """
    device = bb_frame.origin.x.device

    def per_residue(lookup):
        # Tensor of one looked-up position per residue of seq, moved to the working device.
        return torch.tensor([lookup(res) for res in seq]).to(device)

    def length(offsets):
        # Euclidean length along the last dimension.
        return offsets.pow(2).sum(-1).pow(1 / 2)

    def chi_local_frame(offsets, torsion_index):
        # Point the x axis at the offset, then apply the torsion transform at that distance.
        return rotate_x_axis_to_new_vector(offsets) @ rigid_transformation_from_torsion_angles(
            torsion_angles[:, torsion_index], length(offsets))

    # We start with psi
    psi_distance = length(per_residue(lambda res: rel_pos[res][2][1]))
    psi_global = bb_frame @ rigid_transformation_from_torsion_angles(torsion_angles[:, 0], psi_distance)

    # Now all the chis
    chi1_offsets = per_residue(lambda res: rel_pos[res][3][1])
    frames = [bb_frame, psi_global, bb_frame @ chi_local_frame(chi1_offsets, 1)]

    for torsion_index, centers in ((2, chi2_centers), (3, chi3_centers), (4, chi4_centers)):
        offsets = per_residue(
            lambda res, centers=centers: rigid_group_atom_positions2[res][centers[res]][1])
        # Each further chi frame is expressed relative to the previous one.
        frames.append(frames[-1] @ chi_local_frame(offsets, torsion_index))

    return stack_rigids(frames, dim=-1)
def all_atoms_from_global_reference_frames(global_reference_frames, seq):
    """Place up to 14 atoms per residue by applying each atom's frame to its local position.

    Returns a ``(len(seq), 14, 3)`` tensor. Positions where
    ``residue_atoms_mask`` is false for the residue are set to NaN.
    """
    device = global_reference_frames.origin.x.device
    coords = torch.zeros((len(seq), 14, 3)).to(device)

    for slot in range(14):
        local_positions = [rel_pos[res][slot][1] for res in seq]
        # Frame index for this atom slot: the stored group id shifted down by two, floored at 0.
        frame_ids = [max(rel_pos[res][slot][0] - 2, 0) for res in seq]
        # One-hot row per residue selecting which of the six frames applies.
        frame_selector = torch.tensor(
            [[candidate == chosen for candidate in range(6)] for chosen in frame_ids]).to(device)

        chosen_frames = global_reference_frames[frame_selector]
        placed = chosen_frames @ vec_from_tensor(torch.tensor(local_positions).to(device))
        coords[:, slot] = torch.stack([placed.x, placed.y, placed.z], dim=-1)

    present = torch.tensor([residue_atoms_mask[res] for res in seq]).to(device)
    coords[~present] = float("Nan")
    return coords