# FlexSED/src/desed_task/utils/schedulers.py
import numpy as np
import torch

# Copied from https://github.com/asteroid-team/asteroid/blob/master/asteroid/engine/schedulers.py
# Copied because this is the last piece of asteroid we still use (and to avoid pulling in its other dependencies).
class BaseScheduler(object):
    """Base class for the step-wise scheduler logic.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer instance to apply the lr schedule on.

    Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.step_num = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def _get_lr(self):
        raise NotImplementedError

    def _set_lr(self, lr):
        # Apply the same learning rate to every parameter group.
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def step(self, metrics=None, epoch=None):
        """Update the step-wise learning rate before optimizer.step."""
        self.step_num += 1
        lr = self._get_lr()
        self._set_lr(lr)

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def state_dict(self):
        # Exclude the optimizer so the state dict stays serializable.
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def as_tensor(self, start=0, stop=100_000):
        """Returns the scheduler values from start to stop."""
        lr_list = []
        for _ in range(start, stop):
            self.step_num += 1
            lr_list.append(self._get_lr())
        self.step_num = 0
        return torch.tensor(lr_list)

    def plot(self, start=0, stop=100_000):  # noqa
        """Plot the scheduler values from start to stop."""
        import matplotlib.pyplot as plt

        all_lr = self.as_tensor(start=start, stop=stop)
        plt.plot(all_lr.numpy())
        plt.show()

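# --- Illustrative sketch, not part of the original file ---------------------------------
# The BaseScheduler docstring says to subclass it and override ``_get_lr``. A minimal,
# hypothetical linear warmup (the names ``LinearWarmup`` and ``warmup_steps`` are made up
# here purely for illustration) could look like this:
class LinearWarmup(BaseScheduler):
    def __init__(self, optimizer, max_lr, warmup_steps):
        super().__init__(optimizer)
        self.max_lr = max_lr
        self.warmup_steps = warmup_steps

    def _get_lr(self):
        # Ramp linearly from 0 up to max_lr, then hold at max_lr.
        return self.max_lr * min(1.0, self.step_num / self.warmup_steps)
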
class ExponentialWarmup(BaseScheduler):
    """Scheduler that applies an exponential ramp-up to the learning rate during training.

    Args:
        optimizer: torch.optim.Optimizer, the optimizer whose learning rate is ramped up.
        max_lr: float, the maximum learning rate to use at the end of the ramp-up.
        rampup_length: int, the length of the ramp-up (number of steps).
        exponent: float, the exponent used in the ramp-up curve.
    """

    def __init__(self, optimizer, max_lr, rampup_length, exponent=-5.0):
        super().__init__(optimizer)
        self.rampup_len = rampup_length
        self.max_lr = max_lr
        self.step_num = 1
        self.exponent = exponent

    def _get_scaling_factor(self):
        # Follows exp(exponent * (1 - step / rampup_len) ** 2): starts at exp(exponent)
        # and reaches 1.0 once step_num passes rampup_len.
        if self.rampup_len == 0:
            return 1.0
        else:
            current = np.clip(self.step_num, 0.0, self.rampup_len)
            phase = 1.0 - current / self.rampup_len
            return float(np.exp(self.exponent * phase * phase))

    def _get_lr(self):
        return self.max_lr * self._get_scaling_factor()
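

# --- Usage sketch, assuming a toy model and arbitrary hyperparameters -------------------
# Illustration only (not part of the original training code): shows how ExponentialWarmup
# is meant to be driven, with ``scheduler.step()`` called before ``optimizer.step()`` as
# noted in BaseScheduler.step. The model, learning rate, and step count are placeholders.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0)  # lr is overwritten by the scheduler
    scheduler = ExponentialWarmup(optimizer, max_lr=1e-3, rampup_length=100)

    for _ in range(5):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).sum()
        loss.backward()
        scheduler.step()  # updates param_group["lr"] for this step
        optimizer.step()
    print("current lr:", optimizer.param_groups[0]["lr"])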