import os
import logging

import numpy as np
import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Qwen3-Embedding-4B model for retrieval
MODEL_NAME = "Qwen/Qwen3-Embedding-4B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMBEDDING_DIM = 2560  # Max dimension for Qwen3-Embedding-4B


class EmbeddingModel:
    def __init__(self):
        logger.info(f"Loading {MODEL_NAME} on {DEVICE}")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = AutoModel.from_pretrained(MODEL_NAME)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info("✅ Model loaded successfully")

    def encode(self, texts, batch_size=16):
        """Encode texts to embeddings using Qwen3-Embedding-4B."""
        if isinstance(texts, str):
            texts = [texts]

        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            # Qwen3 instruction format for retrieval
            batch_texts = [
                f"Instruct: Retrieve semantically similar text.\nQuery: {text}"
                for text in batch_texts
            ]

            inputs = self.tokenizer(
                batch_texts,
                padding="longest",
                truncation=True,
                max_length=32768,  # Qwen3 supports up to 32k context
                return_tensors="pt",
            ).to(DEVICE)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Last-token pooling, as in the Qwen3-Embedding usage example:
            # take the hidden state of the final non-padding token. This
            # assumes right padding (the tokenizer default), where the last
            # real token of each sequence sits at attention_mask.sum() - 1.
            sequence_lengths = inputs["attention_mask"].sum(dim=1) - 1

            batch_embeddings = []
            for j, seq_len in enumerate(sequence_lengths):
                embedding = outputs.last_hidden_state[j, seq_len, :].cpu().numpy()
                batch_embeddings.append(embedding)

            batch_embeddings = np.array(batch_embeddings)

            # L2-normalize so dot product equals cosine similarity
            batch_embeddings = batch_embeddings / np.linalg.norm(
                batch_embeddings, axis=1, keepdims=True
            )
            embeddings.extend(batch_embeddings)

        return embeddings


# Global model instance (lazily initialized)
embedding_model = None


def get_model():
    global embedding_model
    if embedding_model is None:
        embedding_model = EmbeddingModel()
    return embedding_model


@app.route("/", methods=["GET"])
def health_check():
    return jsonify({
        "status": "healthy",
        "model": MODEL_NAME,
        "device": DEVICE,
        "embedding_dim": EMBEDDING_DIM,
        "max_context": 32768,
    })


@app.route("/embed", methods=["POST"])
def embed_texts():
    """Embed a list of texts and return their embeddings."""
    try:
        data = request.get_json()
        if not data or "texts" not in data:
            return jsonify({"error": "Missing 'texts' field"}), 400

        texts = data["texts"]
        if not isinstance(texts, list):
            texts = [texts]

        logger.info(f"Embedding {len(texts)} texts")
        model = get_model()
        embeddings = model.encode(texts)

        return jsonify({
            "embeddings": [embedding.tolist() for embedding in embeddings],
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]) if embeddings else 0,
            "count": len(embeddings),
        })
    except Exception as e:
        logger.error(f"Embedding error: {str(e)}")
        return jsonify({"error": str(e)}), 500


@app.route("/embed_single", methods=["POST"])
def embed_single():
    """Embed a single text (convenience endpoint)."""
    try:
        data = request.get_json()
        if not data or "text" not in data:
            return jsonify({"error": "Missing 'text' field"}), 400

        text = data["text"]
        logger.info(f"Embedding single text: {text[:100]}...")
        model = get_model()
        embeddings = model.encode([text])

        return jsonify({
            "embedding": embeddings[0].tolist(),
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]),
            "text_length": len(text),
        })
    except Exception as e:
        logger.error(f"Single embedding error: {str(e)}")
        return jsonify({"error": str(e)}), 500

"__main__": # Initialize model on startup logger.info("🚀 Starting embedding service...") get_model() logger.info("✅ Service ready!") port = int(os.environ.get("PORT", 7860)) app.run(host="0.0.0.0", port=port)