Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from .linear import LinearSE3
|
||||
from .norm import NormSE3
|
||||
from .pooling import GPooling
|
||||
from .convolution import ConvSE3
|
||||
from .attention import AttentionBlockSE3
|
||||
186
rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py
Normal file
186
rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||
# copy of this software and associated documentation files (the "Software"),
|
||||
# to deal in the Software without restriction, including without limitation
|
||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
# and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
# DEALINGS IN THE SOFTWARE.
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import dgl
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dgl import DGLGraph
|
||||
from dgl.ops import edge_softmax
|
||||
from torch import Tensor
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
|
||||
from torch.cuda.nvtx import range as nvtx_range
|
||||
|
||||
|
||||
class AttentionSE3(nn.Module):
    """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """

    def __init__(
            self,
            num_heads: int,
            key_fiber: Fiber,
            value_fiber: Fiber
    ):
        """
        :param num_heads: Number of attention heads
        :param key_fiber: Fiber for the keys (and also for the queries)
        :param value_fiber: Fiber for the values
        """
        super().__init__()
        self.num_heads = num_heads
        self.key_fiber = key_fiber
        self.value_fiber = value_fiber

    def forward(
            self,
            value: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            key: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            query: Dict[str, Tensor],  # node features
            graph: DGLGraph
    ):
        """
        Attention-weighted aggregation of per-edge values onto destination nodes.

        ``key`` and ``value`` are per-edge features; ``query`` is per-node.
        When the upstream convolution emitted a fused representation, ``key``
        and ``value`` arrive as single Tensors instead of per-degree dicts.

        :return: Dict of per-degree node features (degree as string key)
        """
        with nvtx_range('AttentionSE3'):
            with nvtx_range('reshape keys and queries'):
                if isinstance(key, Tensor):
                    # case where features of all types are fused
                    key = key.reshape(key.shape[0], self.num_heads, -1)
                    # need to reshape queries that way to keep the same layout as keys
                    out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
                    query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
                else:
                    # features are not fused, need to fuse and reshape them
                    key = self.key_fiber.to_attention_heads(key, self.num_heads)
                    query = self.key_fiber.to_attention_heads(query, self.num_heads)

            with nvtx_range('attention dot product + softmax'):
                # Compute attention weights (softmax of inner product between key and query)
                # autocast is disabled here — presumably to keep the scaled dot
                # product + softmax in full precision under AMP; confirm if changing.
                with torch.cuda.amp.autocast(False):
                    edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
                    edge_weights /= np.sqrt(self.key_fiber.num_features)  # scale by sqrt(d_k)
                    edge_weights = edge_softmax(graph, edge_weights)  # softmax over incoming edges
                    edge_weights = edge_weights[..., None, None]  # broadcast over channel/degree dims

            with nvtx_range('weighted sum'):
                if isinstance(value, Tensor):
                    # features of all types are fused
                    v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
                    weights = edge_weights * v
                    # sum weighted edge features into destination nodes
                    feat_out = dgl.ops.copy_e_sum(graph, weights)
                    feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1])  # merge heads
                    out = unfuse_features(feat_out, self.value_fiber.degrees)
                else:
                    out = {}
                    for degree, channels in self.value_fiber:
                        v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
                                                    degree_to_dim(degree))
                        weights = edge_weights * v
                        res = dgl.ops.copy_e_sum(graph, weights)
                        out[str(degree)] = res.view(-1, channels, degree_to_dim(degree))  # merge heads

            return out
