Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths

This commit is contained in:
2026-03-17 17:57:24 +01:00
commit 6eef3bb748
108 changed files with 28144 additions and 0 deletions

View File

@@ -0,0 +1,70 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import argparse
import pathlib
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
paths = PARSER.add_argument_group('Paths')
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
help='Directory where the data is located or should be downloaded')
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
help='Directory where the results logs should be saved')
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
help='Name for the resulting DLLogger JSON file')
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
help='File where the checkpoint should be saved')
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
help='File of the checkpoint to be loaded')
optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
optimizer.add_argument('--momentum', type=float, default=0.9)
optimizer.add_argument('--weight_decay', type=float, default=0.1)
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
help='Minimize stdout output')
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
help='Benchmark mode')
QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)

View File

@@ -0,0 +1,160 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
import time
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import torch
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import Logger
from rf2aa.SE3Transformer.se3_transformer.runtime.metrics import MeanAbsoluteError
class BaseCallback(ABC):
def on_fit_start(self, optimizer, args):
pass
def on_fit_end(self):
pass
def on_epoch_end(self):
pass
def on_batch_start(self):
pass
def on_validation_step(self, input, target, pred):
pass
def on_validation_end(self, epoch=None):
pass
def on_checkpoint_load(self, checkpoint):
pass
def on_checkpoint_save(self, checkpoint):
pass
class LRSchedulerCallback(BaseCallback):
def __init__(self, logger: Optional[Logger] = None):
self.logger = logger
self.scheduler = None
@abstractmethod
def get_scheduler(self, optimizer, args):
pass
def on_fit_start(self, optimizer, args):
self.scheduler = self.get_scheduler(optimizer, args)
def on_checkpoint_load(self, checkpoint):
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
def on_checkpoint_save(self, checkpoint):
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
def on_epoch_end(self):
if self.logger is not None:
self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
self.scheduler.step()
class QM9MetricCallback(BaseCallback):
""" Logs the rescaled mean absolute error for QM9 regression tasks """
def __init__(self, logger, targets_std, prefix=''):
self.mae = MeanAbsoluteError()
self.logger = logger
self.targets_std = targets_std
self.prefix = prefix
self.best_mae = float('inf')
def on_validation_step(self, input, target, pred):
self.mae(pred.detach(), target.detach())
def on_validation_end(self, epoch=None):
mae = self.mae.compute() * self.targets_std
logging.info(f'{self.prefix} MAE: {mae}')
self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
self.best_mae = min(self.best_mae, mae)
def on_fit_end(self):
if self.best_mae != float('inf'):
self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})
class QM9LRSchedulerCallback(LRSchedulerCallback):
def __init__(self, logger, epochs):
super().__init__(logger)
self.epochs = epochs
def get_scheduler(self, optimizer, args):
min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)
class PerformanceCallback(BaseCallback):
def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
self.batch_size = batch_size
self.warmup_epochs = warmup_epochs
self.epoch = 0
self.timestamps = []
self.mode = mode
self.logger = logger
def on_batch_start(self):
if self.epoch >= self.warmup_epochs:
self.timestamps.append(time.time() * 1000.0)
def _log_perf(self):
stats = self.process_performance_stats()
for k, v in stats.items():
logging.info(f'performance {k}: {v}')
self.logger.log_metrics(stats)
def on_epoch_end(self):
self.epoch += 1
def on_fit_end(self):
if self.epoch > self.warmup_epochs:
self._log_perf()
self.timestamps = []
def process_performance_stats(self):
timestamps = np.asarray(self.timestamps)
deltas = np.diff(timestamps)
throughput = (self.batch_size / deltas).mean()
stats = {
f"throughput_{self.mode}": throughput,
f"latency_{self.mode}_mean": deltas.mean(),
f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
}
for level in [90, 95, 99]:
stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})
return stats

View File

