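"""Refresh the dataset viewer search index.

Embeds Hugging Face dataset viewer previews via a dedicated inference endpoint
and upserts the vectors into a ChromaDB collection used for dataset search.
"""
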
import asyncio
import logging

import chromadb
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

from prep_viewer_data import prep_data
from utils import get_chroma_client
# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

EMBEDDING_MODEL_NAME = "davanstrien/query-to-dataset-viewer-descriptions"
EMBEDDING_MODEL_REVISION = "07c71d97861a73695f0c53cd6b4b32980007d908"
INFERENCE_MODEL_URL = (
    "https://ecg0by60w2vo9j8h.us-east-1.aws.endpoints.huggingface.cloud"
)
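
# NOTE: the dedicated endpoint is assumed to serve EMBEDDING_MODEL_NAME, so the
# document vectors it returns and the query vectors computed locally by the
# collection's embedding function live in the same vector space.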


def initialize_clients():
    logger.info("Initializing clients")
    chroma_client = get_chroma_client()
    inference_client = InferenceClient(INFERENCE_MODEL_URL)
    return chroma_client, inference_client


def create_collection(chroma_client):
    logger.info("Creating or getting collection")
    embedding_function = SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME,
        trust_remote_code=True,
        revision=EMBEDDING_MODEL_REVISION,
    )
    logger.info(f"Embedding function: {embedding_function}")
    logger.info(f"Embedding model name: {EMBEDDING_MODEL_NAME}")
    logger.info(f"Embedding model revision: {EMBEDDING_MODEL_REVISION}")
    return chroma_client.create_collection(
        name="dataset-viewer-descriptions",
        get_or_create=True,
        embedding_function=embedding_function,
        metadata={"hnsw:space": "cosine"},  # use cosine distance for the HNSW index
    )


# Retry transient HTTP failures when calling the inference endpoint
@stamina.retry(on=requests.exceptions.RequestException, attempts=3)
def embed_card(text, client):
    # Truncate to the model's maximum input length before embedding
    text = text[:8192]
    return client.feature_extraction(text)


def embed_and_upsert_datasets(
    dataset_rows_and_ids: list[dict[str, str]],
    collection: chromadb.Collection,
    inference_client: InferenceClient,
    batch_size: int = 100,
):
    logger.info(
        f"Embedding and upserting {len(dataset_rows_and_ids)} datasets for viewer data"
    )
    for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
        batch = dataset_rows_and_ids[i : i + batch_size]
        ids = []
        documents = []
        for item in batch:
            ids.append(item["dataset_id"])
            documents.append(f"HUB_DATASET_PREVIEW: {item['formatted_prompt']}")
        # Embed the batch concurrently against the inference endpoint
        results = thread_map(
            lambda doc: embed_card(doc, inference_client), documents, leave=False
        )
        logger.info(f"Embedded {len(results)} documents")
        # Each result arrives as a nested (1, dim) array; take the single row
        collection.upsert(
            ids=ids,
            embeddings=[embedding.tolist()[0] for embedding in results],
        )
        logger.debug(f"Processed batch {i // batch_size + 1}")


async def refresh_viewer_data(sample_size=200_000, min_likes=2):
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)
    logger.info("Collection created successfully")
    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    if df is not None:
        # Only persist and index the data once prep has succeeded
        df.write_parquet("viewer_data.parquet")
        logger.info("Data prepared successfully")
        logger.info(f"Data: {df}")
        dataset_rows_and_ids = df.to_dicts()
        logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
        embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
        logger.info("Refresh completed successfully")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(refresh_viewer_data())
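
# Sketch of querying the populated collection afterwards. The "USER_QUERY:"
# prefix mirrors the document-side "HUB_DATASET_PREVIEW:" prefix and is an
# assumption based on the embedding model's name, not confirmed by this script:
#
#   collection = create_collection(get_chroma_client())
#   results = collection.query(
#       query_texts=["USER_QUERY: datasets for medical question answering"],
#       n_results=10,
#   )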