Spaces:
Runtime error
Runtime error
| # Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| from environments.ant_maze_env import AntMazeEnv | |
| from environments.point_maze_env import PointMazeEnv | |
| import tensorflow as tf | |
| import gin.tf | |
| from tf_agents.environments import gym_wrapper | |
| from tf_agents.environments import tf_py_environment | |
def create_maze_env(env_name=None, top_down_view=False):
    """Build a maze environment from its name and wrap it for TF-Agents.

    ``env_name`` is parsed as ``[Ego]<Robot><Maze>``:

    * optional ``Ego`` prefix: passes ``n_bins=8`` to the env (0 otherwise);
    * ``<Robot>``: ``Ant`` -> ``AntMazeEnv`` with ``maze_size_scaling=8``, or
      ``Point`` -> ``PointMazeEnv`` with ``maze_size_scaling=4`` and
      ``manual_collision=True``;
    * ``<Maze>``: one of ``Maze``, ``Push``, ``Fall``, ``Block``,
      ``BlockMaze``.  ``Block`` and ``BlockMaze`` additionally set
      ``observe_blocks=True`` and ``put_spin_near_agent=True``.

    Examples of valid names: ``'AntMaze'``, ``'PointPush'``,
    ``'EgoAntBlockMaze'``.

    Args:
      env_name: Environment name in the format described above.
      top_down_view: Forwarded unchanged to the env constructor.

    Returns:
      A ``gym_wrapper.GymWrapper`` around the constructed env, which has
      already been ``reset()`` once.

    Raises:
      ValueError: If ``env_name`` is missing/empty, or names an unknown robot
        or maze layout.
    """
    if not env_name:
        raise ValueError(
            'env_name must be a non-empty string, got %r' % (env_name,))

    name = env_name
    n_bins = 0
    if name.startswith('Ego'):
        n_bins = 8
        name = name[len('Ego'):]

    # Parse the robot prefix without touching the env classes yet, so the
    # whole name is validated before anything is resolved or constructed.
    if name.startswith('Ant'):
        use_ant = True
        name = name[len('Ant'):]
    elif name.startswith('Point'):
        use_ant = False
        name = name[len('Point'):]
    else:
        # A real exception rather than `assert False`: asserts vanish under
        # `python -O`, which left `cls` unbound and failed much later.
        raise ValueError('unknown env %s' % env_name)

    # maze_id -> (observe_blocks, put_spin_near_agent)
    maze_options = {
        'Maze': (False, False),
        'Push': (False, False),
        'Fall': (False, False),
        'Block': (True, True),
        'BlockMaze': (True, True),
    }
    if name not in maze_options:
        raise ValueError('Unknown maze environment %s' % name)
    observe_blocks, put_spin_near_agent = maze_options[name]

    if use_ant:
        cls, maze_size_scaling, manual_collision = AntMazeEnv, 8, False
    else:
        cls, maze_size_scaling, manual_collision = PointMazeEnv, 4, True

    gym_mujoco_kwargs = {
        'maze_id': name,
        'n_bins': n_bins,
        'observe_blocks': observe_blocks,
        'put_spin_near_agent': put_spin_near_agent,
        'top_down_view': top_down_view,
        'manual_collision': manual_collision,
        'maze_size_scaling': maze_size_scaling,
    }
    gym_env = cls(**gym_mujoco_kwargs)
    gym_env.reset()
    wrapped_env = gym_wrapper.GymWrapper(gym_env)
    return wrapped_env
class TFPyEnvironment(tf_py_environment.TFPyEnvironment):
    """TF-Agents ``TFPyEnvironment`` exposing a single-element interface.

    ``step`` adds a leading axis of size 1 to the incoming action, and
    ``current_obs``/``step`` return element 0 of the base class results, so
    callers deal with one unbatched value at a time.  NOTE(review): this
    leading axis is presumably the batch dimension the base class adds --
    confirm against the TF-Agents version in use.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through: every argument goes straight to the base class.
        super(TFPyEnvironment, self).__init__(*args, **kwargs)

    def start_collect(self):
        """Intentionally a no-op."""

    def current_obs(self):
        """Return the current observation with its leading axis removed."""
        batched_observation = self.current_time_step().observation
        return batched_observation[0]

    def step(self, actions):
        """Advance the environment by one step with a single action.

        Args:
          actions: Action for this step; a leading axis of size 1 is added
            before it is handed to the base ``step``.

        Returns:
          Tuple ``(is_last, reward, discount)``, each taken from element 0 of
          the resulting time step.
        """
        batched_actions = tf.expand_dims(actions, 0)
        time_step = super(TFPyEnvironment, self).step(batched_actions)
        done = time_step.is_last()[0]
        reward = time_step.reward[0]
        discount = time_step.discount[0]
        return done, reward, discount

    def reset(self):
        """Reset via the base class and return its result unchanged."""
        return super(TFPyEnvironment, self).reset()