# ChatbotRAG / embedding_service.py
# Uploaded by minhvtt in commit 6c982a7 ("Upload 6 files").
import torch
import numpy as np
from PIL import Image
from transformers import AutoModel
from typing import Union, List
import io
class JinaClipEmbeddingService:
    """
    Embedding service built on Jina CLIP v2, with Vietnamese text support.

    Loads the model through ``AutoModel.from_pretrained(...,
    trust_remote_code=True)`` and exposes text, image and combined
    text+image encoders. Inputs are always sent to the model as a batch
    (list), so results are 2-D arrays with one embedding row per input.
    """

    # Full (untruncated) embedding width of Jina CLIP v2.
    FULL_EMBEDDING_DIM = 1024

    def __init__(self, model_path: str = "jinaai/jina-clip-v2"):
        """
        Load the Jina CLIP v2 model and move it to the best available device.

        Args:
            model_path: Local path to the model or a Hugging Face model name.
        """
        print(f"Loading Jina CLIP v2 model from {model_path}...")
        # The model ships its own code (encode_text / encode_image are not
        # part of the standard AutoModel API), so remote code must be trusted.
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        # Inference only: switch training-time layers (e.g. dropout) off.
        self.model.eval()
        # Prefer the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        print(f"✓ Loaded Jina CLIP v2 model on: {self.device}")

    @staticmethod
    def _to_numpy(embeddings) -> np.ndarray:
        """Convert a torch tensor result to numpy; other values pass through."""
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()
        return embeddings

    @staticmethod
    def _l2_normalize(embeddings: np.ndarray) -> np.ndarray:
        """L2-normalize each row; all-zero rows are left as zeros (no NaN)."""
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Dividing a zero row by 1 keeps it zero instead of producing 0/0 = NaN.
        norms = np.where(norms == 0, 1.0, norms)
        return embeddings / norms

    @staticmethod
    def _load_image_bytes(data: bytes) -> "Image.Image":
        """Decode raw encoded image bytes into an RGB PIL image."""
        # convert() returns a new, fully loaded image, so the source image
        # (and its BytesIO buffer) can be closed right away.
        with Image.open(io.BytesIO(data)) as img:
            return img.convert('RGB')

    def encode_text(
        self,
        text: Union[str, List[str]],
        truncate_dim: Union[int, None] = None,
        normalize: bool = True
    ) -> np.ndarray:
        """
        Encode one or more texts (Vietnamese supported) into embeddings.

        Args:
            text: A single string or a list of strings.
            truncate_dim: Matryoshka dimension to truncate to (64-1024);
                None keeps the full 1024 dimensions.
            normalize: If True, L2-normalize every embedding row.

        Returns:
            2-D numpy array of embeddings, one row per input text.
        """
        # Always send a batch so a single string still yields one row.
        texts = [text] if isinstance(text, str) else text
        # Tokenization is handled inside the model's encode_text.
        embeddings = self._to_numpy(
            self.model.encode_text(texts, truncate_dim=truncate_dim)
        )
        return self._l2_normalize(embeddings) if normalize else embeddings

    def encode_image(
        self,
        image: Union["Image.Image", bytes, List, str],
        truncate_dim: Union[int, None] = None,
        normalize: bool = True
    ) -> np.ndarray:
        """
        Encode one or more images into embeddings.

        Args:
            image: A PIL Image, raw encoded image bytes, a file path / URL
                string, or a list mixing any of those.
            truncate_dim: Matryoshka dimension to truncate to (64-1024);
                None keeps the full 1024 dimensions.
            normalize: If True, L2-normalize every embedding row.

        Returns:
            2-D numpy array of embeddings, one row per input image.
        """
        # Wrap every single input (PIL image, bytes, path/URL string) into a
        # batch. Previously only PIL images were wrapped, so single bytes/str
        # inputs gave a non-2-D result that broke row-wise normalization.
        batch = image if isinstance(image, list) else [image]
        # Decode raw bytes; PIL images, file paths and URLs go to the model
        # unchanged.
        batch = [
            self._load_image_bytes(item) if isinstance(item, bytes) else item
            for item in batch
        ]
        embeddings = self._to_numpy(
            self.model.encode_image(batch, truncate_dim=truncate_dim)
        )
        return self._l2_normalize(embeddings) if normalize else embeddings

    def encode_multimodal(
        self,
        text: Union[str, List[str]] = None,
        image: Union["Image.Image", bytes, List] = None,
        truncate_dim: Union[int, None] = None,
        normalize: bool = True
    ) -> np.ndarray:
        """
        Encode text and/or images and combine them into one embedding set.

        With both inputs, the text and image embeddings are averaged row by
        row. With only one input, that input's embeddings are returned.

        Args:
            text: A single string or a list of strings (Vietnamese supported).
            image: A PIL Image, raw image bytes, or a list of images.
            truncate_dim: Matryoshka dimension to truncate to (64-1024);
                None keeps the full 1024 dimensions.
            normalize: If True, L2-normalize the combined rows.

        Returns:
            2-D numpy array of combined embeddings.

        Raises:
            ValueError: If neither text nor image is given, or if both are
                given with a different number of items.
        """
        if text is None and image is None:
            raise ValueError("Phải cung cấp ít nhất text hoặc image")

        parts = []
        # Sub-encoders are called with normalize=False here; the combined
        # result is normalized once at the end.
        if text is not None:
            parts.append(
                self.encode_text(text, truncate_dim=truncate_dim, normalize=False)
            )
        if image is not None:
            parts.append(
                self.encode_image(image, truncate_dim=truncate_dim, normalize=False)
            )

        if len(parts) == 2:
            text_emb, image_emb = parts
            if text_emb.shape != image_emb.shape:
                raise ValueError(
                    f"Số lượng text ({text_emb.shape[0]}) và image "
                    f"({image_emb.shape[0]}) phải bằng nhau"
                )
            # Element-wise average of the text and image embeddings.
            combined = np.mean(parts, axis=0)
        else:
            combined = parts[0]

        return self._l2_normalize(combined) if normalize else combined

    def get_embedding_dimension(self) -> int:
        """
        Return the full embedding dimension (1024 for Jina CLIP v2).

        This does not reflect any ``truncate_dim`` passed to the encoders.
        """
        return self.FULL_EMBEDDING_DIM