Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| """Schedule functions for controlling hparams over time.""" | |
| from abc import ABCMeta | |
| from abc import abstractmethod | |
| import math | |
| from common import config_lib # brain coder | |
class Schedule(object):
    """Base class for functions that set a hyperparameter's value over time.

    A schedule can, for example, decay an hparam or make it oscillate. Each
    subclass is built from a config_lib.Config whose fields are specific to that
    subclass (a decay schedule might read its decay rate and start time from
    it). Once built, an instance is used like a function that maps the global
    step (an integer counting how many times the train op has run) to the
    hparam's value.

    Contract for a schedule function f(t):
      0) t ranges over the non-negative integers, including 0.
      1) f(t) is a real number.
      2) Calls arrive in non-decreasing order of t, so a schedule may keep
         state between calls.
      3) f must be deterministic. Two instances built from the same config
         return the same value for every t, regardless of which t's they were
         called on before. Callers may skip forward by any (positive) amount,
         so that schedule instances used in replica training agree with each
         other.
      4) Calling f repeatedly with the same t is allowed.
    """

    # Only honored by Python 2; Python 3 ignores a `__metaclass__` attribute.
    __metaclass__ = ABCMeta

    def __init__(self, config):
        """Build the schedule from its implementation-specific config.

        Args:
          config: An instance of config_lib.Config.
        """

    def __call__(self, global_step):
        """Return the hparam value for `global_step`.

        `global_step` counts calls to the train op summed over all replicas,
        which is why it is "global". Implementations may rely on calls being
        made in time order, i.e. each `global_step` is >= the previous one.

        Args:
          global_step: Non-negative integer.

        Returns:
          The hparam value at this step, a number. The base implementation
          returns None.
        """
class ConstSchedule(Schedule):
    """Schedule that returns the same value at every step: f(t) = const.

    config:
      const: The value returned for every global step.
    """

    def __init__(self, config):
        super(ConstSchedule, self).__init__(config)
        self.const = config.const

    def __call__(self, global_step):
        del global_step  # Unused: the value does not depend on time.
        return self.const
