import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """ Implements the learning rate schedule defined in the AlphaFold 2
        supplement. A linear warmup is followed by a plateau at the maximum
        learning rate and then exponential decay.

        Note that the initial learning rate of the optimizer in question is
        ignored; use this class' base_lr parameter to specify the starting
        point of the warmup.
    """
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if(v < 0):
                raise ValueError(f"{k} must be nonnegative")

        if(warmup_no_steps > start_decay_after_n_steps):
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super(AlphaFoldLRScheduler, self).__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        state_dict = {
            k: v for k, v in self.__dict__.items() if k not in ["optimizer"]
        }
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if(not self._get_lr_called_within_step):
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if(step_no <= self.warmup_no_steps):
            # linear warmup
            lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
        elif(step_no > self.start_decay_after_n_steps):
            # exponential decay
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:  # plateau
            lr = self.max_lr

        return [lr for group in self.optimizer.param_groups]
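A minimal usage sketch follows, assuming the class above is in scope. The toy model, the choice of Adam, and the shortened step counts are illustrative assumptions only; the key points are that the scheduler is stepped once per optimizer step and that the optimizer's initial lr is ignored in favor of base_lr/max_lr.

    # Usage sketch (hypothetical model and step counts, for illustration only)
    import torch

    model = torch.nn.Linear(8, 8)
    # The lr passed here is ignored by the scheduler; the schedule starts at base_lr.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    scheduler = AlphaFoldLRScheduler(
        optimizer,
        max_lr=1e-3,
        warmup_no_steps=10,              # shortened for the example
        start_decay_after_n_steps=50,    # shortened for the example
        decay_every_n_steps=50,
        decay_factor=0.95,
    )

    for step in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the schedule once per optimizer step
        if step % 25 == 0:
            print(step, scheduler.get_last_lr())

With the real defaults (warmup_no_steps=1000, start_decay_after_n_steps=50000), the printed learning rate ramps linearly to max_lr over the first 1000 steps, plateaus at max_lr, and is then multiplied by 0.95 every 50000 steps.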