Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
d849643
1
Parent(s):
a0c28a9
add trending models and datasets fetching endpoints with summaries
Browse files
main.py
CHANGED
|
@@ -14,12 +14,15 @@ from huggingface_hub import HfApi
|
|
| 14 |
from transformers import AutoTokenizer
|
| 15 |
import torch
|
| 16 |
import dateutil.parser
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Configuration constants
|
| 19 |
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
|
| 20 |
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
|
| 21 |
BATCH_SIZE = 2000
|
| 22 |
CACHE_TTL = "60"
|
|
|
|
| 23 |
|
| 24 |
if torch.cuda.is_available():
|
| 25 |
DEVICE = "cuda"
|
|
@@ -463,6 +466,156 @@ def process_search_results(results, id_field, k, sort_by, exclude_id=None):
|
|
| 463 |
return query_results
|
| 464 |
|
| 465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
if __name__ == "__main__":
|
| 467 |
import uvicorn
|
| 468 |
|
|
|
|
| 14 |
from transformers import AutoTokenizer
|
| 15 |
import torch
|
| 16 |
import dateutil.parser
|
| 17 |
+
import httpx
|
| 18 |
+
from datetime import datetime
|
| 19 |
|
| 20 |
# Configuration constants
|
| 21 |
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
|
| 22 |
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
|
| 23 |
BATCH_SIZE = 2000
|
| 24 |
CACHE_TTL = "60"
|
| 25 |
+
TRENDING_CACHE_TTL = "900" # 15 minutes cache for trending data
|
| 26 |
|
| 27 |
if torch.cuda.is_available():
|
| 28 |
DEVICE = "cuda"
|
|
|
|
| 466 |
return query_results
|
| 467 |
|
| 468 |
|
| 469 |
+
async def fetch_trending_models():
|
| 470 |
+
"""Fetch trending models from HuggingFace API"""
|
| 471 |
+
async with httpx.AsyncClient() as client:
|
| 472 |
+
response = await client.get("https://huggingface.co/api/models")
|
| 473 |
+
response.raise_for_status()
|
| 474 |
+
return response.json()
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
@cache(ttl=TRENDING_CACHE_TTL)
|
| 478 |
+
async def get_trending_models_with_summaries(
|
| 479 |
+
limit: int = 10,
|
| 480 |
+
min_likes: int = 0,
|
| 481 |
+
min_downloads: int = 0,
|
| 482 |
+
) -> List[ModelQueryResult]:
|
| 483 |
+
"""Fetch trending models and combine with summaries from database"""
|
| 484 |
+
try:
|
| 485 |
+
# Fetch trending models
|
| 486 |
+
trending_models = await fetch_trending_models()
|
| 487 |
+
|
| 488 |
+
# Filter by minimum likes/downloads
|
| 489 |
+
trending_models = [
|
| 490 |
+
model
|
| 491 |
+
for model in trending_models
|
| 492 |
+
if model.get("likes", 0) >= min_likes
|
| 493 |
+
and model.get("downloads", 0) >= min_downloads
|
| 494 |
+
]
|
| 495 |
+
|
| 496 |
+
# Sort by trending score and limit
|
| 497 |
+
trending_models = sorted(
|
| 498 |
+
trending_models, key=lambda x: x.get("trendingScore", 0), reverse=True
|
| 499 |
+
)[:limit]
|
| 500 |
+
|
| 501 |
+
# Get model IDs
|
| 502 |
+
model_ids = [model["modelId"] for model in trending_models]
|
| 503 |
+
|
| 504 |
+
# Fetch summaries from ChromaDB
|
| 505 |
+
collection = client.get_collection("model_cards")
|
| 506 |
+
summaries = collection.get(ids=model_ids, include=["documents"])
|
| 507 |
+
|
| 508 |
+
# Create mapping of model_id to summary
|
| 509 |
+
id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
|
| 510 |
+
|
| 511 |
+
# Combine data
|
| 512 |
+
results = []
|
| 513 |
+
for model in trending_models:
|
| 514 |
+
if model["modelId"] in id_to_summary:
|
| 515 |
+
result = ModelQueryResult(
|
| 516 |
+
model_id=model["modelId"],
|
| 517 |
+
similarity=1.0, # Not applicable for trending
|
| 518 |
+
summary=id_to_summary[model["modelId"]],
|
| 519 |
+
likes=model.get("likes", 0),
|
| 520 |
+
downloads=model.get("downloads", 0),
|
| 521 |
+
)
|
| 522 |
+
results.append(result)
|
| 523 |
+
|
| 524 |
+
return results
|
| 525 |
+
|
| 526 |
+
except Exception as e:
|
| 527 |
+
logger.error(f"Error fetching trending models: {str(e)}")
|
| 528 |
+
raise HTTPException(status_code=500, detail="Failed to fetch trending models")
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
@app.get("/trending/models", response_model=ModelQueryResponse)
|
| 532 |
+
async def get_trending_models(
|
| 533 |
+
limit: int = Query(default=10, ge=1, le=100),
|
| 534 |
+
min_likes: int = Query(default=0, ge=0),
|
| 535 |
+
min_downloads: int = Query(default=0, ge=0),
|
| 536 |
+
):
|
| 537 |
+
"""Get trending models with their summaries"""
|
| 538 |
+
results = await get_trending_models_with_summaries(
|
| 539 |
+
limit=limit, min_likes=min_likes, min_downloads=min_downloads
|
| 540 |
+
)
|
| 541 |
+
return ModelQueryResponse(results=results)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
async def fetch_trending_datasets():
|
| 545 |
+
"""Fetch trending datasets from HuggingFace API"""
|
| 546 |
+
async with httpx.AsyncClient() as client:
|
| 547 |
+
response = await client.get("https://huggingface.co/api/datasets")
|
| 548 |
+
response.raise_for_status()
|
| 549 |
+
return response.json()
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
@cache(ttl=TRENDING_CACHE_TTL)
|
| 553 |
+
async def get_trending_datasets_with_summaries(
|
| 554 |
+
limit: int = 10,
|
| 555 |
+
min_likes: int = 0,
|
| 556 |
+
min_downloads: int = 0,
|
| 557 |
+
) -> List[QueryResult]:
|
| 558 |
+
"""Fetch trending datasets and combine with summaries from database"""
|
| 559 |
+
try:
|
| 560 |
+
# Fetch trending datasets
|
| 561 |
+
trending_datasets = await fetch_trending_datasets()
|
| 562 |
+
|
| 563 |
+
# Filter by minimum likes/downloads
|
| 564 |
+
trending_datasets = [
|
| 565 |
+
dataset
|
| 566 |
+
for dataset in trending_datasets
|
| 567 |
+
if dataset.get("likes", 0) >= min_likes
|
| 568 |
+
and dataset.get("downloads", 0) >= min_downloads
|
| 569 |
+
]
|
| 570 |
+
|
| 571 |
+
# Sort by trending score and limit
|
| 572 |
+
trending_datasets = sorted(
|
| 573 |
+
trending_datasets, key=lambda x: x.get("trendingScore", 0), reverse=True
|
| 574 |
+
)[:limit]
|
| 575 |
+
|
| 576 |
+
# Get dataset IDs
|
| 577 |
+
dataset_ids = [dataset["id"] for dataset in trending_datasets]
|
| 578 |
+
|
| 579 |
+
# Fetch summaries from ChromaDB
|
| 580 |
+
collection = client.get_collection("dataset_cards")
|
| 581 |
+
summaries = collection.get(ids=dataset_ids, include=["documents"])
|
| 582 |
+
|
| 583 |
+
# Create mapping of dataset_id to summary
|
| 584 |
+
id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
|
| 585 |
+
|
| 586 |
+
# Combine data
|
| 587 |
+
results = []
|
| 588 |
+
for dataset in trending_datasets:
|
| 589 |
+
if dataset["id"] in id_to_summary:
|
| 590 |
+
result = QueryResult(
|
| 591 |
+
dataset_id=dataset["id"],
|
| 592 |
+
similarity=1.0, # Not applicable for trending
|
| 593 |
+
summary=id_to_summary[dataset["id"]],
|
| 594 |
+
likes=dataset.get("likes", 0),
|
| 595 |
+
downloads=dataset.get("downloads", 0),
|
| 596 |
+
)
|
| 597 |
+
results.append(result)
|
| 598 |
+
|
| 599 |
+
return results
|
| 600 |
+
|
| 601 |
+
except Exception as e:
|
| 602 |
+
logger.error(f"Error fetching trending datasets: {str(e)}")
|
| 603 |
+
raise HTTPException(status_code=500, detail="Failed to fetch trending datasets")
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
@app.get("/trending/datasets", response_model=QueryResponse)
|
| 607 |
+
async def get_trending_datasets(
|
| 608 |
+
limit: int = Query(default=10, ge=1, le=100),
|
| 609 |
+
min_likes: int = Query(default=0, ge=0),
|
| 610 |
+
min_downloads: int = Query(default=0, ge=0),
|
| 611 |
+
):
|
| 612 |
+
"""Get trending datasets with their summaries"""
|
| 613 |
+
results = await get_trending_datasets_with_summaries(
|
| 614 |
+
limit=limit, min_likes=min_likes, min_downloads=min_downloads
|
| 615 |
+
)
|
| 616 |
+
return QueryResponse(results=results)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
if __name__ == "__main__":
|
| 620 |
import uvicorn
|
| 621 |
|