import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List, Optional

import chromadb
import dateutil.parser
import httpx
import polars as pl
import torch
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer
from dotenv import load_dotenv
from huggingface_hub import login

load_dotenv(override=True)
HF_TOKEN = os.getenv("HF_TOKEN")
login(token=HF_TOKEN)

# Configuration constants
MODEL_NAME = "davanstrien/Smol-Hub-tldr"
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
BATCH_SIZE = 2000
CACHE_TTL = "24h"
TRENDING_CACHE_TTL = "1h"  # 1 hour cache for trending data
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

LOCAL = False
if sys.platform == "darwin":
    LOCAL = True
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="8gb")

# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
# Initialize FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup
    setup_database()
    yield
    # Cleanup
    await cache.close()


app = FastAPI(lifespan=lifespan)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # Starlette matches allow_origins entries literally (no wildcard expansion inside
    # an origin), so the Hugging Face Spaces and Hub subdomains are allowed via a regex.
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)",
    # allow_origins=["http://localhost:5500"],  # TODO remove before prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the embedding function at module level
def get_embedding_function():
    logger.info(f"Using device: {DEVICE}")
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL, device=DEVICE
    )
def setup_database():
    try:
        embedding_function = get_embedding_function()
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        model_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="model_cards",
            metadata={"hnsw:space": "cosine"},
        )
        # Load dataset data
        df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        )
        df = df.filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        df = df.filter(
            pl.col("datasetId")
            .str.contains_any(
                ["gemma-2-2B-it-thinking-function_calling-V0"]
            )  # course model that's not useful for retrieving
            .not_()
        )
        # Get the most recent last_modified date from the collection
        latest_update = None
        if dataset_collection.count() > 0:
            metadata = dataset_collection.get(include=["metadatas"]).get("metadatas")
            logger.info(f"Found {len(metadata)} existing records in collection")
            last_modifieds = [
                dateutil.parser.parse(m.get("last_modified")) for m in metadata
            ]
            latest_update = max(last_modifieds)
            logger.info(f"Most recent record in DB from: {latest_update}")
            logger.info(f"Oldest record in DB from: {min(last_modifieds)}")
            # Log sample of existing timestamps for debugging
            sample_timestamps = sorted(last_modifieds, reverse=True)[:5]
            logger.info(f"Sample of most recent DB timestamps: {sample_timestamps}")
        # Filter and process only newer records
        df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])
        # Log some stats about the incoming data BEFORE collecting
        total_incoming = df.select(pl.len()).collect().item()
        logger.info(f"Total incoming records from source: {total_incoming}")
        # Get sample of dates to understand the data
        sample_df = (
            df.select(["datasetId", "last_modified"])
            .sort("last_modified", descending=True)
            .limit(5)
            .collect()
        )
        logger.info(f"Sample of most recent incoming records: {sample_df.rows()[:3]}")
        if latest_update:
            logger.info(f"Filtering records newer than {latest_update}")
            logger.info(f"Latest update type: {type(latest_update)}")
            # Get date range before filtering
            date_stats = df.select(
                [
                    pl.col("last_modified").min().alias("min_date"),
                    pl.col("last_modified").max().alias("max_date"),
                ]
            ).collect()
            logger.info(f"Incoming data date range: {date_stats.row(0)}")
            # Ensure last_modified is datetime before comparison
            df = df.with_columns(pl.col("last_modified").str.to_datetime())
            # Keep a handle on the unfiltered frame so records at or below the
            # cutoff can still be inspected if nothing new is found
            df_all = df
            df = df.filter(pl.col("last_modified") > latest_update)
            filtered_count = df.select(pl.len()).collect().item()
            logger.info(f"Found {filtered_count} records to update after filtering")
            if filtered_count == 0:
                logger.warning(
                    "No new records found after filtering! This might indicate a problem."
                )
                # Log a few records that were just below the cutoff
                just_before = (
                    df_all.select(["datasetId", "last_modified"])
                    .filter(pl.col("last_modified") <= latest_update)
                    .sort("last_modified", descending=True)
                    .limit(3)
                    .collect()
                )
                if len(just_before) > 0:
                    logger.info(f"Records just before cutoff: {just_before.rows()}")
        df = df.collect()
        total_rows = len(df)
        if total_rows > 0:
            logger.info(f"Updating dataset collection with {total_rows} new records")
            logger.info(
                f"Date range of updates: {df['last_modified'].min()} to {df['last_modified'].max()}"
            )
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
                batch_size = len(batch_df)
                logger.info(
                    f"Processing batch {i // BATCH_SIZE + 1}: {batch_size} records "
                    f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
                )
                ids_to_upsert = batch_df.select(["datasetId"]).to_series().to_list()
                # Log progress for the first batch and then every 5th batch
                if i == 0 or (i // BATCH_SIZE + 1) % 5 == 0:
                    logger.info(
                        f"Upserting batch {i // BATCH_SIZE + 1} "
                        f"(sample IDs: {ids_to_upsert[:3]})"
                    )
                # Check if any of these already exist (sample only)
                if i == 0:  # Only log for first batch to reduce noise
                    existing_check = dataset_collection.get(
                        ids=ids_to_upsert[:3], include=["metadatas"]
                    )
                    if existing_check["ids"]:
                        logger.info(
                            f"Sample: {len(existing_check['ids'])} existing records being updated"
                        )
                dataset_collection.upsert(
                    ids=ids_to_upsert,
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(f"Processed {i + batch_size:,} / {total_rows:,} records")
        # Final validation
        final_count = dataset_collection.count()
        logger.info(f"Database initialized with {final_count:,} total rows")
        # Verify the update worked by checking latest records
        if final_count > 0:
            # Get ALL metadata to find the true latest timestamp (not just 5 records)
            final_metadata = dataset_collection.get(include=["metadatas"])
            final_timestamps = [
                dateutil.parser.parse(m.get("last_modified"))
                for m in final_metadata.get("metadatas")
            ]
            if final_timestamps:
                latest_after_update = max(final_timestamps)
                logger.info(f"Latest record after update: {latest_after_update}")
                if latest_update and latest_after_update <= latest_update:
                    logger.error(
                        "WARNING: No new records were added! Latest timestamp hasn't changed."
                    )
                elif latest_update:
                    logger.info(
                        f"Successfully added records from {latest_update} to {latest_after_update}"
                    )
                else:
                    logger.info(
                        f"Initial database setup completed. Latest record: {latest_after_update}"
                    )
        # Load model data
        model_lazy_df = pl.scan_parquet(
            "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
        )
        model_row_count = model_lazy_df.select(pl.len()).collect().item()
        logger.info(f"Total model records in source: {model_row_count}")
        # Get the most recent last_modified date from the model collection
        model_latest_update = None
        if model_collection.count() > 0:
            model_metadata = model_collection.get(include=["metadatas"]).get(
                "metadatas"
            )
            logger.info(
                f"Found {len(model_metadata)} existing model records in collection"
            )
            model_last_modifieds = [
                dateutil.parser.parse(m.get("last_modified")) for m in model_metadata
            ]
            model_latest_update = max(model_last_modifieds)
            logger.info(f"Most recent model record in DB from: {model_latest_update}")
        # Set up model schema columns
        schema = model_lazy_df.collect_schema()
        select_columns = [
            "modelId",
            "summary",
            "likes",
            "downloads",
            "last_modified",
        ]
        if "param_count" in schema:
            logger.info("Found 'param_count' column in model data schema.")
            select_columns.append("param_count")
        else:
            logger.warning(
                "'param_count' column not found in model data schema. Will add it with null values."
            )
        # Filter and process only newer model records
        model_df = model_lazy_df.select(select_columns)
        # Apply timestamp filtering like we do for datasets
        if model_latest_update:
            logger.info(f"Filtering model records newer than {model_latest_update}")
            model_df = model_df.with_columns(pl.col("last_modified").str.to_datetime())
            model_df = model_df.filter(pl.col("last_modified") > model_latest_update)
            model_filtered_count = model_df.select(pl.len()).collect().item()
            logger.info(
                f"Found {model_filtered_count} model records to update after filtering"
            )
        else:
            model_filtered_count = model_df.select(pl.len()).collect().item()
            logger.info(
                f"Initial model load: processing all {model_filtered_count} model records"
            )
        if model_filtered_count > 0:
            model_df = model_df.collect()
            # If param_count was not in the original schema, add it now to the collected DataFrame
            if "param_count" not in model_df.columns:
                model_df = model_df.with_columns(
                    pl.lit(None).cast(pl.Int64).alias("param_count")
                )
            total_rows = len(model_df)
            logger.info(f"Updating model collection with {total_rows} new records")
            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
                model_collection.upsert(
                    ids=batch_df.select(["modelId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                            "param_count": int(param_count)
                            if param_count is not None
                            else 0,
                        }
                        for likes, downloads, last_modified, param_count in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                            batch_df.select(["param_count"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(
                    f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
                )
        logger.info(
            f"Model database initialized with {model_collection.count():,} rows"
        )
    except Exception as e:
        logger.error(f"Setup error: {e}")


# Setup database is called in lifespan, not here
class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]


class ModelQueryResult(BaseModel):
    model_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int
    param_count: Optional[int] = None


class ModelQueryResponse(BaseModel):
    results: List[ModelQueryResult]
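# Assumed route decorator: the path is a guess inferred from the handler below.
@app.get("/")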
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")
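# Assumed route decorator: the path is a guess inferred from the handler below.
@app.get("/search/datasets")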
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        task_description = "Given a search query, retrieve relevant model and dataset summaries that match the query. "
        query = f"Instruct: {task_description}\nQuery:{query}"
        results = collection.query(
            query_texts=[query],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = await process_search_results(results, "dataset", k, sort_by)
        return QueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")
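# Assumed route decorator: the path is a guess inferred from the handler below.
@app.get("/similarity/datasets")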
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("dataset_cards")
        results = collection.get(ids=[dataset_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = await process_search_results(
            results, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")
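# Assumed route decorator: the path is a guess inferred from the handler below.
@app.get("/search/models")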
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: str = Query(
        default="similarity",
        enum=["similarity", "likes", "downloads", "trending"],
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Search for models based on a text query with optional filtering.
    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    try:
        collection = client.get_collection(
            name="model_cards", embedding_function=get_embedding_function()
        )
        where_conditions = []
        if min_likes > 0:
            where_conditions.append({"likes": {"$gte": min_likes}})
        if min_downloads > 0:
            where_conditions.append({"downloads": {"$gte": min_downloads}})
        # Add parameter count filters
        using_param_filters = min_param_count > 0 or max_param_count is not None
        if using_param_filters:
            # Always exclude zero param count when using any parameter filters
            where_conditions.append({"param_count": {"$gt": 0}})
            if min_param_count > 0:
                where_conditions.append({"param_count": {"$gte": min_param_count}})
            if max_param_count is not None:
                where_conditions.append({"param_count": {"$lte": max_param_count}})
        # Handle where clause creation based on number of conditions
        where_clause = None
        if len(where_conditions) > 1:
            where_clause = {"$and": where_conditions}
        elif len(where_conditions) == 1:
            where_clause = where_conditions[0]  # Single condition without $and
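        # Illustrative example (hypothetical values): min_likes=100 with
        # min_param_count=1_000_000 produces
        #   {"$and": [{"likes": {"$gte": 100}},
        #             {"param_count": {"$gt": 0}},
        #             {"param_count": {"$gte": 1000000}}]}
        # while a single active condition is passed through without the "$and" wrapper.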
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where=where_clause,
        )
        query_results = await process_search_results(results, "model", k, sort_by)
        return ModelQueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed")
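# Assumed route decorator: the path is a guess inferred from the handler below.
@app.get("/similarity/models")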
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: str = Query(
        default="similarity",
        enum=["similarity", "likes", "downloads", "trending"],
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Find similar models to a specified model with optional filtering.
    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    try:
        collection = client.get_collection("model_cards")
        results = collection.get(ids=[model_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )
        where_conditions = []
        if min_likes > 0:
            where_conditions.append({"likes": {"$gte": min_likes}})
        if min_downloads > 0:
            where_conditions.append({"downloads": {"$gte": min_downloads}})
        # Add parameter count filters
        using_param_filters = min_param_count > 0 or max_param_count is not None
        if using_param_filters:
            # Always exclude zero param count when using any parameter filters
            where_conditions.append({"param_count": {"$gt": 0}})
            if min_param_count > 0:
                where_conditions.append({"param_count": {"$gte": min_param_count}})
            if max_param_count is not None:
                where_conditions.append({"param_count": {"$lte": max_param_count}})
        # Handle where clause creation based on number of conditions
        where_clause = None
        if len(where_conditions) > 1:
            where_clause = {"$and": where_conditions}
        elif len(where_conditions) == 1:
            where_clause = where_conditions[0]  # Single condition without $and
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where=where_clause,
        )
        query_results = await process_search_results(
            results, "model", k, sort_by, model_id
        )
        return ModelQueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Model similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model similarity search failed")
async def get_trending_score(item_id: str, item_type: str) -> float:
    """Fetch trending score for a model or dataset from HuggingFace API"""
    try:
        async with httpx.AsyncClient() as client:
            endpoint = "models" if item_type == "model" else "datasets"
            response = await client.get(
                f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
            )
            response.raise_for_status()
            return response.json().get("trendingScore", 0)
    except Exception as e:
        logger.error(
            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
        )
        return 0
async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format."""
    query_results = []
    # Create base results
    for i in range(len(results["ids"][0])):
        current_id = results["ids"][0][i]
        if exclude_id and current_id == exclude_id:
            continue
        result = {
            f"{id_field}_id": current_id,
            "similarity": float(results["distances"][0][i]),
            "summary": results["documents"][0][i],
            "likes": results["metadatas"][0][i]["likes"],
            "downloads": results["metadatas"][0][i]["downloads"],
        }
        # Add param_count for models if it exists in metadata
        if id_field == "model" and "param_count" in results["metadatas"][0][i]:
            result["param_count"] = results["metadatas"][0][i]["param_count"]
        if id_field == "dataset":
            query_results.append(QueryResult(**result))
        else:
            query_results.append(ModelQueryResult(**result))
    # Handle sorting
    if sort_by == "trending":
        # Fetch trending scores for all results
        trending_scores = {}
        async with httpx.AsyncClient() as client:
            tasks = [
                get_trending_score(
                    getattr(result, f"{id_field}_id"),
                    "model" if id_field == "model" else "dataset",
                )
                for result in query_results
            ]
            scores = await asyncio.gather(*tasks)
            trending_scores = {
                getattr(result, f"{id_field}_id"): score
                for result, score in zip(query_results, scores)
            }
        # Sort by trending score
        query_results.sort(
            key=lambda x: trending_scores.get(getattr(x, f"{id_field}_id"), 0),
            reverse=True,
        )
        query_results = query_results[:k]
    elif sort_by != "similarity":
        query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
        query_results = query_results[:k]
    elif exclude_id:  # We fetched extra for similarity + exclude_id case
        query_results = query_results[:k]
    return query_results
async def fetch_trending_models():
    """Fetch trending models from HuggingFace API"""
    async with httpx.AsyncClient() as client:
        response = await client.get("https://huggingface.co/api/models")
        response.raise_for_status()
        return response.json()


async def get_trending_models_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
    min_param_count: int = 0,
    max_param_count: Optional[int] = None,
) -> List[ModelQueryResult]:
    """Fetch trending models and combine with summaries from database"""
    try:
        # Fetch trending models
        trending_models = await fetch_trending_models()
        # Filter by minimum likes/downloads
        trending_models = [
            model
            for model in trending_models
            if model.get("likes", 0) >= min_likes
            and model.get("downloads", 0) >= min_downloads
        ]
        # Sort by trending score
        trending_models = sorted(
            trending_models, key=lambda x: x.get("trendingScore", 0), reverse=True
        )
        # Fetch up to 3x the limit (buffer for filtering) or all available if fewer
        # This ensures we have enough models to filter from
        fetch_limit = min(len(trending_models), limit * 3)
        trending_models = trending_models[:fetch_limit]
        # Get model IDs
        model_ids = [model["modelId"] for model in trending_models]
        # Fetch summaries from ChromaDB
        collection = client.get_collection("model_cards")
        summaries = collection.get(ids=model_ids, include=["documents", "metadatas"])
        # Create mapping of model_id to summary and metadata
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
        id_to_metadata = dict(zip(summaries["ids"], summaries["metadatas"]))
        # Log parameters for debugging
        print(
            f"Filter params - min_param_count: {min_param_count}, max_param_count: {max_param_count}"
        )
        # Combine data - collect all results first
        all_results = []
        for model in trending_models:
            if model["modelId"] in id_to_summary:
                metadata = id_to_metadata.get(model["modelId"], {})
                param_count = metadata.get("param_count", 0)
                # Log model parameter counts
                print(f"Model: {model['modelId']}, param_count: {param_count}")
                result = ModelQueryResult(
                    model_id=model["modelId"],
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[model["modelId"]],
                    likes=model.get("likes", 0),
                    downloads=model.get("downloads", 0),
                    param_count=param_count,
                )
                all_results.append(result)
        # Apply parameter filtering after collecting all results
        filtered_results = all_results
        # Check if any parameter filtering is being applied
        using_param_filters = min_param_count > 0 or max_param_count is not None
        # Only filter by params if we have specific parameter constraints
        if using_param_filters:
            filtered_results = []
            for result in all_results:
                should_include = True
                # Always exclude models with param_count=0 when any parameter filtering is active
                if result.param_count == 0:
                    print(
                        f"Filtering out {result.model_id} - has param_count=0 but parameter filtering is active"
                    )
                    should_include = False
                # Apply min param filter if specified
                elif min_param_count > 0 and result.param_count < min_param_count:
                    print(
                        f"Filtering out {result.model_id} - param_count {result.param_count} < min_param_count {min_param_count}"
                    )
                    should_include = False
                # Apply max param filter if specified
                elif (
                    max_param_count is not None and result.param_count > max_param_count
                ):
                    print(
                        f"Filtering out {result.model_id} - param_count {result.param_count} > max_param_count {max_param_count}"
                    )
                    should_include = False
                if should_include:
                    filtered_results.append(result)
        print(f"After filtering: {len(filtered_results)} models remain")
        # Finally limit to the requested number
        return filtered_results[:limit]
    except Exception as e:
        logger.error(f"Error fetching trending models: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending models")
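# Assumed route decorator: the path is a guess inferred from the handler below.
@app.get("/trending/models")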
async def get_trending_models(
    limit: int = Query(
        default=10, ge=1, le=100, description="Number of results to return"
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any parameter filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Get trending models with their summaries and optional filtering.
    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    print(
        f"Request for trending models with params: limit={limit}, min_likes={min_likes}, min_downloads={min_downloads}, min_param_count={min_param_count}, max_param_count={max_param_count}"
    )
    results = await get_trending_models_with_summaries(
        limit=limit,
        min_likes=min_likes,
        min_downloads=min_downloads,
        min_param_count=min_param_count,
        max_param_count=max_param_count,
    )
    print(f"Returning {len(results)} trending model results")
    return ModelQueryResponse(results=results)
async def fetch_trending_datasets():
    """Fetch trending datasets from HuggingFace API"""
    async with httpx.AsyncClient() as client:
        response = await client.get("https://huggingface.co/api/datasets")
        response.raise_for_status()
        return response.json()


async def get_trending_datasets_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
) -> List[QueryResult]:
    """Fetch trending datasets and combine with summaries from database"""
    try:
        # Fetch trending datasets
        trending_datasets = await fetch_trending_datasets()
        # Filter by minimum likes/downloads
        trending_datasets = [
            dataset
            for dataset in trending_datasets
            if dataset.get("likes", 0) >= min_likes
            and dataset.get("downloads", 0) >= min_downloads
        ]
        # Sort by trending score and limit
        trending_datasets = sorted(
            trending_datasets, key=lambda x: x.get("trendingScore", 0), reverse=True
        )[:limit]
        # Get dataset IDs
        dataset_ids = [dataset["id"] for dataset in trending_datasets]
        # Fetch summaries from ChromaDB
        collection = client.get_collection("dataset_cards")
        summaries = collection.get(ids=dataset_ids, include=["documents"])
        # Create mapping of dataset_id to summary
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
        # Combine data
        results = []
        for dataset in trending_datasets:
            if dataset["id"] in id_to_summary:
                result = QueryResult(
                    dataset_id=dataset["id"],
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[dataset["id"]],
                    likes=dataset.get("likes", 0),
                    downloads=dataset.get("downloads", 0),
                )
                results.append(result)
        return results
    except Exception as e:
        logger.error(f"Error fetching trending datasets: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending datasets")
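# Assumed route decorator: the path is a guess inferred from the handler below.
@app.get("/trending/datasets")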
async def get_trending_datasets(
    limit: int = Query(default=10, ge=1, le=100),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Get trending datasets with their summaries"""
    results = await get_trending_datasets_with_summaries(
        limit=limit, min_likes=min_likes, min_downloads=min_downloads
    )
    return QueryResponse(results=results)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
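# Example client usage (illustrative; assumes a route such as /search/datasets is
# registered as sketched above and the server is running locally):
#
#   import httpx
#
#   resp = httpx.get(
#       "http://localhost:8000/search/datasets",
#       params={"query": "medical question answering", "k": 5, "sort_by": "likes"},
#   )
#   resp.raise_for_status()
#   for hit in resp.json()["results"]:
#       print(hit["dataset_id"], hit["similarity"])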