# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """A circular buffer where each element is a list of tensors. | |
| Each element of the buffer is a list of tensors. An example use case is a replay | |
| buffer in reinforcement learning, where each element is a list of tensors | |
| representing the state, action, reward etc. | |
| New elements are added sequentially, and once the buffer is full, we | |
| start overwriting them in a circular fashion. Reading does not remove any | |
| elements, only adding new elements does. | |
| """ | |

import collections

import numpy as np
import tensorflow as tf

import gin.tf


@gin.configurable
class CircularBuffer(object):
  """A circular buffer where each element is a list of tensors."""

  def __init__(self, buffer_size=1000, scope='replay_buffer'):
    """Circular buffer of list of tensors.

    Args:
      buffer_size: (integer) maximum number of tensor lists the buffer can
        hold.
      scope: (string) variable scope for creating the variables.
    """
    self._buffer_size = np.int64(buffer_size)
    self._scope = scope
    self._tensors = collections.OrderedDict()
    with tf.variable_scope(self._scope):
      self._num_adds = tf.Variable(0, dtype=tf.int64, name='num_adds')
      self._num_adds_cs = tf.CriticalSection(name='num_adds')

  @property
  def buffer_size(self):
    return self._buffer_size

  @property
  def scope(self):
    return self._scope

  @property
  def num_adds(self):
    return self._num_adds

  def _create_variables(self, tensors):
    with tf.variable_scope(self._scope):
      for name in tensors.keys():
        tensor = tensors[name]
        self._tensors[name] = tf.get_variable(
            name='BufferVariable_' + name,
            shape=[self._buffer_size] + tensor.get_shape().as_list(),
            dtype=tensor.dtype,
            trainable=False)

  def _validate(self, tensors):
    """Validates shapes and dtypes of the input tensors."""
    if len(tensors) != len(self._tensors):
      raise ValueError('Expected tensors to have %d elements. Received %d '
                       'instead.' % (len(self._tensors), len(tensors)))
    if self._tensors.keys() != tensors.keys():
      raise ValueError('The keys of tensors should always be the same. '
                       'Received %s instead of %s.' %
                       (tensors.keys(), self._tensors.keys()))
    for name, tensor in tensors.items():
      if tensor.get_shape().as_list() != self._tensors[
          name].get_shape().as_list()[1:]:
        raise ValueError('Tensor %s has incorrect shape.' % name)
      if not tensor.dtype.is_compatible_with(self._tensors[name].dtype):
        raise ValueError(
            'Tensor %s has incorrect data type. Expected %s, received %s.' %
            (name, self._tensors[name].read_value().dtype, tensor.dtype))

  def add(self, tensors):
    """Adds an element (list/tuple/dict of tensors) to the buffer.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.

    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
        an enqueue_op.

    Raises:
      ValueError: If the shapes and data types of the input `tensors` are not
        the same across calls to the add function.
    """
    return self.maybe_add(tensors, True)

  def maybe_add(self, tensors, condition):
    """Adds an element (tensors) to the buffer based on the condition.

    Args:
      tensors: (list/tuple/dict of tensors) to be added to the buffer.
      condition: A boolean Tensor controlling whether the tensors would be
        added to the buffer or not.

    Returns:
      An add operation that adds the input `tensors` to the buffer. Similar to
        a maybe_enqueue_op.

    Raises:
      ValueError: If the shapes and data types of the input `tensors` are not
        the same across calls to the add function.
    """
    if not isinstance(tensors, dict):
      names = [str(i) for i in range(len(tensors))]
      tensors = collections.OrderedDict(zip(names, tensors))
    if not isinstance(tensors, collections.OrderedDict):
      tensors = collections.OrderedDict(
          sorted(tensors.items(), key=lambda t: t[0]))
    if not self._tensors:
      self._create_variables(tensors)
    else:
      self._validate(tensors)

    # @tf.critical_section(self._position_mutex)
    def _increment_num_adds():
      # Adding 0 to the num_adds variable is a trick to read the value of the
      # variable and return a read-only tensor. Doing this in a critical
      # section allows us to capture a snapshot of the variable that will
      # not be affected by other threads updating num_adds.
      return self._num_adds.assign_add(1) + 0

    def _add():
      num_adds_inc = self._num_adds_cs.execute(_increment_num_adds)
      current_pos = tf.mod(num_adds_inc - 1, self._buffer_size)
      update_ops = []
      for name in self._tensors.keys():
        update_ops.append(
            tf.scatter_update(self._tensors[name], current_pos, tensors[name]))
      return tf.group(*update_ops)

    return tf.contrib.framework.smart_cond(condition, _add, tf.no_op)
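
  # Illustrative example (assumed TF 1.x graph mode; `buffer` is a
  # CircularBuffer instance and `state_t`, `action_t`, `done_t` are
  # hypothetical tensors, not part of this class):
  #
  #   add_op = buffer.add({'state': state_t, 'action': action_t})
  #   # Only record the transition while the episode is still running:
  #   maybe_add_op = buffer.maybe_add(
  #       {'state': state_t, 'action': action_t}, tf.logical_not(done_t))
  #
  # The first call decides the keys, shapes, and dtypes of the buffer slots;
  # later calls must match them or _validate raises a ValueError.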

  def get_random_batch(self, batch_size, keys=None, num_steps=1):
    """Samples a batch of tensors from the buffer with replacement.

    Args:
      batch_size: (integer) number of elements to sample.
      keys: List of keys of tensors to retrieve. If None retrieve all.
      num_steps: (integer) length of trajectories to return. If > 1, each
        returned tensor has shape [batch_size, num_steps] + element_shape,
        i.e. one trajectory of length num_steps per sampled start index.

    Returns:
      A list of tensors, where each element in the list is a batch sampled from
        one of the tensors in the buffer.

    Raises:
      ValueError: If get_random_batch is called before calling the add
        function.
      tf.errors.InvalidArgumentError: If this operation is executed before any
        items are added to the buffer.
    """
    if not self._tensors:
      raise ValueError(
          'The add function must be called before get_random_batch.')
    if keys is None:
      keys = self._tensors.keys()

    latest_start_index = self.get_num_adds() - num_steps + 1
    empty_buffer_assert = tf.Assert(
        tf.greater(latest_start_index, 0),
        ['Not enough elements have been added to the buffer.'])
    with tf.control_dependencies([empty_buffer_assert]):
      max_index = tf.minimum(self._buffer_size, latest_start_index)
      indices = tf.random_uniform(
          [batch_size],
          minval=0,
          maxval=max_index,
          dtype=tf.int64)
      if num_steps == 1:
        return self.gather(indices, keys)
      else:
        return self.gather_nstep(num_steps, indices, keys)
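
  # Illustrative example (graph mode): sampling 32 single-step elements
  # returns one tensor per key with a leading batch dimension, while
  # num_steps=3 returns, per key, a tensor of shape [32, 3] + element_shape:
  #
  #   states, actions = buffer.get_random_batch(32, keys=['state', 'action'])
  #   state_trajs, = buffer.get_random_batch(32, keys=['state'], num_steps=3)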

  def gather(self, indices, keys=None):
    """Returns elements at the specified indices from the buffer.

    Args:
      indices: (list of integers or rank 1 int Tensor) indices in the buffer to
        retrieve elements from.
      keys: List of keys of tensors to retrieve. If None retrieve all.

    Returns:
      A list of tensors, where each element in the list is obtained by indexing
        one of the tensors in the buffer.

    Raises:
      ValueError: If gather is called before calling the add function.
      tf.errors.InvalidArgumentError: If indices are bigger than the number of
        items in the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before calling '
                       'gather.')
    if keys is None:
      keys = self._tensors.keys()

    with tf.name_scope('Gather'):
      index_bound_assert = tf.Assert(
          tf.less(
              tf.to_int64(tf.reduce_max(indices)),
              tf.minimum(self.get_num_adds(), self._buffer_size)),
          ['Index out of bounds.'])
      with tf.control_dependencies([index_bound_assert]):
        indices = tf.convert_to_tensor(indices)

      batch = []
      for key in keys:
        batch.append(tf.gather(self._tensors[key], indices, name=key))
      return batch
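
  # Illustrative example: reading back the most recently written element by
  # its position in the buffer (both ops below are graph-mode tensors):
  #
  #   last_index = buffer.get_position()
  #   last_state, = buffer.gather([last_index], keys=['state'])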

  def gather_nstep(self, num_steps, indices, keys=None):
    """Returns trajectories starting at the specified indices from the buffer.

    Args:
      num_steps: (integer) length of trajectories to return.
      indices: (rank 1 int Tensor) start indices in the buffer; each entry
        yields one trajectory of length num_steps, wrapping around the buffer
        if needed.
      keys: List of keys of tensors to retrieve. If None retrieve all.

    Returns:
      A list of tensors, one per key, where each tensor stacks the gathered
        trajectories and has shape [len(indices), num_steps] + element_shape.

    Raises:
      ValueError: If gather_nstep is called before calling the add function.
      tf.errors.InvalidArgumentError: If the trajectory indices go out of
        bounds of the items added to the buffer.
    """
    if not self._tensors:
      raise ValueError('The add function must be called before calling '
                       'gather_nstep.')
    if keys is None:
      keys = self._tensors.keys()

    with tf.name_scope('Gather'):
      index_bound_assert = tf.Assert(
          tf.less_equal(
              tf.to_int64(tf.reduce_max(indices) + num_steps),
              self.get_num_adds()),
          ['Trajectory indices go out of bounds.'])
      with tf.control_dependencies([index_bound_assert]):
        indices = tf.map_fn(
            lambda x: tf.mod(tf.range(x, x + num_steps), self._buffer_size),
            indices,
            dtype=tf.int64)

      batch = []
      for key in keys:

        def SampleTrajectories(trajectory_indices,
                               key=key,
                               num_steps=num_steps):
          trajectory_indices.set_shape([num_steps])
          return tf.gather(self._tensors[key], trajectory_indices, name=key)

        batch.append(tf.map_fn(SampleTrajectories, indices,
                               dtype=self._tensors[key].dtype))
      return batch
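
  # Illustrative example: given start indices [0, 5] and num_steps=3, each
  # returned tensor stacks two trajectories, i.e. it has shape
  # [2, 3] + element_shape, with indices wrapping modulo the buffer size:
  #
  #   starts = tf.constant([0, 5], dtype=tf.int64)
  #   state_trajs, = buffer.gather_nstep(3, starts, keys=['state'])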

  def get_position(self):
    """Returns the position at which the last element was added.

    Returns:
      An int tensor representing the index at which the last element was added
        to the buffer, or -1 if no elements were added.
    """
    return tf.cond(self.get_num_adds() < 1,
                   lambda: self.get_num_adds() - 1,
                   lambda: tf.mod(self.get_num_adds() - 1, self._buffer_size))
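
  # For example, with buffer_size=1000 and 1003 total adds, the most recent
  # element was written at index (1003 - 1) % 1000 = 2, so get_position()
  # evaluates to 2; before any add it evaluates to -1.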

  def get_num_adds(self):
    """Returns the number of additions to the buffer.

    Returns:
      An int tensor representing the number of elements that were added.
    """
    def num_adds():
      return self._num_adds.value()

    return self._num_adds_cs.execute(num_adds)

  def get_num_tensors(self):
    """Returns the number of tensors (slots) in the buffer."""
    return len(self._tensors)
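

# A minimal usage sketch, not part of the original module. It assumes TF 1.x
# graph mode with `tf.contrib` available; the placeholder names below are
# illustrative only. Calling this function builds an add op and sampling ops
# for a buffer holding state/action/reward tensors.
def _example_usage():
  buf = CircularBuffer(buffer_size=100, scope='example_buffer')
  state = tf.placeholder(tf.float32, [4], name='state')
  action = tf.placeholder(tf.float32, [2], name='action')
  reward = tf.placeholder(tf.float32, [], name='reward')
  # The first add defines the keys, shapes, and dtypes of the buffer slots.
  add_op = buf.add({'state': state, 'action': action, 'reward': reward})
  # Sampling returns one tensor per requested key, each with a leading batch
  # dimension of size 8; run `add_op` at least once before evaluating these.
  sampled_state, sampled_action = buf.get_random_batch(
      8, keys=['state', 'action'])
  return add_op, sampled_state, sampled_action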