import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import List

# Must be set before huggingface_hub is imported (sentence_transformers imports
# it below): the flag is read once at import time, so setting it afterwards has
# no effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on hf_transfer downloads

import duckdb
from cashews import cache  # in-memory response cache for the endpoints below
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a local data directory when developing on macOS; Spaces mounts /data
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="4gb")
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: nothing special needed here since model and DB are initialized at module level
    yield
    # Cleanup
    await cache.close()
    con.close()


# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://*.hf.space",  # Allow all Hugging Face Spaces
        "https://*.huggingface.co",  # Allow all Hugging Face domains
        # "http://localhost:5500",  # Allow localhost:5500  # TODO remove before prod
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize model and DuckDB
model = SentenceTransformer("nomic-ai/modernbert-embed-base", backend="onnx")
embedding_dim = model.get_sentence_embedding_dimension()

# Database setup with fallback
db_path = f"{DATA_DIR}/vector_store.db"
try:
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(db_path), exist_ok=True)
    con = duckdb.connect(db_path)
    logger.info(f"Connected to persistent database at {db_path}")
except (OSError, PermissionError) as e:
    logger.warning(
        f"Could not create/access {db_path}. Falling back to in-memory database. Error: {e}"
    )
    con = duckdb.connect(":memory:")

# Initialize VSS extension (HNSW persistence for on-disk DBs is experimental)
con.sql("INSTALL vss; LOAD vss;")
con.sql("SET hnsw_enable_experimental_persistence=true;")
def setup_database():
    try:
        # Create table with properly typed embeddings
        con.sql(f"""
            CREATE TABLE IF NOT EXISTS model_cards AS
            SELECT *, embeddings::FLOAT[{embedding_dim}] as embeddings_float
            FROM 'hf://datasets/davanstrien/outputs-embeddings/**/*.parquet';
        """)

        # Check if index exists
        index_exists = (
            con.sql("""
                SELECT COUNT(*) as count
                FROM duckdb_indexes
                WHERE index_name = 'my_hnsw_index';
            """).fetchone()[0]
            > 0
        )

        if index_exists:
            # Drop existing index
            con.sql("DROP INDEX my_hnsw_index;")
            logger.info("Dropped existing HNSW index")

        # Create/Recreate HNSW index
        con.sql("""
            CREATE INDEX my_hnsw_index ON model_cards
            USING HNSW (embeddings_float) WITH (metric = 'cosine');
        """)
        logger.info("Created/Recreated HNSW index")

        # Log the number of rows in the database
        row_count = con.sql("SELECT COUNT(*) as count FROM model_cards").fetchone()[0]
        logger.info(f"Database initialized with {row_count:,} rows")
    except Exception as e:
        logger.error(f"Setup error: {e}")


# Run setup on startup
setup_database()
class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]
# Redirect the root path to the interactive API docs
@app.get("/")
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")
@app.get("/search/datasets")  # route path is an assumption
@cache(ttl="1h")  # cache repeated queries in memory; the TTL is an assumption
async def search_datasets(query: str, k: int = Query(default=5, ge=1, le=100)):
    try:
        # nomic embedding models expect a task prefix on query text
        query_embedding = model.encode(f"search_query: {query}").tolist()

        # Retrieve the k most similar rows, including likes and downloads
        result = con.sql(f"""
            SELECT
                datasetId as dataset_id,
                1 - array_cosine_distance(
                    embeddings_float::FLOAT[{embedding_dim}],
                    {query_embedding}::FLOAT[{embedding_dim}]
                ) as similarity,
                summary,
                likes,
                downloads
            FROM model_cards
            ORDER BY similarity DESC
            LIMIT {k};
        """).df()

        # Convert result rows into response models
        results = [
            QueryResult(
                dataset_id=row["dataset_id"],
                similarity=float(row["similarity"]),
                summary=row["summary"],
                likes=int(row["likes"]),
                downloads=int(row["downloads"]),
            )
            for _, row in result.iterrows()
        ]
        return QueryResponse(results=results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")
@app.get("/similarity/datasets")  # route path is an assumption
@cache(ttl="1h")  # cache repeated lookups in memory; the TTL is an assumption
async def find_similar_datasets(
    dataset_id: str, k: int = Query(default=5, ge=1, le=100)
):
    try:
        # First, check that the input dataset_id exists; parameterized to avoid
        # SQL injection through the user-supplied ID
        reference_embedding = con.execute(
            """
            SELECT embeddings_float
            FROM model_cards
            WHERE datasetId = ?
            LIMIT 1;
            """,
            [dataset_id],
        ).df()
        if reference_embedding.empty:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )

        # Similarity search against the reference embedding, including likes
        # and downloads
        result = con.execute(
            f"""
            SELECT
                datasetId as dataset_id,
                1 - array_cosine_distance(
                    embeddings_float::FLOAT[{embedding_dim}],
                    (SELECT embeddings_float FROM model_cards WHERE datasetId = ? LIMIT 1)
                ) as similarity,
                summary,
                likes,
                downloads
            FROM model_cards
            WHERE datasetId != ?
            ORDER BY similarity DESC
            LIMIT {k};
            """,
            [dataset_id, dataset_id],
        ).df()

        # Convert result rows into response models
        results = [
            QueryResult(
                dataset_id=row["dataset_id"],
                similarity=float(row["similarity"]),
                summary=row["summary"],
                likes=int(row["likes"]),
                downloads=int(row["downloads"]),
            )
            for _, row in result.iterrows()
        ]
        return QueryResponse(results=results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
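
# Example requests against a local run, for illustration only. The paths match
# the assumed routes above, and "user/some-dataset" is a hypothetical dataset
# ID; substitute a real one from the indexed data:
#   curl "http://localhost:8000/search/datasets?query=OCR+training+data&k=5"
#   curl "http://localhost:8000/similarity/datasets?dataset_id=user/some-dataset&k=5"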