Spaces: Running on CPU Upgrade
Commit · abbed11
1 Parent(s): ab03ef7
Return early if no dataset card is available
Browse files
main.py
CHANGED
|
@@ -5,11 +5,12 @@ from typing import List, Optional
|
|
| 5 |
import chromadb
|
| 6 |
from cashews import cache
|
| 7 |
from fastapi import FastAPI, HTTPException, Query
|
|
|
|
|
|
|
| 8 |
from pydantic import BaseModel
|
| 9 |
from starlette.responses import RedirectResponse
|
| 10 |
-
|
| 11 |
from load_data import get_embedding_function, get_save_path, refresh_data
|
| 12 |
-
from huggingface_hub import DatasetCard
|
| 13 |
|
| 14 |
# Set up logging
|
| 15 |
logging.basicConfig(
|
|
@@ -87,7 +88,7 @@ async def try_get_card(hub_id: str) -> Optional[str]:
|
|
| 87 |
return None
|
| 88 |
|
| 89 |
|
| 90 |
-
@app.get("/similar", response_model=
|
| 91 |
@cache(ttl="1h")
|
| 92 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
| 93 |
try:
|
|
@@ -99,6 +100,8 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 99 |
try:
|
| 100 |
embedding_function = get_embedding_function()
|
| 101 |
card = await try_get_card(dataset_id)
|
|
|
|
|
|
|
| 102 |
embeddings = embedding_function(card)
|
| 103 |
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
| 104 |
logger.info(f"Dataset {dataset_id} added to collection")
|
|
@@ -107,7 +110,7 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 107 |
logger.error(
|
| 108 |
f"Error adding dataset {dataset_id} to collection: {str(e)}"
|
| 109 |
)
|
| 110 |
-
|
| 111 |
|
| 112 |
embedding = result["embeddings"][0]
|
| 113 |
|
|
@@ -118,7 +121,7 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 118 |
|
| 119 |
if not query_result["ids"]:
|
| 120 |
logger.info(f"No similar datasets found for: {dataset_id}")
|
| 121 |
-
return
|
| 122 |
|
| 123 |
# Prepare the response
|
| 124 |
results = [
|
|
@@ -133,8 +136,7 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
|
|
| 133 |
|
| 134 |
except Exception as e:
|
| 135 |
logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
|
| 136 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 137 |
-
|
| 138 |
|
| 139 |
if __name__ == "__main__":
|
| 140 |
import uvicorn
|
|
|
|
| 5 |
import chromadb
|
| 6 |
from cashews import cache
|
| 7 |
from fastapi import FastAPI, HTTPException, Query
|
| 8 |
+
from httpx import AsyncClient
|
| 9 |
+
from huggingface_hub import DatasetCard
|
| 10 |
from pydantic import BaseModel
|
| 11 |
from starlette.responses import RedirectResponse
|
| 12 |
+
|
| 13 |
from load_data import get_embedding_function, get_save_path, refresh_data
|
|
|
|
| 14 |
|
| 15 |
# Set up logging
|
| 16 |
logging.basicConfig(
|
|
|
|
| 88 |
return None
|
| 89 |
|
| 90 |
|
| 91 |
+
@app.get("/similar", response_model=QueryResponse)
|
| 92 |
@cache(ttl="1h")
|
| 93 |
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
|
| 94 |
try:
|
|
|
|
| 100 |
try:
|
| 101 |
embedding_function = get_embedding_function()
|
| 102 |
card = await try_get_card(dataset_id)
|
| 103 |
+
if card is None:
|
| 104 |
+
return QueryResponse(message="No dataset card available for recommendations.")
|
| 105 |
embeddings = embedding_function(card)
|
| 106 |
collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
|
| 107 |
logger.info(f"Dataset {dataset_id} added to collection")
|
|
|
|
| 110 |
logger.error(
|
| 111 |
f"Error adding dataset {dataset_id} to collection: {str(e)}"
|
| 112 |
)
|
| 113 |
+
return QueryResponse(message="No dataset card available for recommendations.")
|
| 114 |
|
| 115 |
embedding = result["embeddings"][0]
|
| 116 |
|
|
|
|
| 121 |
|
| 122 |
if not query_result["ids"]:
|
| 123 |
logger.info(f"No similar datasets found for: {dataset_id}")
|
| 124 |
+
return QueryResponse(message="No similar datasets found.")
|
| 125 |
|
| 126 |
# Prepare the response
|
| 127 |
results = [
|
|
|
|
| 136 |
|
| 137 |
except Exception as e:
|
| 138 |
logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
|
| 139 |
+
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
| 140 |
|
| 141 |
if __name__ == "__main__":
|
| 142 |
import uvicorn
|