Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
rf2aa/SE3Transformer/se3_transformer/model/transformer.py (Normal file, 257 additions)
@@ -0,0 +1,257 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
from typing import Optional, Literal, Dict

import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor

from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis, update_basis_with_fused
from rf2aa.SE3Transformer.se3_transformer.model.layers.attention import AttentionBlockSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.norm import NormSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.pooling import GPooling
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber


class Sequential(nn.Sequential):
    """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """

    def forward(self, input, *args, **kwargs):
        for module in self:
            input = module(input, *args, **kwargs)
        return input
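
# Unlike torch.nn.Sequential, which threads only a single positional input through
# its children, this subclass forwards any extra *args/**kwargs to every module,
# e.g. (illustrative) Sequential(block_a, block_b)(node_feats, edge_feats, graph=g, basis=basis),
# which is how graph, basis and edge features reach each SE3 layer below.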


def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
    """ Add relative positions to existing edge features """
    edge_features = edge_features.copy() if edge_features else {}
    r = relative_pos.norm(dim=-1, keepdim=True)
    if '0' in edge_features:
        edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
    else:
        edge_features['0'] = r[..., None]

    return edge_features
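
# Shape note (illustrative): for relative_pos of shape [n_edges, 3] the result is
# {'0': tensor of shape [n_edges, 1, 1]}; if a degree-0 edge feature of shape
# [n_edges, c, 1] is passed in, the distance is appended as one extra channel,
# giving [n_edges, c + 1, 1].

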
class SE3Transformer(nn.Module):
    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 final_layer: Optional[Literal['conv', 'lin', 'att']] = 'conv',
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 populate_edge: Optional[Literal['lin', 'arcsin', 'log', 'zero']] = 'lin',
                 sum_over_edge: bool = True,
                 **kwargs):
        """
        :param num_layers: Number of attention layers
        :param fiber_in: Input fiber description
        :param fiber_hidden: Hidden fiber description
        :param fiber_out: Output fiber description
        :param fiber_edge: Input edge fiber description
        :param num_heads: Number of attention heads
        :param channels_div: Channels division before feeding to attention layer
        :param return_type: Return only features of this type
        :param pooling: 'avg' or 'max' graph pooling before MLP layers
        :param final_layer: Type of the last graph module: 'conv' (ConvSE3), 'lin' (LinearSE3) or 'att' (single-head AttentionBlockSE3)
        :param norm: Apply a normalization layer after each attention block
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory: If True, will use slower ops that use less memory
        :param populate_edge: How the edge distance is encoded into the degree-0 edge features: 'lin' (raw distance), 'arcsin' (shifted arcsinh), 'log' (log(1 + r)) or 'zero' (append a zero channel)
        :param sum_over_edge: Passed to the final ConvSE3 layer when final_layer == 'conv'
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory
        self.populate_edge = populate_edge

        if low_memory and not tensor_cores:
            logging.warning('Low memory mode will have no effect with no Tensor Cores')

        # Fully fused convolutions when using Tensor Cores (and not low memory mode)
        fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL

        # Per-degree channel division used inside the hidden attention blocks
        div = dict((str(degree), channels_div) for degree in range(self.max_degree+1))
        # For the final attention layer, only the degree-0 channels are divided
        div_fin = dict((str(degree), 1) for degree in range(self.max_degree+1))
        div_fin['0'] = channels_div

        graph_modules = []
        for i in range(num_layers):
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_hidden,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=num_heads,
                                                   channels_div=div,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=fuse_level))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            fiber_in = fiber_hidden

        if final_layer == 'conv':
            graph_modules.append(ConvSE3(fiber_in=fiber_in,
                                         fiber_out=fiber_out,
                                         fiber_edge=fiber_edge,
                                         self_interaction=True,
                                         sum_over_edge=sum_over_edge,
                                         use_layer_norm=use_layer_norm,
                                         max_degree=self.max_degree))
        elif final_layer == "lin":
            graph_modules.append(LinearSE3(fiber_in=fiber_in,
                                           fiber_out=fiber_out))
        else:
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_out,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=1,
                                                   channels_div=div_fin,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=fuse_level))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)

    def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
        # Compute bases in case they weren't precomputed as part of the data loading
        basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
                                   use_pad_trick=self.tensor_cores and not self.low_memory,
                                   amp=torch.is_autocast_enabled())

        # Add fused bases (per output degree, per input degree, and fully fused) to the dict
        basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
                                        fully_fused=self.tensor_cores and not self.low_memory)

        # Encode the edge distance into the degree-0 edge features; all branches except
        # 'lin' assume edge_feats already contains a '0' entry.
        if self.populate_edge == 'lin':
            edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
        elif self.populate_edge == 'arcsin':
            r = graph.edata['rel_pos'].norm(dim=-1, keepdim=True)
            # shift so distances below 4.0 map to zero, then compress with arcsinh
            r = torch.maximum(r, torch.zeros_like(r) + 4.0) - 4.0
            r = torch.arcsinh(r)/3.0
            edge_feats['0'] = torch.cat([edge_feats['0'], r[..., None]], dim=1)
        elif self.populate_edge == 'log':
            # fd - replace with log(1+x)
            r = torch.log(1 + graph.edata['rel_pos'].norm(dim=-1, keepdim=True))
            edge_feats['0'] = torch.cat([edge_feats['0'], r[..., None]], dim=1)
        else:
            # 'zero': append an all-zero channel so the edge-feature width matches the other modes
            edge_feats['0'] = torch.cat((edge_feats['0'], torch.zeros_like(edge_feats['0'][:, :1, :])), dim=1)

        node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is not None:
            return node_feats[str(self.return_type)]

        return node_feats

    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument('--num_layers', type=int, default=7,
                            help='Number of stacked Transformer layers')
        parser.add_argument('--num_heads', type=int, default=8,
                            help='Number of heads in self-attention')
        parser.add_argument('--channels_div', type=int, default=2,
                            help='Channels division before feeding to attention layer')
        parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
                            help='Type of graph pooling')
        parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply a normalization layer after each attention block')
        parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply layer normalization between MLP layers')
        parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
                            help='If true, will use fused ops that are slower but that use less memory '
                                 '(expect 25 percent less memory). '
                                 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')

        return parser


class SE3TransformerPooled(nn.Module):
    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        super().__init__()
        kwargs['pooling'] = kwargs['pooling'] or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )

    def forward(self, graph, node_feats, edge_feats, basis=None):
        feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
        y = self.mlp(feats).squeeze(-1)
        return y

    @staticmethod
    def add_argparse_args(parent_parser):
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=4)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
        return parent_parser
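

# ---------------------------------------------------------------------------
# Illustrative smoke test, not part of the upstream NVIDIA / RF2AA module. It
# assumes the rf2aa package and its dependencies (torch, dgl) import cleanly,
# and it only exercises the pieces that need no DGL graph: the edge-feature
# helper defined above and the argparse registration. Building a working
# SE3Transformer additionally requires a DGLGraph with edata['rel_pos'] and
# Fiber definitions whose channel counts match the node/edge feature tensors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    # Degree-0 edge features from relative positions: [n_edges, 3] -> {'0': [n_edges, 1, 1]}
    rel_pos = torch.randn(10, 3)
    edge_feats = get_populated_edge_features(rel_pos)
    print({k: tuple(v.shape) for k, v in edge_feats.items()})

    # Flags registered by SE3Transformer.add_argparse_args; parsing an empty
    # argument list shows the defaults.
    parser = argparse.ArgumentParser()
    SE3Transformer.add_argparse_args(parser)
    print(vars(parser.parse_args([])))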