Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import logging | |
| from contextlib import asynccontextmanager | |
| from typing import List, Optional | |
| import chromadb | |
| from cashews import cache | |
| from fastapi import FastAPI, HTTPException, Query | |
| from pydantic import BaseModel | |
| from starlette.responses import RedirectResponse | |
| from httpx import AsyncClient | |
| from load_data import get_embedding_function, get_save_path, refresh_data | |
| from huggingface_hub import DatasetCard | |
# Logging: configure the root logger once at import time (INFO and above,
# "<timestamp> - <level> - <message>" records) and create this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Caching: in-memory cashews backend; its settings are encoded in the URL.
cache.setup("mem://?check_interval=10&size=10000")

# Chroma: persistent client rooted at the path reported by load_data.
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)

# Assigned by `lifespan` during application startup; None until then.
collection = None

# Shared HTTP client used for outbound requests; redirects are followed.
async_client = AsyncClient(follow_redirects=True)
class QueryResult(BaseModel):
    # One entry of a similarity query result.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.)

    # Identifier of the matched dataset.
    dataset_id: str
    # Similarity score; the query handler computes it as ``1 - distance``.
    similarity: float
class QueryResponse(BaseModel):
    # Response payload of a similarity query: the matches, in the order the
    # collection query returned them.
    results: List[QueryResult]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup and shutdown around the serving phase.

    Startup gets or creates the ``dataset_cards`` Chroma collection, stores it
    in the module-level ``collection`` and then calls ``refresh_data()``.
    Control is yielded while the application handles requests; after that a
    shutdown message is logged.

    The function is wrapped with ``asynccontextmanager`` because it is handed
    to ``FastAPI(lifespan=...)``, which expects an async context manager
    rather than a bare async generator function.

    Args:
        app: The FastAPI application being started (not used directly).

    Raises:
        Exception: Any error raised during startup is logged with its
            traceback and re-raised, aborting application startup.
    """
    global collection

    logger.info("Starting up the application")
    try:
        # Create or get the collection, bound to the project's embedding function.
        collection = client.get_or_create_collection(
            name="dataset_cards",
            embedding_function=get_embedding_function(),
        )
        logger.info("Collection initialized successfully")

        refresh_data()
        logger.info("Data refresh completed successfully")
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception("Error during startup: %s", e)
        raise

    yield  # The app is running and handling requests here.

    logger.info("Shutting down the application")
# The ASGI application; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
)
# Register the handler on the application; without a route decorator the
# function is never reachable over HTTP. It is kept out of the OpenAPI schema
# because it only redirects.
@app.get("/", include_in_schema=False)
def root():
    """Redirect the root URL to the interactive API documentation at ``/docs``."""
    return RedirectResponse(url="/docs")
async def try_get_card(hub_id: str) -> Optional[str]:
    """Fetch the README of a Hub dataset and return its card text.

    This is a best-effort lookup: every failure is logged and reported as
    ``None`` instead of being raised.

    Args:
        hub_id: Dataset identifier used to build the README URL
            (``https://huggingface.co/datasets/<hub_id>/raw/main/README.md``).

    Returns:
        The ``text`` attribute of the ``DatasetCard`` parsed from the README,
        or ``None`` when the response status is not 200 or any error occurs.
    """
    from urllib.parse import quote

    # hub_id comes from the caller: percent-encode everything except the
    # namespace separator so characters such as "?" or "#" cannot change the
    # query or fragment of the requested URL.
    safe_hub_id = quote(hub_id, safe="/")
    url = f"https://huggingface.co/datasets/{safe_hub_id}/raw/main/README.md"

    try:
        response = await async_client.get(url)
        if response.status_code != 200:
            logger.info(
                "No card for hub_id %s (HTTP %s)", hub_id, response.status_code
            )
            return None
        return DatasetCard(response.text).text
    except Exception as e:
        logger.error(f"Error fetching card for hub_id {hub_id}: {str(e)}")
        return None
def _has_embeddings(result) -> bool:
    """Return True when a ``collection.get`` result holds at least one embedding."""
    embeddings = result.get("embeddings")
    # Use len() rather than truthiness so the check also works for array-like
    # values whose truth value is ambiguous.
    return embeddings is not None and len(embeddings) > 0


# NOTE(review): the handler was not registered on `app`; the route path below
# is inferred from the handler's purpose — confirm it matches the public API.
@app.get("/query", response_model=Optional[QueryResponse])
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Return the datasets most similar to ``dataset_id``.

    The stored embedding of ``dataset_id`` is looked up in the collection. If
    it is missing, the dataset card is fetched, embedded and upserted first.
    The collection is then queried with that embedding.

    Args:
        dataset_id: Identifier of the dataset to find neighbours for.
        n: Number of results to request from the collection (1-100).

    Returns:
        A ``QueryResponse`` whose results carry ``similarity = 1 - distance``,
        or ``None`` when the collection query returns no ids.

    Raises:
        HTTPException: 404 when the dataset is not in the collection and
            cannot be added; 500 for any other failure.
    """
    try:
        logger.info(f"Querying dataset: {dataset_id}")

        # Get the stored embedding for the given dataset_id.
        result = collection.get(ids=[dataset_id], include=["embeddings"])

        if not _has_embeddings(result):
            logger.info(f"Dataset not found: {dataset_id}")
            try:
                card = await try_get_card(dataset_id)
                if card is None:
                    # Fail explicitly instead of handing None to the embedder.
                    raise LookupError(f"No dataset card available for {dataset_id}")
                embedding_function = get_embedding_function()
                # The same function is registered as the collection's Chroma
                # embedding function, which is called with a list of
                # documents; pass a one-element list and take its vector.
                embeddings = embedding_function([card])
                collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
                logger.info(f"Dataset {dataset_id} added to collection")
                result = collection.get(ids=[dataset_id], include=["embeddings"])
            except Exception as e:
                logger.error(
                    f"Error adding dataset {dataset_id} to collection: {str(e)}"
                )
                raise HTTPException(status_code=404, detail="Dataset not found") from e

        embedding = result["embeddings"][0]

        # Query the collection for similar datasets.
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        if not query_result["ids"]:
            logger.info(f"No similar datasets found for: {dataset_id}")
            return None

        # NOTE(review): "1 - distance" assumes the collection's distance is
        # one where that yields a similarity (e.g. cosine) — confirm.
        results = [
            QueryResult(dataset_id=neighbour_id, similarity=1 - distance)
            for neighbour_id, distance in zip(
                query_result["ids"][0], query_result["distances"][0]
            )
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged; the
        # generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")