fix: truncate texts that are longer than allowed
lightweight_embeddings/service.py CHANGED
@@ -1,3 +1,5 @@
+# filename: service.py
+
 from __future__ import annotations
 
 import asyncio
@@ -44,6 +46,21 @@ class ImageModelType(str, Enum):
     SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"
 
 
+class MaxModelLength(str, Enum):
+    """
+    Enumeration of maximum token lengths for supported text models.
+    """
+
+    MULTILINGUAL_E5_SMALL = 512
+    MULTILINGUAL_E5_BASE = 512
+    MULTILINGUAL_E5_LARGE = 512
+    SNOWFLAKE_ARCTIC_EMBED_L_V2 = 8192
+    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = 128
+    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = 128
+    BGE_M3 = 8192
+    GTE_MULTILINGUAL_BASE = 8192
+
+
 class ModelInfo(NamedTuple):
     """
     Container mapping a model type to its model identifier and optional ONNX file.
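Note: because MaxModelLength mixes in str, the integer assigned to each member is coerced to its string form, so .value is "512" rather than 512; that is why the loader in the next hunk wraps the lookup in int(...). A standalone sketch of the behavior (not part of the diff):

    from enum import Enum

    class MaxModelLength(str, Enum):
        MULTILINGUAL_E5_SMALL = 512

    print(MaxModelLength.MULTILINGUAL_E5_SMALL.value)          # '512' (a str)
    print(int(MaxModelLength["MULTILINGUAL_E5_SMALL"].value))  # 512

An int-valued enum (enum.IntEnum) would make the int(...) round-trip unnecessary.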
@@ -200,6 +217,14 @@ class EmbeddingsService:
                 device=self.device,
                 trust_remote_code=True,
             )
+            # Set maximum sequence length based on configuration.
+            max_length = int(MaxModelLength[t_model_type.name].value)
+            self.text_models[t_model_type].max_seq_length = max_length
+            logger.info(
+                "Set max_seq_length=%d for text model: %s",
+                max_length,
+                info.model_id,
+            )
 
         # Preload image models.
         for i_model_type in ImageModelType:
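Note: max_seq_length is a settable attribute on SentenceTransformer that caps how many tokens the underlying tokenizer keeps per input. A minimal sketch, assuming the sentence-transformers package; the model id is illustrative, since the diff does not show the enum-to-checkpoint mapping:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("intfloat/multilingual-e5-small")  # illustrative id
    model.max_seq_length = 512  # encode() will now truncate inputs to 512 tokens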
@@ -265,6 +290,37 @@ class EmbeddingsService:
 
         return input_images
 
+    def _truncate_text(self, text: str, model: SentenceTransformer) -> str:
+        """
+        Truncate the input text to the maximum allowed tokens for the given model.
+
+        Args:
+            text: The input text.
+            model: The SentenceTransformer model used for tokenization.
+
+        Returns:
+            The truncated text if token length exceeds the maximum allowed length,
+            otherwise the original text.
+        """
+        try:
+            # Attempt to get the tokenizer from the first module of the SentenceTransformer.
+            module = model._first_module()
+            if not hasattr(module, "tokenizer"):
+                return text
+            tokenizer = module.tokenizer
+            # Tokenize without truncation.
+            encoded = tokenizer(text, add_special_tokens=True, truncation=False)
+            max_length = model.max_seq_length
+            if len(encoded["input_ids"]) > max_length:
+                truncated_ids = encoded["input_ids"][:max_length]
+                truncated_text = tokenizer.decode(
+                    truncated_ids, skip_special_tokens=True
+                )
+                return truncated_text
+        except Exception as e:
+            logger.warning("Error during text truncation: %s", str(e))
+        return text
+
     async def _fetch_image(self, path_or_url: str) -> Image.Image:
         """
         Asynchronously fetch an image from a URL or load from a local path.
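Note: model._first_module() is a private sentence-transformers helper that returns the first pipeline module (normally the Transformer wrapper, which carries the tokenizer), and decoding a truncated id list may not reproduce the original surface text byte-for-byte. The same token-level truncation can be sketched against a bare Hugging Face tokenizer; the function name and model id below are illustrative:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")

    def truncate_to_max_tokens(text: str, max_length: int = 512) -> str:
        # Tokenize without truncation, then keep only the first max_length ids.
        ids = tokenizer(text, add_special_tokens=True, truncation=False)["input_ids"]
        if len(ids) <= max_length:
            return text
        return tokenizer.decode(ids[:max_length], skip_special_tokens=True)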
@@ -312,14 +368,16 @@ class EmbeddingsService:
         return processed_data
 
     def _generate_text_embeddings(
-        self,
-        model_id: TextModelType,
-        texts: List[str],
+        self, model_id: TextModelType, texts: List[str]
     ) -> np.ndarray:
         """
         Generate text embeddings using the SentenceTransformer model.
         Single-text requests are cached using an LRU cache.
 
+        Args:
+            model_id: The text model type.
+            texts: A list of input texts.
+
         Returns:
             A NumPy array of text embeddings.
 
@@ -345,9 +403,7 @@ class EmbeddingsService:
             ) from e
 
     async def _async_generate_image_embeddings(
-        self,
-        model_id: ImageModelType,
-        images: List[str],
+        self, model_id: ImageModelType, images: List[str]
     ) -> np.ndarray:
         """
         Asynchronously generate image embeddings.
@@ -355,6 +411,10 @@ class EmbeddingsService:
         This method concurrently processes multiple images and offloads
         the blocking model inference to a separate thread.
 
+        Args:
+            model_id: The image model type.
+            images: A list of image URLs or file paths.
+
         Returns:
             A NumPy array of image embeddings.
 
@@ -386,9 +446,7 @@ class EmbeddingsService:
             ) from e
 
     async def generate_embeddings(
-        self,
-        model: str,
-        inputs: Union[str, List[str]],
+        self, model: str, inputs: Union[str, List[str]]
     ) -> np.ndarray:
         """
         Asynchronously generate embeddings for text or image inputs based on model type.
@@ -402,16 +460,21 @@ class EmbeddingsService:
         """
         modality = detect_model_kind(model)
         if modality == ModelKind.TEXT:
-
+            text_model_enum = TextModelType(model)
             text_list = self._validate_text_list(inputs)
+            model_instance = self.text_models[text_model_enum]
+            # Truncate each text if it exceeds the maximum allowed token length.
+            truncated_texts = [
+                self._truncate_text(text, model_instance) for text in text_list
+            ]
             return await asyncio.to_thread(
-                self._generate_text_embeddings,
+                self._generate_text_embeddings, text_model_enum, truncated_texts
             )
         elif modality == ModelKind.IMAGE:
-
+            image_model_enum = ImageModelType(model)
             image_list = self._validate_image_list(inputs)
             return await self._async_generate_image_embeddings(
-
+                image_model_enum, image_list
             )
 
     async def rank(
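Note: the text branch now resolves the model enum, validates, truncates, and only then hands the blocking encode call to a worker thread; asyncio.to_thread forwards the extra positional arguments to the callable. A self-contained sketch of that offload pattern (names are illustrative):

    import asyncio

    def blocking_embed(model_id: str, texts: list[str]) -> list[int]:
        # Stand-in for real model inference.
        return [len(t) for t in texts]

    async def main() -> None:
        result = await asyncio.to_thread(blocking_embed, "demo-model", ["hello", "world"])
        print(result)  # [5, 5]

    asyncio.run(main())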
@@ -424,6 +487,11 @@ class EmbeddingsService:
         Asynchronously rank candidate texts/images against the provided queries.
         Embeddings for queries and candidates are generated concurrently.
 
+        Args:
+            model: The model identifier.
+            queries: The query input(s).
+            candidates: The candidate input(s).
+
         Returns:
             A dictionary containing probabilities, cosine similarities, and usage statistics.
         """
@@ -469,6 +537,9 @@ class EmbeddingsService:
         """
         Estimate the token count for the given text input using the SentenceTransformer tokenizer.
 
+        Args:
+            input_data: The text input(s).
+
         Returns:
             The total number of tokens.
         """
@@ -482,8 +553,11 @@ class EmbeddingsService:
         """
         Compute the softmax over the last dimension of the input array.
 
+        Args:
+            scores: A NumPy array of scores.
+
         Returns:
-
+            A NumPy array of softmax probabilities.
         """
         exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
         return exps / np.sum(exps, axis=-1, keepdims=True)
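Note: subtracting the row-wise maximum before exponentiating is the standard numerically stable softmax; it changes nothing mathematically but prevents overflow for large scores. A quick check:

    import numpy as np

    scores = np.array([[1000.0, 1001.0, 1002.0]])  # naive np.exp would overflow
    exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    print(exps / np.sum(exps, axis=-1, keepdims=True))
    # [[0.09003057 0.24472847 0.66524096]]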
@@ -493,6 +567,10 @@ class EmbeddingsService:
         """
         Compute the pairwise cosine similarity between all rows of arrays a and b.
 
+        Args:
+            a: A NumPy array.
+            b: A NumPy array.
+
         Returns:
             A (N x M) matrix of cosine similarities.
         """
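Note: this hunk touches only the docstring; an implementation consistent with it normalizes the rows of both arrays and takes a single matrix product. A minimal sketch (not the service's code):

    import numpy as np

    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Normalize rows, then (N x D) @ (D x M) yields the (N x M) similarity matrix.
        a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
        b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
        return a_norm @ b_norm.T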
|