import contextlib
import sys
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn


class EMA(nn.Module):
    """Maintains an exponential moving average (EMA) copy of a model.

    During training, ``update()`` blends the live model's parameters into a
    detached shadow copy; during evaluation, ``forward()`` dispatches to the
    shadow model so inference runs on the averaged weights.
    """

    def __init__(self, model, decay):
        super().__init__()
        self.decay = decay
        self.model = model
        self.shadow = deepcopy(self.model)

        # The shadow parameters are never trained directly, so detach them
        # from the autograd graph.
        for param in self.shadow.parameters():
            param.detach_()

    @torch.no_grad()
    def update(self):
        if not self.training:
            print("EMA update should only be called during training",
                  file=sys.stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # Check that both models contain the same set of parameter names.
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # See https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            if param.requires_grad:
                shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # Check that both models contain the same set of buffer names.
        assert model_buffers.keys() == shadow_buffers.keys()

        # Buffers (e.g. BatchNorm running statistics) are copied verbatim
        # rather than averaged.
        for name, buffer in model_buffers.items():
            shadow_buffers[name].copy_(buffer)

    # A hack to allow non-DDP models to be passed into the Trainer: mimics
    # DistributedDataParallel.no_sync() with a no-op context manager.
    def no_sync(self):
        return contextlib.nullcontext()

    def forward(self, *args, **kwargs):
        if self.training:
            return self.model(*args, **kwargs)
        else:
            return self.shadow(*args, **kwargs)


def count_parameters(model):
    """Count a model's trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
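

if __name__ == "__main__":
    # A minimal usage sketch. The toy network, decay value, optimizer, and
    # training loop below are illustrative assumptions, not part of the
    # original module; they only demonstrate how EMA is intended to be driven.
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    ema = EMA(model, decay=0.999)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    ema.train()  # train mode: forward() uses the live model
    for _ in range(100):
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(ema(x), y)
        loss.backward()
        optimizer.step()
        ema.update()  # blend the freshly updated weights into the shadow copy

    ema.eval()  # eval mode: forward() dispatches to the averaged shadow model
    with torch.no_grad():
        preds = ema(torch.randn(4, 10))
    print("trainable parameters:", count_parameters(model))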