|
||||
|
||||
|
||||
class AttentionBlockSE3(nn.Module):
    """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """

    def __init__(
            self,
            fiber_in: Fiber,
            fiber_out: Fiber,
            fiber_edge: Optional[Fiber] = None,
            num_heads: int = 4,
            channels_div: Optional[Dict[str, int]] = None,
            use_layer_norm: bool = False,
            max_degree: int = 4,  # annotation corrected: this is an int degree, not a bool
            fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
            **kwargs
    ):
        """
        :param fiber_in: Fiber describing the input features
        :param fiber_out: Fiber describing the output features
        :param fiber_edge: Fiber describing the edge features (node distances excluded)
        :param num_heads: Number of attention heads
        :param channels_div: Divide the channels by this integer for computing values
                             (per-degree mapping keyed by degree as string)
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param max_degree: Maximum degree used in the bases computation
        :param fuse_level: Maximum fuse level to use in TFN convolutions
        """
        super().__init__()
        if fiber_edge is None:
            fiber_edge = Fiber({})
        self.fiber_in = fiber_in
        # value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
        if channels_div is not None:
            value_fiber = Fiber([(degree, channels // channels_div[str(degree)]) for degree, channels in fiber_out])
        else:
            value_fiber = Fiber([(degree, channels) for degree, channels in fiber_out])

        # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
        # (queries are merely projected, hence degrees have to match input)
        key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])

        # One convolution produces keys and values together (fibers concatenated);
        # they are split apart again in _get_key_value_from_fused.
        self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
                                    use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
                                    allow_fused_output=True)
        self.to_query = LinearSE3(fiber_in, key_query_fiber)
        self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
        # Final projection also consumes the skip-connected input features
        self.project = LinearSE3(value_fiber + fiber_in, fiber_out)

    def forward(
            self,
            node_features: Dict[str, Tensor],
            edge_features: Dict[str, Tensor],
            graph: DGLGraph,
            basis: Dict[str, Tensor]
    ):
        """
        Run attention over the graph and project the result.

        :param node_features: Per-degree node features (degree as string key)
        :param edge_features: Per-degree edge features
        :param graph: DGL graph defining the sparsity pattern
        :param basis: Precomputed equivariant bases for the TFN convolution
        :return: Per-degree output node features described by fiber_out
        """
        with nvtx_range('AttentionBlockSE3'):
            with nvtx_range('keys / values'):
                fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
                key, value = self._get_key_value_from_fused(fused_key_value)

            with nvtx_range('queries'):
                # autocast disabled — keep query projection in full precision under AMP
                with torch.cuda.amp.autocast(False):
                    query = self.to_query(node_features)

            z = self.attention(value, key, query, graph)
            z_concat = aggregate_residual(node_features, z, 'cat')  # skip connection
            return self.project(z_concat)

    def _get_key_value_from_fused(self, fused_key_value):
        # Extract keys and values from the fused output of self.to_key_value
        if isinstance(fused_key_value, Tensor):
            # Previous layer was a fully fused convolution
            value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
        else:
            key, value = {}, {}
            for degree, feat in fused_key_value.items():
                if int(degree) in self.fiber_in.degrees:
                    value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
                else:
                    # degree absent from the input: no key was computed, whole tensor is value
                    value[degree] = feat

        return key, value
|
||||
381
rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py
Normal file
381
rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py
Normal file
@@ -0,0 +1,381 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||
# copy of this software and associated documentation files (the "Software"),
|
||||
# to deal in the Software without restriction, including without limitation
|
||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
# and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
# DEALINGS IN THE SOFTWARE.
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from enum import Enum
|
||||
from itertools import product
|
||||
from typing import Dict
|
||||
|
||||
import dgl
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dgl import DGLGraph
|
||||
from torch import Tensor
|
||||
from torch.cuda.nvtx import range as nvtx_range
|
||||
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
||||
|
||||
|
||||
class ConvSE3FuseLevel(Enum):
    """
    Maximum level of kernel-fusion optimization to attempt in TFN convolutions.

    The chosen level is an upper bound: when its structural requirements are
    not met, the next lower applicable level is used instead. Higher levels
    trade memory for speed — pick NONE/PARTIAL when memory-bound, FULL when
    training throughput matters (FULL with AMP is the recommended setting).

    Requirements per level:

    FULL (fully fused convolution):
        - all input channel counts equal, input degrees span [0, max_degree]
        - all output channel counts equal, output degrees span [0, max_degree]

    PARTIAL (fused per output degree OR per input degree):
        - per output degree: equal input channels, input degrees span [0, max_degree]
        - per input degree: equal output channels, output degrees span [0, max_degree]

    NONE (original pairwise TFN convolutions): no requirements.
    """

    FULL = 2      # single fused kernel over all degree pairs
    PARTIAL = 1   # one kernel per input or per output degree
    NONE = 0      # one kernel per (input degree, output degree) pair
