# app.py — Qwen3-Embedding-4B Flask embedding service
# (Hugging Face page header removed; original upload commit: ea8e9c7)
import os
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from flask import Flask, request, jsonify
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
# Qwen3-Embedding-4B model for retrieval
MODEL_NAME = "Qwen/Qwen3-Embedding-4B"
# Prefer GPU when available; otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMBEDDING_DIM = 2560 # Max dimension for Qwen3-Embedding-4B
class EmbeddingModel:
    """Wraps Qwen3-Embedding-4B for batched text embedding.

    Embeddings are taken from the hidden state of the last non-padding
    token of each sequence (last-token pooling, as recommended for the
    Qwen3 embedding models) and L2-normalized.
    """

    def __init__(self):
        logger.info(f"Loading {MODEL_NAME} on {DEVICE}")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = AutoModel.from_pretrained(MODEL_NAME)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info("βœ… Model loaded successfully")

    def encode(self, texts, batch_size=16):
        """Encode texts to L2-normalized embeddings.

        Args:
            texts: a single string or a list of strings.
            batch_size: number of texts tokenized and forwarded per batch.

        Returns:
            List of 1-D numpy float arrays, one per input text, each
            normalized to unit L2 norm.
        """
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            # Qwen3 instruction format for retrieval.
            # NOTE(review): the official Qwen3 recipe applies the
            # instruction prefix to queries only, not to documents —
            # kept unconditional here to preserve this service's API.
            batch_texts = [
                f"Instruct: Retrieve semantically similar text.\nQuery: {text}"
                for text in batch_texts
            ]
            inputs = self.tokenizer(
                batch_texts,
                padding="longest",
                truncation=True,
                max_length=32768,  # Qwen3 supports up to 32k context
                return_tensors="pt",
            ).to(DEVICE)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Last-token pooling via attention_mask.
            # BUGFIX: the previous code did
            #   (input_ids == eos_token_id).argmax(-1) - 1
            # which (a) locates the FIRST eos position — wrong when the
            # pad token equals the eos token, as right-padding fills the
            # tail with eos; (b) the "- 1" stepped one token BEFORE the
            # last real token; and (c) yields index -1 when truncation
            # drops the eos. Summing the attention mask reliably gives
            # the index of the final non-padding token.
            seq_lens = inputs["attention_mask"].sum(dim=1) - 1
            batch_embeddings = (
                outputs.last_hidden_state[
                    torch.arange(seq_lens.size(0), device=seq_lens.device),
                    seq_lens,
                ]
                .cpu()
                .numpy()
            )
            # L2-normalize; clamp the norm so an all-zero vector cannot
            # cause a division by zero.
            norms = np.linalg.norm(batch_embeddings, axis=1, keepdims=True)
            batch_embeddings = batch_embeddings / np.maximum(norms, 1e-12)
            embeddings.extend(batch_embeddings)
        return embeddings
# Process-wide singleton so the (large) model is loaded at most once.
embedding_model = None

def get_model():
    """Return the shared EmbeddingModel, creating it on first use."""
    global embedding_model
    if embedding_model is not None:
        return embedding_model
    embedding_model = EmbeddingModel()
    return embedding_model
@app.route("/", methods=["GET"])
def health_check():
    """Report service status and the active model configuration."""
    payload = {
        "status": "healthy",
        "model": MODEL_NAME,
        "device": DEVICE,
        "embedding_dim": EMBEDDING_DIM,
        "max_context": 32768,
    }
    return jsonify(payload)
@app.route("/embed", methods=["POST"])
def embed_texts():
    """Embed a batch of texts and return their embeddings as JSON.

    Expects a JSON body {"texts": [...]}; a single non-list value is
    wrapped in a list. Returns 400 for a missing/malformed body and
    500 if embedding itself fails.
    """
    try:
        # BUGFIX: silent=True makes a missing or malformed JSON body
        # return None (-> 400 below) instead of raising inside the try
        # and being misreported as a 500 server error.
        data = request.get_json(silent=True)
        if not data or "texts" not in data:
            return jsonify({"error": "Missing 'texts' field"}), 400
        texts = data["texts"]
        if not isinstance(texts, list):
            texts = [texts]
        # Lazy %-style args avoid formatting when the level is disabled.
        logger.info("Embedding %d texts", len(texts))
        model = get_model()
        embeddings = model.encode(texts)
        return jsonify({
            "embeddings": [embedding.tolist() for embedding in embeddings],
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]) if embeddings else 0,
            "count": len(embeddings),
        })
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Embedding error")
        return jsonify({"error": str(e)}), 500
@app.route("/embed_single", methods=["POST"])
def embed_single():
    """Embed a single text (convenience endpoint).

    Expects a JSON body {"text": "..."}. Returns 400 for a
    missing/malformed body and 500 if embedding fails.
    """
    try:
        # BUGFIX: silent=True makes a missing or malformed JSON body
        # return None (-> 400 below) instead of raising inside the try
        # and being misreported as a 500 server error.
        data = request.get_json(silent=True)
        if not data or "text" not in data:
            return jsonify({"error": "Missing 'text' field"}), 400
        text = data["text"]
        # Log only a prefix to keep log lines bounded.
        logger.info("Embedding single text: %s...", text[:100])
        model = get_model()
        embeddings = model.encode([text])
        return jsonify({
            "embedding": embeddings[0].tolist(),
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]),
            "text_length": len(text),
        })
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Single embedding error")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Warm the model at startup so the first request doesn't pay the
    # full load cost.
    logger.info("πŸš€ Starting embedding service...")
    get_model()
    logger.info("βœ… Service ready!")
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))