import contextlib
import copy
import itertools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel
from pytorch_lightning import Callback


class EMACallback(Callback):
    def __init__(
        self,
        module_attr_name,
        ema_module_attr_name,
        decay=0.999,
        start_ema_step=0,
        init_ema_random=True,
    ):
        super().__init__()
        self.decay = decay
        self.module_attr_name = module_attr_name
        self.ema_module_attr_name = ema_module_attr_name
        self.start_ema_step = start_ema_step
        self.init_ema_random = init_ema_random

    def on_train_start(self, trainer, pl_module):
        if pl_module.global_step == 0:
            if not hasattr(pl_module, self.module_attr_name):
                raise ValueError(
                    f"Module {pl_module} does not have attribute {self.module_attr_name}"
                )
            if not hasattr(pl_module, self.ema_module_attr_name):
                # Register a frozen deep copy of the target module as the EMA model,
                # so it is tracked (and checkpointed) as a submodule of pl_module.
                pl_module.add_module(
                    self.ema_module_attr_name,
                    copy.deepcopy(getattr(pl_module, self.module_attr_name))
                    .eval()
                    .requires_grad_(False),
                )
            self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if pl_module.global_step == self.start_ema_step:
            # Reset the EMA weights at the step where averaging proper begins.
            self.reset_ema(pl_module)
        elif (
            pl_module.global_step < self.start_ema_step
            and pl_module.global_step % 100 == 0
        ):
            # Before the start step, only occasional faster-decay EMA updates,
            # for visualisation.
            self.update_ema(pl_module, decay=0.9)
        elif pl_module.global_step > self.start_ema_step:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        module = getattr(pl_module, self.module_attr_name)
        context_manager = self.get_model_context_manager(module)
        with context_manager:
            with torch.no_grad():
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params and param.requires_grad:
                        # ema = decay * ema + (1 - decay) * param
                        ema_params[name].copy_(
                            ema_params[name]
                            .detach()
                            .lerp(param.detach(), 1.0 - decay)
                        )

    def get_model_context_manager(self, module):
        # If the module is sharded with FSDP, gather the full parameters while we
        # read them; otherwise use a no-op context.
        fsdp_enabled = is_model_fsdp(module)
        model_context_manager = contextlib.nullcontext()
        if fsdp_enabled:
            model_context_manager = FullyShardedDataParallel.summon_full_params(module)
        return model_context_manager

    def reset_ema(self, pl_module):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            # Assumes the wrapped module implements an `init_weights()` method.
            ema_module.init_weights()
        else:
            # Otherwise copy the current weights of the target module into the EMA copy.
            module = getattr(pl_module, self.module_attr_name)
            context_manager = self.get_model_context_manager(module)
            with context_manager:
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        ema_params[name].copy_(param.detach())


def is_model_fsdp(model: torch.nn.Module) -> bool:
    try:
        if isinstance(model, FullyShardedDataParallel):
            return True
        # Check if model has FSDP-wrapped children
        for _, obj in model.named_children():
            if isinstance(obj, FullyShardedDataParallel):
                return True
        return False
    except ImportError:
        return False
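

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original callback). It
# assumes a LightningModule that keeps its trainable network under an
# attribute named "model"; the callback deep-copies it into "ema_model" at
# step 0 and updates that copy after every training batch. `ToyLitModule`
# and the random dataset below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class ToyLitModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # Attribute name must match `module_attr_name` below.
            self.model = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.model(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.model.parameters(), lr=1e-2)

    train_loader = DataLoader(
        TensorDataset(torch.randn(256, 8), torch.randn(256, 1)), batch_size=32
    )
    ema_callback = EMACallback(
        module_attr_name="model",
        ema_module_attr_name="ema_model",
        decay=0.999,
        start_ema_step=50,      # copy the live weights into the EMA model at this step
        init_ema_random=False,  # True would require `model.init_weights()` to exist
    )
    trainer = pl.Trainer(
        max_steps=200,
        callbacks=[ema_callback],
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(ToyLitModule(), train_loader)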