iBrokeTheCode committed on
Commit
df5c96c
·
1 Parent(s): 0225bda

chore: Get embeddings and predictions

Browse files
Files changed (1) hide show
  1. predictor.py +158 -42
predictor.py CHANGED
@@ -1,60 +1,176 @@
 
 
 
 
 
 
 
1
  from tensorflow.keras.models import load_model
 
 
 
 
 
 
 
 
 
 
2
 
3
- # TODO: Review Code
4
- # Load the models once at the start of the script
5
- print("πŸ’¬ Loading models...")
 
 
 
 
 
 
 
 
 
 
 
6
  try:
7
  text_model = load_model("./models/text_model")
8
  image_model = load_model("./models/image_model")
9
  multimodal_model = load_model("./models/multimodal_model")
10
- print("βœ… Models loaded successfully!")
11
  except Exception as e:
12
- print(f"❌ Error loading models: {e}")
13
- text_model = None
14
- image_model = None
15
- multimodal_model = None
16
 
17
- # A placeholder for your class labels
18
- CLASS_LABELS = [
19
- "abcat0100000",
20
- "abcat0200000",
21
- "abcat0207000",
22
- ] # Add your actual labels
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # 📌 FUNCTIONS
26
- def predict(mode, text, image_path):
27
  """
28
- This placeholder function now returns a dictionary
29
- in the format expected by the gr.Label component.
 
 
 
 
 
 
30
  """
31
- multimodal_output = {
32
- "abcat0100000": 0.05,
33
- "abcat0200000": 0.10,
34
- "abcat0300000": 0.20,
35
- "abcat0400000": 0.45,
36
- "abcat0500000": 0.20,
37
- }
38
- text_only_output = {
39
- "abcat0100000": 0.08,
40
- "abcat0200000": 0.15,
41
- "abcat0300000": 0.25,
42
- "abcat0400000": 0.35,
43
- "abcat0500000": 0.17,
44
- }
45
- image_only_output = {
46
- "abcat0100000": 0.10,
47
- "abcat0200000": 0.20,
48
- "abcat0300000": 0.30,
49
- "abcat0400000": 0.25,
50
- "abcat0500000": 0.15,
51
- }
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  if mode == "Multimodal":
54
- return multimodal_output
55
  elif mode == "Text Only":
56
- return text_only_output
57
  elif mode == "Image Only":
58
- return image_only_output
59
  else:
 
60
  return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict, Optional
3
+
4
+ from numpy import array, expand_dims, float32, ndarray, transpose, zeros
5
+ from PIL import Image
6
+ from sentence_transformers import SentenceTransformer
7
+ from tensorflow import constant
8
  from tensorflow.keras.models import load_model
9
+ from transformers import TFConvNextV2Model
10
+
11
# Class labels reported by `predict`. They are paired positionally with the
# classifier's output vector, so the order here is significant.
# TODO: Labels are hardcoded for now.
# NOTE(review): presumably this order must match the label encoding used when
# the classification heads were trained - confirm.
CLASS_LABELS = [
    "abcat0100000",
    "abcat0200000",
    "abcat0300000",
    "abcat0400000",
    "abcat0500000",
]

# 📌 LOAD MODELS
# Everything is loaded once, at import time. A failure is reported on stdout
# and the affected names are set to None, so importing this module never raises.
print("💬 Loading embedding models...")
try:
    text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    image_feature_extractor = TFConvNextV2Model.from_pretrained(
        "facebook/convnextv2-tiny-22k-224"
    )
    print("✅ Embedding models loaded successfully!")
except Exception as e:
    print(f"❌ Error loading embedding models: {e}")
    text_embedding_model, image_feature_extractor = None, None

# Load the final classification models (MLP heads)
print("💬 Loading classification models...")
try:
    text_model = load_model("./models/text_model")
    image_model = load_model("./models/image_model")
    multimodal_model = load_model("./models/multimodal_model")
    print("✅ Classification models loaded successfully!")
except Exception as e:
    print(f"❌ Error loading classification models: {e}")
    # If any one of the three fails, all three are disabled together.
    text_model, image_model, multimodal_model = None, None, None
 
 
42
 
 
 
 
 
 
 
43
 
44
+ # 📌 EMBEDDING FUNCTIONS
45
def get_text_embeddings(text: Optional[str]) -> ndarray:
    """
    Generates a dense embedding vector from a text string.

    Args:
        text (Optional[str]): The input text. Can be None or an empty string.

    Returns:
        np.ndarray: A float32 NumPy array of shape (1, D), where D is the
            dimension reported by the text embedding model's
            `get_sentence_embedding_dimension()`. A zero vector is returned
            when the input is None, empty or whitespace-only.

    Raises:
        RuntimeError: If the text embedding model failed to load (the module
            sets it to None in that case).
    """
    # Fail with a clear message instead of an AttributeError on None
    if text_embedding_model is None:
        raise RuntimeError(
            "Text embedding model is not loaded; cannot generate text embeddings."
        )

    # Handle cases where no text is provided
    if not text or not text.strip():
        # Zero vector sized from the model so it always matches real embeddings
        embedding_dim = text_embedding_model.get_sentence_embedding_dimension()
        return zeros((1, embedding_dim), dtype=float32)

    # Encode a one-element batch so the result keeps a leading batch axis
    embeddings = text_embedding_model.encode([text])
    return array(embeddings, dtype=float32)
66
 