@@ -0,0 +1,325 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import collections
import itertools
import math
import os
import pathlib
import re
import pynvml
class Device:
# assumes nvml returns list of 64 bit ints
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
def __init__(self, device_idx):
super().__init__()
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
def get_name(self):
return pynvml.nvmlDeviceGetName(self.handle)
def get_uuid(self):
return pynvml.nvmlDeviceGetUUID(self.handle)
def get_cpu_affinity(self):
affinity_string = ""
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
# assume nvml returns list of 64 bit ints
affinity_string = "{:064b}".format(j) + affinity_string
affinity_list = [int(x) for x in affinity_string]
affinity_list.reverse() # so core 0 is in 0th element of list
ret = [i for i, e in enumerate(affinity_list) if e != 0]
return ret
def get_thread_siblings_list():
"""
Returns a list of 2-element integer tuples representing pairs of
hyperthreading cores.
"""
path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
thread_siblings_list = []
pattern = re.compile(r"(\d+)\D(\d+)")
for fname in pathlib.Path(path[0]).glob(path[1:]):
with open(fname) as f:
content = f.read().strip()
res = pattern.findall(content)
if res:
pair = tuple(map(int, res[0]))
thread_siblings_list.append(pair)
return thread_siblings_list
def check_socket_affinities(socket_affinities):
# sets of cores should be either identical or disjoint
for i, j in itertools.product(socket_affinities, socket_affinities):
if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")
def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
devices = [Device(i) for i in range(nproc_per_node)]
socket_affinities = [dev.get_cpu_affinity() for dev in devices]
if exclude_unavailable_cores:
available_cores = os.sched_getaffinity(0)
socket_affinities = [list(set(affinity) & available_cores) for affinity in socket_affinities]
check_socket_affinities(socket_affinities)
return socket_affinities
def set_socket_affinity(gpu_id):
"""
The process is assigned with all available logical CPU cores from the CPU
socket connected to the GPU with a given id.
Args:
gpu_id: index of a GPU
"""
dev = Device(gpu_id)
affinity = dev.get_cpu_affinity()
os.sched_setaffinity(0, affinity)
def set_single_affinity(gpu_id):
"""
The process is assigned with the first available logical CPU core from the
list of all CPU cores from the CPU socket connected to the GPU with a given
id.
Args:
gpu_id: index of a GPU
"""
dev = Device(gpu_id)
affinity = dev.get_cpu_affinity()
# exclude unavailable cores
available_cores = os.sched_getaffinity(0)
affinity = list(set(affinity) & available_cores)
os.sched_setaffinity(0, affinity[:1])
def set_single_unique_affinity(gpu_id, nproc_per_node):
"""
The process is assigned with a single unique available physical CPU core
from the list of all CPU cores from the CPU socket connected to the GPU with
a given id.
Args:
gpu_id: index of a GPU
"""
socket_affinities = get_socket_affinities(nproc_per_node)
siblings_list = get_thread_siblings_list()
siblings_dict = dict(siblings_list)
# remove siblings
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
affinities = []
assigned = []
for socket_affinity in socket_affinities:
for core in socket_affinity:
if core not in assigned:
affinities.append([core])
assigned.append(core)
break
os.sched_setaffinity(0, affinities[gpu_id])
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
"""
The process is assigned with an unique subset of available physical CPU
cores from the CPU socket connected to a GPU with a given id.
Assignment automatically includes hyperthreading siblings (if siblings are
available).
Args:
gpu_id: index of a GPU
nproc_per_node: total number of processes per node
mode: mode
balanced: assign an equal number of physical cores to each process
"""
socket_affinities = get_socket_affinities(nproc_per_node)
siblings_list = get_thread_siblings_list()
siblings_dict = dict(siblings_list)
# remove hyperthreading siblings
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
socket_affinities_to_device_ids = collections.defaultdict(list)
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
# compute minimal number of physical cores per GPU across all GPUs and
# sockets, code assigns this number of cores per GPU if balanced == True
min_physical_cores_per_gpu = min(
[len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()]
)
for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
devices_per_group = len(device_ids)
if balanced:
cores_per_device = min_physical_cores_per_gpu
socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
else:
cores_per_device = len(socket_affinity) // devices_per_group
for group_id, device_id in enumerate(device_ids):
if device_id == gpu_id:
# In theory there should be no difference in performance between
# 'interleaved' and 'continuous' pattern on Intel-based DGX-1,
# but 'continuous' should be better for DGX A100 because on AMD
# Rome 4 consecutive cores are sharing L3 cache.
# TODO: code doesn't attempt to automatically detect layout of
# L3 cache, also external environment may already exclude some
# cores, this code makes no attempt to detect it and to align
# mapping to multiples of 4.
if mode == "interleaved":
affinity = list(socket_affinity[group_id::devices_per_group])
elif mode == "continuous":
affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
else:
raise RuntimeError("Unknown set_socket_unique_affinity mode")
# unconditionally reintroduce hyperthreading siblings, this step
# may result in a different numbers of logical cores assigned to
# each GPU even if balanced == True (if hyperthreading siblings
# aren't available for a subset of cores due to some external
# constraints, siblings are re-added unconditionally, in the
# worst case unavailable logical core will be ignored by
# os.sched_setaffinity().
affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
os.sched_setaffinity(0, affinity)
def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
"""
The process is assigned with a proper CPU affinity which matches hardware
architecture on a given platform. Usually it improves and stabilizes
performance of deep learning training workloads.
This function assumes that the workload is running in multi-process
single-device mode (there are multiple training processes and each process
is running on a single GPU), which is typical for multi-GPU training
workloads using `torch.nn.parallel.DistributedDataParallel`.
Available affinity modes:
* 'socket' - the process is assigned with all available logical CPU cores
from the CPU socket connected to the GPU with a given id.
* 'single' - the process is assigned with the first available logical CPU
core from the list of all CPU cores from the CPU socket connected to the GPU
with a given id (multiple GPUs could be assigned with the same CPU core).
* 'single_unique' - the process is assigned with a single unique available
physical CPU core from the list of all CPU cores from the CPU socket
connected to the GPU with a given id.
* 'socket_unique_interleaved' - the process is assigned with an unique
subset of available physical CPU cores from the CPU socket connected to a
GPU with a given id, hyperthreading siblings are included automatically,
cores are assigned with interleaved indexing pattern
* 'socket_unique_continuous' - (the default) the process is assigned with an
unique subset of available physical CPU cores from the CPU socket connected
to a GPU with a given id, hyperthreading siblings are included
automatically, cores are assigned with continuous indexing pattern
'socket_unique_continuous' is the recommended mode for deep learning
training workloads on NVIDIA DGX machines.
Args:
gpu_id: integer index of a GPU
nproc_per_node: number of processes per node
mode: affinity mode
balanced: assign an equal number of physical cores to each process,
affects only 'socket_unique_interleaved' and
'socket_unique_continuous' affinity modes
Returns a set of logical CPU cores on which the process is eligible to run.
Example:
import argparse
import os
import gpu_affinity
import torch
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--local_rank',
type=int,
default=os.getenv('LOCAL_RANK', 0),
)
args = parser.parse_args()
nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
print(f'{args.local_rank}: core affinity: {affinity}')
if __name__ == "__main__":
main()
Launch the example with:
python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py
WARNING: On DGX A100 only a half of CPU cores have direct access to GPUs.
This function restricts execution only to the CPU cores directly connected
to GPUs, so on DGX A100 it will limit the code to half of CPU cores and half
of CPU memory bandwidth (which may be fine for many DL models).
"""
pynvml.nvmlInit()
if mode == "socket":
set_socket_affinity(gpu_id)
elif mode == "single":
set_single_affinity(gpu_id)
elif mode == "single_unique":
set_single_unique_affinity(gpu_id, nproc_per_node)
elif mode == "socket_unique_interleaved":
set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
elif mode == "socket_unique_continuous":
set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
else:
raise RuntimeError("Unknown affinity mode")
affinity = os.sched_getaffinity(0)
return affinity

