Spaces:
Paused
Paused
File size: 2,241 Bytes
ebf11c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
"""Collection classes for managing multiple tools."""
from typing import Any, Dict, List, Optional

from app.exceptions import ToolError
from app.logger import logger
from app.tool.base import BaseTool, ToolFailure, ToolResult
class ToolCollection:
    """A collection of defined tools.

    Tools are stored twice: as an ordered tuple (``tools``) used for
    iteration, and as a name-keyed dict (``tool_map``) used for lookup.
    """

    class Config:
        # NOTE(review): ToolCollection declares no base class, so nothing in
        # this class reads this setting; it is kept so existing references to
        # ``ToolCollection.Config`` keep working -- confirm before removing.
        arbitrary_types_allowed = True

    def __init__(self, *tools: BaseTool):
        """Build the collection from the given tools.

        If several tools share a name, all of them stay in ``tools`` but
        ``tool_map`` keeps only the last one.
        """
        self.tools = tools
        self.tool_map = {tool.name: tool for tool in tools}

    def __iter__(self):
        """Iterate over the tools in insertion order."""
        return iter(self.tools)

    def to_params(self) -> List[Dict[str, Any]]:
        """Return ``tool.to_param()`` for every tool, in insertion order."""
        return [tool.to_param() for tool in self.tools]

    async def execute(
        self, *, name: str, tool_input: Optional[Dict[str, Any]] = None
    ) -> ToolResult:
        """Run the tool registered under ``name``.

        Args:
            name: Name of the tool to run.
            tool_input: Keyword arguments forwarded to the tool. ``None``
                (the default) means "call the tool with no arguments".

        Returns:
            The awaited result of the tool, or a ``ToolFailure`` when no tool
            is registered under ``name`` or the tool raises ``ToolError``.
            Any other exception propagates to the caller.
        """
        tool = self.tool_map.get(name)
        # Compare against None explicitly so a registered tool that happens
        # to evaluate as falsy is not mistaken for a missing one.
        if tool is None:
            return ToolFailure(error=f"Tool {name} is invalid")
        # ``**None`` raises TypeError, so the default must become an empty
        # mapping before unpacking.
        kwargs = {} if tool_input is None else tool_input
        try:
            return await tool(**kwargs)
        except ToolError as e:
            return ToolFailure(error=e.message)

    async def execute_all(self) -> List[ToolResult]:
        """Execute all tools in the collection sequentially.

        Each tool is called without arguments. A ``ToolError`` raised by one
        tool is recorded as a ``ToolFailure`` in its slot and does not stop
        the remaining tools from running.
        """
        results = []
        for tool in self.tools:
            try:
                result = await tool()
                results.append(result)
            except ToolError as e:
                results.append(ToolFailure(error=e.message))
        return results

    def get_tool(self, name: str) -> Optional[BaseTool]:
        """Return the tool registered under ``name``, or ``None`` if absent."""
        return self.tool_map.get(name)

    def add_tool(self, tool: BaseTool) -> "ToolCollection":
        """Add a single tool to the collection.

        If a tool with the same name already exists, it will be skipped and a
        warning will be logged.

        Returns:
            This collection, so calls can be chained.
        """
        if tool.name in self.tool_map:
            logger.warning(f"Tool {tool.name} already exists in collection, skipping")
            return self

        # ``tools`` is a tuple, so this rebinds the attribute to a new tuple
        # rather than mutating the old one in place.
        self.tools += (tool,)
        self.tool_map[tool.name] = tool
        return self

    def add_tools(self, *tools: BaseTool) -> "ToolCollection":
        """Add multiple tools to the collection.

        If any tool has a name conflict with an existing tool, it will be
        skipped and a warning will be logged.

        Returns:
            This collection, so calls can be chained.
        """
        for tool in tools:
            self.add_tool(tool)
        return self
|