# Hugging Face Hub page header (not part of the module source):
#   Uploader: burtenshaw (HF Staff)
#   Commit message: "Upload folder using huggingface_hub"
#   Commit: 7d5289a (verified)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Protocol, TypedDict
from .types import Action, Observation, State
class Message(TypedDict):
    """One message within a conversation.

    The layout is compatible with the Huggingface chat template format.
    """

    # Role attached to this message.
    role: str
    # Content of this message.
    content: str
class ModelTokenizer(Protocol):
    """Structural interface for tokenizers that support chat templates.

    A tokenizer must provide the methods declared here in order to work
    with chat-based environments. The interface is compatible with
    Huggingface transformers tokenizers.
    """

    def apply_chat_template(
        self,
        conversation: list[Message],
        tokenize: bool = True,
        return_tensors: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Format a conversation with a chat template, optionally tokenizing it.

        Args:
            conversation: List of message dictionaries, each carrying
                'role' and 'content'.
            tokenize: Whether the output should be tokenized.
            return_tensors: Format for returned tensors ('pt' for PyTorch).
            **kwargs: Additional arguments.

        Returns:
            The formatted and optionally tokenized conversation.
        """
        ...

    def decode(
        self,
        token_ids: Any,
        skip_special_tokens: bool = False,
        **kwargs: Any,
    ) -> str:
        """Turn token IDs back into text.

        Args:
            token_ids: Token IDs to decode.
            skip_special_tokens: Whether special tokens are skipped in
                the output.
            **kwargs: Additional arguments.

        Returns:
            The decoded text string.
        """
        ...
class Transform(ABC):
    """Callable that maps an observation to a (possibly modified) observation.

    Intended for adding rewards, metrics, or other modifications to
    observations. The design follows the TorchRL pattern: a transform
    receives an observation and hands back an observation, which allows
    flexible reward computation and observation augmentation.
    """

    @abstractmethod
    def __call__(self, observation: Observation) -> Observation:
        """Transform an observation.

        Args:
            observation: The input observation.

        Returns:
            The transformed observation.
        """
class Environment(ABC):
    """Abstract base for all environment servers, following the Gym/Gymnasium API.

    Concrete environments implement ``reset``, ``step`` and the ``state``
    property. The transform given at construction time is stored on the
    instance; ``_apply_transform`` runs an observation through it.

    Args:
        transform: Optional transform to apply to observations.
    """

    def __init__(self, transform: Transform | None = None):
        # Kept exactly as passed in; None means "no transform".
        self.transform = transform

    @abstractmethod
    def reset(self) -> Observation:
        """Reset the environment and return the initial observation."""

    @abstractmethod
    def step(self, action: Action) -> Observation:
        """Take one step in the environment using ``action``."""

    @property
    @abstractmethod
    def state(self) -> State:
        """The current environment state."""

    def _apply_transform(self, observation: Observation) -> Observation:
        """Run ``observation`` through the stored transform, if any.

        Args:
            observation: The observation to process.

        Returns:
            The transform's result when a transform is set, otherwise
            ``observation`` unchanged.
        """
        # Guard clause: with no transform the observation passes straight through.
        if self.transform is None:
            return observation
        return self.transform(observation)