Commit 97ab261
1 Parent(s): d574b22

add trending sorting option and fetch trending scores for datasets and models

main.py CHANGED
```diff
@@ -1,21 +1,23 @@
+import asyncio
 import logging
 import os
-from typing import List
 import sys
+from contextlib import asynccontextmanager
+from datetime import datetime
+from typing import List
+
 import chromadb
-
+import dateutil.parser
+import httpx
+import polars as pl
+import torch
 from cashews import cache
+from chromadb.utils import embedding_functions
 from fastapi import FastAPI, HTTPException, Query
 from fastapi.middleware.cors import CORSMiddleware
+from huggingface_hub import HfApi, model_info
 from pydantic import BaseModel
-from contextlib import asynccontextmanager
-import polars as pl
-from huggingface_hub import HfApi
 from transformers import AutoTokenizer
-import torch
-import dateutil.parser
-import httpx
-from datetime import datetime
 
 # Configuration constants
 MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
```
```diff
@@ -272,18 +274,16 @@ async def search_datasets(
     query: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
     ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
 ):
     try:
-        # Get collection with proper embedding function
         collection = client.get_collection(
             name="dataset_cards", embedding_function=get_embedding_function()
         )
 
-        # Query ChromaDB
         results = collection.query(
             query_texts=[f"search_query: {query}"],
             n_results=k * 4 if sort_by != "similarity" else k,
```
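The `trending` option is the visible API change; the same enum extension is repeated below for `find_similar_datasets`, `search_models`, and `find_similar_models`. A minimal client sketch for trying it, assuming the handler is mounted at `/search/datasets` on a local server (the route path and host are not shown in this diff):

```python
import httpx

# Hedged sketch: host and route path are assumptions; the query parameters
# (query, k, sort_by, min_likes, min_downloads) come from the signature above.
response = httpx.get(
    "http://localhost:8000/search/datasets",
    params={"query": "medical imaging", "k": 5, "sort_by": "trending"},
)
response.raise_for_status()
print(response.json()["results"])
```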
```diff
@@ -297,8 +297,7 @@ async def search_datasets(
             else None,
         )
 
-
-        query_results = process_search_results(results, "dataset", k, sort_by)
+        query_results = await process_search_results(results, "dataset", k, sort_by)
 
         return QueryResponse(results=query_results)
 
```
```diff
@@ -313,7 +312,7 @@ async def find_similar_datasets(
     dataset_id: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
     ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
```
```diff
@@ -321,7 +320,6 @@ async def find_similar_datasets(
     try:
         collection = client.get_collection("dataset_cards")
 
-        # Get the reference document
         results = collection.get(ids=[dataset_id], include=["embeddings"])
 
         if not results["ids"]:
```
```diff
@@ -329,12 +327,9 @@ async def find_similar_datasets(
                 status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
             )
 
-        # Query using the embedding
         results = collection.query(
             query_embeddings=[results["embeddings"][0]],
-            n_results=k * 4
-            if sort_by != "similarity"
-            else k + 1,  # +1 to account for self-match
+            n_results=k * 4 if sort_by != "similarity" else k + 1,
             where={
                 "$and": [
                     {"likes": {"$gte": min_likes}},
```
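For context on the `where` clause above: ChromaDB filters on stored metadata with operators like `$gte`, which is how `min_likes` and `min_downloads` are enforced before any re-sorting happens. A self-contained sketch with an in-memory client (collection name, embeddings, and metadata are illustrative, not the app's real `dataset_cards` data):

```python
import chromadb

client = chromadb.Client()  # in-memory client, for illustration only
collection = client.get_or_create_collection("demo")
collection.add(
    ids=["a", "b"],
    embeddings=[[0.1, 0.2], [0.2, 0.1]],
    metadatas=[{"likes": 10, "downloads": 500}, {"likes": 1, "downloads": 3}],
)

# The same $and / $gte shape the endpoint builds from min_likes and min_downloads:
results = collection.query(
    query_embeddings=[[0.1, 0.2]],
    n_results=2,
    where={"$and": [{"likes": {"$gte": 5}}, {"downloads": {"$gte": 100}}]},
)
print(results["ids"])  # [['a']] -- only "a" passes both filters
```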
```diff
@@ -345,8 +340,7 @@ async def find_similar_datasets(
             else None,
         )
 
-
-        query_results = process_search_results(
+        query_results = await process_search_results(
             results, "dataset", k, sort_by, dataset_id
         )
 
```
```diff
@@ -365,7 +359,7 @@ async def search_models(
     query: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
```
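One caveat on this recurring signature pattern: `Query(enum=[...])` is forwarded into the OpenAPI schema as an extra, so the allowed values appear in the generated docs, but FastAPI does not validate the incoming value against them. A `Literal` annotation would do both; the following is a hypothetical alternative, not what the commit changes:

```python
from typing import Literal

from fastapi import FastAPI, Query

app = FastAPI()


@app.get("/demo")
async def demo(
    sort_by: Literal["similarity", "likes", "downloads", "trending"] = Query(
        default="similarity"
    ),
):
    # An unknown sort_by value is rejected with a 422 instead of passed through.
    return {"sort_by": sort_by}
```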
```diff
@@ -388,7 +382,7 @@ async def search_models(
             else None,
         )
 
-        query_results = process_search_results(results, "model", k, sort_by)
+        query_results = await process_search_results(results, "model", k, sort_by)
 
         return ModelQueryResponse(results=query_results)
 
```
```diff
@@ -431,7 +425,9 @@ async def find_similar_models(
             else None,
         )
 
-        query_results = process_search_results(results, "model", k, sort_by, model_id)
+        query_results = await process_search_results(
+            results, "model", k, sort_by, model_id
+        )
 
         return ModelQueryResponse(results=query_results)
 
```
```diff
@@ -442,9 +438,29 @@ async def find_similar_models(
         raise HTTPException(status_code=500, detail="Model similarity search failed")
 
 
-def process_search_results(results, id_field, k, sort_by, exclude_id=None):
+@cache(ttl="1h")
+async def get_trending_score(item_id: str, item_type: str) -> float:
+    """Fetch trending score for a model or dataset from HuggingFace API"""
+    try:
+        async with httpx.AsyncClient() as client:
+            endpoint = "models" if item_type == "model" else "datasets"
+            response = await client.get(
+                f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
+            )
+            response.raise_for_status()
+            return response.json().get("trendingScore", 0)
+    except Exception as e:
+        logger.error(
+            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
+        )
+        return 0
+
+
+async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
     """Process search results into a standardized format."""
     query_results = []
+
+    # Create base results
     for i in range(len(results["ids"][0])):
         current_id = results["ids"][0][i]
         if exclude_id and current_id == exclude_id:
```
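The `expand=trendingScore` query used by `get_trending_score` is a public Hub API parameter, so the lookup can be exercised standalone. A minimal sketch (the model id is only an example):

```python
import asyncio

import httpx


async def main() -> None:
    # Endpoint shape taken from the diff; "bert-base-uncased" is an arbitrary example id.
    url = "https://huggingface.co/api/models/bert-base-uncased?expand=trendingScore"
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        response.raise_for_status()
        print(response.json().get("trendingScore", 0))


asyncio.run(main())
```

Because the function is wrapped in cashews' `@cache(ttl="1h")`, each id should hit the Hub at most once per hour, provided the app configures a cache backend (a `cache.setup(...)` call, presumably elsewhere in `main.py`). On any error it logs and falls back to a score of 0 rather than failing the search.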
```diff
@@ -463,7 +479,31 @@ def process_search_results(results, id_field, k, sort_by, exclude_id=None):
             else:
                 query_results.append(ModelQueryResult(**result))
 
-    if sort_by != "similarity":
+    # Handle sorting
+    if sort_by == "trending":
+        # Fetch trending scores for all results
+        trending_scores = {}
+        async with httpx.AsyncClient() as client:
+            tasks = [
+                get_trending_score(
+                    getattr(result, f"{id_field}_id"),
+                    "model" if id_field == "model" else "dataset",
+                )
+                for result in query_results
+            ]
+            scores = await asyncio.gather(*tasks)
+            trending_scores = {
+                getattr(result, f"{id_field}_id"): score
+                for result, score in zip(query_results, scores)
+            }
+
+        # Sort by trending score
+        query_results.sort(
+            key=lambda x: trending_scores.get(getattr(x, f"{id_field}_id"), 0),
+            reverse=True,
+        )
+        query_results = query_results[:k]
+    elif sort_by != "similarity":
         query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
         query_results = query_results[:k]
     elif exclude_id:  # We fetched extra for similarity + exclude_id case
```
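Two notes on the sorting block above. The endpoints over-fetch candidates (`n_results=k * 4`) whenever `sort_by != "similarity"`, so there is a pool to re-rank before truncating back to `k`. Also, the `async with httpx.AsyncClient() as client:` wrapper appears unused here, since each cached `get_trending_score` call opens its own client. The fan-out-plus-cache shape itself looks like the following sketch, where `fetch_score` and the `mem://` backend are stand-ins rather than code from the commit:

```python
import asyncio

from cashews import cache

cache.setup("mem://")  # illustrative backend; the real app configures its own


@cache(ttl="1h")
async def fetch_score(item_id: str) -> float:
    # Stand-in for the real HTTP lookup; the body runs only on a cache miss.
    print(f"cache miss for {item_id}")
    return float(len(item_id))


async def main() -> None:
    ids = ["user/model-a", "user/model-b"]
    scores = await asyncio.gather(*(fetch_score(i) for i in ids))
    print(dict(zip(ids, scores)))
    await fetch_score("user/model-a")  # within the TTL: served from cache, no print


asyncio.run(main())
```

Because scores are cached per id, repeated searches over popular items should reach the Hub API at most once per TTL window.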