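"""Flask embedding service backed by Qwen/Qwen3-Embedding-4B.

Exposes a / health check plus /embed (batch) and /embed_single endpoints
that return L2-normalized embeddings as JSON.
"""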
import os
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from flask import Flask, request, jsonify
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Qwen3-Embedding-4B model for retrieval
MODEL_NAME = "Qwen/Qwen3-Embedding-4B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMBEDDING_DIM = 2560  # Max dimension for Qwen3-Embedding-4B
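# Per the model card, Qwen3-Embedding-4B also supports user-defined output
# dimensions (Matryoshka-style truncation) below 2560; this service always
# returns the full-size vectors.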

class EmbeddingModel:
    def __init__(self):
        logger.info(f"Loading {MODEL_NAME} on {DEVICE}")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # torch_dtype="auto" loads the checkpoint in its native precision
        # (bfloat16 here), using roughly a quarter of the memory of the
        # float32 default; drop it if your hardware lacks bf16 support.
        self.model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype="auto")
        self.model.to(DEVICE)
        self.model.eval()
        logger.info("βœ… Model loaded successfully")
    
    def encode(self, texts, batch_size=16):
        """Encode texts to embeddings using Qwen3-Embedding-4B"""
        if isinstance(texts, str):
            texts = [texts]
        
        embeddings = []
        
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            
            # Qwen3 instruction format. Caveat: the model card applies this
            # "Instruct: ...\nQuery: ..." prefix to queries only and encodes
            # documents bare; this service prefixes every input for
            # simplicity, which may cost some retrieval accuracy.
            batch_texts = [f"Instruct: Retrieve semantically similar text.\nQuery: {text}" for text in batch_texts]
            
            inputs = self.tokenizer(
                batch_texts,
                padding="longest",  
                truncation=True,
                max_length=32768,  # Qwen3 supports up to 32k context
                return_tensors="pt"
            ).to(DEVICE)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Last-token pooling, as recommended for Qwen3-Embedding:
                # take the hidden state of the last non-padding token (the
                # appended EOS). Deriving the index from the attention mask
                # is robust even when truncation drops the EOS token,
                # whereas argmax over input_ids == eos_token_id returns
                # index 0 (and hence position -1) when no EOS is present.
                mask = inputs["attention_mask"]
                last_token_idx = mask.sum(dim=1) - 1
                batch_indices = torch.arange(mask.shape[0], device=mask.device)
                batch_embeddings = (
                    outputs.last_hidden_state[batch_indices, last_token_idx]
                    .float()  # bf16 tensors cannot be converted to numpy directly
                    .cpu()
                    .numpy()
                )
                
                # L2-normalize so cosine similarity between two embeddings
                # reduces to a plain dot product
                batch_embeddings = batch_embeddings / np.linalg.norm(batch_embeddings, axis=1, keepdims=True)
                
                embeddings.extend(batch_embeddings)
        
        return embeddings

# Global model instance
embedding_model = None

def get_model():
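    # Lazy singleton: the model is built on first use; __main__ below also
    # pre-warms it at startup so the first request does not pay load time.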
    global embedding_model
    if embedding_model is None:
        embedding_model = EmbeddingModel()
    return embedding_model

@app.route("/", methods=["GET"])
def health_check():
    return jsonify({
        "status": "healthy",
        "model": MODEL_NAME,
        "device": DEVICE,
        "embedding_dim": EMBEDDING_DIM,
        "max_context": 32768
    })

@app.route("/embed", methods=["POST"])
def embed_texts():
    """Embed texts and return embeddings"""
    try:
        data = request.get_json()
        
        if not data or "texts" not in data:
            return jsonify({"error": "Missing 'texts' field"}), 400
        
        texts = data["texts"]
        if not isinstance(texts, list):
            texts = [texts]
        
        logger.info(f"Embedding {len(texts)} texts")
        
        model = get_model()
        embeddings = model.encode(texts)
        
        return jsonify({
            "embeddings": [embedding.tolist() for embedding in embeddings],
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]) if embeddings else 0,
            "count": len(embeddings)
        })
        
    except Exception as e:
        logger.error(f"Embedding error: {str(e)}")
        return jsonify({"error": str(e)}), 500

@app.route("/embed_single", methods=["POST"])
def embed_single():
    """Embed single text (convenience endpoint)"""
    try:
        data = request.get_json()
        
        if not data or "text" not in data:
            return jsonify({"error": "Missing 'text' field"}), 400
        
        text = data["text"]
        logger.info(f"Embedding single text: {text[:100]}...")
        
        model = get_model()
        embeddings = model.encode([text])
        
        return jsonify({
            "embedding": embeddings[0].tolist(),
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]),
            "text_length": len(text)
        })
        
    except Exception as e:
        logger.error(f"Single embedding error: {str(e)}")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Initialize model on startup
    logger.info("πŸš€ Starting embedding service...")
    get_model()
    logger.info("βœ… Service ready!")
    
    port = int(os.environ.get("PORT", 7860))
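    # Note: Flask's built-in server is fine for demos; for production
    # traffic, front this with a WSGI server such as gunicorn.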
    app.run(host="0.0.0.0", port=port)
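
# Example client call (a sketch: assumes the service is reachable at
# localhost:7860 and that the third-party `requests` package is installed;
# neither is part of this service):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/embed",
#       json={"texts": ["hello world", "goodbye world"]},
#   )
#   vectors = resp.json()["embeddings"]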