Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Protocol, TypedDict | |
| from .types import Action, Observation, State | |
class Message(TypedDict):
    """One turn of a chat conversation, stored as a plain dict.

    The key layout matches the message dictionaries used by Huggingface
    chat templates, so a ``list[Message]`` can be passed to a tokenizer's
    ``apply_chat_template`` as-is.
    """

    # Who produced this turn (the chat-template role name).
    role: str
    # The text of this turn.
    content: str
class ModelTokenizer(Protocol):
    """Structural interface for tokenizers that understand chat templates.

    Any object exposing the two methods below matches this protocol; no
    subclassing is needed. The signatures mirror Huggingface ``transformers``
    tokenizers, which is what lets those tokenizers plug into chat-based
    environments unchanged.
    """

    def apply_chat_template(
        self,
        conversation: list[Message],
        tokenize: bool = True,
        return_tensors: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Render a conversation through the chat template.

        Args:
            conversation: Ordered messages, each a dict holding ``role`` and
                ``content``.
            tokenize: Whether the rendered conversation should also be
                tokenized.
            return_tensors: Tensor format of the result (``'pt'`` means
                PyTorch), or None.
            **kwargs: Any further implementation-specific options.

        Returns:
            The rendered conversation, tokenized if requested.
        """

    def decode(
        self,
        token_ids: Any,
        skip_special_tokens: bool = False,
        **kwargs: Any,
    ) -> str:
        """Turn a sequence of token IDs back into text.

        Args:
            token_ids: The token IDs to convert.
            skip_special_tokens: Whether special tokens are left out of the
                resulting text.
            **kwargs: Any further implementation-specific options.

        Returns:
            The decoded string.
        """
class Transform(ABC):
    """Transform observations to add rewards, metrics, or other modifications.

    Transforms follow the TorchRL pattern where they take an observation
    and return a (potentially modified) observation. This allows for
    flexible reward computation and observation augmentation.

    Subclasses must implement ``__call__``; the base class cannot be
    instantiated directly.
    """

    # Abstract so that a bare Transform (whose body would just return None)
    # can never be constructed and silently replace observations with None.
    @abstractmethod
    def __call__(self, observation: Observation) -> Observation:
        """Transform an observation.

        Args:
            observation: The input observation.

        Returns:
            The transformed observation.
        """
        ...
class Environment(ABC):
    """Base class for all environment servers following Gym/Gymnasium API.

    Subclasses must implement ``reset``, ``step`` and ``state``. The base
    class never applies the optional transform by itself; implementations
    call ``_apply_transform`` on the observations they produce.

    Args:
        transform: Optional transform to apply to observations.
    """

    def __init__(self, transform: Transform | None = None):
        self.transform = transform

    # The three API methods are abstract: a subclass that forgets one now
    # fails at construction time instead of silently returning None.
    @abstractmethod
    def reset(self) -> Observation:
        """Reset the environment and return initial observation."""
        ...

    @abstractmethod
    def step(self, action: Action) -> Observation:
        """Take a step in the environment."""
        ...

    @abstractmethod
    def state(self) -> State:
        """Get the current environment state."""
        ...

    def _apply_transform(self, observation: Observation) -> Observation:
        """Apply transform if one is provided.

        Returns:
            ``self.transform(observation)`` when a transform was given,
            otherwise ``observation`` unchanged.
        """
        if self.transform is None:
            return observation
        return self.transform(observation)