View File

@@ -0,0 +1,131 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import List
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import BaseCallback
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import DLLogger
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank
@torch.inference_mode()
def evaluate(model: nn.Module,
dataloader: DataLoader,
callbacks: List[BaseCallback],
args):
model.eval()
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc=f'Evaluation',
leave=False, disable=(args.silent or get_local_rank() != 0)):
*input, target = to_cuda(batch)
for callback in callbacks:
callback.on_batch_start()
with torch.cuda.amp.autocast(enabled=args.amp):
pred = model(*input)
for callback in callbacks:
callback.on_validation_step(input, target, pred)
if __name__ == '__main__':
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import init_distributed, seed_everything
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled, Fiber
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
import torch.distributed as dist
import logging
import sys
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('====== SE(3)-Transformer ======')
logging.info('| Inference on the test set |')
logging.info('===============================')
if not args.benchmark and args.load_ckpt_path is None:
logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
sys.exit(1)
if args.benchmark:
logging.info('Running benchmark mode with one warmup pass')
if args.seed is not None:
seed_everything(args.seed)
major_cc, minor_cc = torch.cuda.get_device_capability()
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
output_dim=1,
tensor_cores=(args.amp and major_cc >= 7) or major_cc >= 8, # use Tensor Cores more effectively
**vars(args)
)
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]
model.to(device=torch.cuda.current_device())
if args.load_ckpt_path is not None:
checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
model.load_state_dict(checkpoint['state_dict'])
if is_distributed:
nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
evaluate(model,
test_dataloader,
callbacks,
args)
for callback in callbacks:
callback.on_validation_end()
if args.benchmark:
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
for _ in range(6):
evaluate(model,
test_dataloader,
callbacks,
args)
callbacks[0].on_epoch_end()
callbacks[0].on_fit_end()

View File

