import os
import logging

import numpy as np
import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Qwen3-Embedding-4B model for retrieval
MODEL_NAME = "Qwen/Qwen3-Embedding-4B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMBEDDING_DIM = 2560  # Max embedding dimension for Qwen3-Embedding-4B

class EmbeddingModel:
    def __init__(self):
        logger.info(f"Loading {MODEL_NAME} on {DEVICE}")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = AutoModel.from_pretrained(MODEL_NAME)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info("✅ Model loaded successfully")
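
    # Optional memory saving (a sketch, not enabled here): on CUDA devices the
    # 4B model can be loaded in half precision via the standard transformers
    # torch_dtype argument, roughly halving GPU memory:
    #
    #     self.model = AutoModel.from_pretrained(
    #         MODEL_NAME, torch_dtype=torch.float16
    #     )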
    def encode(self, texts, batch_size=16):
        """Encode texts to embeddings using Qwen3-Embedding-4B."""
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            # Qwen3 instruction format for retrieval
            batch_texts = [
                f"Instruct: Retrieve semantically similar text.\nQuery: {text}"
                for text in batch_texts
            ]
            inputs = self.tokenizer(
                batch_texts,
                padding="longest",
                truncation=True,
                max_length=32768,  # Qwen3 supports up to 32k context
                return_tensors="pt"
            ).to(DEVICE)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Last-token pooling: take the hidden state at the last
            # non-padding position of each sequence. Summing the attention
            # mask is robust even when the pad token id equals the EOS
            # token id.
            sequence_lengths = inputs["attention_mask"].sum(dim=1) - 1
            batch_embeddings = []
            for j, seq_len in enumerate(sequence_lengths):
                embedding = outputs.last_hidden_state[j, seq_len, :].cpu().numpy()
                batch_embeddings.append(embedding)
            batch_embeddings = np.array(batch_embeddings)
            # L2-normalize so cosine similarity reduces to a dot product
            batch_embeddings = batch_embeddings / np.linalg.norm(batch_embeddings, axis=1, keepdims=True)
            embeddings.extend(batch_embeddings)
        return embeddings
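
# Minimal usage sketch for EmbeddingModel.encode (illustrative; the variable
# names are examples, not part of the service API). Because encode()
# L2-normalizes its output, cosine similarity between two embeddings is just
# their dot product:
#
#     model = EmbeddingModel()
#     query_vec, doc_vec = model.encode(["what is RAG?", "RAG pairs retrieval with generation."])
#     similarity = float(np.dot(query_vec, doc_vec))  # cosine similarity in [-1, 1]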

# Global model instance (lazily initialized on first use)
embedding_model = None


def get_model():
    global embedding_model
    if embedding_model is None:
        embedding_model = EmbeddingModel()
    return embedding_model

@app.route("/", methods=["GET"])
def health_check():
    return jsonify({
        "status": "healthy",
        "model": MODEL_NAME,
        "device": DEVICE,
        "embedding_dim": EMBEDDING_DIM,
        "max_context": 32768
    })

@app.route("/embed", methods=["POST"])
def embed_texts():
    """Embed texts and return embeddings."""
    try:
        data = request.get_json()
        if not data or "texts" not in data:
            return jsonify({"error": "Missing 'texts' field"}), 400
        texts = data["texts"]
        if not isinstance(texts, list):
            texts = [texts]
        logger.info(f"Embedding {len(texts)} texts")
        model = get_model()
        embeddings = model.encode(texts)
        return jsonify({
            "embeddings": [embedding.tolist() for embedding in embeddings],
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]) if embeddings else 0,
            "count": len(embeddings)
        })
    except Exception as e:
        logger.error(f"Embedding error: {str(e)}")
        return jsonify({"error": str(e)}), 500

@app.route("/embed_single", methods=["POST"])
def embed_single():
    """Embed a single text (convenience endpoint)."""
    try:
        data = request.get_json()
        if not data or "text" not in data:
            return jsonify({"error": "Missing 'text' field"}), 400
        text = data["text"]
        logger.info(f"Embedding single text: {text[:100]}...")
        model = get_model()
        embeddings = model.encode([text])
        return jsonify({
            "embedding": embeddings[0].tolist(),
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]),
            "text_length": len(text)
        })
    except Exception as e:
        logger.error(f"Single embedding error: {str(e)}")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Initialize model on startup
    logger.info("🚀 Starting embedding service...")
    get_model()
    logger.info("✅ Service ready!")
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)
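
# Example requests (a sketch, assuming the service listens on localhost:7860;
# adjust host and port to your deployment):
#
#     curl http://localhost:7860/
#
#     curl -X POST http://localhost:7860/embed \
#          -H "Content-Type: application/json" \
#          -d '{"texts": ["hello world", "goodbye world"]}'
#
#     curl -X POST http://localhost:7860/embed_single \
#          -H "Content-Type: application/json" \
#          -d '{"text": "hello world"}'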