from json import load
from typing import Any, Dict, List, Optional
from numpy import array, expand_dims, float32, ndarray, transpose, zeros
from PIL import Image
from sentence_transformers import SentenceTransformer
from tensorflow import constant
from tensorflow.keras.models import load_model
from transformers import TFConvNextV2Model
# πŸ“Œ GLOBAL VARIABLES (categories)
CATEGORY_MAP: Dict[str, str] = {}
CLASS_LABELS: List[str] = []
def build_category_map(categories_json_path: str):
"""
    Builds a flat id-to-name dictionary by traversing the hierarchical categories.json
    file and sets the fixed list of class labels the classification models were trained on.
"""
global CATEGORY_MAP, CLASS_LABELS
try:
with open(categories_json_path, "r") as f:
categories_data = load(f)
    except FileNotFoundError:
        print(
            f"❌ Error: {categories_json_path} not found. "
            "CATEGORY_MAP and CLASS_LABELS will remain empty."
        )
        return
category_map = {}
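    # Category IDs the classification heads were trained on. The order matters:
    # it must match the models' output units, since these labels are zipped with
    # the prediction vector in predict().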
model_trained_ids = [
"abcat0100000",
"abcat0200000",
"abcat0207000",
"abcat0300000",
"abcat0400000",
"abcat0500000",
"abcat0700000",
"abcat0800000",
"abcat0900000",
"cat09000",
"pcmcat128500050004",
"pcmcat139900050002",
"pcmcat242800050021",
"pcmcat252700050006",
"pcmcat312300050015",
"pcmcat332000050000",
]
def traverse_categories(categories):
for category in categories:
category_map[category["id"]] = category["name"]
if "subCategories" in category and category["subCategories"]:
traverse_categories(category["subCategories"])
if "path" in category and category["path"]:
for path_item in category["path"]:
category_map[path_item["id"]] = path_item["name"]
traverse_categories(categories_data)
CATEGORY_MAP = category_map
CLASS_LABELS = model_trained_ids
# πŸ“Œ LOAD MODELS
print("πŸ’¬ Loading embedding models...")
try:
text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
image_feature_extractor = TFConvNextV2Model.from_pretrained(
"facebook/convnextv2-tiny-22k-224"
)
print("βœ… Embedding models loaded successfully!")
except Exception as e:
print(f"❌ Error loading embedding models: {e}")
text_embedding_model, image_feature_extractor = None, None
# Load the final classification models (MLP heads)
print("πŸ’¬ Loading classification models...")
try:
text_model = load_model("./models/text_model")
image_model = load_model("./models/image_model")
multimodal_model = load_model("./models/multimodal_model")
print("βœ… Classification models loaded successfully!")
except Exception as e:
print(f"❌ Error loading classification models: {e}")
text_model, image_model, multimodal_model = None, None, None
# Generate category map and class labels list
build_category_map("./data/raw/categories.json")
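# Note: this runs at import time, so CATEGORY_MAP and CLASS_LABELS are populated
# (or left empty if categories.json is missing) before predict() is ever called.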
# πŸ“Œ EMBEDDING FUNCTIONS
def get_text_embeddings(text: Optional[str]) -> ndarray:
"""
Generates a dense embedding vector from a text string.
Args:
text (Optional[str]): The input text. Can be None or an empty string.
Returns:
np.ndarray: A NumPy array of shape (1, 384) representing the text
embedding. Returns a zero vector if the input is empty.
"""
# Handle cases where no text is provided
if not text or not text.strip():
# Returns a zero vector with the correct dimension (384)
return zeros(
(1, text_embedding_model.get_sentence_embedding_dimension()), dtype=float32
)
# Use the pre-trained SentenceTransformer to encode the text
embeddings = text_embedding_model.encode([text])
return array(embeddings, dtype=float32)
def get_image_embeddings(image_path: Optional[str]) -> ndarray:
"""
Preprocesses an image and generates an embedding vector using a pre-trained model.
Args:
image_path (Optional[str]): The file path to the image.
Returns:
np.ndarray: A NumPy array of shape (1, 768) representing the image
embedding. Returns a zero vector if no image is provided.
"""
# Handle cases where no image is provided
if image_path is None:
return zeros((1, 768), dtype=float32)
# Load the image and convert to RGB format
image = Image.open(image_path).convert("RGB")
# Resize the image to the model's expected input size (224x224)
image = image.resize((224, 224), Image.Resampling.LANCZOS)
# Convert to NumPy array and add a batch dimension (1, H, W, C)
image_array = array(image, dtype=float32)
image_array = expand_dims(image_array, axis=0)
# Transpose the array to match the model's channel order (1, C, H, W)
image_array = transpose(image_array, (0, 3, 1, 2))
    # Scale the pixel values to the [0, 1] range
image_array = image_array / 255.0
# Pass the preprocessed image through the feature extractor model
embeddings_output = image_feature_extractor(constant(image_array))
# Extract the final embedding from the pooler_output
embeddings = embeddings_output.pooler_output
return embeddings.numpy()
# πŸ“Œ MAIN PREDICTION FUNCTION
def predict(
mode: str, text: Optional[str], image_path: Optional[str]
) -> Dict[str, Any]:
"""
Predicts the category of a product based on the selected mode.
Args:
mode (str): The prediction mode ("Multimodal", "Text Only", "Image Only").
text (Optional[str]): The product description text.
image_path (Optional[str]): The file path to the product image.
Returns:
Dict[str, Any]: A dictionary of class labels and their corresponding
prediction probabilities. Returns an empty dictionary
if the mode is invalid.
"""
# Generate embeddings for both inputs
text_emb = get_text_embeddings(text)
image_emb = get_image_embeddings(image_path)
# Get predictions based on the selected mode
if mode == "Multimodal":
predictions = multimodal_model.predict([text_emb, image_emb])
elif mode == "Text Only":
predictions = text_model.predict(text_emb)
elif mode == "Image Only":
predictions = image_model.predict(image_emb)
else:
# Return an empty dictionary if the mode is not recognized
return {}
# Format the output into a dictionary with labels and probabilities
# The model's output is a 2D array, so we take the first row (index 0)
prediction_dict_raw = dict(zip(CLASS_LABELS, predictions[0]))
# Map the raw IDs to human-readable names
prediction_dict_mapped = {}
for class_id, probability in prediction_dict_raw.items():
# Get the human-readable name, defaulting to the raw ID if not found
category_name = CATEGORY_MAP.get(class_id, class_id)
prediction_dict_mapped[category_name] = probability
# Sort the dictionary by probability in descending order for a cleaner display
sorted_predictions = dict(
sorted(prediction_dict_mapped.items(), key=lambda item: item[1], reverse=True)
)
return sorted_predictions
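# πŸ“Œ QUICK LOCAL SMOKE TEST
# Minimal sketch for running this module directly. The sample description below is
# an illustrative placeholder, and the image path in the commented line is
# hypothetical; swap in your own inputs before relying on the output.
if __name__ == "__main__":
    sample_text = "Wireless noise-cancelling over-ear headphones with Bluetooth"

    if text_model is not None and text_embedding_model is not None:
        top_text = predict("Text Only", sample_text, None)
        print("Text Only top-3 predictions:", list(top_text.items())[:3])
    else:
        print("⚠️ Models not loaded; skipping smoke test.")

    # For a multimodal run, point this at a real product image (path is hypothetical):
    # print(list(predict("Multimodal", sample_text, "./data/samples/example.jpg").items())[:3])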