iBrokeTheCode committed on
Commit
a061490
Β·
1 Parent(s): 43fe501

feat: Get product category predictions

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. predictor.py +72 -42
app.py CHANGED
@@ -64,6 +64,7 @@ with gr.Blocks(
64
  text_input = gr.Textbox(
65
  label="Product Description:",
66
  placeholder="e.g., Apple iPhone 15 Pro Max 256GB",
 
67
  )
68
 
69
  image_input = gr.Image(
 
64
  text_input = gr.Textbox(
65
  label="Product Description:",
66
  placeholder="e.g., Apple iPhone 15 Pro Max 256GB",
67
+ lines=2,
68
  )
69
 
70
  image_input = gr.Image(
predictor.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
  from typing import Any, Dict, Optional
3
 
4
  from numpy import array, expand_dims, float32, ndarray, transpose, zeros
@@ -8,14 +8,61 @@ from tensorflow import constant
8
  from tensorflow.keras.models import load_model
9
  from transformers import TFConvNextV2Model
10
 
11
- # TODO: Hardcoded class labels for the output, as discussed
12
- CLASS_LABELS = [
13
- "abcat0100000",
14
- "abcat0200000",
15
- "abcat0300000",
16
- "abcat0400000",
17
- "abcat0500000",
18
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # πŸ“Œ LOAD MODELS
21
  print("πŸ’¬ Loading embedding models...")
@@ -40,6 +87,9 @@ except Exception as e:
40
  print(f"❌ Error loading classification models: {e}")
41
  text_model, image_model, multimodal_model = None, None, None
42
 
 
 
 
43
 
44
  # πŸ“Œ EMBEDDING FUNCTIONS
45
  def get_text_embeddings(text: Optional[str]) -> ndarray:
@@ -139,38 +189,18 @@ def predict(
139
 
140
  # Format the output into a dictionary with labels and probabilities
141
  # The model's output is a 2D array, so we take the first row (index 0)
142
- prediction_dict = dict(zip(CLASS_LABELS, predictions[0]))
143
-
144
- return prediction_dict
145
-
146
-
147
- # πŸ“Œ SANITY CHECKS
148
- if __name__ == "__main__":
149
- print("\n--- Running sanity checks for predictor.py ---")
150
-
151
- # Check text embedding function
152
- print("\n--- Testing get_text_embeddings ---")
153
- sample_text = (
154
- "A sleek silver laptop with a large screen and high-resolution display."
155
  )
156
- text_emb = get_text_embeddings(sample_text)
157
- print(f"Embedding shape for a normal string: {text_emb.shape}")
158
- empty_text_emb = get_text_embeddings("")
159
- print(f"Embedding shape for an empty string: {empty_text_emb.shape}")
160
- spaces_text_emb = get_text_embeddings(" ")
161
- print(f"Embedding shape for a string with spaces: {spaces_text_emb.shape}")
162
-
163
- # Check image embedding function
164
- print("\n--- Testing get_image_embeddings ---")
165
- test_image_path = "test.jpeg" # Ensure this file exists for the test to pass
166
- if os.path.exists(test_image_path):
167
- image_emb = get_image_embeddings(test_image_path)
168
- print(f"βœ… Embedding shape for an image file: {image_emb.shape}")
169
- else:
170
- print(
171
- f"⚠️ Warning: Test image file not found at {test_image_path}. Skipping image embedding test."
172
- )
173
 
174
- empty_image_emb = get_image_embeddings(None)
175
- print(f"Embedding shape for a None input: {empty_image_emb.shape}")
176
- print("--- Sanity checks complete. ---")
 
1
+ from json import load
2
  from typing import Any, Dict, Optional
3
 
4
  from numpy import array, expand_dims, float32, ndarray, transpose, zeros
 
8
  from tensorflow.keras.models import load_model
9
  from transformers import TFConvNextV2Model
10
 
11
# πŸ“Œ GLOBAL VARIABLES (categories)
CATEGORY_MAP: Dict[str, str] = {}
CLASS_LABELS = []

# Category IDs the classifier was trained on, in the model's output-index
# order. Hoisted to module level so the file-not-found fallback path can
# still populate CLASS_LABELS (the previous version printed "Using
# hardcoded labels as fallback" but left CLASS_LABELS empty).
_MODEL_TRAINED_IDS = [
    "abcat0100000",
    "abcat0200000",
    "abcat0207000",
    "abcat0300000",
    "abcat0400000",
    "abcat0500000",
    "abcat0700000",
    "abcat0800000",
    "abcat0900000",
    "cat09000",
    "pcmcat128500050004",
    "pcmcat139900050002",
    "pcmcat242800050021",
    "pcmcat252700050006",
    "pcmcat312300050015",
    "pcmcat332000050000",
]


def build_category_map(categories_json_path: str):
    """
    Build a flat {category_id: name} map and the ordered class-label list
    by traversing the hierarchical categories JSON file.

    Populates the module-level CATEGORY_MAP and CLASS_LABELS globals, and
    also returns them so callers (and tests) need not rely on globals.

    Args:
        categories_json_path: Path to the hierarchical categories JSON
            file (a list of category dicts with "id", "name", and optional
            "subCategories" / "path" entries).

    Returns:
        Tuple of (category_map, class_labels). When the JSON file is
        missing, category_map is left unchanged and class_labels falls
        back to the hardcoded model-trained IDs.
    """
    global CATEGORY_MAP, CLASS_LABELS

    # Bug fix: assign the trained label order BEFORE the early return so
    # the advertised hardcoded fallback actually takes effect.
    CLASS_LABELS = list(_MODEL_TRAINED_IDS)

    try:
        with open(categories_json_path, "r") as f:
            categories_data = load(f)
    except FileNotFoundError:
        print(
            f"❌ Error: {categories_json_path} not found. Using hardcoded labels as fallback."
        )
        return CATEGORY_MAP, CLASS_LABELS

    category_map: Dict[str, str] = {}

    def traverse_categories(categories):
        # Record every node's id -> name, then recurse into children and
        # walk the "path" breadcrumb entries so ancestor IDs are covered
        # even when they are not top-level nodes.
        for category in categories:
            category_map[category["id"]] = category["name"]
            if category.get("subCategories"):
                traverse_categories(category["subCategories"])
            for path_item in category.get("path") or []:
                category_map[path_item["id"]] = path_item["name"]

    traverse_categories(categories_data)

    CATEGORY_MAP = category_map
    return CATEGORY_MAP, CLASS_LABELS
65
+
66
 
67
  # πŸ“Œ LOAD MODELS
68
  print("πŸ’¬ Loading embedding models...")
 
87
  print(f"❌ Error loading classification models: {e}")
88
  text_model, image_model, multimodal_model = None, None, None
89
 
90
+ # Generate category map and class labels list
91
+ build_category_map("./data/raw/categories.json")
92
+
93
 
94
  # πŸ“Œ EMBEDDING FUNCTIONS
95
  def get_text_embeddings(text: Optional[str]) -> ndarray:
 
189
 
190
  # Format the output into a dictionary with labels and probabilities
191
  # The model's output is a 2D array, so we take the first row (index 0)
192
+ prediction_dict_raw = dict(zip(CLASS_LABELS, predictions[0]))
193
+
194
+ # Map the raw IDs to human-readable names
195
+ prediction_dict_mapped = {}
196
+ for class_id, probability in prediction_dict_raw.items():
197
+ # Get the human-readable name, defaulting to the raw ID if not found
198
+ category_name = CATEGORY_MAP.get(class_id, class_id)
199
+ prediction_dict_mapped[category_name] = probability
200
+
201
+ # Sort the dictionary by probability in descending order for a cleaner display
202
+ sorted_predictions = dict(
203
+ sorted(prediction_dict_mapped.items(), key=lambda item: item[1], reverse=True)
 
204
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ return sorted_predictions