|
||||
|
||||
|
||||
class RadialProfile(nn.Module):
    """
    Radial profile function.

    Produces the scalar weights used to combine basis matrices into TFN
    convolution kernels ($R^{l,k}$ in TFN notation, $\phi^{l,k}$ in
    SE(3)-Transformer notation).

    Unlike the original papers, the input is not restricted to relative node
    distances ||x||: any invariant edge features may be appended, which keeps
    equivariance intact while adding expressive power.

    Diagram:
        invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
    """

    def __init__(
            self,
            num_freq: int,
            channels_in: int,
            channels_out: int,
            edge_dim: int = 1,
            mid_dim: int = 32,
            use_layer_norm: bool = False
    ):
        """
        :param num_freq: Number of frequencies
        :param channels_in: Number of input channels
        :param channels_out: Number of output channels
        :param edge_dim: Number of invariant edge features (input to the radial function)
        :param mid_dim: Size of the hidden MLP layers
        :param use_layer_norm: Apply layer normalization between MLP layers
        """
        super().__init__()
        # Two hidden layers, each optionally followed by LayerNorm, then ReLU;
        # final projection to one weight per (frequency, in-channel, out-channel).
        layers = [nn.Linear(edge_dim, mid_dim)]
        if use_layer_norm:
            layers.append(nn.LayerNorm(mid_dim))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(mid_dim, mid_dim))
        if use_layer_norm:
            layers.append(nn.LayerNorm(mid_dim))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, features: Tensor) -> Tensor:
        """Map invariant edge features to flattened radial weights."""
        return self.net(features)
