import logging

import torch
from pytorch_lightning.callbacks import Callback

log = logging.getLogger(__name__)


class FixNANinGrad(Callback):
    """Zero out non-finite gradients and stop training when the monitored
    metric stays non-finite for several consecutive batches."""

    def __init__(self, monitor):
        super().__init__()
        # List of metric names to check, in order of preference.
        self.monitor = monitor
        self.continuous_nan_batches = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        has_nan = []
        is_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    has_nan.append(name)
                if torch.isinf(param.grad).any():
                    is_inf.append(name)
                # Replace NaN/Inf gradient entries with 0 in place so the
                # optimizer step can still proceed.
                torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
        if len(has_nan) > 0:
            print(f"Found NaN in {has_nan}")
        if len(is_inf) > 0:
            print(f"Found Inf in {is_inf}")

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        logs = trainer.callback_metrics
        i = 0
        found_metric = False
        # Use the first monitored metric that is present in the logged metrics.
        while i < len(self.monitor) and not found_metric:
            if self.monitor[i] in logs.keys():
                current = logs[self.monitor[i]].squeeze()
                found_metric = True
            else:
                i += 1
        if not found_metric:
            raise ValueError("Asked metric not in logs")
        if not torch.isfinite(current):
            self.continuous_nan_batches += 1
            # Stop training after 5 consecutive non-finite values of the metric.
            if self.continuous_nan_batches >= 5:
                trainer.should_stop = True
                log.info(f"Training interrupted because of NaN in {self.monitor}")
        else:
            self.continuous_nan_batches = 0
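
The callback is registered like any other PyTorch Lightning callback, passing the names of the logged metrics it should watch. Below is a minimal usage sketch; TinyModule, the synthetic dataset, and the "train_loss" key are illustrative assumptions and not part of the original code.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class TinyModule(pl.LightningModule):
    # Minimal illustrative module; "train_loss" is the metric FixNANinGrad watches.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[FixNANinGrad(monitor=["train_loss"])],
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(TinyModule(), DataLoader(dataset, batch_size=8))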