import numpy as np
import torch

class BaseScheduler:
    """Base class for step-wise learning-rate scheduler logic.

    Args:
        optimizer (torch.optim.Optimizer): optimizer instance to apply the lr schedule to.

    Subclass this and override ``_get_lr`` to write your own step-wise scheduler.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.step_num = 0

    def zero_grad(self):
        # Delegate gradient clearing to the wrapped optimizer.
        self.optimizer.zero_grad()

    def _get_lr(self):
        raise NotImplementedError

    def _set_lr(self, lr):
        # Apply the computed lr to every parameter group of the optimizer.
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def step(self, metrics=None, epoch=None):
        """Update the step-wise learning rate; call before ``optimizer.step``."""
        # ``metrics`` and ``epoch`` are unused; they are accepted for API
        # compatibility with other scheduler interfaces.
        self.step_num += 1
        lr = self._get_lr()
        self._set_lr(lr)

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def state_dict(self):
        # Exclude the optimizer itself: it has its own state_dict and is
        # checkpointed separately.
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def as_tensor(self, start=0, stop=100_000):
        """Return the scheduler values from ``start`` to ``stop`` as a tensor."""
        # Save the live step counter so probing the schedule does not disturb
        # a scheduler that is already being used for training.
        orig_step_num = self.step_num
        lr_list = []
        for step in range(start, stop):
            self.step_num = step + 1
            lr_list.append(self._get_lr())
        self.step_num = orig_step_num
        return torch.tensor(lr_list)

    def plot(self, start=0, stop=100_000):
        """Plot the scheduler values from ``start`` to ``stop``."""
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        all_lr = self.as_tensor(start=start, stop=stop)
        plt.plot(all_lr.numpy())
        plt.show()
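

# The base class docstring says to subclass ``BaseScheduler`` and override
# ``_get_lr``. A minimal sketch of what such a subclass looks like; the name
# ``LinearDecay`` and its arguments are illustrative, not part of this module:
class LinearDecay(BaseScheduler):
    """Toy scheduler that decays the lr linearly from ``max_lr`` to zero."""

    def __init__(self, optimizer, max_lr, total_steps):
        super().__init__(optimizer)
        self.max_lr = max_lr
        self.total_steps = total_steps

    def _get_lr(self):
        # ``step_num`` is advanced by ``BaseScheduler.step`` before this is called.
        remaining = max(0.0, 1.0 - self.step_num / self.total_steps)
        return self.max_lr * remaining

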
class ExponentialWarmup(BaseScheduler):
    """Scheduler applying an exponential ramp-up to the learning rate during training.

    Args:
        optimizer (torch.optim.Optimizer): the optimizer whose learning rate is ramped up.
        max_lr (float): the maximum learning rate to use at the end of the ramp-up.
        rampup_length (int): the length of the ramp-up (number of steps).
        exponent (float): the exponent used in the ramp-up curve.
    """

    def __init__(self, optimizer, max_lr, rampup_length, exponent=-5.0):
        super().__init__(optimizer)
        self.rampup_len = rampup_length
        self.max_lr = max_lr
        # Override the base class counter: the warmup starts counting from 1.
        self.step_num = 1
        self.exponent = exponent

    def _get_scaling_factor(self):
        if self.rampup_len == 0:
            return 1.0
        # Exponential (sigmoid-shaped) ramp-up: with the default exponent of
        # -5.0 the factor grows from exp(-5) at step 0 to 1.0 once step_num
        # reaches rampup_len, and stays at 1.0 afterwards.
        current = np.clip(self.step_num, 0.0, self.rampup_len)
        phase = 1.0 - current / self.rampup_len
        return float(np.exp(self.exponent * phase * phase))

    def _get_lr(self):
        return self.max_lr * self._get_scaling_factor()
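

# A minimal usage sketch, assuming a standard PyTorch training loop
# (``model``, ``loader``, and ``criterion`` are placeholder names):
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.0)
#     scheduler = ExponentialWarmup(optimizer, max_lr=1e-3, rampup_length=5_000)
#
#     for x, y in loader:
#         scheduler.zero_grad()
#         loss = criterion(model(x), y)
#         loss.backward()
#         scheduler.step()      # updates the lr on every param group...
#         optimizer.step()      # ...before the actual optimizer update
#
# To inspect the schedule without training, ``scheduler.plot(stop=10_000)``
# draws the warmup curve.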