Spaces:
Build error
Build error
| from torch.optim.lr_scheduler import _LRScheduler | |
class PolyScheduler(_LRScheduler):
    """Linear warmup followed by polynomial decay of the learning rate.

    For step ``t`` (held in ``self.last_epoch``) every param group receives:

    * ``t == -1``: ``warmup_lr_init``.
    * ``t < warmup_steps``: ``base_lr * t / warmup_steps`` (linear ramp from 0).
    * otherwise: ``base_lr * (1 - p) ** power`` with
      ``p = (t - warmup_steps) / (max_steps - warmup_steps)`` clamped to at
      most 1, so the rate reaches 0 at ``max_steps`` and stays there instead
      of climbing back up.

    The same ``base_lr`` is used for every param group; learning rates
    previously configured per group on the optimizer are not preserved.

    Args:
        optimizer: Wrapped optimizer.
        base_lr: Peak learning rate, reached at the end of warmup.
        max_steps: Step at which the decayed rate reaches 0.
        warmup_steps: Number of linear-warmup steps.
        last_epoch: Step index to resume from; ``-1`` starts a fresh run.
        power: Exponent of the polynomial decay (default 2).
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1, power=2):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = power
        # Initialise the base class as a fresh run (last_epoch=-1) so it does
        # not require an 'initial_lr' key in the param groups, then restore the
        # caller's step counter. ``verbose`` is intentionally not passed: it is
        # deprecated in recent PyTorch releases.
        super().__init__(optimizer, last_epoch=-1)
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        """Return ``base_lr * last_epoch / warmup_steps`` for every param group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate of every param group at ``self.last_epoch``."""
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        # A non-positive decay span (max_steps <= warmup_steps) used to divide
        # by zero or yield an ever-growing rate; treat it as a one-step decay.
        decay_steps = max(1, self.max_steps - self.warmup_steps)
        progress = float(self.last_epoch - self.warmup_steps) / float(decay_steps)
        # Past max_steps (1 - progress) would go negative: an even power makes
        # the rate rise again and a fractional power yields a complex number.
        progress = min(progress, 1.0)
        alpha = pow(1.0 - progress, self.power)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]