@@ -0,0 +1,134 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import pathlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any, Callable, Optional
import dllogger
import torch.distributed as dist
import wandb
from dllogger import Verbosity
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import rank_zero_only
class Logger(ABC):
@rank_zero_only
@abstractmethod
def log_hyperparams(self, params):
pass
@rank_zero_only
@abstractmethod
def log_metrics(self, metrics, step=None):
pass
@staticmethod
def _sanitize_params(params):
def _sanitize(val):
if isinstance(val, Callable):
try:
_val = val()
if isinstance(_val, Callable):
return val.__name__
return _val
except Exception:
return getattr(val, "__name__", None)
elif isinstance(val, pathlib.Path) or isinstance(val, Enum):
return str(val)
return val
return {key: _sanitize(val) for key, val in params.items()}
class LoggerCollection(Logger):
def __init__(self, loggers):
super().__init__()
self.loggers = loggers
def __getitem__(self, index):
return [logger for logger in self.loggers][index]
@rank_zero_only
def log_metrics(self, metrics, step=None):
for logger in self.loggers:
logger.log_metrics(metrics, step)
@rank_zero_only
def log_hyperparams(self, params):
for logger in self.loggers:
logger.log_hyperparams(params)
class DLLogger(Logger):
def __init__(self, save_dir: pathlib.Path, filename: str):
super().__init__()
if not dist.is_initialized() or dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
dllogger.init(
backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])
@rank_zero_only
def log_hyperparams(self, params):
params = self._sanitize_params(params)
dllogger.log(step="PARAMETER", data=params)
@rank_zero_only
def log_metrics(self, metrics, step=None):
if step is None:
step = tuple()
dllogger.log(step=step, data=metrics)
class WandbLogger(Logger):
def __init__(
self,
name: str,
save_dir: pathlib.Path,
id: Optional[str] = None,
project: Optional[str] = None
):
super().__init__()
if not dist.is_initialized() or dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
self.experiment = wandb.init(name=name,
project=project,
id=id,
dir=str(save_dir),
resume='allow',
anonymous='must')
@rank_zero_only
def log_hyperparams(self, params: Dict[str, Any]) -> None:
params = self._sanitize_params(params)
self.experiment.config.update(params, allow_val_change=True)
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
if step is not None:
self.experiment.log({**metrics, 'epoch': step})
else:
self.experiment.log(metrics)

View File

@@ -0,0 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
from torch import Tensor
class Metric(ABC):
""" Metric class with synchronization capabilities similar to TorchMetrics """
def __init__(self):
self.states = {}
def add_state(self, name: str, default: Tensor):
assert name not in self.states
self.states[name] = default.clone()
setattr(self, name, default)
def synchronize(self):
if dist.is_initialized():
for state in self.states:
dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)
def __call__(self, *args, **kwargs):
self.update(*args, **kwargs)
def reset(self):
for name, default in self.states.items():
setattr(self, name, default.clone())
def compute(self):
self.synchronize()
value = self._compute().item()
self.reset()
return value
@abstractmethod
def _compute(self):
pass
@abstractmethod
def update(self, preds: Tensor, targets: Tensor):
pass
class MeanAbsoluteError(Metric):
def __init__(self):
super().__init__()
self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))
def update(self, preds: Tensor, targets: Tensor):
preds = preds.detach()
n = preds.shape[0]
error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
self.total += n
self.error += error
def _compute(self):
return self.error / self.total

View File

