Spaces:
Runtime error
Runtime error
| # Copyright 2017 Google, Inc. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """A trainable ADAM optimizer that learns its internal variables.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import numpy as np | |
| import tensorflow as tf | |
| from learned_optimizer.optimizer import trainable_optimizer as opt | |
| from learned_optimizer.optimizer import utils | |
| class TrainableAdam(opt.TrainableOptimizer): | |
| """Adam optimizer with learnable scalar parameters. | |
| See Kingma et. al., 2014 for algorithm (http://arxiv.org/abs/1412.6980). | |
| """ | |
| def __init__(self, | |
| learning_rate=1e-3, | |
| beta1=0.9, | |
| beta2=0.999, | |
| epsilon=1e-8, | |
| **kwargs): | |
| """Initializes the TrainableAdam optimizer with the given initial values. | |
| Args: | |
| learning_rate: The learning rate (default: 1e-3). | |
| beta1: The exponential decay rate for the 1st moment estimates. | |
| beta2: The exponential decay rate for the 2nd moment estimates. | |
| epsilon: A small constant for numerical stability. | |
| **kwargs: Any additional keyword arguments for TrainableOptimizer. | |
| Raises: | |
| ValueError: if the learning rate or epsilon is not positive | |
| ValueError: if beta1 or beta2 is not in (0, 1). | |
| """ | |
| if learning_rate <= 0: | |
| raise ValueError("Learning rate must be positive.") | |
| if epsilon <= 0: | |
| raise ValueError("Epsilon must be positive.") | |
| if not 0 < beta1 < 1 or not 0 < beta2 < 1: | |
| raise ValueError("Beta values must be between 0 and 1, exclusive.") | |
| self._reuse_vars = False | |
| with tf.variable_scope(opt.OPTIMIZER_SCOPE): | |
| def inv_sigmoid(x): | |
| return np.log(x / (1.0 - x)) | |
| self.log_learning_rate = tf.get_variable( | |
| "log_learning_rate", | |
| shape=[], | |
| initializer=tf.constant_initializer(np.log(learning_rate))) | |
| self.beta1_logit = tf.get_variable( | |
| "beta1_logit", | |
| shape=[], | |
| initializer=tf.constant_initializer(inv_sigmoid(beta1))) | |
| self.beta2_logit = tf.get_variable( | |
| "beta2_logit", | |
| shape=[], | |
| initializer=tf.constant_initializer(inv_sigmoid(beta2))) | |
| self.log_epsilon = tf.get_variable( | |
| "log_epsilon", | |
| shape=[], | |
| initializer=tf.constant_initializer(np.log(epsilon))) | |
| # Key names are derived from Algorithm 1 described in | |
| # https://arxiv.org/pdf/1412.6980.pdf | |
| state_keys = ["m", "v", "t"] | |
| super(TrainableAdam, self).__init__("Adam", state_keys, **kwargs) | |
| def _initialize_state(self, var): | |
| """Returns a dictionary mapping names of state variables to their values.""" | |
| vectorized_shape = var.get_shape().num_elements(), 1 | |
| return {key: tf.zeros(vectorized_shape) for key in self.state_keys} | |
| def _compute_update(self, param, grad, state): | |
| """Calculates the new internal state and parameters. | |
| If the gradient is sparse, updates the appropriate slices in the internal | |
| state and stacks the update tensor. | |
| Args: | |
| param: A tensor of parameters. | |
| grad: A tensor of gradients with the same shape as param. | |
| state: A dictionary containing any state for the optimizer. | |
| Returns: | |
| updated_param: The updated parameters. | |
| updated_state: The updated state variables in a dictionary. | |
| """ | |
| with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope: | |
| if self._reuse_vars: | |
| scope.reuse_variables() | |
| else: | |
| self._reuse_vars = True | |
| (grad_values, first_moment, second_moment, timestep, grad_indices | |
| ) = self._extract_gradients_and_internal_state( | |
| grad, state, tf.shape(param)) | |
| beta1 = tf.nn.sigmoid(self.beta1_logit) | |
| beta2 = tf.nn.sigmoid(self.beta2_logit) | |
| epsilon = tf.exp(self.log_epsilon) + 1e-10 | |
| learning_rate = tf.exp(self.log_learning_rate) | |
| old_grad_shape = tf.shape(grad_values) | |
| grad_values = tf.reshape(grad_values, [-1, 1]) | |
| new_timestep = timestep + 1 | |
| new_first_moment = self._update_adam_estimate( | |
| first_moment, grad_values, beta1) | |
| new_second_moment = self._debias_adam_estimate( | |
| second_moment, tf.square(grad_values), beta2) | |
| debiased_first_moment = self._debias_adam_estimate( | |
| new_first_moment, beta1, new_timestep) | |
| debiased_second_moment = self._debias_adam_estimate( | |
| new_second_moment, beta2, new_timestep) | |
| # Propagating through the square root of 0 is very bad for stability. | |
| update = (learning_rate * debiased_first_moment / | |
| (tf.sqrt(debiased_second_moment + 1e-10) + epsilon)) | |
| update = tf.reshape(update, old_grad_shape) | |
| if grad_indices is not None: | |
| param_shape = tf.shape(param) | |
| update = utils.stack_tensor( | |
| update, grad_indices, param, param_shape[:1]) | |
| new_first_moment = utils.update_slices( | |
| new_first_moment, grad_indices, state["m"], param_shape) | |
| new_second_moment = utils.update_slices( | |
| new_second_moment, grad_indices, state["v"], param_shape) | |
| new_timestep = utils.update_slices( | |
| new_timestep, grad_indices, state["t"], param_shape) | |
| new_param = param - update | |
| # collect the update and new state | |
| new_state = { | |
| "m": new_first_moment, | |
| "v": new_second_moment, | |
| "t": new_timestep | |
| } | |
| return new_param, new_state | |
| def _update_adam_estimate(self, estimate, value, beta): | |
| """Returns a beta-weighted average of estimate and value.""" | |
| return (beta * estimate) + ((1 - beta) * value) | |
| def _debias_adam_estimate(self, estimate, beta, t_step): | |
| """Returns a debiased estimate based on beta and the timestep.""" | |
| return estimate / (1 - tf.pow(beta, t_step)) | |
| def _extract_gradients_and_internal_state(self, grad, state, param_shape): | |
| """Extracts the gradients and relevant internal state. | |
| If the gradient is sparse, extracts the appropriate slices from the state. | |
| Args: | |
| grad: The current gradient. | |
| state: The current state. | |
| param_shape: The shape of the parameter (used if gradient is sparse). | |
| Returns: | |
| grad_values: The gradient value tensor. | |
| first_moment: The first moment tensor (internal state). | |
| second_moment: The second moment tensor (internal state). | |
| timestep: The current timestep (internal state). | |
| grad_indices: The indices for the gradient tensor, if sparse. | |
| None otherwise. | |
| """ | |
| grad_values = grad | |
| grad_indices = None | |
| first_moment = state["m"] | |
| second_moment = state["v"] | |
| timestep = state["t"] | |
| if isinstance(grad, tf.IndexedSlices): | |
| grad_indices, grad_values = utils.accumulate_sparse_gradients(grad) | |
| first_moment = utils.slice_tensor( | |
| first_moment, grad_indices, param_shape) | |
| second_moment = utils.slice_tensor( | |
| second_moment, grad_indices, param_shape) | |
| timestep = utils.slice_tensor(timestep, grad_indices, param_shape) | |
| return grad_values, first_moment, second_moment, timestep, grad_indices | |