Spaces:
Running
on
L4
Running
on
L4
| from collections import OrderedDict | |
| import copy | |
| import torch | |
| import torch.nn as nn | |
| from dockformerpp.utils.tensor_utils import tensor_tree_map | |
class ExponentialMovingAverage:
    """
    Maintains moving averages of parameters with exponential decay.

    At each step, the stored copy `copy` of each parameter `param` is
    updated as follows:

        `copy = decay * copy + (1 - decay) * param`

    where `decay` is an attribute of the ExponentialMovingAverage object.

    The tracked values come from `model.state_dict()`, so buffers are
    tracked alongside parameters. Entries that are not floating-point or
    complex (e.g. integer counters or boolean masks registered as buffers)
    cannot be averaged. They are overwritten with the model's latest value
    instead.
    """

    def __init__(self, model: nn.Module, decay: float):
        """
        Args:
            model:
                A torch.nn.Module whose parameters are to be tracked
            decay:
                A value (usually close to 1.) by which updates are
                weighted as part of the above formula
        """
        super().__init__()

        def _clone(t: torch.Tensor) -> torch.Tensor:
            # Independent, detached copy: later in-place EMA updates must
            # never alias the model's own tensors.
            return t.clone().detach()

        self.params = tensor_tree_map(_clone, model.state_dict())
        self.decay = decay
        # Device of the model's first parameter. Raises StopIteration if the
        # model has no parameters at all.
        self.device = next(model.parameters()).device

    def to(self, device):
        """
        Moves every stored average to `device` and records it.

        Returns:
            self, so calls can be chained (mirrors `nn.Module.to`).
        """
        self.params = tensor_tree_map(lambda t: t.to(device), self.params)
        self.device = device
        return self

    def _update_state_dict_(self, update, state_dict):
        """
        Blends `update` into `state_dict` in place, recursing into any
        non-tensor (nested mapping) values.

        Floating-point and complex tensors receive the EMA step. Other
        tensors (integer/bool buffers) are copied verbatim, because the EMA
        formula is undefined for them and the in-place float multiply would
        raise.
        """
        with torch.no_grad():
            for k, v in update.items():
                stored = state_dict[k]
                if not isinstance(v, torch.Tensor):
                    self._update_state_dict_(v, stored)
                    continue

                # The averages may live on a different device than the model
                # (e.g. after `self.to("cpu")`); bring the new value over.
                v = v.to(stored.device)

                if stored.is_floating_point() or stored.is_complex():
                    # stored <- stored - (1 - decay) * (stored - v)
                    #         = decay * stored + (1 - decay) * v
                    diff = stored - v
                    diff *= 1 - self.decay
                    stored -= diff
                else:
                    stored.copy_(v)

    def update(self, model: torch.nn.Module) -> None:
        """
        Updates the stored parameters using the state dict of the provided
        module. The module should have the same structure as that used to
        initialize the ExponentialMovingAverage object.
        """
        self._update_state_dict_(model.state_dict(), self.params)

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        """
        Restores averages and decay from a dict produced by `state_dict()`.

        Each top-level tensor under `state_dict["params"]` is cloned, so the
        caller's dict is not aliased afterwards.
        """
        for k in state_dict["params"].keys():
            self.params[k] = state_dict["params"][k].clone()
        self.decay = state_dict["decay"]

    def state_dict(self) -> OrderedDict:
        """
        Returns `{"params": ..., "decay": ...}` for checkpointing.

        Note: "params" is the live internal dict (not a copy), so mutating the
        returned tensors mutates the stored averages.
        """
        return OrderedDict(
            {
                "params": self.params,
                "decay": self.decay,
            }
        )