Some checks failed
Code Quality Main / code-quality (push) Has been cancelled
Release Drafter / update_release_draft (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.10) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.8) (push) Has been cancelled
Tests / run_tests_ubuntu (ubuntu-latest, 3.9) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.10) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.8) (push) Has been cancelled
Tests / run_tests_macos (macos-latest, 3.9) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.10) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.8) (push) Has been cancelled
Tests / run_tests_windows (windows-latest, 3.9) (push) Has been cancelled
Tests / code-coverage (push) Has been cancelled
1440 lines
56 KiB
Python
1440 lines
56 KiB
Python
import random
|
|
|
|
import rootutils
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from beartype.typing import Any, Dict, Literal, Optional, Tuple, Union
|
|
from lightning import LightningModule
|
|
|
|
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
|
from flowdock.utils.frame_utils import cartesian_to_internal, get_frame_matrix
|
|
from flowdock.utils.metric_utils import compute_per_atom_lddt
|
|
from flowdock.utils.model_utils import (
|
|
distance_to_gaussian_contact_logits,
|
|
distogram_to_gaussian_contact_logits,
|
|
eval_true_contact_maps,
|
|
sample_res_rowmask_from_contacts,
|
|
sample_reslig_contact_matrix,
|
|
segment_mean,
|
|
)
|
|
|
|
MODEL_BATCH = Dict[str, Any]
|
|
MODEL_STAGE = Literal["train", "val", "test", "predict"]
|
|
LOSS_MODES = Literal[
|
|
"structure_prediction",
|
|
"auxiliary_estimation",
|
|
"auxiliary_estimation_without_structure_prediction",
|
|
]
|
|
|
|
|
|
def compute_contact_prediction_losses(
|
|
pred_distograms: torch.Tensor,
|
|
ref_dist_mat: torch.Tensor,
|
|
dist_bins: torch.Tensor,
|
|
contact_scale: float,
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
"""Compute the contact prediction losses for a given batch.
|
|
|
|
:param pred_distograms: The predicted distograms.
|
|
:param ref_dist_mat: The reference distance matrix.
|
|
:param dist_bins: The distance bins.
|
|
:param contact_scale: The contact scale.
|
|
:return: The distogram and forward KL losses.
|
|
"""
|
|
# True onehot distance and distogram loss
|
|
distance_bin_idx = torch.bucketize(ref_dist_mat, dist_bins[:-1], right=True)
|
|
distogram_loss = F.cross_entropy(pred_distograms.flatten(0, -2), distance_bin_idx.flatten())
|
|
# Evaluate contact logits via log(\sum_j p_j \exp(-\alpha*r_j^2))
|
|
ref_contact_logits = distance_to_gaussian_contact_logits(ref_dist_mat, contact_scale)
|
|
pred_contact_logits = distogram_to_gaussian_contact_logits(
|
|
pred_distograms,
|
|
dist_bins,
|
|
contact_scale,
|
|
)
|
|
forward_kl_loss = F.kl_div(
|
|
F.log_softmax(
|
|
pred_contact_logits.flatten(-2, -1),
|
|
dim=-1,
|
|
),
|
|
F.log_softmax(
|
|
ref_contact_logits.flatten(-2, -1),
|
|
dim=-1,
|
|
),
|
|
log_target=True,
|
|
reduction="batchmean",
|
|
)
|
|
return distogram_loss, forward_kl_loss
|
|
|
|
|
|
def compute_protein_distogram_loss(
|
|
batch: MODEL_BATCH,
|
|
target_coords: torch.Tensor,
|
|
dist_bins: torch.Tensor,
|
|
dgram_head: torch.nn.Module,
|
|
entry: str = "res_res_grid_attr_flat",
|
|
) -> torch.Tensor:
|
|
"""Compute the protein distogram loss for a given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param target_coords: The target coordinates.
|
|
:param dist_bins: The distance bins.
|
|
:param dgram_head: The distogram head to use for loss calculation.
|
|
:param entry: The entry to use.
|
|
:return: The distogram loss.
|
|
"""
|
|
n_protein_patches = batch["metadata"]["n_prot_patches_per_sample"]
|
|
sampled_grid_features = batch["features"][entry]
|
|
sampled_ca_coords = target_coords[batch["indexer"]["gather_idx_pid_a"]].view(
|
|
batch["metadata"]["num_structid"], n_protein_patches, 3
|
|
)
|
|
sampled_ca_dist = torch.norm(
|
|
sampled_ca_coords[:, :, None] - sampled_ca_coords[:, None, :], dim=-1
|
|
)
|
|
# Using AF2 parameters
|
|
distance_bin_idx = torch.bucketize(sampled_ca_dist, dist_bins[:-1], right=True)
|
|
distogram_loss = F.cross_entropy(dgram_head(sampled_grid_features), distance_bin_idx.flatten())
|
|
return distogram_loss
|
|
|
|
|
|
def compute_fape_from_atom37(
|
|
batch: MODEL_BATCH,
|
|
device: Union[str, torch.device],
|
|
pred_prot_coords: torch.Tensor, # [N_res, 37, 3]
|
|
target_prot_coords: torch.Tensor, # [N_res, 37, 3]
|
|
pred_lig_coords: Optional[torch.Tensor] = None, # [N_atom, 3]
|
|
target_lig_coords: Optional[torch.Tensor] = None, # [N_atom, 3]
|
|
lig_frame_atm_idx: Optional[torch.Tensor] = None, # [3, N_atom]
|
|
split_pl_views: bool = False,
|
|
cap_size: int = 8000,
|
|
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
|
"""Compute the Frame Aligned Point Error (FAPE) loss from `atom37` coordinates.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param device: The device to use.
|
|
:param pred_prot_coords: The predicted protein coordinates.
|
|
:param target_prot_coords: The target protein coordinates.
|
|
:param pred_lig_coords: The predicted ligand coordinates.
|
|
:param target_lig_coords: The target ligand coordinates.
|
|
:param lig_frame_atm_idx: The ligand frame atom indices.
|
|
:param split_pl_views: Whether to split the protein-ligand views.
|
|
:param cap_size: The capped size.
|
|
:return: The FAPE loss.
|
|
"""
|
|
features = batch["features"]
|
|
batch_size = batch["metadata"]["num_structid"]
|
|
with torch.no_grad():
|
|
atom_mask = (
|
|
features["res_atom_mask"].bool().view(batch["metadata"]["num_structid"], -1, 37)
|
|
).clone()
|
|
atom_mask[:, :, [6, 7, 12, 13, 16, 17, 20, 21, 26, 27, 29, 30]] = False
|
|
pred_prot_coords = pred_prot_coords.view(batch_size, -1, 37, 3)
|
|
target_prot_coords = target_prot_coords.view(batch_size, -1, 37, 3)
|
|
pred_bb_frames = get_frame_matrix(
|
|
pred_prot_coords[:, :, 0, :],
|
|
pred_prot_coords[:, :, 1, :],
|
|
pred_prot_coords[:, :, 2, :],
|
|
)
|
|
# pred_bb_frames.R = pred_bb_frames.R.detach()
|
|
target_bb_frames = get_frame_matrix(
|
|
target_prot_coords[:, :, 0, :],
|
|
target_prot_coords[:, :, 1, :],
|
|
target_prot_coords[:, :, 2, :],
|
|
)
|
|
pred_prot_coords_flat = pred_prot_coords[atom_mask].view(batch_size, -1, 3)
|
|
target_prot_coords_flat = target_prot_coords[atom_mask].view(batch_size, -1, 3)
|
|
if pred_lig_coords is not None:
|
|
assert target_prot_coords is not None, "Target protein coordinates must be provided."
|
|
assert lig_frame_atm_idx is not None, "Ligand frame atom indices must be provided."
|
|
pred_lig_coords = pred_lig_coords.view(batch_size, -1, 3)
|
|
target_lig_coords = target_lig_coords.view(batch_size, -1, 3)
|
|
pred_coords = torch.cat([pred_prot_coords_flat, pred_lig_coords], dim=1)
|
|
target_coords = torch.cat([target_prot_coords_flat, target_lig_coords], dim=1)
|
|
pred_lig_frames = get_frame_matrix(
|
|
pred_lig_coords[:, lig_frame_atm_idx[0]],
|
|
pred_lig_coords[:, lig_frame_atm_idx[1]],
|
|
pred_lig_coords[:, lig_frame_atm_idx[2]],
|
|
)
|
|
pred_frames = pred_bb_frames.concatenate(pred_lig_frames, dim=1)
|
|
target_lig_frames = get_frame_matrix(
|
|
target_lig_coords[:, lig_frame_atm_idx[0]],
|
|
target_lig_coords[:, lig_frame_atm_idx[1]],
|
|
target_lig_coords[:, lig_frame_atm_idx[2]],
|
|
)
|
|
target_frames = target_bb_frames.concatenate(target_lig_frames, dim=1)
|
|
else:
|
|
pred_coords = pred_prot_coords_flat
|
|
target_coords = target_prot_coords_flat
|
|
pred_frames = pred_bb_frames
|
|
target_frames = target_bb_frames
|
|
# Columns-frames, rows-points
|
|
# [B, 1, N, 3] - [B, F, 1, 3]
|
|
sampling_rate = cap_size / (batch_size * target_coords.shape[1])
|
|
sampling_mask = torch.rand(target_coords.shape[1], device=device) < sampling_rate
|
|
aligned_pred_points = cartesian_to_internal(
|
|
pred_coords[:, sampling_mask].unsqueeze(1), pred_frames.unsqueeze(2)
|
|
)
|
|
with torch.no_grad():
|
|
aligned_target_points = cartesian_to_internal(
|
|
target_coords[:, sampling_mask].unsqueeze(1), target_frames.unsqueeze(2)
|
|
)
|
|
pair_dist_aligned = (
|
|
torch.square(aligned_pred_points - aligned_target_points)
|
|
.sum(-1)
|
|
.add(1e-4)
|
|
.sqrt()
|
|
.sub(1e-2)
|
|
)
|
|
cropped_pair_dists = torch.clamp(pair_dist_aligned, max=10)
|
|
normalized_pair_dists = (
|
|
pair_dist_aligned / aligned_target_points.square().sum(-1).add(1e-4).sqrt()
|
|
)
|
|
if split_pl_views:
|
|
fape_protframe = cropped_pair_dists[:, : target_bb_frames.t.shape[1]].mean((1, 2)) / 10
|
|
fape_ligframe = cropped_pair_dists[:, target_bb_frames.t.shape[1] :].mean((1, 2)) / 10
|
|
return fape_protframe, fape_ligframe, normalized_pair_dists.mean((1, 2))
|
|
return cropped_pair_dists.mean((1, 2)) / 10, normalized_pair_dists.mean((1, 2))
|
|
|
|
|
|
def compute_aa_distance_geometry_loss(
|
|
batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""Compute the amino acid distance geometry loss for a given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param pred_coords: The predicted coordinates.
|
|
:param target_coords: The target coordinates.
|
|
:return: The distance geometry loss.
|
|
"""
|
|
batch_size = batch["metadata"]["num_structid"]
|
|
features = batch["features"]
|
|
atom_mask = features["res_atom_mask"].bool()
|
|
# Add backbone atoms from previous residue
|
|
atom_mask = atom_mask.view(batch_size, -1, 37)
|
|
atom_mask = torch.cat([atom_mask[:, 1:], atom_mask[:, :-1, 0:3]], dim=2).flatten(0, 1)
|
|
pred_coords = pred_coords.view(batch_size, -1, 37, 3)
|
|
pred_coords = torch.cat([pred_coords[:, 1:], pred_coords[:, :-1, 0:3]], dim=2).flatten(0, 1)
|
|
target_coords = target_coords.view(batch_size, -1, 37, 3)
|
|
target_coords = torch.cat([target_coords[:, 1:], target_coords[:, :-1, 0:3]], dim=2).flatten(
|
|
0, 1
|
|
)
|
|
local_pair_dist_target = (
|
|
(target_coords[:, None, :] - target_coords[:, :, None]).square().sum(-1).add(1e-4).sqrt()
|
|
)
|
|
local_pair_dist_pred = (
|
|
(pred_coords[:, None, :] - pred_coords[:, :, None]).square().sum(-1).add(1e-4).sqrt()
|
|
)
|
|
local_pair_mask = (
|
|
atom_mask[:, None, :] & atom_mask[:, :, None] & (local_pair_dist_target < 3.0)
|
|
)
|
|
ret = (local_pair_dist_target - local_pair_dist_pred).abs()[local_pair_mask]
|
|
return ret.view(batch["metadata"]["num_structid"], -1).mean(dim=1)
|
|
|
|
|
|
def compute_sm_distance_geometry_loss(
|
|
batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""Compute the small molecule distance geometry loss for a given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param pred_coords: The predicted coordinates.
|
|
:param target_coords: The target coordinates.
|
|
:return: The distance geometry loss.
|
|
"""
|
|
batch_size = batch["metadata"]["num_structid"]
|
|
pred_coords = pred_coords.view(batch_size, -1, 3)
|
|
target_coords = target_coords.view(batch_size, -1, 3)
|
|
pair_dist_target = (
|
|
(target_coords[:, None, :] - target_coords[:, :, None]).square().sum(-1).add(1e-4).sqrt()
|
|
)
|
|
pair_dist_pred = (
|
|
(pred_coords[:, None, :] - pred_coords[:, :, None]).square().sum(-1).add(1e-4).sqrt()
|
|
)
|
|
local_pair_mask = pair_dist_target < 3.0
|
|
ret = (pair_dist_target - pair_dist_pred).abs()[local_pair_mask]
|
|
return ret.view(batch_size, -1).mean(dim=1)
|
|
|
|
|
|
def compute_drmsd_and_clashloss(
|
|
batch: MODEL_BATCH,
|
|
device: Union[str, torch.device],
|
|
pred_prot_coords: torch.Tensor,
|
|
target_prot_coords: torch.Tensor,
|
|
atnum2vdw_uff: torch.nn.Parameter,
|
|
cap_size: int = 4000,
|
|
pred_lig_coords: Optional[torch.Tensor] = None,
|
|
target_lig_coords: Optional[torch.Tensor] = None,
|
|
ligatm_types: Optional[torch.Tensor] = None,
|
|
binding_site: bool = False,
|
|
pl_interface: bool = False,
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
"""Compute the differentiable root-mean-square deviation (dRMSD) and optional clash loss for a
|
|
given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param device: The device to use.
|
|
:param pred_prot_coords: The predicted protein coordinates.
|
|
:param target_prot_coords: The target protein coordinates.
|
|
:param atnum2vdw_uff: The atomic number to UFF VDW parameters mapping `Parameter`.
|
|
:param cap_size: The capped size.
|
|
:param pred_lig_coords: The predicted ligand coordinates.
|
|
:param target_lig_coords: The target ligand coordinates.
|
|
:param ligatm_types: The ligand atom types.
|
|
:param binding_site: Whether to compute the binding site.
|
|
:param pl_interface: Whether to compute the protein-ligand interface.
|
|
:return: The dRMSD and optional clash loss.
|
|
"""
|
|
features = batch["features"]
|
|
with torch.no_grad():
|
|
if not binding_site:
|
|
atom_mask = features["res_atom_mask"].bool().clone()
|
|
else:
|
|
atom_mask = (
|
|
features["res_atom_mask"].bool() & features["binding_site_mask_clean"][:, None]
|
|
)
|
|
if pl_interface:
|
|
# Removing ambiguous atoms
|
|
atom_mask[:, [6, 7, 12, 13, 16, 17, 20, 21, 26, 27, 29, 30]] = False
|
|
|
|
batch_size = batch["metadata"]["num_structid"]
|
|
pred_prot_coords = pred_prot_coords[atom_mask].view(batch_size, -1, 3)
|
|
if pred_lig_coords is not None:
|
|
assert target_prot_coords is not None, "Target protein coordinates must be provided."
|
|
assert ligatm_types is not None, "Ligand atom types must be provided."
|
|
pred_lig_coords = pred_lig_coords.view(batch_size, -1, 3)
|
|
pred_coords = torch.cat([pred_prot_coords, pred_lig_coords], dim=1)
|
|
else:
|
|
pred_coords = pred_prot_coords
|
|
sampling_rate = cap_size / pred_coords.shape[1]
|
|
sampling_mask = torch.rand(pred_coords.shape[1], device=device) < sampling_rate
|
|
pred_coords = pred_coords[:, sampling_mask]
|
|
if pl_interface:
|
|
pred_dist = (
|
|
torch.square(pred_coords[:, :, None] - pred_lig_coords[:, None, :])
|
|
.sum(-1)
|
|
.add(1e-4)
|
|
.sqrt()
|
|
.sub(1e-2)
|
|
)
|
|
else:
|
|
pred_dist = (
|
|
torch.square(pred_coords[:, :, None] - pred_coords[:, None, :])
|
|
.sum(-1)
|
|
.add(1e-4)
|
|
.sqrt()
|
|
.sub(1e-2)
|
|
)
|
|
with torch.no_grad():
|
|
target_prot_coords = target_prot_coords[atom_mask].view(batch_size, -1, 3)
|
|
if pred_lig_coords is not None:
|
|
target_lig_coords = target_lig_coords.view(batch_size, -1, 3)
|
|
target_coords = torch.cat([target_prot_coords, target_lig_coords], dim=1)
|
|
else:
|
|
target_coords = target_prot_coords
|
|
target_coords = target_coords[:, sampling_mask]
|
|
if pl_interface:
|
|
target_dist = (
|
|
torch.square(target_coords[:, :, None] - target_lig_coords[:, None, :])
|
|
.sum(-1)
|
|
.add(1e-4)
|
|
.sqrt()
|
|
.sub(1e-2)
|
|
)
|
|
else:
|
|
target_dist = (
|
|
torch.square(target_coords[:, :, None] - target_coords[:, None, :])
|
|
.sum(-1)
|
|
.add(1e-4)
|
|
.sqrt()
|
|
.sub(1e-2)
|
|
)
|
|
# In Angstrom, using UFF params to compute clash loss
|
|
protatm_types = features["res_atom_types"].long()[atom_mask]
|
|
protatm_vdw = atnum2vdw_uff[protatm_types].view(batch_size, -1)
|
|
if pred_lig_coords is not None:
|
|
ligatm_vdw = atnum2vdw_uff[ligatm_types].view(batch_size, -1)
|
|
atm_vdw = torch.cat([protatm_vdw, ligatm_vdw], dim=1)
|
|
else:
|
|
atm_vdw = protatm_vdw
|
|
atm_vdw = atm_vdw[:, sampling_mask]
|
|
average_vdw = (atm_vdw[:, :, None] + atm_vdw[:, None, :]) / 2
|
|
# Use conservative cutoffs to avoid mis-penalization
|
|
|
|
dist_errors = (pred_dist - target_dist).square()
|
|
drmsd = dist_errors.add(1e-2).sqrt().sub(1e-1).mean(dim=(1, 2))
|
|
if pl_interface:
|
|
return drmsd, None
|
|
|
|
covalent_like = target_dist < (average_vdw * 1.2)
|
|
# Alphafold supplementary Eq. 46, modified
|
|
clash_pairwise = torch.clamp(average_vdw * 1.1 - pred_dist.add(1e-6), min=0.0)
|
|
clash_loss = clash_pairwise.mul(~covalent_like).sum(dim=2).mean(dim=1)
|
|
return drmsd, clash_loss
|
|
|
|
|
|
def compute_template_weighted_centroid_drmsd(
|
|
batch: MODEL_BATCH,
|
|
pred_prot_coords: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
"""Compute the template-weighted centroid dRMSD for a given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param pred_prot_coords: The predicted protein coordinates.
|
|
:return: The dRMSD.
|
|
"""
|
|
batch_size = batch["metadata"]["num_structid"]
|
|
|
|
pred_cent_coords = (
|
|
pred_prot_coords.mul(batch["features"]["res_atom_mask"].bool()[:, :, None])
|
|
.sum(dim=1)
|
|
.div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9)
|
|
).view(batch_size, -1, 3)
|
|
pred_dist = (
|
|
torch.square(pred_cent_coords[:, :, None] - pred_cent_coords[:, None, :])
|
|
.sum(-1)
|
|
.add(1e-4)
|
|
.sqrt()
|
|
.sub(1e-2)
|
|
)
|
|
with torch.no_grad():
|
|
target_cent_coords = (
|
|
batch["features"]["res_atom_positions"]
|
|
.mul(batch["features"]["res_atom_mask"].bool()[:, :, None])
|
|
.sum(dim=1)
|
|
.div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9)
|
|
).view(batch_size, -1, 3)
|
|
template_cent_coords = (
|
|
batch["features"]["apo_res_atom_positions"]
|
|
.mul(batch["features"]["apo_res_atom_mask"].bool()[:, :, None])
|
|
.sum(dim=1)
|
|
.div(batch["features"]["apo_res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9)
|
|
).view(batch_size, -1, 3)
|
|
target_dist = (
|
|
torch.square(target_cent_coords[:, :, None] - target_cent_coords[:, None, :])
|
|
.sum(-1)
|
|
.add(1e-4)
|
|
.sqrt()
|
|
.sub(1e-2)
|
|
)
|
|
template_dist = (
|
|
torch.square(template_cent_coords[:, :, None] - template_cent_coords[:, None, :])
|
|
.sum(-1)
|
|
.add(1e-4)
|
|
.sqrt()
|
|
.sub(1e-2)
|
|
)
|
|
template_alignment_mask = (
|
|
batch["features"]["apo_res_alignment_mask"].bool().view(batch_size, -1)
|
|
)
|
|
motion_mask = (
|
|
((target_dist - template_dist).abs() > 2.0)
|
|
* template_alignment_mask[:, None, :]
|
|
* template_alignment_mask[:, :, None]
|
|
)
|
|
|
|
dist_errors = (pred_dist - target_dist).square()
|
|
drmsd = (dist_errors.add(1e-4).sqrt().sub(1e-2).mul(motion_mask).sum(dim=(1, 2))) / (
|
|
motion_mask.long().sum(dim=(1, 2)) + 1
|
|
)
|
|
return drmsd
|
|
|
|
|
|
def compute_TMscore_lbound(
|
|
batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""Compute the TM-score lower bound for a given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param pred_coords: The predicted coordinates.
|
|
:param target_coords: The target coordinates.
|
|
:return: The TM-score lower bound.
|
|
"""
|
|
features = batch["features"]
|
|
atom_mask = features["res_atom_mask"].bool().view(batch["metadata"]["num_structid"], -1, 37)
|
|
pred_coords = pred_coords.view(batch["metadata"]["num_structid"], -1, 37, 3)
|
|
target_coords = target_coords.view(batch["metadata"]["num_structid"], -1, 37, 3)
|
|
pred_bb_frames = get_frame_matrix(
|
|
pred_coords[:, :, 0, :],
|
|
pred_coords[:, :, 1, :],
|
|
pred_coords[:, :, 2, :],
|
|
strict=True,
|
|
)
|
|
target_bb_frames = get_frame_matrix(
|
|
target_coords[:, :, 0, :],
|
|
target_coords[:, :, 1, :],
|
|
target_coords[:, :, 2, :],
|
|
strict=True,
|
|
)
|
|
pred_coords_flat = pred_coords[atom_mask].view(batch["metadata"]["num_structid"], -1, 3)
|
|
target_coords_flat = target_coords[atom_mask].view(batch["metadata"]["num_structid"], -1, 3)
|
|
# Columns-frames, rows-points
|
|
# [B, 1, N, 3] - [B, F, 1, 3]
|
|
aligned_pred_points = cartesian_to_internal(
|
|
pred_coords_flat.unsqueeze(1), pred_bb_frames.unsqueeze(2)
|
|
)
|
|
with torch.no_grad():
|
|
aligned_target_points = cartesian_to_internal(
|
|
target_coords_flat.unsqueeze(1), target_bb_frames.unsqueeze(2)
|
|
)
|
|
pair_dist_aligned = (aligned_pred_points - aligned_target_points).norm(dim=-1)
|
|
tm_normalizer = 1.24 * (max(target_coords.shape[1], 19) - 15) ** (1 / 3) - 1.8
|
|
per_frame_tm = torch.mean(1 / (1 + (pair_dist_aligned / tm_normalizer) ** 2), dim=2)
|
|
return torch.amax(per_frame_tm, dim=1)
|
|
|
|
|
|
def compute_TMscore_raw(
|
|
batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""Compute the raw TM-score for a given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param pred_coords: The predicted coordinates.
|
|
:param target_coords: The target coordinates.
|
|
:return: The raw TM-score.
|
|
"""
|
|
pred_coords = pred_coords.view(batch["metadata"]["num_structid"], -1, 3)
|
|
target_coords = target_coords.view(batch["metadata"]["num_structid"], -1, 3)
|
|
pair_dist_aligned = (pred_coords - target_coords).norm(dim=-1)
|
|
tm_normalizer = 1.24 * (max(target_coords.shape[1], 19) - 15) ** (1 / 3) - 1.8
|
|
per_struct_tm = torch.mean(1 / (1 + (pair_dist_aligned / tm_normalizer) ** 2), dim=1)
|
|
return per_struct_tm
|
|
|
|
|
|
def compute_lddt_ca(
|
|
batch: MODEL_BATCH, pred_coords: torch.Tensor, target_coords: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""Compute the local distance difference test (lDDT) for C-alpha atoms for a given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param pred_coords: The predicted coordinates.
|
|
:param target_coords: The target coordinates.
|
|
:return: The lDDT for C-alpha atoms.
|
|
"""
|
|
pred_coords = pred_coords.view(batch["metadata"]["num_structid"], -1, 37, 3)
|
|
target_coords = target_coords.view(batch["metadata"]["num_structid"], -1, 37, 3)
|
|
pred_ca_flat = pred_coords[:, :, 1]
|
|
target_ca_flat = target_coords[:, :, 1]
|
|
target_dist = (target_ca_flat[:, :, None] - target_ca_flat[:, None, :]).norm(dim=-1)
|
|
pred_dist = (pred_ca_flat[:, :, None] - pred_ca_flat[:, None, :]).norm(dim=-1)
|
|
conserved_mask = target_dist < 15.0
|
|
lddt = 0
|
|
for threshold in [0.5, 1, 2, 4]:
|
|
below_threshold = (pred_dist - target_dist).abs() < threshold
|
|
lddt = lddt + below_threshold.mul(conserved_mask).sum((1, 2)) / conserved_mask.sum((1, 2))
|
|
return lddt / 4
|
|
|
|
|
|
def compute_lddt_pli(
|
|
batch: MODEL_BATCH,
|
|
pred_prot_coords: torch.Tensor,
|
|
target_prot_coords: torch.Tensor,
|
|
pred_lig_coords: torch.Tensor,
|
|
target_lig_coords: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
"""Compute the local distance difference test (lDDT) for protein-ligand interface atoms for a
|
|
given batch.
|
|
|
|
:param batch: A batch dictionary.
|
|
:param pred_prot_coords: The predicted protein coordinates.
|
|
:param target_prot_coords: The target protein coordinates.
|
|
:param pred_lig_coords: The predicted ligand coordinates.
|
|
:param target_lig_coords: The target ligand coordinates.
|
|
:return: The lDDT for protein-ligand interface atoms.
|
|
"""
|
|
features = batch["features"]
|
|
batch_size = batch["metadata"]["num_structid"]
|
|
atom_mask = features["res_atom_mask"].bool()
|
|
pred_prot_coords = pred_prot_coords[atom_mask].view(batch_size, -1, 3)
|
|
target_prot_coords = target_prot_coords[atom_mask].view(batch_size, -1, 3)
|
|
pred_lig_coords = pred_lig_coords.view(batch_size, -1, 3)
|
|
target_lig_coords = target_lig_coords.view(batch_size, -1, 3)
|
|
target_dist = (target_prot_coords[:, :, None] - target_lig_coords[:, None, :]).norm(dim=-1)
|
|
pred_dist = (pred_prot_coords[:, :, None] - pred_lig_coords[:, None, :]).norm(dim=-1)
|
|
conserved_mask = target_dist < 6.0
|
|
lddt = 0
|
|
for threshold in [0.5, 1, 2, 4]:
|
|
below_threshold = (pred_dist - target_dist).abs() < threshold
|
|
lddt = lddt + below_threshold.mul(conserved_mask).sum((1, 2)) / conserved_mask.sum((1, 2))
|
|
return lddt / 4
|
|
|
|
|
|
def eval_structure_prediction_losses(
|
|
lit_module: LightningModule,
|
|
batch: MODEL_BATCH,
|
|
batch_idx: int,
|
|
device: Union[str, torch.device],
|
|
stage: MODEL_STAGE,
|
|
t_1: float = 1.0,
|
|
) -> MODEL_BATCH:
|
|
"""Evaluate the structure prediction losses for a given batch.
|
|
|
|
:param lit_module: The LightningModule object to reference.
|
|
:param batch: A batch dictionary.
|
|
:param batch_idx: The batch index.
|
|
:param device: The device to use.
|
|
:param stage: The stage of the training.
|
|
:param t_1: The final timestep in the range [0, 1].
|
|
:return: Batch dictionary with losses.
|
|
"""
|
|
assert 0 <= t_1 <= 1, "`t_1` must be in the range `[0, 1]`."
|
|
batch_size = batch["metadata"]["num_structid"]
|
|
max(batch["metadata"]["num_a_per_sample"])
|
|
|
|
if "num_molid" in batch["metadata"].keys() and batch["metadata"]["num_molid"] > 0:
|
|
batch["misc"]["protein_only"] = False
|
|
else:
|
|
batch["misc"]["protein_only"] = True
|
|
|
|
if "augmented_coordinates" in batch["features"].keys():
|
|
batch["features"]["sdf_coordinates"] = batch["features"]["augmented_coordinates"]
|
|
is_native_sample = 0
|
|
else:
|
|
is_native_sample = 1
|
|
|
|
# Sample the timestep for each structure
|
|
t = torch.rand((batch_size, 1), device=device)
|
|
|
|
prior_training = int(random.randint(0, 10) == 1) # nosec
|
|
if prior_training == 1:
|
|
t = torch.full_like(t, t_1)
|
|
|
|
if lit_module.training and lit_module.hparams.cfg.task.use_template:
|
|
use_template = bool(random.randint(0, 1)) # nosec
|
|
else:
|
|
use_template = lit_module.hparams.cfg.task.use_template
|
|
|
|
lit_module.net.assign_timestep_encodings(batch, t)
|
|
features = batch["features"]
|
|
indexer = batch["indexer"]
|
|
metadata = batch["metadata"]
|
|
|
|
loss = 0
|
|
forward_lat_converter = lit_module.net.resolve_latent_converter(
|
|
[
|
|
("features", "res_atom_positions"),
|
|
("features", "input_protein_coords"),
|
|
],
|
|
[("features", "sdf_coordinates"), ("features", "input_ligand_coords")],
|
|
)
|
|
batch = lit_module.net.prepare_protein_patch_indexers(batch)
|
|
if not batch["misc"]["protein_only"]:
|
|
max(metadata["num_i_per_sample"])
|
|
|
|
# Evaluate the contact map
|
|
ref_dist_mat, contact_logit_matrix = eval_true_contact_maps(
|
|
batch, lit_module.net.CONTACT_SCALE
|
|
)
|
|
num_cont_to_sample = max(metadata["num_I_per_sample"])
|
|
sampled_block_contacts = [
|
|
None,
|
|
]
|
|
# Onehot contact code sampling
|
|
with torch.no_grad():
|
|
for _ in range(num_cont_to_sample):
|
|
sampled_block_contacts.append(
|
|
sample_reslig_contact_matrix(
|
|
batch, contact_logit_matrix, last=sampled_block_contacts[-1]
|
|
)
|
|
)
|
|
forward_lat_converter.lig_res_anchor_mask = sample_res_rowmask_from_contacts(
|
|
batch,
|
|
contact_logit_matrix,
|
|
lit_module.hparams.cfg.task.single_protein_batch,
|
|
)
|
|
with torch.no_grad():
|
|
batch = lit_module.net.forward_interp_plcomplex_latinp(
|
|
batch, t[:, :, None], forward_lat_converter
|
|
)
|
|
if prior_training == 1:
|
|
iter_id = random.randint(0, num_cont_to_sample) # nosec
|
|
else:
|
|
iter_id = num_cont_to_sample
|
|
batch = lit_module.forward(
|
|
batch, contact_prediction=False, score=False, use_template=use_template
|
|
)
|
|
batch = lit_module.net.run_contact_map_stack(
|
|
batch,
|
|
iter_id=iter_id,
|
|
observed_block_contacts=sampled_block_contacts[iter_id],
|
|
)
|
|
pred_distogram = batch["outputs"][f"res_lig_distogram_out_{iter_id}"]
|
|
(
|
|
pl_distogram_loss,
|
|
pl_contact_loss_forward,
|
|
) = compute_contact_prediction_losses(
|
|
pred_distogram, ref_dist_mat, lit_module.net.dist_bins, lit_module.net.CONTACT_SCALE
|
|
)
|
|
cont_loss = 0
|
|
cont_loss = (
|
|
cont_loss
|
|
+ pl_distogram_loss
|
|
* lit_module.hparams.cfg.task.contact_loss_weight
|
|
* is_native_sample
|
|
)
|
|
cont_loss = (
|
|
cont_loss
|
|
+ pl_contact_loss_forward
|
|
* lit_module.hparams.cfg.task.contact_loss_weight
|
|
* is_native_sample
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_contact/contact_loss_distogram",
|
|
pl_distogram_loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_contact/contact_loss_forwardKL",
|
|
pl_contact_loss_forward.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
if lit_module.hparams.cfg.task.freeze_contact_predictor:
|
|
# Keep the contact prediction parameters in the computational graph but with zero gradients
|
|
cont_loss *= 0.0
|
|
else:
|
|
with torch.no_grad():
|
|
batch = lit_module.net.forward_interp_plcomplex_latinp(
|
|
batch, t[:, :, None], forward_lat_converter
|
|
)
|
|
iter_id = 0
|
|
batch = lit_module.forward(
|
|
batch,
|
|
iter_id=0,
|
|
contact_prediction=True,
|
|
score=False,
|
|
use_template=use_template,
|
|
)
|
|
protein_distogram_loss = compute_protein_distogram_loss(
|
|
batch,
|
|
batch["features"]["res_atom_positions"][:, 1],
|
|
lit_module.net.dist_bins,
|
|
lit_module.net.dgram_head,
|
|
entry=f"res_res_grid_attr_flat_out_{iter_id}",
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_contact/prot_distogram_loss",
|
|
protein_distogram_loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
if lit_module.hparams.cfg.task.freeze_contact_predictor:
|
|
# Keep the distogram prediction parameters in the computational graph but with zero gradients
|
|
protein_distogram_loss *= 0.0
|
|
|
|
# NOTE: we keep the loss weighting time-independent since `sigma=1` for all prior distributions (where relevant)
|
|
lambda_weighting = t.new_ones(batch_size)
|
|
|
|
# Run score head and evaluate structure prediction losses
|
|
res_atom_mask = features["res_atom_mask"].bool()
|
|
|
|
scores = lit_module.net.run_score_head(batch, embedding_iter_id=iter_id)
|
|
|
|
if lit_module.training:
|
|
# # Sigmoid scaling
|
|
# violation_loss_ratio = 1 / (
|
|
# 1
|
|
# + math.exp(10 - 12 * lit_module.current_epoch / lit_module.trainer.max_epochs)
|
|
# )
|
|
# violation_loss_ratio = (lit_module.current_epoch / lit_module.trainer.max_epochs)
|
|
violation_loss_ratio = 1.0
|
|
else:
|
|
violation_loss_ratio = 1.0
|
|
|
|
if not batch["misc"]["protein_only"]:
|
|
if "binding_site_mask_clean" not in batch["features"]:
|
|
with torch.no_grad():
|
|
min_lig_res_dist_clean = (
|
|
(
|
|
batch["features"]["res_atom_positions"][:, 1].view(batch_size, -1, 3)[
|
|
:, :, None
|
|
]
|
|
- batch["features"]["sdf_coordinates"].view(batch_size, -1, 3)[:, None, :]
|
|
)
|
|
.norm(dim=-1)
|
|
.amin(dim=2)
|
|
).flatten(0, 1)
|
|
binding_site_mask_clean = (
|
|
min_lig_res_dist_clean < lit_module.net.BINDING_SITE_CUTOFF
|
|
)
|
|
batch["features"]["binding_site_mask_clean"] = binding_site_mask_clean
|
|
coords_pred_prot = scores["final_coords_prot_atom_padded"][res_atom_mask].view(
|
|
batch_size, -1, 3
|
|
)
|
|
coords_ref_prot = batch["features"]["res_atom_positions"][res_atom_mask].view(
|
|
batch_size, -1, 3
|
|
)
|
|
coords_pred_bs_prot = scores["final_coords_prot_atom_padded"][
|
|
res_atom_mask & batch["features"]["binding_site_mask_clean"][:, None]
|
|
].view(batch_size, -1, 3)
|
|
coords_ref_bs_prot = batch["features"]["res_atom_positions"][
|
|
res_atom_mask & batch["features"]["binding_site_mask_clean"][:, None]
|
|
].view(batch_size, -1, 3)
|
|
coords_pred_lig = scores["final_coords_lig_atom"].view(batch_size, -1, 3)
|
|
coords_ref_lig = batch["features"]["sdf_coordinates"].view(batch_size, -1, 3)
|
|
coords_pred = torch.cat([coords_pred_prot, coords_pred_lig], dim=1)
|
|
coords_ref = torch.cat([coords_ref_prot, coords_ref_lig], dim=1)
|
|
coords_pred_bs = torch.cat([coords_pred_bs_prot, coords_pred_lig], dim=1)
|
|
coords_ref_bs = torch.cat([coords_ref_bs_prot, coords_ref_lig], dim=1)
|
|
n_I_per_sample = max(metadata["num_I_per_sample"])
|
|
lig_frame_atm_idx = torch.stack(
|
|
[
|
|
indexer["gather_idx_ijk_i"][indexer["gather_idx_I_ijk"]][:n_I_per_sample],
|
|
indexer["gather_idx_ijk_j"][indexer["gather_idx_I_ijk"]][:n_I_per_sample],
|
|
indexer["gather_idx_ijk_k"][indexer["gather_idx_I_ijk"]][:n_I_per_sample],
|
|
],
|
|
dim=0,
|
|
)
|
|
(
|
|
global_fape_pview,
|
|
global_fape_lview,
|
|
normalized_fape,
|
|
) = compute_fape_from_atom37(
|
|
batch,
|
|
device,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
pred_lig_coords=scores["final_coords_lig_atom"],
|
|
target_lig_coords=batch["features"]["sdf_coordinates"],
|
|
lig_frame_atm_idx=lig_frame_atm_idx,
|
|
split_pl_views=True,
|
|
)
|
|
aa_distgeom_error = compute_aa_distance_geometry_loss(
|
|
batch,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
)
|
|
lig_distgeom_error = compute_sm_distance_geometry_loss(
|
|
batch,
|
|
scores["final_coords_lig_atom"],
|
|
batch["features"]["sdf_coordinates"],
|
|
)
|
|
glob_drmsd, _ = compute_drmsd_and_clashloss(
|
|
batch,
|
|
device,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
lit_module.net.atnum2vdw_uff,
|
|
pred_lig_coords=scores["final_coords_lig_atom"],
|
|
target_lig_coords=batch["features"]["sdf_coordinates"],
|
|
ligatm_types=batch["features"]["atomic_numbers"].long(),
|
|
)
|
|
bs_drmsd, clash_error = compute_drmsd_and_clashloss(
|
|
batch,
|
|
device,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
lit_module.net.atnum2vdw_uff,
|
|
pred_lig_coords=scores["final_coords_lig_atom"],
|
|
target_lig_coords=batch["features"]["sdf_coordinates"],
|
|
ligatm_types=batch["features"]["atomic_numbers"].long(),
|
|
binding_site=True,
|
|
)
|
|
pli_drmsd, _ = compute_drmsd_and_clashloss(
|
|
batch,
|
|
device,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
lit_module.net.atnum2vdw_uff,
|
|
pred_lig_coords=scores["final_coords_lig_atom"],
|
|
target_lig_coords=batch["features"]["sdf_coordinates"],
|
|
ligatm_types=batch["features"]["atomic_numbers"].long(),
|
|
pl_interface=True,
|
|
)
|
|
distgeom_loss = (
|
|
aa_distgeom_error.mul(lambda_weighting) * max(metadata["num_a_per_sample"])
|
|
+ lig_distgeom_error.mul(lambda_weighting) * max(metadata["num_i_per_sample"])
|
|
).mean() / max(metadata["num_a_per_sample"])
|
|
|
|
fape_loss = (
|
|
(
|
|
global_fape_pview
|
|
+ global_fape_lview
|
|
* (
|
|
lit_module.hparams.cfg.task.ligand_score_loss_weight
|
|
/ lit_module.hparams.cfg.task.global_score_loss_weight
|
|
)
|
|
+ normalized_fape
|
|
)
|
|
.mul(lambda_weighting)
|
|
.mean()
|
|
)
|
|
|
|
if not lit_module.hparams.cfg.task.freeze_score_head:
|
|
loss = (
|
|
loss
|
|
+ fape_loss
|
|
* lit_module.hparams.cfg.task.global_score_loss_weight
|
|
* is_native_sample
|
|
)
|
|
loss = (
|
|
loss
|
|
+ glob_drmsd.mul(lambda_weighting).mean()
|
|
* lit_module.hparams.cfg.task.drmsd_loss_weight
|
|
)
|
|
if use_template:
|
|
twe_drmsd = compute_template_weighted_centroid_drmsd(
|
|
batch, scores["final_coords_prot_atom_padded"]
|
|
)
|
|
if not lit_module.hparams.cfg.task.freeze_score_head:
|
|
loss = (
|
|
loss
|
|
+ twe_drmsd.mul(lambda_weighting).mean()
|
|
* lit_module.hparams.cfg.task.drmsd_loss_weight
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_loss_weighted",
|
|
twe_drmsd.mul(lambda_weighting).mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_weighted",
|
|
twe_drmsd.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
if not lit_module.hparams.cfg.task.freeze_score_head:
|
|
loss = (
|
|
loss
|
|
+ bs_drmsd.mul(lambda_weighting).mean()
|
|
* lit_module.hparams.cfg.task.drmsd_loss_weight
|
|
)
|
|
loss = (
|
|
loss
|
|
+ pli_drmsd.mul(lambda_weighting).mean()
|
|
* lit_module.hparams.cfg.task.drmsd_loss_weight
|
|
)
|
|
|
|
loss = (
|
|
loss
|
|
+ distgeom_loss
|
|
* lit_module.hparams.cfg.task.local_distgeom_loss_weight
|
|
* violation_loss_ratio
|
|
)
|
|
loss = (
|
|
loss
|
|
+ clash_error.mul(lambda_weighting).mean()
|
|
* lit_module.hparams.cfg.task.clash_loss_weight
|
|
* violation_loss_ratio
|
|
)
|
|
if not lit_module.hparams.cfg.task.freeze_contact_predictor:
|
|
loss = (0.1 + 0.9 * prior_training) * cont_loss + (1 - prior_training * 0.99) * loss
|
|
loss = (
|
|
loss + protein_distogram_loss * lit_module.hparams.cfg.task.distogram_loss_weight
|
|
)
|
|
with torch.no_grad():
|
|
tm_lbound = compute_TMscore_lbound(
|
|
batch,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
)
|
|
lig_rmsd = segment_mean(
|
|
(
|
|
(coords_pred_lig - coords_pred_prot.mean(dim=1, keepdim=True))
|
|
- (coords_ref_lig - coords_ref_prot.mean(dim=1, keepdim=True))
|
|
)
|
|
.square()
|
|
.sum(dim=-1)
|
|
.flatten(0, 1),
|
|
indexer["gather_idx_i_molid"],
|
|
metadata["num_molid"],
|
|
).sqrt()
|
|
lit_module.log(
|
|
f"{stage}/tm_lbound",
|
|
tm_lbound.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/ligand_rmsd_ubound",
|
|
lig_rmsd.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=lig_rmsd.shape[0],
|
|
)
|
|
# L1 score matching loss
|
|
dsm_loss_global = (
|
|
(
|
|
(coords_pred - coords_pred_prot.mean(dim=1, keepdim=True))
|
|
- (coords_ref - coords_ref_prot.mean(dim=1, keepdim=True))
|
|
)
|
|
.square()
|
|
.sum(dim=-1)
|
|
.add(1e-2)
|
|
.sqrt()
|
|
.sub(1e-1)
|
|
.mean(dim=1)
|
|
.mul(lambda_weighting)
|
|
)
|
|
dsm_loss_site = (
|
|
(
|
|
(coords_pred_bs - coords_pred_bs_prot.mean(dim=1, keepdim=True))
|
|
- (coords_ref_bs - coords_ref_bs_prot.mean(dim=1, keepdim=True))
|
|
)
|
|
.square()
|
|
.sum(dim=-1)
|
|
.add(1e-2)
|
|
.sqrt()
|
|
.sub(1e-1)
|
|
.mean(dim=1)
|
|
.mul(lambda_weighting)
|
|
)
|
|
dsm_loss_ligand = (
|
|
(
|
|
(coords_pred_lig - coords_pred.mean(dim=1, keepdim=True))
|
|
- (coords_ref_lig - coords_ref.mean(dim=1, keepdim=True))
|
|
)
|
|
.square()
|
|
.sum(dim=-1)
|
|
.add(1e-2)
|
|
.sqrt()
|
|
.sub(1e-1)
|
|
.mean(dim=1)
|
|
.mul(lambda_weighting)
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/denoising_loss_global",
|
|
dsm_loss_global.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/denoising_loss_site",
|
|
dsm_loss_site.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/denoising_loss_ligand",
|
|
dsm_loss_ligand.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_loss_global",
|
|
glob_drmsd.mul(lambda_weighting).mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_loss_site",
|
|
bs_drmsd.mul(lambda_weighting).mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_loss_pli",
|
|
pli_drmsd.mul(lambda_weighting).mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_global",
|
|
glob_drmsd.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_site",
|
|
bs_drmsd.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_pli",
|
|
pli_drmsd.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/fape_global_proteinview",
|
|
global_fape_pview.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/fape_global_ligandview",
|
|
global_fape_lview.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/fape_normalized",
|
|
normalized_fape.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/fape_loss",
|
|
fape_loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/aa_distgeom_error",
|
|
aa_distgeom_error.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/lig_distgeom_error",
|
|
lig_distgeom_error.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/clash_error",
|
|
clash_error.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/clash_loss",
|
|
clash_error.mul(lambda_weighting).mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/distgeom_loss",
|
|
distgeom_loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
else:
|
|
coords_pred = scores["final_coords_prot_atom_padded"][res_atom_mask].view(
|
|
batch_size, -1, 3
|
|
)
|
|
coords_ref = batch["features"]["res_atom_positions"][res_atom_mask].view(batch_size, -1, 3)
|
|
global_fape_pview, normalized_fape = compute_fape_from_atom37(
|
|
batch,
|
|
device,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
)
|
|
aa_distgeom_error = compute_aa_distance_geometry_loss(
|
|
batch,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
)
|
|
glob_drmsd, clash_error = compute_drmsd_and_clashloss(
|
|
batch,
|
|
device,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
lit_module.net.atnum2vdw_uff,
|
|
)
|
|
distgeom_loss = aa_distgeom_error.mul(lambda_weighting).mean()
|
|
fape_loss = (global_fape_pview + normalized_fape).mul(lambda_weighting).mean()
|
|
|
|
global_fape_pview.detach()
|
|
if not lit_module.hparams.cfg.task.freeze_score_head:
|
|
loss = (
|
|
loss
|
|
+ distgeom_loss
|
|
* lit_module.hparams.cfg.task.local_distgeom_loss_weight
|
|
* violation_loss_ratio
|
|
)
|
|
loss = loss + fape_loss * lit_module.hparams.cfg.task.global_score_loss_weight
|
|
loss = (
|
|
loss
|
|
+ glob_drmsd.mul(lambda_weighting).mean()
|
|
* lit_module.hparams.cfg.task.drmsd_loss_weight
|
|
)
|
|
if use_template:
|
|
twe_drmsd = compute_template_weighted_centroid_drmsd(
|
|
batch, scores["final_coords_prot_atom_padded"]
|
|
)
|
|
if not lit_module.hparams.cfg.task.freeze_score_head:
|
|
loss = (
|
|
loss
|
|
+ twe_drmsd.mul(lambda_weighting).mean()
|
|
* lit_module.hparams.cfg.task.drmsd_loss_weight
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_loss_weighted",
|
|
twe_drmsd.mul(lambda_weighting).mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_weighted",
|
|
twe_drmsd.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
if not lit_module.hparams.cfg.task.freeze_score_head:
|
|
loss = (
|
|
loss
|
|
+ clash_error.mul(lambda_weighting).mean()
|
|
* lit_module.hparams.cfg.task.clash_loss_weight
|
|
* violation_loss_ratio
|
|
)
|
|
if not lit_module.hparams.cfg.task.freeze_contact_predictor:
|
|
loss = (
|
|
loss + protein_distogram_loss * lit_module.hparams.cfg.task.distogram_loss_weight
|
|
)
|
|
|
|
with torch.no_grad():
|
|
dsm_loss_global = (
|
|
(
|
|
(coords_pred - coords_pred.mean(dim=1, keepdim=True))
|
|
- (coords_ref - coords_ref.mean(dim=1, keepdim=True))
|
|
)
|
|
.square()
|
|
.sum(dim=-1)
|
|
.add(1e-2)
|
|
.sqrt()
|
|
.sub(1e-1)
|
|
.mean(dim=1)
|
|
.mul(lambda_weighting)
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/denoising_loss_global",
|
|
dsm_loss_global.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
tm_lbound = compute_TMscore_lbound(
|
|
batch,
|
|
scores["final_coords_prot_atom_padded"],
|
|
batch["features"]["res_atom_positions"],
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/tm_lbound",
|
|
tm_lbound.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_loss_global",
|
|
glob_drmsd.mul(lambda_weighting).mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/drmsd_global",
|
|
glob_drmsd.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/fape_global_proteinview",
|
|
global_fape_pview.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/fape_normalized",
|
|
normalized_fape.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}/fape_loss",
|
|
fape_loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/aa_distgeom_error",
|
|
aa_distgeom_error.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/clash_error",
|
|
clash_error.mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/clash_loss",
|
|
clash_error.mul(lambda_weighting).mean().detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_violation/distgeom_loss",
|
|
distgeom_loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
if torch.is_tensor(loss) and not torch.isnan(loss):
|
|
lit_module.log(
|
|
f"{stage}/loss",
|
|
loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
sync_dist=(stage != "train"),
|
|
)
|
|
batch["outputs"]["loss"] = loss
|
|
if not torch.is_tensor(batch["outputs"]["loss"]) and batch["outputs"]["loss"] == 0:
|
|
batch["outputs"]["loss"] = None
|
|
return batch
|
|
|
|
|
|
def eval_auxiliary_estimation_losses(
|
|
lit_module: LightningModule,
|
|
batch: MODEL_BATCH,
|
|
stage: MODEL_STAGE,
|
|
loss_mode: LOSS_MODES,
|
|
**kwargs: Dict[str, Any],
|
|
) -> MODEL_BATCH:
|
|
"""Evaluate the auxiliary estimation losses for a given batch.
|
|
|
|
:param lit_module: The LightningModule object to reference.
|
|
:param batch: A batch dictionary.
|
|
:param stage: The stage of the training.
|
|
:param loss_mode: The loss mode to use.
|
|
:param kwargs: Additional keyword arguments.
|
|
:return: Batch dictionary with losses.
|
|
"""
|
|
use_template = bool(random.randint(0, 1)) # nosec
|
|
if use_template:
|
|
# Enable higher ligand diversity when using backbone template
|
|
start_time = 1.0
|
|
else:
|
|
start_time = random.randint(1, 5) / 5 # nosec
|
|
with torch.no_grad():
|
|
if loss_mode == "auxiliary_estimation_without_structure_prediction":
|
|
# Sample the structure without using the structure prediction head
|
|
# i.e., Provide the holo (ground-truth) protein and ligand structures for affinity estimation
|
|
output_struct = {
|
|
"receptor": batch["features"]["res_atom_positions"].flatten(0, 1),
|
|
"receptor_padded": batch["features"]["res_atom_positions"],
|
|
"ligands": batch["features"]["sdf_coordinates"],
|
|
}
|
|
else:
|
|
output_struct = lit_module.net.sample_pl_complex_structures(
|
|
batch,
|
|
sampler="VDODE",
|
|
sampler_eta=1.0,
|
|
num_steps=int(5 / start_time),
|
|
start_time=start_time,
|
|
exact_prior=True,
|
|
use_template=use_template,
|
|
cutoff=20.0, # Hot logits
|
|
)
|
|
batch_size = batch["metadata"]["num_structid"]
|
|
batch = lit_module.net.run_auxiliary_estimation(batch, output_struct, **kwargs)
|
|
if lit_module.hparams.cfg.confidence.enabled:
|
|
with torch.no_grad():
|
|
# Receptor centroids
|
|
ref_coords = (
|
|
(
|
|
batch["features"]["res_atom_positions"]
|
|
.mul(batch["features"]["res_atom_mask"].bool()[:, :, None])
|
|
.sum(dim=1)
|
|
.div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9)
|
|
)
|
|
.contiguous()
|
|
.view(batch_size, -1, 3)
|
|
)
|
|
pred_coords = (
|
|
(
|
|
output_struct["receptor_padded"]
|
|
.mul(batch["features"]["res_atom_mask"].bool()[:, :, None])
|
|
.sum(dim=1)
|
|
.div(batch["features"]["res_atom_mask"].bool().sum(dim=1)[:, None] + 1e-9)
|
|
)
|
|
.contiguous()
|
|
.view(batch_size, -1, 3)
|
|
)
|
|
# The number of effective protein atoms used in plddt calculation
|
|
n_protatm_per_sample = pred_coords.shape[1]
|
|
if output_struct["ligands"] is not None:
|
|
ref_lig_coords = (
|
|
batch["features"]["sdf_coordinates"].contiguous().view(batch_size, -1, 3)
|
|
)
|
|
ref_coords = torch.cat([ref_coords, ref_lig_coords], dim=1)
|
|
pred_lig_coords = output_struct["ligands"].contiguous().view(batch_size, -1, 3)
|
|
pred_coords = torch.cat([pred_coords, pred_lig_coords], dim=1)
|
|
per_atom_lddt, per_atom_lddt_gram = compute_per_atom_lddt(
|
|
batch, pred_coords, ref_coords
|
|
)
|
|
|
|
plddt_dev = (per_atom_lddt - batch["outputs"]["plddt"]).abs().mean()
|
|
confidence_loss = (
|
|
F.cross_entropy(
|
|
batch["outputs"]["plddt_logits"].flatten(0, 1),
|
|
per_atom_lddt_gram.flatten(0, 1),
|
|
reduction="none",
|
|
)
|
|
.contiguous()
|
|
.view(batch_size, -1)
|
|
)
|
|
conf_loss = confidence_loss.mean()
|
|
if output_struct["ligands"] is not None:
|
|
plddt_dev_lig = (
|
|
(
|
|
per_atom_lddt.view(batch_size, -1)[:, n_protatm_per_sample:]
|
|
- batch["outputs"]["plddt"].view(batch_size, -1)[:, n_protatm_per_sample:]
|
|
)
|
|
.abs()
|
|
.mean()
|
|
)
|
|
conf_loss_lig = confidence_loss[:, n_protatm_per_sample:].mean()
|
|
conf_loss = conf_loss + conf_loss_lig # + plddt_dev_lig * 0.1
|
|
lit_module.log(
|
|
f"{stage}_confidence/plddt_dev_lig",
|
|
plddt_dev_lig.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_confidence/plddt_dev",
|
|
plddt_dev.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
)
|
|
lit_module.log(
|
|
f"{stage}_confidence/loss",
|
|
conf_loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
sync_dist=(stage != "train"),
|
|
)
|
|
if lit_module.hparams.cfg.task.freeze_confidence:
|
|
# Keep the confidence prediction parameters in the computational graph but with zero gradients
|
|
conf_loss *= 0
|
|
else:
|
|
conf_loss = 0
|
|
if lit_module.hparams.cfg.affinity.enabled:
|
|
num_molid_per_sample = batch["metadata"]["num_molid"] // batch_size
|
|
gather_idx_molid_structid = torch.arange(
|
|
batch_size, device=batch["outputs"]["affinity_logits"].device
|
|
).repeat_interleave(num_molid_per_sample)
|
|
|
|
# Calculate affinity loss as the mean squared error between the predicted affinity logits and the ground-truth affinity values
|
|
affinity_logits = batch["outputs"]["affinity_logits"]
|
|
# Substitute missing ground-truth affinity values with the affinity head's (detached) predicted logits to indicate no learning signal for these examples
|
|
affinity = torch.where(
|
|
batch["features"]["affinity"].isnan(),
|
|
affinity_logits.detach(),
|
|
batch["features"]["affinity"],
|
|
)
|
|
aff_loss = segment_mean(
|
|
# Find the (batched) mean squared error over all ligand chains in the same complex, then calculate the mean of each batch
|
|
(affinity_logits - affinity).square(),
|
|
gather_idx_molid_structid,
|
|
batch_size,
|
|
).mean()
|
|
lit_module.log(
|
|
f"{stage}_affinity/loss",
|
|
aff_loss.detach(),
|
|
on_epoch=True,
|
|
batch_size=batch_size,
|
|
sync_dist=(stage != "train"),
|
|
)
|
|
if lit_module.hparams.cfg.task.freeze_affinity:
|
|
# Keep the affinity prediction parameters in the computational graph but with zero gradients
|
|
aff_loss *= 0
|
|
else:
|
|
aff_loss = 0
|
|
plddt_loss = conf_loss * lit_module.hparams.cfg.task.plddt_loss_weight
|
|
affinity_loss = aff_loss * lit_module.hparams.cfg.task.affinity_loss_weight
|
|
batch["outputs"]["loss"] = plddt_loss + affinity_loss
|
|
if not torch.is_tensor(batch["outputs"]["loss"]) and batch["outputs"]["loss"] == 0:
|
|
batch["outputs"]["loss"] = None
|
|
return batch
|