Spaces:
Runtime error
Runtime error
| from copy import deepcopy | |
| from collections import OrderedDict | |
| import torch | |
class ModelEma:
    """Keep an exponential moving average (EMA) copy of a model's state.

    The EMA copy is a ``deepcopy`` of ``model``, switched to eval mode and with
    gradients disabled on all of its parameters. Calling :meth:`update` blends
    the live model's ``state_dict`` into the copy::

        ema = decay * ema + (1 - decay) * model

    Only floating-point (and complex) entries are averaged. Other entries, for
    example integer buffers such as BatchNorm's ``num_batches_tracked``, cannot
    be averaged meaningfully, so they are copied from the live model verbatim.

    Args:
        model: Module to track. If it exposes a ``.module`` attribute (as
            ``DataParallel`` / ``DistributedDataParallel`` wrappers do), its
            ``state_dict`` keys are assumed to carry a ``'module.'`` prefix.
        decay: EMA decay factor. Values closer to 1 average over more updates.
        device: Optional device for the EMA copy. An empty string leaves the
            copy wherever ``deepcopy`` put it.
    """

    _PREFIX = 'module.'

    def __init__(self, model, decay=0.9999, device=''):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device
        if device:
            self.ema.to(device=device)
        # A `.module` attribute means the copy is a wrapper whose state_dict
        # keys start with 'module.'; key translation below depends on this.
        self.ema_is_dp = hasattr(self.ema, 'module')
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @staticmethod
    def _add_prefix(key):
        """Return *key* with a leading 'module.' added if it is not already there."""
        return key if key.startswith(ModelEma._PREFIX) else ModelEma._PREFIX + key

    @staticmethod
    def _strip_prefix(key):
        """Return *key* with exactly one leading 'module.' removed, if present."""
        if key.startswith(ModelEma._PREFIX):
            return key[len(ModelEma._PREFIX):]
        return key

    def load_checkpoint(self, checkpoint, map_location='cpu'):
        """Load EMA weights from a checkpoint's ``'model_ema'`` entry.

        Keys are adapted to the EMA copy. A leading ``'module.'`` is added when
        the copy is wrapped and removed when it is not. Only that single leading
        prefix is touched, so nested submodules named ``module`` survive.

        Args:
            checkpoint: A checkpoint ``dict``, or a ``str`` path for ``torch.load``.
            map_location: Forwarded to ``torch.load`` when ``checkpoint`` is a
                path. The default ``'cpu'`` lets GPU-saved checkpoints load on any
                host. ``load_state_dict`` then copies values onto the EMA copy's
                own device.

        If the checkpoint has no ``'model_ema'`` entry, a ``UserWarning`` is
        emitted and the EMA copy is left unchanged.
        """
        if isinstance(checkpoint, str):
            checkpoint = torch.load(checkpoint, map_location=map_location)
        assert isinstance(checkpoint, dict), 'checkpoint must be a dict or a path'
        if 'model_ema' not in checkpoint:
            import warnings
            warnings.warn("checkpoint has no 'model_ema' entry; EMA weights were not loaded")
            return
        translate = self._add_prefix if self.ema_is_dp else self._strip_prefix
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_ema'].items():
            new_state_dict[translate(k)] = v
        self.ema.load_state_dict(new_state_dict)

    def state_dict(self):
        """Return the EMA copy's ``state_dict``."""
        return self.ema.state_dict()

    def update(self, model):
        """Blend ``model``'s current state into the EMA copy in place.

        Floating-point and complex entries become
        ``decay * ema + (1 - decay) * model``. All other entries (integer/bool
        buffers) are copied from ``model`` unchanged. When exactly one of
        ``model`` and the EMA copy is wrapped (has ``.module``), keys are matched
        by adding or removing a single leading ``'module.'``. If ``device`` was
        given at construction, ``model``'s tensors are moved there first.
        """
        model_is_dp = hasattr(model, 'module')
        decay = self.decay
        with torch.no_grad():
            curr_msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if model_is_dp and not self.ema_is_dp:
                    k = self._add_prefix(k)
                elif self.ema_is_dp and not model_is_dp:
                    k = self._strip_prefix(k)
                model_v = curr_msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(ema_v * decay + (1. - decay) * model_v)
                else:
                    # Averaging integer/bool buffers and truncating back (the old
                    # behaviour) freezes counters like num_batches_tracked.
                    ema_v.copy_(model_v)