Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
This commit is contained in:
178
rf2aa/SE3Transformer/se3_transformer/model/basis.py
Normal file
178
rf2aa/SE3Transformer/se3_transformer/model/basis.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||
# copy of this software and associated documentation files (the "Software"),
|
||||
# to deal in the Software without restriction, including without limitation
|
||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
# and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
# DEALINGS IN THE SOFTWARE.
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Dict, List
|
||||
|
||||
import e3nn.o3 as o3
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.cuda.nvtx import range as nvtx_range
|
||||
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
|
||||
""" Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """
|
||||
return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
|
||||
all_cb = []
|
||||
for d_in in range(max_degree + 1):
|
||||
for d_out in range(max_degree + 1):
|
||||
K_Js = []
|
||||
for J in range(abs(d_in - d_out), d_in + d_out + 1):
|
||||
K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
|
||||
all_cb.append(K_Js)
|
||||
return all_cb
|
||||
|
||||
|
||||
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
|
||||
all_degrees = list(range(2 * max_degree + 1))
|
||||
with nvtx_range('spherical harmonics'):
|
||||
sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
|
||||
return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def get_basis_script(max_degree: int,
|
||||
use_pad_trick: bool,
|
||||
spherical_harmonics: List[Tensor],
|
||||
clebsch_gordon: List[List[Tensor]],
|
||||
amp: bool) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Compute pairwise bases matrices for degrees up to max_degree
|
||||
:param max_degree: Maximum input or output degree
|
||||
:param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
|
||||
:param spherical_harmonics: List of computed spherical harmonics
|
||||
:param clebsch_gordon: List of computed CB-coefficients
|
||||
:param amp: When true, return bases in FP16 precision
|
||||
"""
|
||||
basis = {}
|
||||
idx = 0
|
||||
# Double for loop instead of product() because of JIT script
|
||||
for d_in in range(max_degree + 1):
|
||||
for d_out in range(max_degree + 1):
|
||||
key = f'{d_in},{d_out}'
|
||||
K_Js = []
|
||||
for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
|
||||
Q_J = clebsch_gordon[idx][freq_idx]
|
||||
K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))
|
||||
|
||||
basis[key] = torch.stack(K_Js, 2) # Stack on second dim so order is n l f k
|
||||
if amp:
|
||||
basis[key] = basis[key].half()
|
||||
if use_pad_trick:
|
||||
basis[key] = F.pad(basis[key], (0, 1)) # Pad the k dimension, that can be sliced later
|
||||
|
||||
idx += 1
|
||||
|
||||
return basis
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def update_basis_with_fused(basis: Dict[str, Tensor],
|
||||
max_degree: int,
|
||||
use_pad_trick: bool,
|
||||
fully_fused: bool) -> Dict[str, Tensor]:
|
||||
""" Update the basis dict with partially and optionally fully fused bases """
|
||||
num_edges = basis['0,0'].shape[0]
|
||||
device = basis['0,0'].device
|
||||
dtype = basis['0,0'].dtype
|
||||
sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])
|
||||
|
||||
# Fused per output degree
|
||||
for d_out in range(max_degree + 1):
|
||||
sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
|
||||
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
|
||||
device=device, dtype=dtype)
|
||||
acc_d, acc_f = 0, 0
|
||||
for d_in in range(max_degree + 1):
|
||||
basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
|
||||
:degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
|
||||
|
||||
acc_d += degree_to_dim(d_in)
|
||||
acc_f += degree_to_dim(min(d_out, d_in))
|
||||
|
||||
basis[f'out{d_out}_fused'] = basis_fused
|
||||
|
||||
# Fused per input degree
|
||||
for d_in in range(max_degree + 1):
|
||||
sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
|
||||
basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
|
||||
device=device, dtype=dtype)
|
||||
acc_d, acc_f = 0, 0
|
||||
for d_out in range(max_degree + 1):
|
||||
basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
|
||||
= basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
|
||||
|
||||
acc_d += degree_to_dim(d_out)
|
||||
acc_f += degree_to_dim(min(d_out, d_in))
|
||||
|
||||
basis[f'in{d_in}_fused'] = basis_fused
|
||||
|
||||
if fully_fused:
|
||||
# Fully fused
|
||||
# Double sum this way because of JIT script
|
||||
sum_freq = sum([
|
||||
sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
|
||||
])
|
||||
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)
|
||||
|
||||
acc_d, acc_f = 0, 0
|
||||
for d_out in range(max_degree + 1):
|
||||
b = basis[f'out{d_out}_fused']
|
||||
basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
|
||||
:degree_to_dim(d_out)]
|
||||
acc_f += b.shape[2]
|
||||
acc_d += degree_to_dim(d_out)
|
||||
|
||||
basis['fully_fused'] = basis_fused
|
||||
|
||||
del basis['0,0'] # We know that the basis for l = k = 0 is filled with a constant
|
||||
return basis
|
||||
|
||||
|
||||
def get_basis(relative_pos: Tensor,
|
||||
max_degree: int = 4,
|
||||
compute_gradients: bool = False,
|
||||
use_pad_trick: bool = False,
|
||||
amp: bool = False) -> Dict[str, Tensor]:
|
||||
with nvtx_range('spherical harmonics'):
|
||||
spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
|
||||
with nvtx_range('CB coefficients'):
|
||||
clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)
|
||||
|
||||
with torch.autograd.set_grad_enabled(compute_gradients):
|
||||
with nvtx_range('bases'):
|
||||
basis = get_basis_script(max_degree=max_degree,
|
||||
use_pad_trick=use_pad_trick,
|
||||
spherical_harmonics=spherical_harmonics,
|
||||
clebsch_gordon=clebsch_gordon,
|
||||
amp=amp)
|
||||
return basis
|
||||
Reference in New Issue
Block a user