class LinearDecaySchedule(Schedule):
    """Linear decay function.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) is a linear function when start_time <= t <= end_time, with slope of
    (final - initial) / (end_time - start_time). f(t) = initial
    when t <= start_time. f(t) = final when t >= end_time.

    If start_time == end_time, this becomes a step function.

    Raises:
      ValueError: If end_time < start_time.
    """

    def __init__(self, config):
        super(LinearDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if self.end_time < self.start_time:
            raise ValueError('start_time must be before end_time.')

        # Linear interpolation.
        self._time_diff = float(self.end_time - self.start_time)
        self._diff = float(self.final - self.initial)
        # A zero-width window has no finite slope. __call__ never reaches the
        # interpolation branch in that case: every step is either
        # <= start_time or >= end_time.
        self._slope = (
            self._diff / self._time_diff if self._time_diff > 0 else float('inf'))

    def __call__(self, global_step):
        if global_step <= self.start_time:
            return self.initial
        # `>=` rather than `>` so that f(end_time) is exactly `final`, as
        # documented. Computing initial + time_diff * slope at the endpoint can
        # miss `final` through rounding (e.g. 0 -> 1 over 49 steps gives
        # 0.9999999999999999).
        if global_step >= self.end_time:
            return self.final
        return self.initial + (global_step - self.start_time) * self._slope
class ExponentialDecaySchedule(Schedule):
    """Exponential decay function.

    See https://en.wikipedia.org/wiki/Exponential_decay.

    Suited to decaying across orders of magnitude, e.g. taking a learning rate
    from 1e-2 down to 1e-6. The exponent (the log of the value) changes
    linearly over time.

    config:
      initial: Decay starts from this value. Must be positive.
      final: Decay ends at this value. Must be positive.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    Between start_time and end_time, f(t) decays exponentially. Its rate and
    amplitude are chosen so that f(start_time) = initial and
    f(end_time) = final. Outside that window f(t) is held constant.

    If start_time == end_time, this becomes a step function.

    Raises:
      ValueError: If initial or final is not positive. An end_time earlier
        than start_time is rejected by the inner LinearDecaySchedule.
    """

    def __init__(self, config):
        super(ExponentialDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if any(value <= 0 for value in (self.initial, self.final)):
            raise ValueError('initial and final must be positive numbers.')

        # Interpolate linearly in log space; __call__ exponentiates the result.
        log_space_config = config_lib.Config(
            initial=math.log(self.initial),
            final=math.log(self.final),
            start_time=self.start_time,
            end_time=self.end_time)
        self._linear_fn = LinearDecaySchedule(log_space_config)

    def __call__(self, global_step):
        log_value = self._linear_fn(global_step)
        return math.exp(log_value)
class SmootherstepDecaySchedule(Schedule):
    """Smootherstep decay function.

    A sigmoidal like transition from initial to final values. A smoother
    transition than linear and exponential decays, hence the name.
    See https://en.wikipedia.org/wiki/Smoothstep.

    config:
      initial: Decay starts from this value.
      final: Decay ends at this value.
      start_time: Step when decay starts. Constant before it.
      end_time: When decay ends. Constant after it.

    f(t) = initial when t <= start_time and f(t) = final when t >= end_time.
    In between:
      f(t) = initial + S(x) * (final - initial),
      x = (t - start_time) / (end_time - start_time),
      S(x) = 6x^5 - 15x^4 + 10x^3
    (see https://en.wikipedia.org/wiki/Smoothstep#Variations). S has zero first
    and second derivatives at x = 0 and x = 1, so the transition joins the
    constant regions smoothly.

    Raises:
      ValueError: If end_time < start_time.
    """

    def __init__(self, config):
        super(SmootherstepDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        if self.end_time < self.start_time:
            raise ValueError('start_time must be before end_time.')

        self._time_diff = float(self.end_time - self.start_time)
        self._diff = float(self.final - self.initial)

    def __call__(self, global_step):
        if global_step <= self.start_time:
            return self.initial
        # `>=` rather than `>` so that f(end_time) is exactly `final`. At x = 1
        # the formula gives initial + (final - initial), which float rounding
        # can move away from `final` (e.g. 0.1 -> 1e-20 gives 0.0). This check
        # also means _time_diff is never zero in the division below.
        if global_step >= self.end_time:
            return self.final
        x = (global_step - self.start_time) / self._time_diff
        # Smootherstep, in Horner form: x^3 * (6x^2 - 15x + 10).
        return self.initial + x * x * x * (x * (x * 6 - 15) + 10) * self._diff
class HardOscillatorSchedule(Schedule):
    """Piecewise-linear ("hard") oscillator between `high` and `low`.

    config:
      high: Maximum value; the level of the constant plateaus.
      low: Minimum value; the level of the constant valleys.
      start_time: Global step at which oscillation begins. Before it the value
        is constant at `high`.
      period: Number of steps in one full oscillation.
      transition_fraction: Fraction of each period spent moving between the
        two levels. Half of that time is spent falling and half rising. The
        remaining time is split evenly between holding at `high` and holding
        at `low`. 1.0 means the whole period is spent rising and falling. 0.0
        means the value jumps instantly between the two levels.

    f(t) = high when t < start_time. For t >= start_time, f is periodic with
    f(t + period) = f(t). Each period starts at `high` and ramps down linearly
    until reaching `low`, which is held until the period's midpoint. It then
    ramps up linearly to `high` and holds there until the next period begins.

    Note: with transition_fraction == 0 each period starts low and ends high.

    Raises:
      ValueError: If transition_fraction is outside [0, 1] or period is not
        positive.
    """

    def __init__(self, config):
        super(HardOscillatorSchedule, self).__init__(config)
        self.high = config.high
        self.low = config.low
        self.start_time = config.start_time
        self.period = float(config.period)
        self.transition_fraction = config.transition_fraction
        self.half_transition_fraction = self.transition_fraction / 2.0

        if self.transition_fraction < 0 or 1.0 < self.transition_fraction:
            raise ValueError('transition_fraction must be between 0 and 1.0')
        if self.period <= 0:
            raise ValueError('period must be positive')

        # Rate of change during a ramp, in value units per unit of phase. With
        # no transition time the ramps are never evaluated, so the infinite
        # placeholder is never used.
        if self.half_transition_fraction > 0:
            self._slope = (
                float(self.high - self.low) / self.half_transition_fraction)
        else:
            self._slope = float('inf')

    def __call__(self, global_step):
        if global_step < self.start_time:
            return self.high

        phase = ((global_step - self.start_time) / self.period) % 1.0
        ramp_width = self.half_transition_fraction

        if phase >= 0.5:
            # Second half of the period: rise from low, then hold at high.
            into_rise = phase - 0.5
            if into_rise < ramp_width:
                return self.low + into_rise * self._slope
            return self.high

        # First half of the period: fall from high, then hold at low.
        if phase < ramp_width:
            return self.high - phase * self._slope
        return self.low
# Registry mapping the `fn` name found in a schedule config to the Schedule
# subclass implementing it. Looked up by make_schedule.
_NAME_TO_CONFIG = dict(
    const=ConstSchedule,
    linear_decay=LinearDecaySchedule,
    exp_decay=ExponentialDecaySchedule,
    smooth_decay=SmootherstepDecaySchedule,
    hard_osc=HardOscillatorSchedule,
)
def make_schedule(config):
    """Schedule factory.

    Given `config` containing a `fn` property, a Schedule implementation is
    instantiated with `config`. See `_NAME_TO_CONFIG` for `fn` options.

    Args:
      config: Config with a `fn` option that specifies which Schedule
        implementation to use. `config` is passed into the constructor.

    Returns:
      A Schedule impl instance.

    Raises:
      KeyError: If `config.fn` is not a registered schedule name. The message
        lists the valid names.
    """
    fn_name = config.fn
    if fn_name not in _NAME_TO_CONFIG:
        # Still a KeyError, as the bare dict lookup raised before, so existing
        # handlers keep working; the message now lists the valid options.
        raise KeyError(
            'Unknown schedule fn %r. Expected one of: %s'
            % (fn_name, ', '.join(sorted(_NAME_TO_CONFIG))))
    schedule_class = _NAME_TO_CONFIG[fn_name]
    return schedule_class(config)