import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import chromadb
from cashews import cache
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from fastapi import FastAPI, HTTPException, Query
from httpx import AsyncClient
from huggingface_hub import DatasetCard
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from starlette.status import (
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    HTTP_500_INTERNAL_SERVER_ERROR,
)

from load_card_data import card_embedding_function, refresh_card_data
from load_viewer_data import refresh_viewer_data
from utils import get_save_path, get_collection

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up caching
cache.setup("mem://?check_interval=10&size=1000")
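# With the cashews in-memory backend, `size` bounds the number of cached entries
# and `check_interval` sets how often (in seconds) expired entries are purged;
# handlers can then be cached with the `@cache(...)` decorator.
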
# Initialize Chroma client
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
async_client = AsyncClient(
    follow_redirects=True,
)
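# A single shared httpx client is reused for all outbound Hub requests;
# follow_redirects=True matters because huggingface.co may redirect raw file URLs.
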

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: refresh data and initialize collection
    logger.info("Starting up the application")
    try:
        # Refresh data
        logger.info("Starting refresh of card data")
        refresh_card_data()
        logger.info("Card data refresh completed")
        logger.info("Starting refresh of viewer data")
        await refresh_viewer_data()
        logger.info("Viewer data refresh completed")
        logger.info("Data refresh completed successfully")
    except Exception as e:
        logger.error(f"Error during startup: {e}")
        logger.warning("Application starting with potential data issues")
    yield
    # Shutdown: perform any cleanup
    logger.info("Shutting down the application")
    # Add any cleanup code here if needed

app = FastAPI(lifespan=lifespan)


@app.get("/")  # route path assumed
def root():
    return RedirectResponse(url="/docs")


async def try_get_card(hub_id: str) -> Optional[DatasetCard]:
    try:
        response = await async_client.get(
            f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
        )
        if response.status_code == 200:
            return DatasetCard(response.text)
    except Exception as e:
        logger.error(f"Error fetching card for hub_id {hub_id}: {e}")
    return None


class QueryResult(BaseModel):
    dataset_id: str
    similarity: float


class QueryResponse(BaseModel):
    results: List[QueryResult]


class DatasetCardNotFoundError(HTTPException):
    def __init__(self, dataset_id: str):
        super().__init__(
            status_code=HTTP_404_NOT_FOUND,
            detail=f"No dataset card available for dataset: {dataset_id}",
        )


class DatasetNotForAllAudiencesError(HTTPException):
    def __init__(self, dataset_id: str):
        super().__init__(
            status_code=HTTP_403_FORBIDDEN,
            detail=f"Dataset {dataset_id} is not for all audiences and not supported in this service.",
        )
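# Both errors subclass FastAPI's HTTPException, so raising them inside a handler
# yields a JSON error body ({"detail": ...}) with the matching status code and
# needs no custom exception handler.
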

@app.get("/similar", response_model=QueryResponse)  # route path assumed
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    embedding_function = card_embedding_function()
    collection = get_collection(client, embedding_function, "dataset_cards")
    try:
        logger.info(f"Querying dataset: {dataset_id}")
        # Get the embedding for the given dataset_id
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        if not result.get("embeddings"):
            logger.info(f"Dataset not found: {dataset_id}")
            try:
                card = await try_get_card(dataset_id)
                if card is None:
                    raise DatasetCardNotFoundError(dataset_id)
                # Assumed gate: reject datasets whose card metadata carries the
                # not-for-all-audiences tag before embedding and storing them.
                if "not-for-all-audiences" in (card.data.tags or []):
                    raise DatasetNotForAllAudiencesError(dataset_id)
                # Chroma embedding functions take a list of documents and return
                # a list of embeddings, one per document.
                embeddings = embedding_function([card.text])
                collection.upsert(ids=[dataset_id], embeddings=embeddings)
                logger.info(f"Dataset {dataset_id} added to collection")
                result = collection.get(ids=[dataset_id], include=["embeddings"])
            except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
                raise
            except Exception as e:
                logger.error(f"Error adding dataset {dataset_id} to collection: {e}")
                raise DatasetCardNotFoundError(dataset_id) from e
        embedding = result["embeddings"][0]
        # Query the collection for similar datasets
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        if not query_result["ids"][0]:
            logger.info(f"No similar datasets found for: {dataset_id}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )
        # Prepare the response; 1 - distance is used as a similarity score,
        # which assumes the collection uses cosine distance.
        results = [
            QueryResult(dataset_id=dataset, similarity=1 - distance)
            for dataset, distance in zip(
                query_result["ids"][0], query_result["distances"][0]
            )
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except HTTPException:
        # DatasetCardNotFoundError and DatasetNotForAllAudiencesError subclass
        # HTTPException, so they propagate unchanged here.
        raise
    except Exception as e:
        logger.error(f"Error querying dataset {dataset_id}: {e}")
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e


@app.get("/search", response_model=QueryResponse)  # route path assumed
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
    try:
        logger.info(f"Querying datasets by text: {query}")
        collection = client.get_collection(
            name="dataset_cards", embedding_function=card_embedding_function()
        )
        query_result = collection.query(
            query_texts=[query], n_results=n, include=["distances"]
        )
        if not query_result["ids"][0]:
            logger.info(f"No similar datasets found for query: {query}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )
        # Prepare the response
        results = [
            QueryResult(dataset_id=str(dataset), similarity=float(1 - distance))
            for dataset, distance in zip(
                query_result["ids"][0], query_result["distances"][0]
            )
        ]
        logger.info(f"Found {len(results)} similar datasets for query: {query}")
        return QueryResponse(results=results)
    except HTTPException:
        # Re-raise so the 404 above is not masked by the generic 500 below.
        raise
    except Exception as e:
        logger.error(f"Error querying datasets by text {query}: {e}")
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e


@app.get("/search-viewer", response_model=QueryResponse)  # route path assumed
async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
    try:
        embedding_function = SentenceTransformerEmbeddingFunction(
            model_name="davanstrien/dataset-viewer-descriptions-processed-st",
            trust_remote_code=True,
        )
        collection = client.get_collection(
            name="dataset-viewer-descriptions",
            embedding_function=embedding_function,
        )
        # Prefix the query; the model presumably expects this marker from training.
        query = f"USER_QUERY: {query}"
        query_result = collection.query(
            query_texts=[query], n_results=n, include=["distances"]
        )
        if not query_result["ids"][0]:
            logger.info(f"No similar datasets found for query: {query}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )
        # Prepare the response
        results = [
            QueryResult(dataset_id=str(dataset), similarity=float(1 - distance))
            for dataset, distance in zip(
                query_result["ids"][0], query_result["distances"][0]
            )
        ]
        logger.info(f"Found {len(results)} similar datasets for query: {query}")
        return QueryResponse(results=results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error querying datasets by text {query}: {e}")
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
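
# Example requests once the server is running (paths as assumed above):
#   curl "http://localhost:8000/similar?dataset_id=stanfordnlp/imdb&n=5"
#   curl "http://localhost:8000/search?query=movie+reviews&n=5"
# Both return JSON like {"results": [{"dataset_id": "...", "similarity": 0.87}, ...]}.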