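"""FastAPI service for semantic search over Hugging Face dataset and model cards.

Card summaries are embedded with a SentenceTransformer model and stored in two
persistent ChromaDB collections ("dataset_cards" and "model_cards"); the endpoint
functions below run free-text and similar-item queries against those collections.
"""
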
import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import List

# Turn on hf_transfer for faster downloads. huggingface_hub reads this variable at
# import time, so it has to be set before the imports below.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import chromadb
import polars as pl
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi
from pydantic import BaseModel
from transformers import AutoTokenizer

# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 1000
CACHE_TTL = "60"

# Hub client and summary-model tokenizer (loaded at import time; not referenced
# elsewhere in this file).
hf_api = HfApi()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="4gb")

# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
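# The Chroma index is written under DATA_DIR so embeddings survive restarts. On a
# Hugging Face Space this assumes a persistent storage volume mounted at /data;
# on macOS (local development) a ./data directory next to the app is used instead.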


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup
    setup_database()
    yield
    # Cleanup
    await cache.close()


# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://*.hf.space",  # Allow all Hugging Face Spaces
        "https://*.huggingface.co",  # Allow all Hugging Face domains
        "http://localhost:5500",  # Allow localhost:5500  # TODO remove before prod
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
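# Caveat: Starlette's CORSMiddleware matches allow_origins entries literally, so the
# wildcard strings above will not actually match subdomains. If per-subdomain access
# is intended, allow_origin_regex (e.g. r"https://.*\.hf\.space") is the usual
# mechanism; that regex is only a sketch of the approach, not part of the original code.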


# Define the embedding function at module level
def get_embedding_function():
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL
    )
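

# The nomic/modernbert embedding model expects task prefixes: documents are stored
# below as plain summaries, while queries are issued as "search_query: <text>" in the
# endpoints further down. If the embedding model is swapped out, that prefix
# convention may no longer apply.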


def setup_database():
    """Create the dataset/model card collections and (re)populate them from parquet."""
    try:
        embedding_function = get_embedding_function()
        # Create dataset collection
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        # Create model collection
        model_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="model_cards",
            metadata={"hnsw:space": "cosine"},
        )
        # TODO incremental updates
        df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        )
        df = df.filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        row_count = df.select(pl.len()).collect().item()
        logger.info(f"Row count of new data: {row_count}")
        if dataset_collection.count() < row_count:
            # Load parquet files and upsert into ChromaDB
            df = df.select(
                ["datasetId", "summary", "likes", "downloads", "last_modified"]
            )
            df = df.collect()
            total_rows = len(df)
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
                dataset_collection.upsert(
                    ids=batch_df.select(["datasetId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")
        logger.info(f"Database initialized with {dataset_collection.count():,} rows")

        # Load model data
        model_df = pl.scan_parquet(
            "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
        )
        model_row_count = model_df.select(pl.len()).collect().item()
        logger.info(f"Row count of new model data: {model_row_count}")
        if model_collection.count() < model_row_count:
            model_df = model_df.select(
                ["modelId", "summary", "likes", "downloads", "last_modified"]
            )
            model_df = model_df.collect()
            total_rows = len(model_df)
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
                model_collection.upsert(
                    ids=batch_df.select(["modelId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(
                    f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
                )
        logger.info(
            f"Model database initialized with {model_collection.count():,} rows"
        )
    except Exception as e:
        logger.error(f"Setup error: {e}")
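
# Note: the checks above only re-embed when a collection holds fewer rows than the
# source parquet export, so upstream edits or deletions will not trigger a refresh
# on their own (see the "TODO incremental updates" above).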


class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]


class ModelQueryResult(BaseModel):
    # "model_" is a protected namespace in Pydantic v2; opt out so the model_id
    # field does not trigger a warning.
    model_config = {"protected_namespaces": ()}

    model_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class ModelQueryResponse(BaseModel):
    results: List[ModelQueryResult]
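
# For reference, a dataset response serialises to JSON shaped like the following
# (values are illustrative, not real data):
# {
#   "results": [
#     {"dataset_id": "some-org/some-dataset", "similarity": 0.23,
#      "summary": "...", "likes": 10, "downloads": 1200}
#   ]
# }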


# NOTE: the route paths and @cache wiring on the endpoints below are illustrative
# assumptions; adjust them to match the deployment.
@app.get("/", include_in_schema=False)
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")


# Assumed route/caching wiring (see note above redirect_to_docs).
@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        # Get collection with proper embedding function
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        # Query ChromaDB
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        # Process results
        query_results = process_search_results(results, "dataset", k, sort_by)
        return QueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")


# Assumed route/caching wiring (see note above redirect_to_docs).
@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("dataset_cards")
        # Get the reference document
        results = collection.get(ids=[dataset_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        # Query using the embedding
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4
            if sort_by != "similarity"
            else k + 1,  # +1 to account for self-match
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        # Process results (excluding the query dataset itself)
        query_results = process_search_results(
            results, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")


# Assumed route/caching wiring (see note above redirect_to_docs).
@app.get("/search/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection(
            name="model_cards", embedding_function=get_embedding_function()
        )
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = process_search_results(results, "model", k, sort_by)
        return ModelQueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed")


# Assumed route/caching wiring (see note above redirect_to_docs).
@app.get("/similarity/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("model_cards")
        results = collection.get(ids=[model_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = process_search_results(results, "model", k, sort_by, model_id)
        return ModelQueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Model similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model similarity search failed")


def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format."""
    query_results = []
    for i in range(len(results["ids"][0])):
        current_id = results["ids"][0][i]
        if exclude_id and current_id == exclude_id:
            continue
        result = {
            f"{id_field}_id": current_id,
            # ChromaDB returns cosine *distances* here, so lower values mean closer matches
            "similarity": float(results["distances"][0][i]),
            "summary": results["documents"][0][i],
            "likes": results["metadatas"][0][i]["likes"],
            "downloads": results["metadatas"][0][i]["downloads"],
        }
        if id_field == "dataset":
            query_results.append(QueryResult(**result))
        else:
            query_results.append(ModelQueryResult(**result))
    if sort_by != "similarity":
        # Callers over-fetch (k * 4) so a full page remains after re-sorting
        query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
        query_results = query_results[:k]
    elif exclude_id:  # We fetched extra for similarity + exclude_id case
        query_results = query_results[:k]
    return query_results


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
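
# Example usage (against the assumed route paths above; adjust if the routes differ,
# and note that "app:app" assumes this file is named app.py):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#   curl "http://localhost:8000/search/datasets?query=medical+question+answering&k=5"
#   curl "http://localhost:8000/similarity/models?model_id=bert-base-uncased&k=5"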