# Hugging Face Spaces page residue (app status "Sleeping") — not part of the module.
| import os | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel | |
| from flask import Flask, request, jsonify | |
| import logging | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Qwen3-Embedding-4B model for retrieval
MODEL_NAME = "Qwen/Qwen3-Embedding-4B"
# Prefer GPU when present; everything below moves tensors to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMBEDDING_DIM = 2560  # Max dimension for Qwen3-Embedding-4B
class EmbeddingModel:
    """Wraps Qwen3-Embedding-4B for batched text embedding.

    The tokenizer and model are loaded once at construction; ``encode``
    returns one L2-normalized embedding vector per input text.
    """

    def __init__(self):
        logger.info(f"Loading {MODEL_NAME} on {DEVICE}")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = AutoModel.from_pretrained(MODEL_NAME)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info("β Model loaded successfully")

    def encode(self, texts, batch_size=16):
        """Encode texts to embeddings using Qwen3-Embedding-4B.

        Args:
            texts: a single string or a list of strings.
            batch_size: number of texts tokenized and forwarded per batch.

        Returns:
            A list of 1-D numpy float arrays (one per input text), each
            normalized to unit L2 length. Empty input yields an empty list.
        """
        if isinstance(texts, str):
            texts = [texts]

        embeddings = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            # Qwen3 instruction format for retrieval
            batch = [
                f"Instruct: Retrieve semantically similar text.\nQuery: {text}"
                for text in batch
            ]
            inputs = self.tokenizer(
                batch,
                padding="longest",
                truncation=True,
                max_length=32768,  # Qwen3 supports up to 32k context
                return_tensors="pt",
            ).to(DEVICE)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Last-token pooling, per the Qwen3-Embedding usage example:
            # attention_mask.sum(-1) - 1 is the index of each sequence's
            # final real token (the EOS). The previous code located the
            # FIRST eos via argmax and stepped one back, which (a) pooled
            # the token before EOS instead of EOS itself and (b) produced
            # index -1 whenever truncation removed the EOS (argmax == 0),
            # silently pooling the last padded position.
            hidden = outputs.last_hidden_state
            last_idx = inputs["attention_mask"].sum(dim=-1) - 1
            rows = torch.arange(hidden.shape[0], device=hidden.device)
            pooled = hidden[rows, last_idx]
            batch_embeddings = pooled.cpu().numpy()

            # Normalize to unit length; clamp zero norms so a degenerate
            # all-zero vector cannot trigger a divide-by-zero.
            norms = np.linalg.norm(batch_embeddings, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            batch_embeddings = batch_embeddings / norms

            embeddings.extend(batch_embeddings)
        return embeddings
# Global model instance (lazily created; None until first request/startup)
embedding_model = None


def get_model():
    """Return the process-wide EmbeddingModel, loading it on first call."""
    global embedding_model
    if embedding_model is None:
        embedding_model = EmbeddingModel()
    return embedding_model
# NOTE(review): none of the view functions carried an @app.route decorator,
# so Flask never registered any endpoint and the app served only 404s. The
# decorators were almost certainly lost when the file was pasted; the paths
# chosen here are a reasonable reconstruction — confirm against the original
# deployment/clients.
@app.route("/health", methods=["GET"])
def health_check():
    """Liveness probe: report service status and the model configuration."""
    return jsonify({
        "status": "healthy",
        "model": MODEL_NAME,
        "device": DEVICE,
        "embedding_dim": EMBEDDING_DIM,
        "max_context": 32768,
    })
# NOTE(review): the original view had no @app.route decorator, so Flask never
# registered it; the path is a reconstruction — confirm against callers.
@app.route("/embed", methods=["POST"])
def embed_texts():
    """Embed texts and return embeddings.

    Expects a JSON body ``{"texts": [...]}`` (a bare string is also
    accepted and wrapped in a list). Returns 400 when the field is
    missing and 500 on any model failure.
    """
    try:
        data = request.get_json()
        if not data or "texts" not in data:
            return jsonify({"error": "Missing 'texts' field"}), 400
        texts = data["texts"]
        if not isinstance(texts, list):
            texts = [texts]
        logger.info(f"Embedding {len(texts)} texts")
        model = get_model()
        embeddings = model.encode(texts)
        return jsonify({
            "embeddings": [embedding.tolist() for embedding in embeddings],
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]) if embeddings else 0,
            "count": len(embeddings),
        })
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Embedding error")
        return jsonify({"error": str(e)}), 500
# NOTE(review): the original view had no @app.route decorator, so Flask never
# registered it; the path is a reconstruction — confirm against callers.
@app.route("/embed_single", methods=["POST"])
def embed_single():
    """Embed single text (convenience endpoint).

    Expects a JSON body ``{"text": "..."}``. Returns 400 when the field
    is missing or not a string, 500 on any model failure.
    """
    try:
        data = request.get_json()
        if not data or "text" not in data:
            return jsonify({"error": "Missing 'text' field"}), 400
        text = data["text"]
        # Reject non-strings up front: slicing text[:100] below (and the
        # tokenizer) would otherwise raise a confusing 500 on e.g. an int.
        if not isinstance(text, str):
            return jsonify({"error": "'text' must be a string"}), 400
        logger.info(f"Embedding single text: {text[:100]}...")
        model = get_model()
        embeddings = model.encode([text])
        return jsonify({
            "embedding": embeddings[0].tolist(),
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]),
            "text_length": len(text),
        })
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Single embedding error")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Initialize model on startup so the first request isn't hit with the
    # multi-GB model load latency.
    logger.info("π Starting embedding service...")
    get_model()
    logger.info("β Service ready!")
    # Hugging Face Spaces convention: honor $PORT, default 7860.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)