Spaces: Running on CPU Upgrade
Commit ed553e8 · Parent(s): 16eb3c6 · improve refresh logic
main.py CHANGED
@@ -13,6 +13,7 @@ import polars as pl
 from huggingface_hub import HfApi
 from transformers import AutoTokenizer
 import torch
+import dateutil.parser
 
 # Configuration constants
 MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
@@ -89,15 +90,11 @@ def setup_database():
 def setup_database():
     try:
         embedding_function = get_embedding_function()
-
-        # Create dataset collection
         dataset_collection = client.get_or_create_collection(
             embedding_function=embedding_function,
             name="dataset_cards",
             metadata={"hnsw:space": "cosine"},
         )
-
-        # Create model collection
         model_collection = client.get_or_create_collection(
             embedding_function=embedding_function,
             name="model_cards",
@@ -111,26 +108,52 @@ def setup_database():
         df = df.filter(
             pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
         )
-        row_count = df.select(pl.len()).collect().item()
-        logger.info(f"Row count of dataset data: {row_count}")
 
-        # …
-
-
+        # Get the most recent last_modified date from the collection
+        latest_update = None
+        if dataset_collection.count() > 0:
+            metadata = dataset_collection.get(include=["metadatas"]).get("metadatas")
+            logger.info(f"Found {len(metadata)} existing records in collection")
+
+            last_modifieds = [
+                dateutil.parser.parse(m.get("last_modified")) for m in metadata
+            ]
+            latest_update = max(last_modifieds)
+            logger.info(f"Most recent record in DB from: {latest_update}")
+            logger.info(f"Oldest record in DB from: {min(last_modifieds)}")
+
+        # Filter and process only newer records
+        df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])
+
+        # Log some stats about the incoming data
+        sample_dates = df.select("last_modified").limit(5).collect()
+        logger.info(f"Sample of incoming dates: {sample_dates}")
+
+        total_incoming = df.select(pl.len()).collect().item()
+        logger.info(f"Total incoming records: {total_incoming}")
+
+        if latest_update:
+            logger.info(f"Filtering records newer than {latest_update}")
+            df = df.filter(pl.col("last_modified") > latest_update)
+            filtered_count = df.select(pl.len()).collect().item()
+            logger.info(f"Found {filtered_count} records to update after filtering")
 
-
+        df = df.collect()
+        total_rows = len(df)
+
+        if total_rows > 0:
+            logger.info(f"Updating dataset collection with {total_rows} new records")
         logger.info(
-            f"…
-        )
-        # Load parquet files and upsert into ChromaDB
-        df = df.select(
-            ["datasetId", "summary", "likes", "downloads", "last_modified"]
+            f"Date range of updates: {df['last_modified'].min()} to {df['last_modified'].max()}"
         )
-        df = df.collect()
-        total_rows = len(df)
 
         for i in range(0, total_rows, BATCH_SIZE):
             batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
+            batch_size = len(batch_df)
+            logger.info(
+                f"Processing batch {i // BATCH_SIZE + 1}: {batch_size} records "
+                f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
+            )
 
             dataset_collection.upsert(
                 ids=batch_df.select(["datasetId"]).to_series().to_list(),
@@ -148,9 +171,11 @@ def setup_database():
                     )
                 ],
             )
-            logger.info(f"Processed {i + …
+            logger.info(f"Processed {i + batch_size:,} / {total_rows:,} records")
 
-        logger.info(…
+        logger.info(
+            f"Database initialized with {dataset_collection.count():,} total rows"
+        )
 
         # Load model data
         model_df = pl.scan_parquet(
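
For context on the change above: the new refresh logic reads the newest last_modified timestamp already stored in the ChromaDB collection, then keeps only incoming rows that are strictly newer before collecting and upserting them. Below is a minimal, self-contained sketch of that pattern, with hypothetical in-memory data standing in for the real collection metadata and parquet scan; the names stored_metadata and incoming are illustrative only.

    import dateutil.parser
    import polars as pl
    from datetime import datetime

    # Hypothetical metadatas, standing in for
    # dataset_collection.get(include=["metadatas"])["metadatas"]
    stored_metadata = [
        {"last_modified": "2025-02-10T09:00:00"},
        {"last_modified": "2025-02-12T15:13:00"},
    ]

    # Newest timestamp already stored: only rows newer than this need upserting
    latest_update = max(
        dateutil.parser.parse(m["last_modified"]) for m in stored_metadata
    )

    # Incoming rows as a lazy frame, standing in for pl.scan_parquet(...)
    incoming = pl.LazyFrame(
        {
            "datasetId": ["a/x", "b/y", "c/z"],
            "last_modified": [
                datetime(2025, 2, 9, 12, 0),
                datetime(2025, 2, 12, 16, 0),
                datetime(2025, 2, 13, 8, 30),
            ],
        }
    )

    # The filter stays lazy, so rows are pruned before anything is materialised
    fresh = incoming.filter(pl.col("last_modified") > latest_update).collect()
    print(fresh)  # b/y and c/z survive; a/x predates latest_update

Note that the comparison pl.col("last_modified") > latest_update assumes the incoming column already has a datetime dtype; if the parquet files stored it as a string, a cast such as .str.to_datetime() would be needed first.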
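
The upsert loop itself is plain offset/length slicing over the collected frame. A runnable sketch of just that batching walk, with an illustrative BATCH_SIZE and a print standing in for dataset_collection.upsert(...):

    import polars as pl

    BATCH_SIZE = 2  # illustrative; the app defines its own constant

    # Toy stand-in for the collected dataset frame
    df = pl.DataFrame({"datasetId": ["a", "b", "c", "d", "e"]})

    total_rows = len(df)
    for i in range(0, total_rows, BATCH_SIZE):
        # slice(offset, length); the last batch is shorter when
        # total_rows is not a multiple of BATCH_SIZE
        batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
        # stand-in for dataset_collection.upsert(ids=..., documents=..., metadatas=...)
        print(f"batch {i // BATCH_SIZE + 1}: {batch_df['datasetId'].to_list()}")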