Spaces:
Runtime error
Runtime error
File size: 3,314 Bytes
7d5289a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Protocol, TypedDict
from .types import Action, Observation, State
class Message(TypedDict, total=True):
    """One turn of a chat conversation.

    The dict layout matches what Hugging Face chat templates consume.
    Both keys are required, e.g. ``{"role": "user", "content": "Hello"}``.
    """

    # Who produced the turn; which role strings are valid is decided by the
    # chat template in use.
    role: str
    # The text of the turn.
    content: str
class ModelTokenizer(Protocol):
    """Structural interface for tokenizers that can apply chat templates.

    Any object providing ``apply_chat_template`` and ``decode`` with the
    signatures below satisfies this protocol; no subclassing is needed.
    The method shapes mirror Hugging Face ``transformers`` tokenizers, so
    those can be used wherever chat-based environments expect a
    ``ModelTokenizer``.
    """

    def apply_chat_template(
        self,
        conversation: list[Message],
        tokenize: bool = True,
        return_tensors: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Render a conversation through the tokenizer's chat template.

        Args:
            conversation: Ordered list of ``Message`` dicts, each carrying
                ``role`` and ``content``.
            tokenize: When true, the rendered conversation is also tokenized.
            return_tensors: Tensor format for the returned data (``"pt"``
                selects PyTorch); ``None`` leaves the choice to the
                implementation.
            **kwargs: Extra implementation-specific options.

        Returns:
            The templated conversation, tokenized if ``tokenize`` is true.
        """

    def decode(
        self,
        token_ids: Any,
        skip_special_tokens: bool = False,
        **kwargs: Any,
    ) -> str:
        """Convert token IDs back into a string.

        Args:
            token_ids: The token IDs to turn back into text.
            skip_special_tokens: When true, special tokens are omitted from
                the resulting text.
            **kwargs: Extra implementation-specific options.

        Returns:
            The decoded text.
        """
class Transform(ABC):
    """Abstract observation-to-observation mapping.

    In the style of TorchRL transforms, an instance is called with an
    ``Observation`` and returns an ``Observation`` — either the same one,
    possibly modified, or a new one. Subclasses use this hook to attach
    rewards or metrics, or to otherwise augment observations.
    """

    @abstractmethod
    def __call__(self, observation: Observation) -> Observation:
        """Produce the transformed version of ``observation``.

        Args:
            observation: The observation to transform.

        Returns:
            The observation after this transform has been applied.
        """
class Environment(ABC):
    """Abstract base for environment servers, modeled on the Gym/Gymnasium
    ``reset``/``step`` style of interface.

    Subclasses must implement ``reset``, ``step`` and the ``state``
    property. Note that ``reset`` and ``step`` are both annotated to return
    a single ``Observation``. Subclasses can pass observations through
    ``_apply_transform`` so that the optional transform supplied at
    construction is applied.

    Args:
        transform: Optional ``Transform`` used by ``_apply_transform``;
            ``None`` (the default) leaves observations untouched.
    """

    def __init__(self, transform: Transform | None = None):
        # Kept as a plain public attribute; _apply_transform reads it on
        # every call, so reassigning it later takes effect immediately.
        self.transform = transform

    @abstractmethod
    def reset(self) -> Observation:
        """Return the environment to its starting condition.

        Returns:
            The initial observation.
        """

    @abstractmethod
    def step(self, action: Action) -> Observation:
        """Advance the environment by applying ``action``.

        Returns:
            The observation that results from the action.
        """

    @property
    @abstractmethod
    def state(self) -> State:
        """The environment's current state, exposed as a property."""

    def _apply_transform(self, observation: Observation) -> Observation:
        """Pass ``observation`` through ``self.transform`` if one is set.

        Returns:
            The transform's output, or ``observation`` itself unchanged when
            ``self.transform`` is ``None``.
        """
        return observation if self.transform is None else self.transform(observation)
|