# app.py (FastAPI server to host the Jina Embedding model)

# Must be set before importing Hugging Face libs
import os

os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModel, AutoTokenizer

app = FastAPI()

# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device)
model.eval()

# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"  # "retrieval", "text-matching", "code", etc.
    prompt_name: Optional[str] = None
    return_token_embeddings: bool = True  # False → pooled embedding (for queries)

class EmbedResponse(BaseModel):
    # (num_tokens, hidden_dim) if token-level, (1, hidden_dim) if pooled query
    embeddings: List[List[float]]

class TokenizeRequest(BaseModel):
    text: str

class TokenizeResponse(BaseModel):
    input_ids: List[int]

class DecodeRequest(BaseModel):
    input_ids: List[int]

class DecodeResponse(BaseModel):
    text: str

# -----------------------------
# Embedding Endpoint
# -----------------------------
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    text = req.text

    # -----------------------------
    # Case 1: Query → single pooled embedding
    # -----------------------------
    if not req.return_token_embeddings:
        with torch.no_grad():
            emb = model.encode_text(
                texts=[text],
                task=req.task,
                prompt_name=req.prompt_name or "query",
                return_multivector=False,
            )
        return {"embeddings": emb.tolist()}  # shape: (1, hidden_dim)

    # -----------------------------
    # Case 2: Long passage → sliding-window token embeddings
    # -----------------------------
    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    input_ids = enc["input_ids"].squeeze(0).to(device)  # (total_tokens,)
    total_tokens = input_ids.size(0)

    max_len = model.config.max_position_embeddings  # e.g., 32k for v4
    stride = 50  # overlap between consecutive windows

    embeddings = []
    position = 0

    while position < total_tokens:
        end = min(position + max_len, total_tokens)
        window_ids = input_ids[position:end].unsqueeze(0)

        with torch.no_grad():
            outputs = model.encode_text(
                texts=[tokenizer.decode(window_ids[0])],
                task=req.task,
                prompt_name=req.prompt_name or "passage",
                return_multivector=True,
            )

        window_embeds = outputs.squeeze(0).cpu()  # (window_len, hidden_dim)

        # Drop the overlapping tokens in every window except the first
        if position > 0:
            window_embeds = window_embeds[stride:]

        embeddings.append(window_embeds)

        # Stop once the current window has reached the end of the passage
        if end >= total_tokens:
            break

        # Advance the window
        position += max_len - stride

    full_embeddings = torch.cat(embeddings, dim=0)  # (total_tokens, hidden_dim)
    return {"embeddings": full_embeddings.tolist()}

# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
    enc = tokenizer(req.text, add_special_tokens=False)
    return {"input_ids": enc["input_ids"]}

# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    decoded = tokenizer.decode(req.input_ids)
    return {"text": decoded}
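
# -----------------------------
# Running the server (usage sketch)
# -----------------------------
# A minimal way to launch this app, assuming uvicorn is installed alongside
# fastapi; the host and port values below are illustrative, not part of the
# original deployment setup.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call against the /embed endpoint (hypothetical localhost URL),
# requesting a pooled query embedding:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/embed",
#       json={"text": "what is a sliding window?", "return_token_embeddings": False},
#   )
#   query_vector = resp.json()["embeddings"][0]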