# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """Generic Gymnasium environment server implementation.""" from __future__ import annotations import logging import uuid from typing import Any, Dict, Optional import numpy as np try: import gymnasium as gym from gymnasium import spaces except ImportError: raise ValueError("Please install gymnasium with: pip install gymnasium") from core.env_server import Environment from ..models import GymAction, GymObservation, GymState logger = logging.getLogger(__name__) logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) class GymnasiumEnvironment(Environment): """ Generic Gymnasium environment wrapper for OpenEnv. Any Gymnasium environment can be served by providing its environment id. The wrapper handles common concerns such as seed management, type conversion, and JSON-friendly serialization of observations. """ def __init__( self, env_id: str, render_mode: Optional[str] = None, max_steps: Optional[int] = None, seed: Optional[int] = None, **gym_kwargs, ): super().__init__() self.env_id = env_id self.render_mode = render_mode self.max_steps = max_steps if max_steps and max_steps > 0 else None self._initial_seed = seed self._next_seed = seed logger.info( "Creating Gymnasium environment '%s' (render_mode=%s, max_steps=%s, seed=%s)", env_id, render_mode, self.max_steps, seed, ) self.env = gym.make(env_id, render_mode=render_mode, **gym_kwargs) if self.max_steps is not None: self.env = gym.wrappers.TimeLimit( self.env, max_episode_steps=self.max_steps ) self._action_space_metadata = self._describe_space(self.env.action_space) self._observation_space_metadata = self._describe_space( self.env.observation_space ) self._legal_actions = self._summarize_action_space(self.env.action_space) self._state = GymState( env_id=env_id, render_mode=render_mode, max_steps=self.max_steps, seed=seed, ) logger.info("GymnasiumEnvironment for '%s' initialized", env_id) def reset(self) -> GymObservation: """Reset the environment and return the initial observation.""" seed = self._consume_seed() obs, info = self.env.reset(seed=seed) self._state.episode_id = str(uuid.uuid4()) self._state.step_count = 0 self._state.episode_length = 0 self._state.total_reward = 0.0 self._state.seed = seed observation = self._make_observation( obs=obs, reward=None, done=False, info=info, terminated=False, truncated=False, raw_reward=0.0, ) logger.info( "Environment '%s' reset (episode_id=%s, seed=%s)", self.env_id, self._state.episode_id, seed, ) return observation def step(self, action: GymAction) -> GymObservation: """Execute an action and return the resulting observation.""" gym_action = self._convert_action(action) obs, reward, terminated, truncated, info = self.env.step(gym_action) self._state.step_count += 1 self._state.episode_length += 1 reward_value, raw_reward = self._normalize_reward(reward) if reward_value is not None: self._state.total_reward += reward_value done = bool(terminated or truncated) observation = self._make_observation( obs=obs, reward=reward_value, done=done, info=info, terminated=terminated, truncated=truncated, raw_reward=raw_reward, ) logger.debug( "Step %s -> reward=%s terminated=%s truncated=%s", self._state.step_count, reward, terminated, truncated, ) return observation @property def state(self) -> GymState: """Return the current environment state.""" return self._state def close(self) -> None: """Close the underlying Gymnasium environment.""" logger.info("Closing GymnasiumEnvironment for '%s'", self.env_id) if hasattr(self.env, "close"): self.env.close() logger.info("GymnasiumEnvironment closed") # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _consume_seed(self) -> Optional[int]: if self._next_seed is None: return None seed = self._next_seed self._next_seed += 1 return seed def _convert_action(self, action: GymAction) -> Any: if not isinstance(action, GymAction): raise ValueError(f"Expected GymAction, received {type(action)}") raw_action = action.action space = self.env.action_space converted = self._convert_action_for_space(space, raw_action) if not space.contains(converted): raise ValueError( f"Action {raw_action!r} could not be converted for space {space}" ) return converted def _convert_action_for_space(self, space: spaces.Space, value: Any) -> Any: if isinstance(space, spaces.Discrete): return int(value) if isinstance(space, spaces.MultiDiscrete): return np.asarray(value, dtype=space.dtype) if isinstance(space, spaces.MultiBinary): return np.asarray(value, dtype=space.dtype) if isinstance(space, spaces.Box): return np.asarray(value, dtype=space.dtype) if isinstance(space, spaces.Tuple): if not isinstance(value, (list, tuple)): raise TypeError( f"Tuple action space expects list/tuple, received {type(value)}" ) if len(value) != len(space.spaces): raise ValueError( f"Tuple action with length {len(value)} does not match " f"expected length {len(space.spaces)}" ) return tuple( self._convert_action_for_space(subspace, subvalue) for subspace, subvalue in zip(space.spaces, value) ) if isinstance(space, spaces.Dict): if not isinstance(value, dict): raise TypeError( f"Dict action space expects dict, received {type(value)}" ) return { key: self._convert_action_for_space(space.spaces[key], value[key]) for key in space.spaces } if isinstance(space, spaces.Text): return str(value) return value def _normalize_reward(self, reward: Any) -> tuple[Optional[float], Any]: if isinstance(reward, (int, float)): value = float(reward) return value, value if isinstance(reward, (np.integer, np.floating)): value = float(reward.item()) return value, value return None, self._to_serializable(reward) def _make_observation( self, obs: Any, reward: Optional[float], done: bool, info: Dict[str, Any], terminated: bool, truncated: bool, raw_reward: Any, ) -> GymObservation: metadata = { "env_id": self.env_id, "render_mode": self.render_mode, "max_steps": self.max_steps, "seed": self._state.seed, "info": self._to_serializable(info), "raw_reward": raw_reward, "terminated": terminated, "truncated": truncated, "action_space": self._action_space_metadata, "observation_space": self._observation_space_metadata, } # Remove keys with None values for cleaner payloads metadata = {key: value for key, value in metadata.items() if value is not None} return GymObservation( state=self._to_serializable(obs), legal_actions=self._legal_actions, episode_length=self._state.episode_length, total_reward=self._state.total_reward, done=done, reward=reward, metadata=metadata, ) def _describe_space(self, space: spaces.Space) -> Dict[str, Any]: description: Dict[str, Any] = {"type": type(space).__name__} if hasattr(space, "shape"): description["shape"] = self._to_serializable(getattr(space, "shape")) dtype = getattr(space, "dtype", None) if dtype is not None: description["dtype"] = str(dtype) if isinstance(space, spaces.Discrete): description["n"] = int(space.n) elif isinstance(space, spaces.MultiDiscrete): description["nvec"] = self._to_serializable(space.nvec) elif isinstance(space, spaces.MultiBinary): description["n"] = self._to_serializable(space.n) elif isinstance(space, spaces.Box): description["low"] = self._to_serializable(space.low) description["high"] = self._to_serializable(space.high) elif isinstance(space, spaces.Tuple): description["spaces"] = [ self._describe_space(subspace) for subspace in space.spaces ] elif isinstance(space, spaces.Dict): description["spaces"] = { key: self._describe_space(subspace) for key, subspace in space.spaces.items() } elif isinstance(space, spaces.Text): description["min_length"] = space.min_length description["max_length"] = space.max_length return description def _summarize_action_space(self, space: spaces.Space) -> Any: if isinstance(space, spaces.Discrete): return list(range(int(space.n))) if isinstance(space, spaces.MultiDiscrete): return [list(range(int(n))) for n in self._to_serializable(space.nvec)] if isinstance(space, spaces.MultiBinary): return [0, 1] if isinstance(space, spaces.Box): return { "low": self._to_serializable(space.low), "high": self._to_serializable(space.high), } if isinstance(space, spaces.Tuple): return [self._summarize_action_space(subspace) for subspace in space.spaces] if isinstance(space, spaces.Dict): return { key: self._summarize_action_space(subspace) for key, subspace in space.spaces.items() } if isinstance(space, spaces.Text): return {"charset": "unicode"} return None def _to_serializable(self, value: Any) -> Any: if isinstance(value, np.ndarray): return [self._to_serializable(v) for v in value.tolist()] if isinstance(value, (np.floating, np.integer)): return self._to_serializable(value.item()) if isinstance(value, np.bool_): return bool(value) if isinstance(value, (list, tuple, set)): return [self._to_serializable(v) for v in value] if isinstance(value, dict): return {str(k): self._to_serializable(v) for k, v in value.items()} if isinstance(value, (int, bool, float)) or value is None: return value return str(value)