@@ -0,0 +1,238 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
import pathlib
from typing import List
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
PerformanceCallback
from rf2aa.SE3Transformer.se3_transformer.runtime.inference import evaluate
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
using_tensor_cores, increase_l2_fetch_granularity
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Saves model, optimizer and epoch states to path (only once per node) """
if get_local_rank() == 0:
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
checkpoint = {
'state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}
for callback in callbacks:
callback.on_checkpoint_save(checkpoint)
torch.save(checkpoint, str(path))
logging.info(f'Saved checkpoint to {str(path)}')
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Loads model, optimizer and epoch states from path """
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
if isinstance(model, DistributedDataParallel):
model.module.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for callback in callbacks:
callback.on_checkpoint_load(checkpoint)
logging.info(f'Loaded checkpoint from {str(path)}')
return checkpoint['epoch']
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
losses = []
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
*inputs, target = to_cuda(batch)
for callback in callbacks:
callback.on_batch_start()
with torch.cuda.amp.autocast(enabled=args.amp):
pred = model(*inputs)
loss = loss_fn(pred, target) / args.accumulate_grad_batches
grad_scaler.scale(loss).backward()
# gradient accumulation
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
if args.gradient_clip:
grad_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad()
losses.append(loss.item())
return np.mean(losses)
def train(model: nn.Module,
loss_fn: _Loss,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
callbacks: List[BaseCallback],
logger: Logger,
args):
device = torch.cuda.current_device()
model.to(device=device)
local_rank = get_local_rank()
world_size = dist.get_world_size() if dist.is_initialized() else 1
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model.train()
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
if args.optimizer == 'adam':
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
elif args.optimizer == 'lamb':
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
for callback in callbacks:
callback.on_fit_start(optimizer, args)
for epoch_idx in range(epoch_start, args.epochs):
if isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch_idx)
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
if dist.is_initialized():
loss = torch.tensor(loss, dtype=torch.float, device=device)
torch.distributed.all_reduce(loss)
loss = (loss / world_size).item()
logging.info(f'Train loss: {loss}')
logger.log_metrics({'train loss': loss}, epoch_idx)
for callback in callbacks:
callback.on_epoch_end()
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
and (epoch_idx + 1) % args.ckpt_interval == 0:
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
evaluate(model, val_dataloader, callbacks, args)
model.train()
for callback in callbacks:
callback.on_validation_end(epoch_idx)
if args.save_ckpt_path is not None and not args.benchmark:
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
for callback in callbacks:
callback.on_fit_end()
def print_parameters_count(model):
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info(f'Number of trainable parameters: {num_params_trainable}')
if __name__ == '__main__':
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('====== SE(3)-Transformer ======')
logging.info('| Training procedure |')
logging.info('===============================')
if args.seed is not None:
logging.info(f'Using seed {args.seed}')
seed_everything(args.seed)
logger = LoggerCollection([
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
])
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
output_dim=1,
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
**vars(args)
)
loss_fn = nn.L1Loss()
if args.benchmark:
logging.info('Running benchmark mode')
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
else:
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
if is_distributed:
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
print_parameters_count(model)
logger.log_hyperparams(vars(args))
increase_l2_fetch_granularity()
train(model,
loss_fn,
datamodule.train_dataloader(),
datamodule.val_dataloader(),
callbacks,
logger,
args)
logging.info('Training finished successfully')

View File

@@ -0,0 +1,130 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import argparse
import ctypes
import logging
import os
import random
from functools import wraps
from typing import Union, List, Dict
import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor
def aggregate_residual(feats1, feats2, method: str):
""" Add or concatenate two fiber features together. If degrees don't match, will use the ones of feats2. """
if method in ['add', 'sum']:
return {k: (v + feats1[k]) if k in feats1 else v for k, v in feats2.items()}
elif method in ['cat', 'concat']:
return {k: torch.cat([v, feats1[k]], dim=1) if k in feats1 else v for k, v in feats2.items()}
else:
raise ValueError('Method must be add/sum or cat/concat')
def degree_to_dim(degree: int) -> int:
return 2 * degree + 1
def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
return dict(zip(map(str, degrees), features.split([degree_to_dim(deg) for deg in degrees], dim=-1)))
def str2bool(v: Union[bool, str]) -> bool:
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def to_cuda(x):
""" Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA """
if isinstance(x, Tensor):
return x.cuda(non_blocking=True)
elif isinstance(x, tuple):
return (to_cuda(v) for v in x)
elif isinstance(x, list):
return [to_cuda(v) for v in x]
elif isinstance(x, dict):
return {k: to_cuda(v) for k, v in x.items()}
else:
# DGLGraph or other objects
return x.to(device=torch.cuda.current_device())
def get_local_rank() -> int:
return int(os.environ.get('LOCAL_RANK', 0))
def init_distributed() -> bool:
world_size = int(os.environ.get('WORLD_SIZE', 1))
distributed = world_size > 1
if distributed:
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend=backend, init_method='env://')
if backend == 'nccl':
torch.cuda.set_device(get_local_rank())
else:
logging.warning('Running on CPU only!')
assert torch.distributed.is_initialized()
return distributed
def increase_l2_fetch_granularity():
# maximum fetch granularity of L2: 128 bytes
_libcudart = ctypes.CDLL('libcudart.so')
# set device limit on the current device
# cudaLimitMaxL2FetchGranularity = 0x05
pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
_libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
assert pValue.contents.value == 128
def seed_everything(seed):
seed = int(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def rank_zero_only(fn):
@wraps(fn)
def wrapped_fn(*args, **kwargs):
if not dist.is_initialized() or dist.get_rank() == 0:
return fn(*args, **kwargs)
return wrapped_fn
def using_tensor_cores(amp: bool) -> bool:
major_cc, minor_cc = torch.cuda.get_device_capability()
return (amp and major_cc >= 7) or major_cc >= 8