from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from gradio_client import Client
import logging
import os
from pathlib import Path
from dotenv import load_dotenv

from .utils import get_auth, process_content

load_dotenv()


class VectorStoreInterface(ABC):
    """Abstract interface for different vector store implementations."""

    @abstractmethod
    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search for similar documents."""


class HuggingFaceSpacesVectorStore(VectorStoreInterface):
    """Vector store implementation for Hugging Face Spaces with MCP endpoints."""

    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        repo_id = url
        logging.info(f"Connecting to Hugging Face Space: {repo_id}")
        if api_key:
            self.client = Client(repo_id, hf_token=api_key)
        else:
            self.client = Client(repo_id)
        self.collection_name = collection_name

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using the Hugging Face Spaces MCP API."""
        try:
            # Use the /search_text endpoint as documented in the Space's API
            result = self.client.predict(
                query=query,
                collection_name=self.collection_name,
                model_name=kwargs.get('model_name'),
                top_k=top_k,
                api_name="/search_text"
            )
            logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
            return result
        except Exception as e:
            logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
            raise
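
# Hedged usage sketch (not part of the original module): the Space repo id,
# collection name, and query below are placeholders; the token is assumed to
# come from the environment loaded by load_dotenv() above.
#
#     store = HuggingFaceSpacesVectorStore(
#         url="owner/space-name",          # hypothetical Space repo id
#         collection_name="docs",          # hypothetical collection
#         api_key=os.getenv("HF_TOKEN"),
#     )
#     hits = store.search("how do I configure retries?", top_k=5)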


class QdrantVectorStore(VectorStoreInterface):
    """Vector store implementation for a direct Qdrant connection."""

    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        from qdrant_client import QdrantClient

        # The Python client needs the HTTPS port set explicitly; `url` is the
        # bare host name (no scheme), and `api_key` authenticates the connection.
        self.client = QdrantClient(
            host=url,
            port=443,
            https=True,
            api_key=api_key,
            timeout=120,
        )
        self.collection_name = collection_name
        # Embedding model is lazy-loaded on first use
        self._embedding_model = None
        self._current_model_name = None

    def _get_embedding_model(self, model_name: Optional[str] = None):
        """Lazily load the embedding model to avoid loading it if not needed."""
        if model_name is None:
            model_name = "BAAI/bge-m3"  # Default from config

        # Only (re)load when the requested model differs from the cached one
        if self._embedding_model is None or self._current_model_name != model_name:
            logging.info(f"Loading embedding model: {model_name}")
            from sentence_transformers import SentenceTransformer

            cache_folder = Path(os.getenv("HF_HUB_CACHE", "/tmp/hf_cache"))
            cache_folder.mkdir(parents=True, exist_ok=True)
            self._embedding_model = SentenceTransformer(
                model_name,
                cache_folder=str(cache_folder)
            )
            self._current_model_name = model_name
            logging.info(f"Successfully loaded embedding model: {model_name}")
        return self._embedding_model

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using a direct Qdrant connection."""
        try:
            # Resolve the embedding model (lazy-loaded, cached between calls)
            model_name = kwargs.get('model_name')
            embedding_model = self._get_embedding_model(model_name)

            # Convert the query to an embedding vector
            logging.info(f"Converting query to embedding using model: {self._current_model_name}")
            query_embedding = embedding_model.encode(query).tolist()

            # Optional metadata filter passed through from the caller
            filter_obj = kwargs.get('filter', None)

            # Perform the vector search
            logging.info(f"Searching Qdrant collection '{self.collection_name}' for top {top_k} results")
            search_result = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False
            )
            logging.debug(search_result)

            # Format results to match the expected output format
            results = []
            for hit in search_result:
                raw_content = hit.payload.get('text', '')
                # Process content to handle malformed nested list structures
                processed_content = process_content(raw_content)
                results.append({
                    'answer': processed_content,
                    'answer_metadata': hit.payload.get('metadata', {}),
                    'score': hit.score
                })

            logging.info(f"Successfully retrieved {len(results)} documents from Qdrant")
            return results
        except Exception as e:
            logging.error(f"Error searching Qdrant: {str(e)}")
            raise
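
# Hedged usage sketch (not part of the original module): shows how a caller
# might pass the optional ``filter`` kwarg accepted by QdrantVectorStore.search().
# The payload key "metadata.source" and its value are hypothetical; Filter,
# FieldCondition, and MatchValue are real qdrant_client model classes.
#
#     from qdrant_client.models import Filter, FieldCondition, MatchValue
#
#     store = QdrantVectorStore(
#         url="qdrant.example.com",        # bare host, no scheme
#         collection_name="docs",
#         api_key=os.getenv("QDRANT_API_KEY"),
#     )
#     only_manuals = Filter(must=[
#         FieldCondition(key="metadata.source", match=MatchValue(value="manual"))
#     ])
#     hits = store.search("reset procedure", top_k=3, filter=only_manuals)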


def create_vectorstore(config: Any) -> VectorStoreInterface:
    """Factory function to create the appropriate vector store from configuration."""
    vectorstore_type = config.get("vectorstore", "PROVIDER")

    # Get the authentication config for the chosen provider
    auth_config = get_auth(vectorstore_type.lower())

    if vectorstore_type.lower() == "huggingface":
        url = config.get("vectorstore", "URL")
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        api_key = auth_config["api_key"]
        return HuggingFaceSpacesVectorStore(url, collection_name, api_key)
    elif vectorstore_type.lower() == "qdrant":
        url = config.get("vectorstore", "URL")  # bare host; port 443 is set in the client
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        api_key = auth_config["api_key"]
        return QdrantVectorStore(url, collection_name, api_key)
    else:
        raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
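

if __name__ == "__main__":
    # Minimal smoke test, assuming an INI-style configuration whose section and
    # key names mirror the config.get() calls above; the host and collection
    # values here are hypothetical. get_auth() is expected to resolve the API
    # key (e.g., from environment variables loaded by load_dotenv()). Because
    # this module uses a relative import, run it as `python -m <package>.<module>`.
    import configparser

    config = configparser.ConfigParser()
    config["vectorstore"] = {
        "PROVIDER": "qdrant",              # or "huggingface"
        "URL": "qdrant.example.com",       # bare host for the Qdrant client
        "COLLECTION_NAME": "docs",
    }

    store = create_vectorstore(config)
    for hit in store.search("example query", top_k=3):
        print(hit["score"], hit["answer"][:80])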