Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| from typing import Dict, Union, get_origin, get_args | |
| from pydantic import BaseModel, Field | |
| from types import UnionType | |
| import os | |
| import logging | |
| from src.vectorstore import VectorStore | |
| # from langchain.tools import tool | |
class ToolBase(BaseModel, ABC):
    """Base class for tools exposed through an OpenAI-style function-calling API.

    A concrete tool subclasses this model, declares its arguments as pydantic
    fields (each with a ``description``), documents itself with a class
    docstring and implements :meth:`invoke`.  Both methods operate on the
    class itself, so they are classmethods.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool with ``input``, a mapping of argument name to value."""
        pass

    @classmethod
    def to_openai_tool(cls):
        """
        Extracts function metadata from a Pydantic class, including function name, parameters, and descriptions.
        Formats it into a structure similar to OpenAI's function metadata.

        Returns:
            A dict of the form ``{"type": "function", "function": {...}}``
            whose ``function`` entry holds ``name`` (the class name, used
            verbatim), ``description`` (the stripped class docstring, or an
            empty string when there is none) and a JSON-schema-like
            ``parameters`` object.

        Raises:
            TypeError: if a field is annotated with a union of more than one
                non-``None`` type.
        """
        properties = {}
        required = []

        for field_name, field_info in cls.model_fields.items():
            annotation = field_info.annotation
            has_none = False

            # Treat both spellings of an optional type the same way:
            # ``Optional[X]`` / ``Union[X, None]`` and ``X | None``.
            origin = get_origin(annotation)
            if origin is Union or origin is UnionType:
                members = get_args(annotation)
                has_none = type(None) in members
                members = [member for member in members if member is not type(None)]
                if len(members) > 1:
                    raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
                elif len(members) == 0:
                    raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
                annotation = members[0]

            # Only int and bool get a dedicated JSON type; everything else is
            # advertised as a string.
            if annotation is int:
                field_type = "integer"
            elif annotation is bool:
                field_type = "boolean"
            else:
                field_type = "string"

            properties[field_name] = {
                "type": field_type,
                "description": field_info.description,
            }

            # A list used as the field default doubles as the list of allowed
            # values (enum) for that parameter.
            default = getattr(field_info, "default", None)
            if isinstance(default, list):
                properties[field_name]["enum"] = default

            # A parameter is required when pydantic says so, or when its type
            # does not admit None (even if a default, such as an enum list,
            # is present).
            if field_info.is_required() or not has_none:
                required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,  # Function name is the class name, unchanged
                # A tool without a docstring yields an empty description
                # instead of failing on ``None.strip()``.
                "description": (cls.__doc__ or "").strip(),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
# Registered tool classes, keyed by the function name found in their
# OpenAI metadata (filled in by tool_register).
tools: Dict[str, type[ToolBase]] = {}

# OpenAI-format metadata dicts of the registered tools, in registration order.
oitools = []

# Vector store queried by get_documents.  It is configured entirely from
# environment variables; a variable that is not set is passed as None.
vector_store = VectorStore(
    embeddings_model=os.getenv("EMBEDDINGS_MODEL"),  # e.g. "BAAI/bge-m3"
    vs_local_path=os.getenv("VS_LOCAL_PATH"),
    vs_hf_path=os.getenv("VS_HF_PATH"),
)
def tool_register(cls: "type[ToolBase]") -> "type[ToolBase]":
    """Register a tool class and return it.

    The class's OpenAI metadata (``cls.to_openai_tool()``) is appended to
    ``oitools`` and the class itself is stored in ``tools`` under the
    function name found in that metadata.

    The class is returned unchanged so that this function can be used as a
    class decorator without rebinding the decorated name to ``None``.
    """
    oaitool = cls.to_openai_tool()
    oitools.append(oaitool)
    tools[oaitool["function"]["name"]] = cls
    return cls
class get_documents(ToolBase):
    """
    Retrieves general information about a region, its cities, activities, tourism, or surrounding areas based on query.
    """
    # The docstring above is runtime data: ToolBase.to_openai_tool() uses it
    # as the tool description, so edit its wording with care.
    #
    # NOTE(review): the log message below mentions @tool_register, but no
    # registration of this class is visible in this block - confirm that
    # tool_register is applied to it.

    # Emitted once, while the class body is being executed.
    logging.info("@tool_register: get_documents()")

    query: str = Field(description="An enhanced user query optimized for retrieving information")

    # Also emitted at class-creation time: this logs the field definition
    # object bound to ``query`` above, not a user query.
    logging.info(f"query: {query}")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up context for the query in the vector store.

        Args:
            input: Mapping of tool arguments; only ``"query"`` is read.

        Returns:
            The result of ``vector_store.get_context(query)``, or the message
            ``"Missing required argument: query."`` when ``input`` is not a
            dict or holds no non-empty ``"query"``.
        """
        logging.info(f"get_documents.invoke() input: {input}")

        # Check if the input is a dictionary; anything else is handled like
        # a missing query instead of raising AttributeError.
        query = input.get("query") if isinstance(input, dict) else None
        if not query:
            return "Missing required argument: query."

        # Maintenance fallback, kept for reference:
        # return "We are currently working on it. You can't use this tool right now—please try again later. Thank you for your patience!"
        return vector_store.get_context(query)