|
||||
|
||||
|
||||
class VersatileConvSE3(nn.Module):
    """
    Building block for TFN convolutions.
    This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
    """

    def __init__(self,
                 freq_sum: int,
                 channels_in: int,
                 channels_out: int,
                 edge_dim: int,
                 use_layer_norm: bool,
                 fuse_level: ConvSE3FuseLevel):
        """
        :param freq_sum: Total number of frequencies across the degree pairs handled here
        :param channels_in: Number of input channels (per edge feature)
        :param channels_out: Number of output channels
        :param edge_dim: Number of invariant edge features fed to the radial MLP
        :param use_layer_norm: Apply layer normalization inside the radial MLP
        :param fuse_level: Fuse level this instance operates at (affects basis padding handling)
        """
        super().__init__()
        self.freq_sum = freq_sum
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.fuse_level = fuse_level
        self.radial_func = RadialProfile(num_freq=freq_sum,
                                         channels_in=channels_in,
                                         channels_out=channels_out,
                                         edge_dim=edge_dim,
                                         use_layer_norm=use_layer_norm)

    def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
        """
        Apply the radial-weighted basis contraction to per-edge features.

        :param features: Per-edge input features; assumed shape (num_edges, channels, rep_dim)
                         — in_dim is read from axis 2. TODO confirm against callers.
        :param invariant_edge_feats: Invariant edge features fed to the radial MLP
        :param basis: Precomputed equivariant basis (may be None for the degree-0-only case)
        :return: Per-edge output features (num_edges, channels_out, out_dim)
        """
        with nvtx_range(f'VersatileConvSE3'):
            num_edges = features.shape[0]
            in_dim = features.shape[2]
            # Fast path: training, or few enough edges that a single pass fits in memory
            if (self.training or num_edges <= 4096):
                with nvtx_range(f'RadialProfile'):
                    radial_weights = self.radial_func(invariant_edge_feats) \
                        .view(-1, self.channels_out, self.channels_in * self.freq_sum)

                if basis is not None:
                    # This block performs the einsum n i l, n o i f, n l f k -> n o k
                    out_dim = basis.shape[-1]
                    if self.fuse_level != ConvSE3FuseLevel.FULL:
                        out_dim += out_dim % 2 - 1  # Account for padded basis
                    basis_view = basis.view(num_edges, in_dim, -1)
                    tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
                    retval = (radial_weights @ tmp)[:, :, :out_dim]
                    return retval
                else:
                    # k = l = 0 non-fused case
                    retval = radial_weights @ features

            else:
                # fd: reduce memory in inference — process edges in fixed-size chunks
                EDGESTRIDE = 65536  # 16384
                if basis is not None:
                    out_dim = basis.shape[-1]
                    if self.fuse_level != ConvSE3FuseLevel.FULL:
                        out_dim += out_dim % 2 - 1  # Account for padded basis
                else:
                    out_dim = features.shape[-1]

                # Preallocate the full output, then fill it one chunk at a time
                retval = torch.zeros(
                    (num_edges, self.channels_out, out_dim),
                    dtype=features.dtype,
                    device=features.device
                )

                for i in range((num_edges - 1) // EDGESTRIDE + 1):
                    e_i, e_j = i * EDGESTRIDE, min((i + 1) * EDGESTRIDE, num_edges)

                    radial_weights = self.radial_func(invariant_edge_feats[e_i:e_j]) \
                        .view(-1, self.channels_out, self.channels_in * self.freq_sum)

                    if basis is not None:
                        # This block performs the einsum n i l, n o i f, n l f k -> n o k
                        basis_view = basis[e_i:e_j].view(e_j - e_i, in_dim, -1)
                        # matmuls forced to fp32 here (inputs explicitly .float()ed),
                        # unlike the fast path above — presumably for accuracy on
                        # large inference graphs; confirm before changing.
                        with torch.cuda.amp.autocast(False):
                            tmp = (features[e_i:e_j] @ basis_view.float()).view(e_j - e_i, -1, basis.shape[-1])
                            retslice = (radial_weights.float() @ tmp)[:, :, :out_dim]
                        retval[e_i:e_j] = retslice

                    else:
                        # k = l = 0 non-fused case
                        retval[e_i:e_j] = radial_weights @ features[e_i:e_j]

            return retval
|
||||
|
||||
class ConvSE3(nn.Module):
    """
    SE(3)-equivariant graph convolution (Tensor Field Network convolution).
    This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
    Features of different degrees interact together to produce output features.

    Note 1:
        The option is given to not pool the output. This means that the convolution sum over neighbors will not be
        done, and the returned features will be edge features instead of node features.

    Note 2:
        Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
        Input edge features are concatenated with input source node features before the kernel is applied.
    """

    def __init__(
            self,
            fiber_in: Fiber,
            fiber_out: Fiber,
            fiber_edge: Fiber,
            pool: bool = True,
            use_layer_norm: bool = False,
            self_interaction: bool = False,
            sum_over_edge: bool = True,
            max_degree: int = 4,
            fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
            allow_fused_output: bool = False
    ):
        """
        :param fiber_in: Fiber describing the input features
        :param fiber_out: Fiber describing the output features
        :param fiber_edge: Fiber describing the edge features (node distances excluded)
        :param pool: If True, compute final node features by aggregating incoming edge features
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param self_interaction: Apply self-interaction of nodes
        :param sum_over_edge: When pooling, sum incoming edge features (True) or average them (False)
        :param max_degree: Maximum degree used in the bases computation
        :param fuse_level: Maximum fuse level to use in TFN convolutions
        :param allow_fused_output: Allow the module to output a fused representation of features
        """
        super().__init__()
        self.pool = pool
        self.fiber_in = fiber_in
        self.fiber_out = fiber_out
        self.self_interaction = self_interaction
        self.sum_over_edge = sum_over_edge
        self.max_degree = max_degree
        self.allow_fused_output = allow_fused_output

        # channels_in: account for the concatenation of edge features
        channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
        channels_out_set = set([f.channels for f in self.fiber_out])
        unique_channels_in = (len(channels_in_set) == 1)
        unique_channels_out = (len(channels_out_set) == 1)
        degrees_up_to_max = list(range(max_degree + 1))
        # +1: the scalar edge feature always includes the node distance
        common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)

        # Dispatch on the highest fuse level whose structural requirements are met
        # (see ConvSE3FuseLevel for the requirements).
        if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
                unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
                unique_channels_out and fiber_out.degrees == degrees_up_to_max:
            # Single fused convolution
            self.used_fuse_level = ConvSE3FuseLevel.FULL

            sum_freq = sum([
                degree_to_dim(min(d_in, d_out))
                for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
            ])

            self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
                                         fuse_level=self.used_fuse_level, **common_args)

        elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
                unique_channels_in and fiber_in.degrees == degrees_up_to_max:
            # Convolutions fused per output degree
            self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
            self.conv_out = nn.ModuleDict()
            for d_out, c_out in fiber_out:
                sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
                self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
                                                             fuse_level=self.used_fuse_level, **common_args)

        elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
                unique_channels_out and fiber_out.degrees == degrees_up_to_max:
            # Convolutions fused per input degree
            self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
            self.conv_in = nn.ModuleDict()
            for d_in, c_in in fiber_in:
                sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
                # NOTE(review): fuse_level=FULL is passed here even though
                # used_fuse_level is PARTIAL; this changes the padded-basis
                # out_dim correction inside VersatileConvSE3 — confirm intentional.
                self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
                                                           fuse_level=ConvSE3FuseLevel.FULL, **common_args)
        else:
            # Use pairwise TFN convolutions
            self.used_fuse_level = ConvSE3FuseLevel.NONE
            self.conv = nn.ModuleDict()
            for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
                dict_key = f'{degree_in},{degree_out}'
                channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
                sum_freq = degree_to_dim(min(degree_in, degree_out))
                self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
                                                       fuse_level=self.used_fuse_level, **common_args)

        if self_interaction:
            # Per-degree learned linear self-interaction (applied to destination nodes)
            self.to_kernel_self = nn.ParameterDict()
            for degree_out, channels_out in fiber_out:
                if fiber_in[degree_out]:
                    self.to_kernel_self[str(degree_out)] = nn.Parameter(
                        torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))

    def forward(
            self,
            node_feats: Dict[str, Tensor],
            edge_feats: Dict[str, Tensor],
            graph: DGLGraph,
            basis: Dict[str, Tensor]
    ):
        """
        Apply the TFN convolution over the graph.

        :param node_feats: Per-degree node features (degree as string key)
        :param edge_feats: Per-degree edge features; '0' must hold the invariant features
        :param graph: DGL graph defining the edges
        :param basis: Precomputed bases, keyed by fuse level ('fully_fused',
                      'out{d}_fused', 'in{d}_fused', or '{d_in},{d_out}')
        :return: Per-degree features — node features if self.pool, else edge
                 features; may be a fused Tensor when allow_fused_output is set
        """
        with nvtx_range(f'ConvSE3'):
            invariant_edge_feats = edge_feats['0'].squeeze(-1)
            src, dst = graph.edges()
            out = {}
            in_features = []

            # Fetch all input features from edge and node features
            for degree_in in self.fiber_in.degrees:
                src_node_features = node_feats[str(degree_in)][src]
                if degree_in > 0 and str(degree_in) in edge_feats:
                    # Handle edge features of any type by concatenating them to node features
                    src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
                in_features.append(src_node_features)

            if self.used_fuse_level == ConvSE3FuseLevel.FULL:
                in_features_fused = torch.cat(in_features, dim=-1)
                out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])

                # Downstream per-degree operations require an unfused dict
                if not self.allow_fused_output or self.self_interaction or self.pool:
                    out = unfuse_features(out, self.fiber_out.degrees)

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
                in_features_fused = torch.cat(in_features, dim=-1)
                for degree_out in self.fiber_out.degrees:
                    out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
                                                                          basis[f'out{degree_out}_fused'])

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
                # Accumulate one fused output tensor over all input degrees
                out = 0
                for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                    out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
                                                        basis[f'in{degree_in}_fused'])
                if not self.allow_fused_output or self.self_interaction or self.pool:
                    out = unfuse_features(out, self.fiber_out.degrees)
            else:
                # Fallback to pairwise TFN convolutions
                for degree_out in self.fiber_out.degrees:
                    out_feature = 0
                    for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                        dict_key = f'{degree_in},{degree_out}'
                        out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
                                                                       basis.get(dict_key, None))
                    out[str(degree_out)] = out_feature

            for degree_out in self.fiber_out.degrees:
                if self.self_interaction and str(degree_out) in self.to_kernel_self:
                    with nvtx_range(f'self interaction'):
                        dst_features = node_feats[str(degree_out)][dst]
                        kernel_self = self.to_kernel_self[str(degree_out)]
                        out[str(degree_out)] += kernel_self @ dst_features

                if self.pool:
                    # NOTE(review): when self.pool is True, `out` was unfused to a
                    # dict in every branch above, so the non-dict (Tensor) arms
                    # below look unreachable — confirm before relying on them.
                    if self.sum_over_edge:
                        with nvtx_range(f'pooling'):
                            if isinstance(out, dict):
                                out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
                            else:
                                out = dgl.ops.copy_e_sum(graph, out)
                    else:
                        with nvtx_range(f'pooling'):
                            if isinstance(out, dict):
                                out[str(degree_out)] = dgl.ops.copy_e_mean(graph, out[str(degree_out)])
                            else:
                                out = dgl.ops.copy_e_mean(graph, out)
            return out
