Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| # import torch | |
| import math | |
| from torch.optim.optimizer import Optimizer, required | |
class Fromage(Optimizer):
    r"""Fromage optimizer implementation (https://arxiv.org/abs/2002.03432).

    For every parameter tensor the gradient is rescaled so that its norm
    matches the norm of the parameter itself, a plain descent step of size
    ``lr`` is taken, and the result is shrunk by ``1 / sqrt(1 + lr ** 2)``.

    Args:
        params (iterable): Parameters to optimize, or dicts defining
            parameter groups.
        lr (float): Learning rate. Must be non-negative.
        momentum (float, optional): Stored in the parameter-group defaults
            for interface compatibility; the update rule implemented in
            :meth:`step` does not read it.

    Raises:
        ValueError: If ``lr`` is given and is negative.
    """

    def __init__(self, params, lr=required, momentum=0):
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)

    def step(self, closure=None):
        r"""Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            # Loop-invariant shrink factor applied after every update.
            shrink = math.sqrt(1 + lr ** 2)
            for p in group['params']:
                if p.grad is None:
                    continue
                # Work on the detached views so that neither the norms nor
                # the in-place update are recorded by autograd.
                weight = p.data
                grad = p.grad.data
                grad_norm = grad.norm()
                weight_norm = weight.norm()
                if weight_norm > 0.0 and grad_norm > 0.0:
                    # Rescale the gradient to the parameter's own norm.
                    update = grad * (weight_norm / grad_norm)
                else:
                    # Degenerate norms: fall back to the raw gradient.
                    update = grad
                # Keyword ``alpha`` replaces the deprecated positional
                # ``add_(scalar, tensor)`` overload.
                weight.add_(update, alpha=-lr)
                weight.div_(shrink)
        return loss