Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
60  rf2aa/training/EMA.py  Normal file
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn

from collections import OrderedDict
from copy import deepcopy
from sys import stderr

import contextlib

class EMA(nn.Module):

    def __init__(self, model, decay):
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = deepcopy(self.model)

        for param in self.shadow.parameters():
            param.detach_()

    @torch.no_grad()
    def update(self):
        if not self.training:
            print("EMA update should only be called during training", file=stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # check that both models contain the same set of keys
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            if param.requires_grad:
                shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # check that both models contain the same set of keys
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied verbatim, not averaged
            shadow_buffers[name].copy_(buffer)

    #fd A hack to allow non-DDP models to be passed into the Trainer
    def no_sync(self):
        return contextlib.nullcontext()

    def forward(self, *args, **kwargs):
        if self.training:
            return self.model(*args, **kwargs)
        else:
            return self.shadow(*args, **kwargs)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
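A minimal usage sketch for the EMA wrapper above, not part of this commit: the toy network, synthetic data, and decay value are illustrative assumptions, and the import path simply follows the package layout shown in this diff.

# Hypothetical training-loop sketch for EMA; the toy model, synthetic data,
# and decay=0.999 are illustrative assumptions, not part of this commit.
import torch
from rf2aa.training.EMA import EMA

base_model = torch.nn.Linear(16, 4)                       # stand-in for the real network
ema_model = EMA(base_model, decay=0.999)
optimizer = torch.optim.Adam(ema_model.model.parameters(), lr=1e-3)
loader = [(torch.randn(8, 16), torch.randn(8, 4)) for _ in range(10)]

ema_model.train()
for x, y in loader:
    loss = torch.nn.functional.mse_loss(ema_model(x), y)  # forward() routes through self.model while training
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_model.update()                                    # shadow -= (1 - decay) * (shadow - param)

ema_model.eval()                                          # forward() now serves the averaged shadow weights

Note that only parameters with requires_grad are averaged; buffers (for example BatchNorm running statistics) are copied from the live model on every update.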
5  rf2aa/training/checkpoint.py  Normal file
@@ -0,0 +1,5 @@
# for gradient checkpointing
def create_custom_forward(module, **kwargs):
    def custom_forward(*inputs):
        return module(*inputs, **kwargs)
    return custom_forward
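A hedged sketch of how create_custom_forward is typically paired with torch.utils.checkpoint: checkpoint() replays only positional tensor arguments during recomputation, so keyword arguments are frozen into the closure ahead of time. The Block module, its scale keyword, and the tensors below are illustrative assumptions, not part of this commit.

# Hypothetical pairing of create_custom_forward with torch.utils.checkpoint;
# Block, its scale kwarg, and the tensors are illustrative assumptions.
import torch
from torch.utils.checkpoint import checkpoint
from rf2aa.training.checkpoint import create_custom_forward

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(32, 32)

    def forward(self, x, scale=1.0):
        return self.lin(x) * scale

block = Block()
x = torch.randn(4, 32, requires_grad=True)

# the scale kwarg is bound inside the closure and reused when the forward is replayed
out = checkpoint(create_custom_forward(block, scale=0.5), x, use_reentrant=False)
out.sum().backward()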
71  rf2aa/training/recycling.py  Normal file
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn

import numpy as np

from contextlib import ExitStack

from rf2aa.chemical import ChemicalData as ChemData

def recycle_step_legacy(ddp_model, input, n_cycle, use_amp, nograds=False, force_device=None):
    if force_device is not None:
        gpu = force_device
    else:
        gpu = ddp_model.device

    xyz_prev, alpha_prev, mask_recycle = \
        input["xyz_prev"], input["alpha_prev"], input["mask_recycle"]
    output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
    for i_cycle in range(n_cycle):
        with ExitStack() as stack:
            stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
            if i_cycle < n_cycle - 1 or nograds is True:
                stack.enter_context(torch.no_grad())
                if force_device is None:
                    stack.enter_context(ddp_model.no_sync())
            return_raw = (i_cycle < n_cycle - 1)
            use_checkpoint = not nograds and (i_cycle == n_cycle - 1)

            input_i = add_recycle_inputs(input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
            output_i = ddp_model(**input_i)
    return output_i


def run_model_forward_legacy(model, network_input, device="cpu"):
    """Run a single forward pass with the legacy model, without recycling or DDP (used in tests)."""
    gpu = device
    xyz_prev, alpha_prev, mask_recycle = \
        network_input["xyz_prev"], network_input["alpha_prev"], network_input["mask_recycle"]
    output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
    input_i = add_recycle_inputs(network_input, output_i, 0, gpu, return_raw=False, use_checkpoint=False)
    input_i["seq_unmasked"] = input_i["seq_unmasked"].to(gpu)
    input_i["sctors"] = input_i["sctors"].to(gpu)
    model.eval()
    with torch.no_grad():
        output_i = model(**input_i)

    return output_i

def add_recycle_inputs(network_input, output_i, i_cycle, gpu, return_raw=False, use_checkpoint=False):
    input_i = {}
    for key in network_input:
        if key in ['msa_latent', 'msa_full', 'seq']:
            input_i[key] = network_input[key][:, i_cycle].to(gpu, non_blocking=True)
        else:
            input_i[key] = network_input[key]

    L = input_i["msa_latent"].shape[2]
    msa_prev, pair_prev, _, alpha, mask_recycle = output_i
    xyz_prev = ChemData().INIT_CRDS.reshape(1, 1, ChemData().NTOTAL, 3).repeat(1, L, 1, 1).to(gpu, non_blocking=True)

    input_i['msa_prev'] = msa_prev
    input_i['pair_prev'] = pair_prev
    input_i['xyz'] = xyz_prev
    input_i['mask_recycle'] = mask_recycle
    input_i['sctors'] = alpha
    input_i['return_raw'] = return_raw
    input_i['use_checkpoint'] = use_checkpoint

    input_i.pop('xyz_prev')
    input_i.pop('alpha_prev')
    return input_i
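recycle_step_legacy runs every recycle except the last under torch.no_grad() (with DDP no_sync, all wrapped in AMP autocast), so only the final pass contributes to the backward graph. Below is a toy illustration of that gradient schedule; it uses a stand-in linear layer and fake features rather than the actual RF2AA network and inputs, which would require the full ChemicalData setup.

# Toy illustration (an assumption, not from this commit) of the gradient schedule
# implemented by recycle_step_legacy: all but the last cycle run under no_grad.
import torch
from contextlib import ExitStack

model = torch.nn.Linear(8, 8)        # stand-in for ddp_model
x = torch.randn(2, 8)
n_cycle = 3

state = torch.zeros(2, 8)            # stands in for the recycled features
for i_cycle in range(n_cycle):
    with ExitStack() as stack:
        if i_cycle < n_cycle - 1:
            stack.enter_context(torch.no_grad())
        state = model(x + state)

print(state.requires_grad)           # True: only the last cycle is differentiable

In the real function the recycled outputs are re-packed each cycle by add_recycle_inputs, which also slices the per-cycle MSA tensors and resets the coordinates to ChemData().INIT_CRDS.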