Spaces:
Paused
Paused
File size: 5,752 Bytes
ebf11c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel, Field
from app.utils.logger import logger
# class BaseTool(ABC, BaseModel):
# name: str
# description: str
# parameters: Optional[dict] = None
# class Config:
# arbitrary_types_allowed = True
# async def __call__(self, **kwargs) -> Any:
# """Execute the tool with given parameters."""
# return await self.execute(**kwargs)
# @abstractmethod
# async def execute(self, **kwargs) -> Any:
# """Execute the tool with given parameters."""
# def to_param(self) -> Dict:
# """Convert tool to function call format."""
# return {
# "type": "function",
# "function": {
# "name": self.name,
# "description": self.description,
# "parameters": self.parameters,
# },
# }
class ToolResult(BaseModel):
    """Represents the result of a tool execution.

    Attributes:
        output: Result payload of the tool; typed ``Any`` and defaults to ``None``.
        error: Error message, or ``None`` when no error was recorded.
        base64_image: Optional image payload (a base64 string, per the field name).
        system: Optional extra text attached to the result.
    """

    output: Any = Field(default=None)
    error: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)
    system: Optional[str] = Field(default=None)

    class Config:
        arbitrary_types_allowed = True

    def __bool__(self):
        """Return True when at least one field holds a truthy value."""
        # NOTE(review): ``__fields__`` is deprecated under pydantic v2 in favour of
        # ``model_fields``; kept because the pydantic version is not visible here.
        return any(getattr(self, field) for field in self.__fields__)

    def __add__(self, other: "ToolResult"):
        """Merge two results field by field into a new ``ToolResult``.

        For ``output``, ``error`` and ``system``: when both sides are truthy they
        are joined with ``+``; otherwise ``left or right`` is used. Because
        ``output`` is typed ``Any``, joining two outputs whose types do not
        support ``+`` raises ``TypeError``.

        Raises:
            ValueError: If both results carry a ``base64_image`` (images are
                never concatenated).
        """

        def combine_fields(
            field: Optional[str], other_field: Optional[str], concatenate: bool = True
        ):
            if field and other_field:
                if concatenate:
                    return field + other_field
                raise ValueError("Cannot combine tool results")
            return field or other_field

        return ToolResult(
            output=combine_fields(self.output, other.output),
            error=combine_fields(self.error, other.error),
            base64_image=combine_fields(self.base64_image, other.base64_image, False),
            system=combine_fields(self.system, other.system),
        )

    def __str__(self):
        """Render the result as text.

        Returns ``"Error: <error>"`` when an error is set. Otherwise returns the
        output converted with ``str()``, or an empty string when there is no output.
        """
        if self.error:
            return f"Error: {self.error}"
        # ``output`` is ``Any`` and defaults to None; ``__str__`` must return a
        # ``str`` or Python raises TypeError, so convert explicitly.
        return "" if self.output is None else str(self.output)

    def replace(self, **kwargs):
        """Return a new instance of the same class with the given fields replaced."""
        return type(self)(**{**self.dict(), **kwargs})
class BaseTool(ABC, BaseModel):
    """Consolidated base class for all tools, combining pydantic models with an async call interface.

    Provides:
    - Pydantic model validation of the declared fields
    - An OpenAI-style function-calling description via ``to_param``
    - Helpers that build ``ToolResult`` objects for success and failure
    - An abstract async ``execute`` method that subclasses implement

    Attributes:
        name (str): Tool name
        description (str): Tool description
        parameters (dict): Tool parameters schema (optional, defaults to None)
    """

    name: str
    description: str
    parameters: Optional[dict] = None

    class Config:
        arbitrary_types_allowed = True

    async def __call__(self, **kwargs) -> Any:
        """Execute the tool with given parameters by awaiting ``execute``."""
        return await self.execute(**kwargs)

    @abstractmethod
    async def execute(self, **kwargs) -> Any:
        """Execute the tool with given parameters. Must be implemented by subclasses."""

    def to_param(self) -> Dict:
        """Convert tool to function call format.

        Returns:
            Dictionary with tool metadata in OpenAI function calling format:
            ``{"type": "function", "function": {name, description, parameters}}``.
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }

    def success_response(self, data: Union[Dict[str, Any], str]) -> ToolResult:
        """Create a successful tool result.

        Args:
            data: Result data. A string is used verbatim; anything else is
                serialized with ``json.dumps(..., indent=2)``.

        Returns:
            ToolResult whose ``output`` holds the text and whose ``error`` is None.
        """
        text = data if isinstance(data, str) else json.dumps(data, indent=2)
        logger.debug(f"Created success response for {self.__class__.__name__}")
        return ToolResult(output=text)

    def fail_response(self, msg: str) -> ToolResult:
        """Create a failed tool result.

        Args:
            msg: Error message describing the failure

        Returns:
            ToolResult whose ``error`` holds ``msg`` and whose ``output`` is None.
        """
        logger.debug(f"Tool {self.__class__.__name__} returned failed result: {msg}")
        return ToolResult(error=msg)
class CLIResult(ToolResult):
    """Tool result intended to be displayed as command-line output.

    Declares no fields or methods of its own; everything is inherited from
    ``ToolResult`` and only the class identity differs.
    """
class ToolFailure(ToolResult):
    """Tool result that stands for a failed execution.

    Declares no fields or methods of its own; everything is inherited from
    ``ToolResult`` and only the class identity differs.
    """
|