Initial commit: RoseTTAFold-All-Atom configured for Wes with Harbor images and s3:// paths
rf2aa/training/EMA.py    Normal file    60 lines added
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn

from collections import OrderedDict
from copy import deepcopy
from sys import stderr

import contextlib


class EMA(nn.Module):
    """Maintains an exponential moving average (shadow) copy of a model's weights."""

    def __init__(self, model, decay):
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = deepcopy(self.model)

        # the shadow copy is never trained directly
        for param in self.shadow.parameters():
            param.detach_()

    @torch.no_grad()
    def update(self):
        if not self.training:
            print("EMA update should only be called during training", file=stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # check that both models contain the same set of keys
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            if param.requires_grad:
                shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # check that both models contain the same set of keys
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied, not averaged
            shadow_buffers[name].copy_(buffer)

    #fd A hack to allow non-DDP models to be passed into the Trainer
    def no_sync(self):
        return contextlib.nullcontext()

    def forward(self, *args, **kwargs):
        if self.training:
            return self.model(*args, **kwargs)
        else:
            return self.shadow(*args, **kwargs)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
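For context, a minimal usage sketch of this wrapper in a training loop follows. It assumes the rf2aa package is importable; the tiny Linear stand-in model, the synthetic batch, and the optimizer settings are illustrative only and not part of this commit. Note that the in-place update shadow -= (1 - decay) * (shadow - param) used above is algebraically identical to the more familiar shadow = decay * shadow + (1 - decay) * param.

# Hypothetical usage sketch -- the stand-in model, data, and optimizer below are illustrative.
import torch
import torch.nn.functional as F

from rf2aa.training.EMA import EMA, count_parameters

net = torch.nn.Linear(16, 4)                      # stand-in for the real RF2AA model
ema = EMA(net, decay=0.999)
optimizer = torch.optim.Adam(ema.model.parameters(), lr=1e-3)
print("trainable parameters:", count_parameters(ema.model))

ema.train()                                       # forward() dispatches to the live model
for x, y in [(torch.randn(8, 16), torch.randn(8, 4))]:   # one synthetic batch
    optimizer.zero_grad()
    loss = F.mse_loss(ema(x), y)
    loss.backward()
    optimizer.step()
    ema.update()                                  # fold the new weights into the shadow copy

ema.eval()                                        # forward() now dispatches to the EMA shadow
with torch.no_grad():
    prediction = ema(torch.randn(2, 16))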