# ChatbotRAG / embedding_service.py
import torch
import numpy as np
from PIL import Image
from transformers import AutoModel
from typing import List, Optional, Union
import io


class JinaClipEmbeddingService:
    """
    Jina CLIP v2 embedding service with Vietnamese language support.
    Uses AutoModel with trust_remote_code.
    """

    def __init__(self, model_path: str = "jinaai/jina-clip-v2"):
        """
        Initialize the Jina CLIP v2 model.

        Args:
            model_path: Local path or HuggingFace model name
        """
        print(f"Loading Jina CLIP v2 model from {model_path}...")
        # Jina CLIP v2 ships custom modeling code, so trust_remote_code is required
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        # Switch to eval mode (inference only)
        self.model.eval()
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        print(f"✓ Loaded Jina CLIP v2 model on: {self.device}")
    def encode_text(
        self,
        text: Union[str, List[str]],
        truncate_dim: Optional[int] = None,
        normalize: bool = True
    ) -> np.ndarray:
        """
        Encode text into vector embeddings (supports Vietnamese).

        Args:
            text: A text or list of texts (Vietnamese supported)
            truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
            normalize: Whether to L2-normalize the embeddings

        Returns:
            numpy array of embeddings, shape (n_texts, dim)
        """
        if isinstance(text, str):
            text = [text]
        # Jina CLIP v2's encode_text handles tokenization internally
        embeddings = self.model.encode_text(
            text,
            truncate_dim=truncate_dim  # Optional: 64, 128, 256, 512, 1024
        )
        # Convert to numpy
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().detach().numpy()
        # Normalize if requested
        if normalize:
            embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings
    def encode_image(
        self,
        image: Union[Image.Image, bytes, List, str],
        truncate_dim: Optional[int] = None,
        normalize: bool = True
    ) -> np.ndarray:
        """
        Encode images into vector embeddings.

        Args:
            image: PIL Image, bytes, URL string, or a list of these
            truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
            normalize: Whether to L2-normalize the embeddings

        Returns:
            numpy array of embeddings, shape (n_images, dim)
        """
        # Decode raw bytes into a PIL Image first
        if isinstance(image, bytes):
            image = Image.open(io.BytesIO(image)).convert('RGB')
        if isinstance(image, list):
            processed_images = []
            for img in image:
                if isinstance(img, bytes):
                    processed_images.append(Image.open(io.BytesIO(img)).convert('RGB'))
                else:
                    # PIL Image, file path, or URL string - Jina CLIP handles these directly
                    processed_images.append(img)
            image = processed_images
        else:
            # Single PIL Image, path, or URL - wrap in a list so the model returns a 2D batch
            image = [image]
        # Jina CLIP v2's encode_image supports PIL Images, file paths, and URLs
        embeddings = self.model.encode_image(
            image,
            truncate_dim=truncate_dim  # Optional: 64, 128, 256, 512, 1024
        )
        # Convert to numpy
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().detach().numpy()
        # Normalize if requested
        if normalize:
            embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings
    def encode_multimodal(
        self,
        text: Optional[Union[str, List[str]]] = None,
        image: Optional[Union[Image.Image, bytes, List]] = None,
        truncate_dim: Optional[int] = None,
        normalize: bool = True
    ) -> np.ndarray:
        """
        Encode both text and image, returning a combined embedding.

        Args:
            text: A text or list of texts (Vietnamese supported)
            image: PIL Image, bytes, or list of images
            truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
            normalize: Whether to L2-normalize the combined embeddings

        Returns:
            numpy array of embeddings, shape (n, dim)
        """
        embeddings = []
        if text is not None:
            text_emb = self.encode_text(text, truncate_dim=truncate_dim, normalize=False)
            embeddings.append(text_emb)
        if image is not None:
            image_emb = self.encode_image(image, truncate_dim=truncate_dim, normalize=False)
            embeddings.append(image_emb)
        # Combine modalities by averaging
        if len(embeddings) == 2:
            # Element-wise average of text and image embeddings
            combined = np.mean(embeddings, axis=0)
        elif len(embeddings) == 1:
            combined = embeddings[0]
        else:
            raise ValueError("At least one of text or image must be provided")
        # Normalize if requested
        if normalize:
            combined = combined / np.linalg.norm(combined, axis=1, keepdims=True)
        return combined
    def get_embedding_dimension(self) -> int:
        """
        Return the embedding dimension (1024 for Jina CLIP v2).
        """
        return 1024
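

# A minimal usage sketch (illustrative, not part of the original file): embeds a
# Vietnamese query and an image, then compares them with cosine similarity. Since
# both vectors are L2-normalized, their dot product equals cosine similarity.
# The image URL below is hypothetical and only for demonstration.
if __name__ == "__main__":
    service = JinaClipEmbeddingService()

    # Text embedding (Vietnamese), truncated to 512 dims via Matryoshka
    text_emb = service.encode_text("Một con mèo đang ngủ trên ghế", truncate_dim=512)
    print(text_emb.shape)  # (1, 512)

    # Image embedding from a URL (hypothetical URL)
    image_emb = service.encode_image("https://example.com/cat.jpg", truncate_dim=512)

    # Cosine similarity between the normalized text and image embeddings
    similarity = float(text_emb[0] @ image_emb[0])
    print(f"Text-image similarity: {similarity:.4f}")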