|
||||
59
rf2aa/SE3Transformer/se3_transformer/model/layers/linear.py
Normal file
59
rf2aa/SE3Transformer/se3_transformer/model/layers/linear.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||
# copy of this software and associated documentation files (the "Software"),
|
||||
# to deal in the Software without restriction, including without limitation
|
||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
# and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
# DEALINGS IN THE SOFTWARE.
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
|
||||
|
||||
class LinearSE3(nn.Module):
    """
    Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
    Maps a fiber to a fiber with the same degrees (channels may be different).
    No interaction between degrees, but interaction between channels.

        type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
        type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
                                                     :
        type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
    """

    def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
        """
        :param fiber_in: Fiber describing the input features
        :param fiber_out: Fiber describing the output features (same degrees,
                          possibly different channel counts)
        """
        super().__init__()
        # One bias-free weight per output degree, scaled by 1/sqrt(fan_in).
        # NOTE(review): assumes fiber_in[degree_out] > 0 for every output degree,
        # otherwise the sqrt divides by zero — confirm against callers.
        self.weights = nn.ParameterDict({
            str(degree_out): nn.Parameter(
                torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
            for degree_out, channels_out in fiber_out
        })

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        """
        Apply the per-degree linear maps.

        :param features: Dict mapping degree (as string) to feature tensors;
                         each must have its channel axis second-to-last so that
                         ``weight @ features[degree]`` contracts over channels
        :return: Dict with the same degrees and mapped channels
        """
        # Use the weight bound by items() directly instead of re-indexing
        # self.weights[degree] (same result, one dict lookup instead of two).
        return {
            degree: weight @ features[degree]
            for degree, weight in self.weights.items()
        }
|
||||
83
rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py
Normal file
83
rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||
# copy of this software and associated documentation files (the "Software"),
|
||||
# to deal in the Software without restriction, including without limitation
|
||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
# and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
# DEALINGS IN THE SOFTWARE.
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.cuda.nvtx import range as nvtx_range
|
||||
|
||||
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||
|
||||
|
||||
class NormSE3(nn.Module):
    """
    Norm-based SE(3)-equivariant nonlinearity.

                 ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
    feature_in ──┤                                              * ──> feature_out
                 └──> feature_phase ────────────────────────────┘
    """

    NORM_CLAMP = 2 ** -24  # Minimum positive subnormal for FP16; guards the divisions below

    def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
        """
        :param fiber: Fiber describing the input (and output) features
        :param nonlinearity: Module applied to the normalized norms.
                             NOTE(review): the nn.ReLU() default is a single module
                             instance shared by every NormSE3 built this way —
                             harmless for stateless ReLU, but worth knowing.
        """
        super().__init__()
        self.fiber = fiber
        self.nonlinearity = nonlinearity

        if len(set(fiber.channels)) == 1:
            # Fuse all the layer normalizations into a group normalization
            self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
        else:
            # Use multiple layer normalizations
            self.layer_norms = nn.ModuleDict({
                str(degree): nn.LayerNorm(channels)
                for degree, channels in fiber
            })

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        """
        Rescale each feature to a nonlinearly transformed norm; the direction
        ('phase') of each feature vector is preserved, which keeps equivariance.

        :param features: Dict mapping degree (as string) to feature tensors
        :return: Dict with the same degrees and shapes
        """
        with nvtx_range('NormSE3'):
            output = {}
            if hasattr(self, 'group_norm'):
                # Fused path (all degrees share a channel count)
                # Compute per-degree norms of features
                norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                         for d in self.fiber.degrees]
                fused_norms = torch.cat(norms, dim=-2)

                # Transform the norms only
                new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
                new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)

                # Scale features to the new norms
                for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
                    output[str(d)] = features[str(d)] / norm * new_norm
            else:
                # Per-degree path: one LayerNorm per degree
                for degree, feat in features.items():
                    norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                    new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
                    output[degree] = new_norm * feat / norm

            return output