67
+
68
def get_image_embeddings(image_path: Optional[str]) -> ndarray:
    """
    Preprocesses an image and generates an embedding vector using a pre-trained model.

    Args:
        image_path (Optional[str]): The file path to the image. None or an
            empty string means "no image".

    Returns:
        np.ndarray: A float32 zero vector of shape (1, 768) when no image is
            provided. Otherwise the feature extractor's `pooler_output`
            converted to a NumPy array (expected to have the same shape as
            the zero vector).

    Raises:
        RuntimeError: If an image is provided but the image feature extractor
            failed to load (the module sets it to None in that case).
    """
    # Handle cases where no image is provided (None or empty path)
    if not image_path:
        return zeros((1, 768), dtype=float32)

    # Fail with a clear message instead of "NoneType is not callable"
    if image_feature_extractor is None:
        raise RuntimeError(
            "Image feature extractor is not loaded; cannot generate image embeddings."
        )

    # Open inside a context manager so the file handle is released promptly.
    # `convert` returns a new, fully loaded image, so it stays usable after
    # the source file has been closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Resize the image to the model's expected input size (224x224)
    image = image.resize((224, 224), Image.Resampling.LANCZOS)

    # Convert to NumPy array and add a batch dimension (1, H, W, C)
    image_array = array(image, dtype=float32)
    image_array = expand_dims(image_array, axis=0)

    # Transpose the array to match the model's channel order (1, C, H, W)
    image_array = transpose(image_array, (0, 3, 1, 2))

    # Scale pixel values from [0, 255] to [0, 1].
    # NOTE(review): no per-channel mean/std normalisation is applied here.
    # Confirm this matches the preprocessing used when the classification
    # heads were trained before changing it.
    image_array = image_array / 255.0

    # Pass the preprocessed image through the feature extractor model
    embeddings_output = image_feature_extractor(constant(image_array))

    # Extract the final embedding from the pooler_output
    return embeddings_output.pooler_output.numpy()
106
+
107
+
108
+ # 📌 MAIN PREDICTION FUNCTION
109
def predict(
    mode: str, text: Optional[str], image_path: Optional[str]
) -> Dict[str, Any]:
    """
    Predicts the category of a product based on the selected mode.

    Only the embeddings required by the selected mode are computed, and the
    mode is validated before any model is touched.

    Args:
        mode (str): The prediction mode ("Multimodal", "Text Only", "Image Only").
        text (Optional[str]): The product description text.
        image_path (Optional[str]): The file path to the product image.

    Returns:
        Dict[str, Any]: A dictionary mapping each class label to its predicted
            probability as a plain Python float. Returns an empty dictionary
            if the mode is invalid.

    Raises:
        RuntimeError: If the classification model for the selected mode
            failed to load (the module sets it to None in that case).
    """
    # Resolve the classifier and which inputs it needs; reject unknown modes
    # before doing any expensive embedding work.
    if mode == "Multimodal":
        model, needs_text, needs_image = multimodal_model, True, True
    elif mode == "Text Only":
        model, needs_text, needs_image = text_model, True, False
    elif mode == "Image Only":
        model, needs_text, needs_image = image_model, False, True
    else:
        # Return an empty dictionary if the mode is not recognized
        return {}

    # Fail with a clear message instead of an AttributeError on None
    if model is None:
        raise RuntimeError(
            f"Classification model for mode '{mode}' is not loaded."
        )

    # Build only the embeddings this mode consumes (text first, then image)
    model_inputs = []
    if needs_text:
        model_inputs.append(get_text_embeddings(text))
    if needs_image:
        model_inputs.append(get_image_embeddings(image_path))

    # Single-input models receive the bare array; the multimodal model
    # receives the [text, image] list.
    if len(model_inputs) == 1:
        predictions = model.predict(model_inputs[0])
    else:
        predictions = model.predict(model_inputs)

    # The model's output is a 2D array, so we take the first row (index 0).
    # Values are converted to built-in floats so the result is plain,
    # serializable Python data rather than NumPy scalars.
    return {
        label: float(probability)
        for label, probability in zip(CLASS_LABELS, predictions[0])
    }
145
+
146
+
147
+ # 📌 SANITY CHECKS
148
if __name__ == "__main__":
    # Manual smoke test: prints the shapes produced by the embedding helpers.
    # Nothing is asserted; inspect the output by eye.
    print("\n--- Running sanity checks for predictor.py ---")

    # Check text embedding function
    print("\n--- Testing get_text_embeddings ---")
    sample_text = (
        "A sleek silver laptop with a large screen and high-resolution display."
    )
    text_emb = get_text_embeddings(sample_text)
    print(f"Embedding shape for a normal string: {text_emb.shape}")
    empty_text_emb = get_text_embeddings("")
    print(f"Embedding shape for an empty string: {empty_text_emb.shape}")
    spaces_text_emb = get_text_embeddings(" ")
    print(f"Embedding shape for a string with spaces: {spaces_text_emb.shape}")

    # Check image embedding function
    print("\n--- Testing get_image_embeddings ---")
    test_image_path = "test.jpeg"  # Ensure this file exists for the test to pass
    if os.path.exists(test_image_path):
        image_emb = get_image_embeddings(test_image_path)
        print(f"✅ Embedding shape for an image file: {image_emb.shape}")
    else:
        # The image check is optional: skip it when the fixture is absent
        print(
            f"⚠️ Warning: Test image file not found at {test_image_path}. Skipping image embedding test."
        )

    empty_image_emb = get_image_embeddings(None)
    print(f"Embedding shape for a None input: {empty_image_emb.shape}")
    print("--- Sanity checks complete. ---")