Configure ImmuneBuilder pipeline for WES execution
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
Some checks failed
CodeQL / Analyze (python) (push) Has been cancelled
- Update container image to harbor.cluster.omic.ai/omic/immunebuilder:latest - Update input/output paths to S3 (s3://omic/eureka/immunebuilder/) - Remove local mount containerOptions (not needed in k8s) - Update homepage to Gitea repo URL - Clean history to remove large model weight blobs
This commit is contained in:
313
ImmuneBuilder/rigids.py
Normal file
313
ImmuneBuilder/rigids.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import torch
|
||||
from ImmuneBuilder.constants import rigid_group_atom_positions2, chi2_centers, chi3_centers, chi4_centers, rel_pos, \
|
||||
residue_atoms_mask
|
||||
|
||||
class Vector:
|
||||
def __init__(self, x, y, z):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.z = z
|
||||
self.shape = x.shape
|
||||
assert (x.shape == y.shape) and (y.shape == z.shape), "x y and z should have the same shape"
|
||||
|
||||
def __add__(self, vec):
|
||||
return Vector(vec.x + self.x, vec.y + self.y, vec.z + self.z)
|
||||
|
||||
def __sub__(self, vec):
|
||||
return Vector(-vec.x + self.x, -vec.y + self.y, -vec.z + self.z)
|
||||
|
||||
def __mul__(self, param):
|
||||
return Vector(param * self.x, param * self.y, param * self.z)
|
||||
|
||||
def __matmul__(self, vec):
|
||||
return vec.x * self.x + vec.y * self.y + vec.z * self.z
|
||||
|
||||
def norm(self):
|
||||
return (self.x ** 2 + self.y ** 2 + self.z ** 2 + 1e-8) ** (1 / 2)
|
||||
|
||||
def cross(self, other):
|
||||
a = (self.y * other.z - self.z * other.y)
|
||||
b = (self.z * other.x - self.x * other.z)
|
||||
c = (self.x * other.y - self.y * other.x)
|
||||
return Vector(a, b, c)
|
||||
|
||||
def dist(self, other):
|
||||
return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2 + (self.z - other.z) ** 2 + 1e-8) ** (1 / 2)
|
||||
|
||||
def unsqueeze(self, dim):
|
||||
return Vector(self.x.unsqueeze(dim), self.y.unsqueeze(dim), self.z.unsqueeze(dim))
|
||||
|
||||
def squeeze(self, dim):
|
||||
return Vector(self.x.squeeze(dim), self.y.squeeze(dim), self.z.squeeze(dim))
|
||||
|
||||
def map(self, func):
|
||||
return Vector(func(self.x), func(self.y), func(self.z))
|
||||
|
||||
def to(self, device):
|
||||
return Vector(self.x.to(device), self.y.to(device), self.z.to(device))
|
||||
|
||||
def __str__(self):
|
||||
return "Vector(x={},\ny={},\nz={})\n".format(self.x, self.y, self.z)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return Vector(self.x[key], self.y[key], self.z[key])
|
||||
|
||||
|
||||
class Rot:
|
||||
def __init__(self, xx, xy, xz, yx, yy, yz, zx, zy, zz):
|
||||
self.xx = xx
|
||||
self.xy = xy
|
||||
self.xz = xz
|
||||
self.yx = yx
|
||||
self.yy = yy
|
||||
self.yz = yz
|
||||
self.zx = zx
|
||||
self.zy = zy
|
||||
self.zz = zz
|
||||
self.shape = xx.shape
|
||||
|
||||
def __matmul__(self, other):
|
||||
if isinstance(other, Vector):
|
||||
return Vector(
|
||||
other.x * self.xx + other.y * self.xy + other.z * self.xz,
|
||||
other.x * self.yx + other.y * self.yy + other.z * self.yz,
|
||||
other.x * self.zx + other.y * self.zy + other.z * self.zz)
|
||||
|
||||
if isinstance(other, Rot):
|
||||
return Rot(
|
||||
xx=self.xx * other.xx + self.xy * other.yx + self.xz * other.zx,
|
||||
xy=self.xx * other.xy + self.xy * other.yy + self.xz * other.zy,
|
||||
xz=self.xx * other.xz + self.xy * other.yz + self.xz * other.zz,
|
||||
yx=self.yx * other.xx + self.yy * other.yx + self.yz * other.zx,
|
||||
yy=self.yx * other.xy + self.yy * other.yy + self.yz * other.zy,
|
||||
yz=self.yx * other.xz + self.yy * other.yz + self.yz * other.zz,
|
||||
zx=self.zx * other.xx + self.zy * other.yx + self.zz * other.zx,
|
||||
zy=self.zx * other.xy + self.zy * other.yy + self.zz * other.zy,
|
||||
zz=self.zx * other.xz + self.zy * other.yz + self.zz * other.zz,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError("Matmul against {}".format(type(other)))
|
||||
|
||||
def inv(self):
|
||||
return Rot(
|
||||
xx=self.xx, xy=self.yx, xz=self.zx,
|
||||
yx=self.xy, yy=self.yy, yz=self.zy,
|
||||
zx=self.xz, zy=self.yz, zz=self.zz
|
||||
)
|
||||
|
||||
def det(self):
|
||||
return self.xx * self.yy * self.zz + self.xy * self.yz * self.zx + self.yx * self.zy * self.xz - self.xz * self.yy * self.zx - self.xy * self.yx * self.zz - self.xx * self.zy * self.yz
|
||||
|
||||
def unsqueeze(self, dim):
|
||||
return Rot(
|
||||
self.xx.unsqueeze(dim=dim), self.xy.unsqueeze(dim=dim), self.xz.unsqueeze(dim=dim),
|
||||
self.yx.unsqueeze(dim=dim), self.yy.unsqueeze(dim=dim), self.yz.unsqueeze(dim=dim),
|
||||
self.zx.unsqueeze(dim=dim), self.zy.unsqueeze(dim=dim), self.zz.unsqueeze(dim=dim)
|
||||
)
|
||||
|
||||
def squeeze(self, dim):
|
||||
return Rot(
|
||||
self.xx.squeeze(dim=dim), self.xy.squeeze(dim=dim), self.xz.squeeze(dim=dim),
|
||||
self.yx.squeeze(dim=dim), self.yy.squeeze(dim=dim), self.yz.squeeze(dim=dim),
|
||||
self.zx.squeeze(dim=dim), self.zy.squeeze(dim=dim), self.zz.squeeze(dim=dim)
|
||||
)
|
||||
|
||||
def detach(self):
|
||||
return Rot(
|
||||
self.xx.detach(), self.xy.detach(), self.xz.detach(),
|
||||
self.yx.detach(), self.yy.detach(), self.yz.detach(),
|
||||
self.zx.detach(), self.zy.detach(), self.zz.detach()
|
||||
)
|
||||
|
||||
def to(self, device):
|
||||
return Rot(
|
||||
self.xx.to(device), self.xy.to(device), self.xz.to(device),
|
||||
self.yx.to(device), self.yy.to(device), self.yz.to(device),
|
||||
self.zx.to(device), self.zy.to(device), self.zz.to(device)
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return "Rot(xx={},\nxy={},\nxz={},\nyx={},\nyy={},\nyz={},\nzx={},\nzy={},\nzz={})\n".format(self.xx, self.xy,
|
||||
self.xz, self.yx,
|
||||
self.yy, self.yz,
|
||||
self.zx, self.zy,
|
||||
self.zz)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return Rot(
|
||||
self.xx[key], self.xy[key], self.xz[key],
|
||||
self.yx[key], self.yy[key], self.yz[key],
|
||||
self.zx[key], self.zy[key], self.zz[key]
|
||||
)
|
||||
|
||||
|
||||
class Rigid:
|
||||
def __init__(self, origin, rot):
|
||||
self.origin = origin
|
||||
self.rot = rot
|
||||
self.shape = self.origin.shape
|
||||
|
||||
def __matmul__(self, other):
|
||||
if isinstance(other, Vector):
|
||||
return self.rot @ other + self.origin
|
||||
elif isinstance(other, Rigid):
|
||||
return Rigid(self.rot @ other.origin + self.origin, self.rot @ other.rot)
|
||||
else:
|
||||
raise TypeError(f"can't multiply rigid by object of type {type(other)}")
|
||||
|
||||
def inv(self):
|
||||
inv_rot = self.rot.inv()
|
||||
t = inv_rot @ self.origin
|
||||
return Rigid(Vector(-t.x, -t.y, -t.z), inv_rot)
|
||||
|
||||
def unsqueeze(self, dim=None):
|
||||
return Rigid(self.origin.unsqueeze(dim=dim), self.rot.unsqueeze(dim=dim))
|
||||
|
||||
def squeeze(self, dim=None):
|
||||
return Rigid(self.origin.squeeze(dim=dim), self.rot.squeeze(dim=dim))
|
||||
|
||||
def to(self, device):
|
||||
return Rigid(self.origin.to(device), self.rot.to(device))
|
||||
|
||||
def __str__(self):
|
||||
return "Rigid(origin={},\nrot={})".format(self.origin, self.rot)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return Rigid(self.origin[key], self.rot[key])
|
||||
|
||||
|
||||
def rigid_body_identity(shape):
|
||||
return Rigid(Vector(*3 * [torch.zeros(shape)]),
|
||||
Rot(torch.ones(shape), *3 * [torch.zeros(shape)], torch.ones(shape), *3 * [torch.zeros(shape)],
|
||||
torch.ones(shape)))
|
||||
|
||||
def vec_from_tensor(tens):
|
||||
assert tens.shape[-1] == 3, "What dimension you in?"
|
||||
return Vector(tens[..., 0], tens[..., 1], tens[..., 2])
|
||||
|
||||
|
||||
def rigid_from_three_points(origin, y_x_plane, x_axis):
|
||||
v1 = x_axis - origin
|
||||
v2 = y_x_plane - origin
|
||||
|
||||
v1 *= 1 / v1.norm()
|
||||
v2 = v2 - v1 * (v1 @ v2)
|
||||
v2 *= 1 / v2.norm()
|
||||
v3 = v1.cross(v2)
|
||||
rot = Rot(v1.x, v2.x, v3.x, v1.y, v2.y, v3.y, v1.z, v2.z, v3.z)
|
||||
return Rigid(origin, rot)
|
||||
|
||||
|
||||
def rigid_from_tensor(tens):
|
||||
assert (tens.shape[-1] == 3), "I want 3D points"
|
||||
return rigid_from_three_points(vec_from_tensor(tens[..., 0, :]), vec_from_tensor(tens[..., 1, :]),
|
||||
vec_from_tensor(tens[..., 2, :]))
|
||||
|
||||
def stack_rigids(rigids, **kwargs):
|
||||
# Probably best to avoid using very much
|
||||
stacked_origin = Vector(torch.stack([rig.origin.x for rig in rigids], **kwargs),
|
||||
torch.stack([rig.origin.y for rig in rigids], **kwargs),
|
||||
torch.stack([rig.origin.z for rig in rigids], **kwargs))
|
||||
stacked_rot = Rot(
|
||||
torch.stack([rig.rot.xx for rig in rigids], **kwargs), torch.stack([rig.rot.xy for rig in rigids], **kwargs),
|
||||
torch.stack([rig.rot.xz for rig in rigids], **kwargs),
|
||||
torch.stack([rig.rot.yx for rig in rigids], **kwargs), torch.stack([rig.rot.yy for rig in rigids], **kwargs),
|
||||
torch.stack([rig.rot.yz for rig in rigids], **kwargs),
|
||||
torch.stack([rig.rot.zx for rig in rigids], **kwargs), torch.stack([rig.rot.zy for rig in rigids], **kwargs),
|
||||
torch.stack([rig.rot.zz for rig in rigids], **kwargs),
|
||||
)
|
||||
return Rigid(stacked_origin, stacked_rot)
|
||||
|
||||
|
||||
def rotate_x_axis_to_new_vector(new_vector):
|
||||
# Extract coordinates
|
||||
c, b, a = new_vector[..., 0], new_vector[..., 1], new_vector[..., 2]
|
||||
|
||||
# Normalize
|
||||
n = (c ** 2 + a ** 2 + b ** 2 + 1e-16) ** (1 / 2)
|
||||
a, b, c = a / n, b / n, -c / n
|
||||
|
||||
# Set new origin
|
||||
new_origin = vec_from_tensor(torch.zeros_like(new_vector))
|
||||
|
||||
# Rotate x-axis to point old origin to new one
|
||||
k = (1 - c) / (a ** 2 + b ** 2 + 1e-8)
|
||||
new_rot = Rot(-c, b, -a, b, 1 - k * b ** 2, a * b * k, a, -a * b * k, k * a ** 2 - 1)
|
||||
|
||||
return Rigid(new_origin, new_rot)
|
||||
|
||||
|
||||
def rigid_transformation_from_torsion_angles(torsion_angles, distance_to_new_origin):
|
||||
dev = torsion_angles.device
|
||||
|
||||
zero = torch.zeros(torsion_angles.shape[:-1]).to(dev)
|
||||
one = torch.ones(torsion_angles.shape[:-1]).to(dev)
|
||||
new_rot = Rot(
|
||||
-one, zero, zero,
|
||||
zero, torsion_angles[..., 0], torsion_angles[..., 1],
|
||||
zero, torsion_angles[..., 1], -torsion_angles[..., 0],
|
||||
)
|
||||
new_origin = Vector(distance_to_new_origin, zero, zero)
|
||||
|
||||
return Rigid(new_origin, new_rot)
|
||||
|
||||
|
||||
def global_frames_from_bb_frame_and_torsion_angles(bb_frame, torsion_angles, seq):
|
||||
dev = bb_frame.origin.x.device
|
||||
|
||||
# We start with psi
|
||||
psi_local_frame_origin = torch.tensor([rel_pos[x][2][1] for x in seq]).to(dev).pow(2).sum(-1).pow(1 / 2)
|
||||
psi_local_frame = rigid_transformation_from_torsion_angles(torsion_angles[:, 0], psi_local_frame_origin)
|
||||
psi_global_frame = bb_frame @ psi_local_frame
|
||||
|
||||
# Now all the chis
|
||||
chi1_local_frame_origin = torch.tensor([rel_pos[x][3][1] for x in seq]).to(dev)
|
||||
chi1_local_frame = rotate_x_axis_to_new_vector(chi1_local_frame_origin) @ rigid_transformation_from_torsion_angles(
|
||||
torsion_angles[:, 1], chi1_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
|
||||
chi1_global_frame = bb_frame @ chi1_local_frame
|
||||
|
||||
chi2_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi2_centers[x]][1] for x in seq]).to(dev)
|
||||
chi2_local_frame = rotate_x_axis_to_new_vector(chi2_local_frame_origin) @ rigid_transformation_from_torsion_angles(
|
||||
torsion_angles[:, 2], chi2_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
|
||||
chi2_global_frame = chi1_global_frame @ chi2_local_frame
|
||||
|
||||
chi3_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi3_centers[x]][1] for x in seq]).to(dev)
|
||||
chi3_local_frame = rotate_x_axis_to_new_vector(chi3_local_frame_origin) @ rigid_transformation_from_torsion_angles(
|
||||
torsion_angles[:, 3], chi3_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
|
||||
chi3_global_frame = chi2_global_frame @ chi3_local_frame
|
||||
|
||||
chi4_local_frame_origin = torch.tensor([rigid_group_atom_positions2[x][chi4_centers[x]][1] for x in seq]).to(dev)
|
||||
chi4_local_frame = rotate_x_axis_to_new_vector(chi4_local_frame_origin) @ rigid_transformation_from_torsion_angles(
|
||||
torsion_angles[:, 4], chi4_local_frame_origin.pow(2).sum(-1).pow(1 / 2))
|
||||
chi4_global_frame = chi3_global_frame @ chi4_local_frame
|
||||
|
||||
return stack_rigids(
|
||||
[bb_frame, psi_global_frame, chi1_global_frame, chi2_global_frame, chi3_global_frame, chi4_global_frame],
|
||||
dim=-1)
|
||||
|
||||
|
||||
def all_atoms_from_global_reference_frames(global_reference_frames, seq):
|
||||
dev = global_reference_frames.origin.x.device
|
||||
|
||||
all_atoms = torch.zeros((len(seq), 14, 3)).to(dev)
|
||||
for atom_pos in range(14):
|
||||
relative_positions = [rel_pos[x][atom_pos][1] for x in seq]
|
||||
local_reference_frame = [max(rel_pos[x][atom_pos][0] - 2, 0) for x in seq]
|
||||
local_reference_frame_mask = torch.tensor([[y == x for y in range(6)] for x in local_reference_frame]).to(dev)
|
||||
global_atom_vector = global_reference_frames[local_reference_frame_mask] @ vec_from_tensor(
|
||||
torch.tensor(relative_positions).to(dev))
|
||||
all_atoms[:, atom_pos] = torch.stack([global_atom_vector.x, global_atom_vector.y, global_atom_vector.z], dim=-1)
|
||||
|
||||
all_atom_mask = torch.tensor([residue_atoms_mask[x] for x in seq]).to(dev)
|
||||
all_atoms[~all_atom_mask] = float("Nan")
|
||||
return all_atoms
|
||||
Reference in New Issue
Block a user