Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import math | |
| from datetime import timedelta | |
| import torch | |
| import wandb | |
| from wandb import AlertLevel | |
| from torch.utils.tensorboard import SummaryWriter | |
| from imaginaire.utils.distributed import master_only, dist_all_reduce_tensor, \ | |
| is_master, get_rank | |
| from imaginaire.utils.distributed import master_only_print as print | |
# Module-wide summary writer; stays None until set_summary_writer() runs.
LOG_WRITER = None
# Log directory given to the latest set_summary_writer() call; None before.
LOG_DIR = None
def sn_reshape_weight_to_matrix(weight):
    r"""Flatten a weight tensor into a 2-D matrix.

    The first dimension is kept as the row count and every remaining
    dimension is folded into the columns.

    Args:
        weight (Parameters): pytorch layer parameter tensor.
    Returns:
        (tensor): ``weight`` reshaped to ``(weight.size(0), -1)``.
    """
    rows = weight.size(0)
    return weight.reshape(rows, -1)
def get_weight_stats(mod):
    r"""Get weight statistics of a module.

    Args:
        mod: Pytorch module exposing ``weight_orig``, ``weight_u`` and
            ``weight_v`` attributes.
    Returns:
        (tuple):
          - grad_norm (float): Norm of ``weight_orig.grad``, or ``0.`` when
            no gradient is present.
          - weight_norm (float): Norm of ``weight_orig``.
          - sigma (tensor): ``sum(weight_u * (W @ weight_v))`` where ``W``
            is ``weight_orig`` in matrix form. Left as a tensor, unlike the
            two norms.
    """
    param = mod.weight_orig
    grad = param.grad
    grad_norm = 0. if grad is None else grad.data.norm().item()
    weight_norm = param.data.norm().item()
    matrix = sn_reshape_weight_to_matrix(param)
    sigma = (mod.weight_u * torch.mv(matrix, mod.weight_v)).sum()
    return grad_norm, weight_norm, sigma
def set_summary_writer(log_dir):
    r"""Install the module-wide summary writer.

    Records ``log_dir`` in ``LOG_DIR`` and stores a new ``SummaryWriter``
    created for that directory in ``LOG_WRITER``.

    Args:
        log_dir (str): Log directory.
    """
    global LOG_DIR
    global LOG_WRITER
    LOG_DIR = log_dir
    LOG_WRITER = SummaryWriter(log_dir=log_dir)
def write_summary(name, summary, step, hist=False):
    r"""Send one summary entry to the module-wide log writer.

    Args:
        name (str): Tag under which the entry is recorded.
        summary: Value handed to the writer.
        step (int): Epoch or iteration number.
        hist (bool): If ``True``, record with ``add_histogram``; otherwise
            record with ``add_scalar``.
    Raises:
        Exception: If ``LOG_WRITER`` is still ``None``.
    """
    writer = LOG_WRITER
    if writer is None:
        raise Exception("Log writer not set.")
    # Look up only the method that is actually going to be called.
    record = writer.add_histogram if hist else writer.add_scalar
    record(name, summary, step)
class Meter(object):
    """Meter is to keep track of statistics along steps.

    Values are collected with :meth:`write`; :meth:`flush` averages the
    finite ones, logs the average from the master process (TensorBoard via
    ``write_summary`` and ``wandb.log``) and clears the meter.

    Args:
        name (str): the name of meter
        reduce (bool): If ``True``, perform a distributed reduce for the log
            values across all GPUs. If ``False``, only rank 0 records and
            flushes values; other ranks ignore every call.
    """

    def __init__(self, name, reduce=True):
        self.name = name
        self.reduce = reduce
        self.values = []

    def _skip_on_this_rank(self):
        r"""Return ``True`` when this meter is unreduced and rank is not 0."""
        # ``get_rank`` is only consulted for unreduced meters.
        return not self.reduce and get_rank() != 0

    def reset(self):
        r"""Reset the meter values"""
        if self._skip_on_this_rank():
            return
        self.values = []

    def write(self, value):
        r"""Record the value

        Args:
            value: Value to append; ``None`` is ignored.
        """
        if self._skip_on_this_rank():
            return
        if value is not None:
            self.values.append(value)

    def flush(self, step):
        r"""Average the recorded values, log the average and reset.

        With ``reduce=True`` the recorded values are passed through
        ``dist_all_reduce_tensor`` and the average is taken over the
        *reduced* values, so the logged number reflects every process.
        Non-finite entries are dropped from the average; if any are found,
        a message is printed and the master process raises a wandb alert.

        Args:
            step (int): Epoch or iteration number.
        """
        if self._skip_on_this_rank():
            return
        if self.reduce:
            # Fall back to CPU so hosts without CUDA do not crash here.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # float64 keeps the precision of the recorded Python floats.
            local = torch.tensor(
                self.values, dtype=torch.float64, device=device)
            # Assumes dist_all_reduce_tensor returns a tensor, as it is
            # given one here. One bulk transfer replaces per-element syncs.
            values = dist_all_reduce_tensor(local).tolist()
        else:
            # No collective involved, so no device tensor is needed.
            values = [float(v) for v in self.values]

        finite_values = [v for v in values if math.isfinite(v)]
        if len(finite_values) != len(values):
            print("meter {} contained a nan or inf.".format(self.name))
            if is_master():
                wandb.alert(
                    title='NaN',
                    text=f'Meter {self.name} contained a nan or inf.',
                    level=AlertLevel.WARN,
                    wait_duration=timedelta(minutes=120)
                )
        if finite_values:
            value = sum(finite_values) / len(finite_values)
            if is_master():
                write_summary(self.name, value, step)
                wandb.log({self.name: value}, step=step)
        self.reset()

    def write_image(self, img_grid, step):
        r"""Write an image to the tensorboard under the "Visualizations" tag.

        Args:
            img_grid: Image passed unchanged to ``add_image``.
            step (int): Epoch or iteration number.
        Raises:
            Exception: If ``LOG_WRITER`` is still ``None``.
        """
        if self._skip_on_this_rank():
            return
        lw = LOG_WRITER
        if lw is None:
            raise Exception("Log writer not set.")
        lw.add_image("Visualizations", img_grid, step)