import gradio as gr

from app_predictor import predict
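# NOTE: predict() is assumed to accept (mode, description, image_path) -- matching the
# classify_button.click() wiring below -- and to return a value gr.Label can display,
# e.g. a {category: confidence} mapping.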

# πŸ“Œ CUSTOM CSS
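# Pins #footer-container to the bottom of the viewport and pads the page bottom so the
# fixed footer never covers the content; .center is a small text-centering helper.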
css_code = """
#footer-container {
    position: fixed;
    bottom: 0;
    left: 0;
    right: 0;
    z-index: 1000;
    background-color: var(--background-fill-primary);
    padding: var(--spacing-md);
    border-top: 1px solid var(--border-color-primary);
    text-align: center;
}

.gradio-container {
    padding-bottom: 70px !important;
}

.center {
    text-align: center;
}
"""


def update_inputs(mode: str):
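    """Show or hide the text and image inputs to match the selected classification mode."""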
    if mode == "Multimodal":
        return gr.Textbox(visible=True), gr.Image(visible=True)
    elif mode == "Text Only":
        return gr.Textbox(visible=True), gr.Image(visible=False)
    elif mode == "Image Only":
        return gr.Textbox(visible=False), gr.Image(visible=True)
    else:  # Default case
        return gr.Textbox(visible=True), gr.Image(visible=True)


# πŸ“Œ USER INTERFACE
with gr.Blocks(
    title="Multimodal Product Classification",
    theme=gr.themes.Ocean(),
    css=css_code,
) as demo:
    with gr.Tabs():
        # πŸ“Œ APP TAB
        with gr.TabItem("πŸš€ App"):
            with gr.Row(elem_classes="center"):
                gr.HTML("""
                    <div>
                        <h1>πŸ›οΈ Multimodal Product Classification</h1>
                    </div>
                    <br><br>
                    """)

            with gr.Row(equal_height=True):
                # πŸ“Œ CLASSIFICATION INPUTS COLUMN
                with gr.Column():
                    with gr.Column():
                        gr.Markdown("## πŸ“ Classification Inputs")

                        mode_radio = gr.Radio(
                            choices=["Multimodal", "Image Only", "Text Only"],
                            value="Multimodal",
                            label="Choose Classification Mode:",
                        )

                        text_input = gr.Textbox(
                            label="Product Description:",
                            placeholder="e.g., Apple iPhone 15 Pro Max 256GB",
                            lines=1,
                        )

                        image_input = gr.Image(
                            label="Product Image",
                            type="filepath",
                            visible=True,
                            height=300,
                            width="100%",
                        )

                        classify_button = gr.Button(
                            "✨ Classify Product", variant="primary"
                        )

                # πŸ“Œ RESULTS COLUMN
                with gr.Column():
                    with gr.Column():
                        gr.Markdown("## πŸ“Š Results")

                        gr.Markdown("""
**πŸ’‘ How to use this app**

This app classifies a product based on its description and image.
- **Multimodal:** The most accurate mode, using both the image and a detailed description for the prediction.
- **Image Only:** Highly effective for visual products, relying solely on the product image.
- **Text Only:** Less precise; this mode needs a very descriptive and specific product description to achieve good results.
""")

                        gr.HTML("<hr>")

                        output_label = gr.Label(
                            label="Predict category", num_top_classes=5
                        )

            # πŸ“Œ EXAMPLES SECTION
            gr.Examples(
                examples=[
                    [
                        "Multimodal",
                        'Laptop Asus - 15.6" / CPU I9 / 2Tb SSD / 32Gb RAM / RTX 2080',
                        "./assets/sample2.jpg",
                    ],
                    [
                        "Multimodal",
                        "Red Electric Guitar – Stratocaster Style, 6-String, White Pickguard, Solid-Body, Ideal for Rock & Roll",
                        "./assets/sample1.jpg",
                    ],
                    [
                        "Multimodal",
                        "Portable Wireless Speaker / JBL / Black / High Quality Sound",
                        "./assets/sample3.jpg",
                    ],
                ],
                label="Select an example to pre-fill the inputs, then click the 'Classify Product' button.",
                inputs=[mode_radio, text_input, image_input],
                # outputs=output_label,
                # fn=predict,
                # cache_examples=True,
            )
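            # NOTE: supplying fn=predict, outputs=output_label, and cache_examples=True above
            # would precompute and cache the example predictions at startup; they are left
            # commented out here, presumably to keep startup lightweight.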

        # πŸ“Œ ABOUT TAB
        with gr.TabItem("ℹ️ About"):
            gr.Markdown("""
## Project Overview
                        
- This project is a multimodal product classification system for Best Buy products. 
- The core objective is to categorize products using both their text descriptions and images. 
- The system was trained on a dataset of **almost 50,000** products and their corresponding images to generate embeddings and train the classification models.

<br>

## Technical Workflow
                        
1.  **Data Preprocessing:** Product descriptions and images are extracted from the dataset, and a `categories.json` file is used to map product IDs to human-readable category names.
2.  **Embedding Generation** (see the sketch after this list):
    - **Text:** A pre-trained `SentenceTransformer` model (`all-MiniLM-L6-v2`) is used to generate dense vector embeddings from the product descriptions.
    - **Image:** A pre-trained computer vision model from the Hugging Face `transformers` library (`TFConvNextV2Model`) is used to extract image features.
3.  **Model Training:** The generated text and image embeddings are then used to train a multi-layer perceptron (MLP) model for classification. Separate models were trained for text-only, image-only, and multimodal (combined embeddings) classification.
4.  **Deployment:** The trained models are deployed via a Gradio web interface, allowing for live prediction on new product data.
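
A minimal, illustrative sketch of the embedding step from point 2 (simplified; the exact parameters live in the notebook and may differ):

```python
from sentence_transformers import SentenceTransformer

# Encode product descriptions into dense 384-dimensional vectors.
text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
text_embeddings = text_encoder.encode(["Apple iPhone 15 Pro Max 256GB"])

# Image features are obtained analogously from a pre-trained `TFConvNextV2Model`,
# and the two embeddings are combined for the multimodal classifier.
```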

<br>
                                   
> **πŸ’‘ Want to explore the process in detail?**   
> See the full πŸ‘‰ [Jupyter notebook](https://huggingface.co/spaces/iBrokeTheCode/Multimodal_Product_Classification/blob/main/notebook_guide.ipynb) πŸ‘ˆοΈ for an end-to-end walkthrough, including Exploratory Data Analysis, embedding generation, model training, evaluation, and model selection.
""")

        # πŸ“Œ MODEL TAB
        with gr.TabItem("🎯 Model"):
            gr.Markdown("""
## Model Details
The final classification is performed by a Multi-layer Perceptron (MLP) trained on the embeddings. This architecture allows the model to learn the relationships between the textual and visual features.
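
As a rough illustration only (hypothetical layer sizes and a Keras-style definition; the actual implementation may differ), the MLP head over an embedding vector can be sketched as:

```python
import tensorflow as tf

# Placeholder dimensions -- the real values depend on the embedding size and category count.
combined_dim = 1024   # illustrative size of the (combined) embedding vector
num_categories = 20   # illustrative number of product categories

mlp = tf.keras.Sequential([
    tf.keras.Input(shape=(combined_dim,)),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(num_categories, activation="softmax"),
])
mlp.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
```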

<br>
                        
## Performance Summary
                        
The following table summarizes the performance of all models trained in this project.
                        
<br>

| Model               | Modality     | Accuracy | Macro Avg F1-Score | Weighted Avg F1-Score |
| :------------------ | :----------- | :------- | :----------------- | :-------------------- |
| Random Forest       | Text         | 0.90     | 0.83               | 0.90                  |
| Logistic Regression | Text         | 0.90     | 0.84               | 0.90                  |
| Random Forest       | Image        | 0.80     | 0.70               | 0.79                  |
| Random Forest       | Combined     | 0.89     | 0.79               | 0.89                  |
| Logistic Regression | Combined     | 0.89     | 0.83               | 0.89                  |
| **MLP**             | **Image**    | **0.84** | **0.77**           | **0.84**              |
| **MLP**             | **Text**     | **0.92** | **0.87**           | **0.92**              |
| **MLP**             | **Combined** | **0.92** | **0.85**           | **0.92**              |

<br>
                        
## Conclusion
                        
- Based on the overall results, the MLP models consistently outperformed their classical machine learning counterparts, demonstrating their ability to learn intricate, non-linear relationships within the data.
- Both the Text MLP and Combined MLP models achieved the highest accuracy and weighted F1-score, confirming their superior ability to classify the products.
- This modular approach demonstrates the ability to handle various data modalities and evaluate the contribution of each to the final prediction.
""")

    # πŸ“Œ FOOTER
    # gr.HTML("<hr>")
    with gr.Row(elem_id="footer-container"):
        gr.HTML("""
<div>
    <b>Connect with me:</b> πŸ’Ό <a href="https://www.linkedin.com/in/alex-turpo/" target="_blank">LinkedIn</a> β€’
    🐱 <a href="https://github.com/iBrokeTheCode" target="_blank">GitHub</a> β€’
    πŸ€— <a href="https://huggingface.co/iBrokeTheCode" target="_blank">Hugging Face</a>
</div>
""")

    # πŸ“Œ EVENT LISTENERS
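    # Swap input visibility when the mode changes; run the classifier on button click.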
    mode_radio.change(
        fn=update_inputs,
        inputs=mode_radio,
        outputs=[text_input, image_input],
    )

    classify_button.click(
        fn=predict, inputs=[mode_radio, text_input, image_input], outputs=output_label
    )


demo.launch()