|
||||
53
rf2aa/SE3Transformer/se3_transformer/model/layers/pooling.py
Normal file
53
rf2aa/SE3Transformer/se3_transformer/model/layers/pooling.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||
# copy of this software and associated documentation files (the "Software"),
|
||||
# to deal in the Software without restriction, including without limitation
|
||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
# and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
# DEALINGS IN THE SOFTWARE.
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import Dict, Literal
|
||||
|
||||
import torch.nn as nn
|
||||
from dgl import DGLGraph
|
||||
from dgl.nn.pytorch import AvgPooling, MaxPooling
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class GPooling(nn.Module):
    """
    Graph-level max/average pooling over one selected feature type.

    Averaging is safe for any feature type and preserves equivariance;
    taking a maximum is only valid for invariant (type-0) features.
    For max-pooling of type > 0 features, look into Vector Neurons.
    """

    def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
        """
        :param feat_type: Feature type to pool
        :param pool: Type of pooling: max or avg
        """
        super().__init__()
        assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
        assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
        self.feat_type = feat_type
        # Delegate the actual reduction to DGL's readout modules
        if pool == 'max':
            self.pool = MaxPooling()
        else:
            self.pool = AvgPooling()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
        """Pool the selected feature type over each graph and drop the trailing axis."""
        selected = features[str(self.feat_type)]
        pooled = self.pool(graph, selected)
        return pooled.squeeze(dim=-1)
|
||||
Reference in New Issue
Block a user