File size: 9,701 Bytes
21fb9ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import streamlit as st
from PIL import Image
import io
import os
import time
import tempfile
from pathlib import Path

from src.models.model_discovery import discover_models
from src.labels import LABELS


def load_model(model_info):
    """Load and cache the selected model with proper error handling."""
    model_class = model_info["class"]
    model_name = model_info["class_name"]

    # Set up custom cache directory to avoid permission issues
    custom_cache = Path(tempfile.gettempdir()) / "tikka_masalai_cache"
    custom_cache.mkdir(exist_ok=True)

    # Set HuggingFace cache directory (use HF_HOME instead of deprecated TRANSFORMERS_CACHE)
    os.environ["HF_HOME"] = str(custom_cache)
    os.environ["TRANSFORMERS_CACHE"] = str(
        custom_cache
    )  # Keep for backward compatibility

    try:
        # Use a placeholder for the loading message that we can clear
        loading_placeholder = st.empty()
        loading_placeholder.info(f"Loading {model_name} model...")

        # Try to load the model - handle different model initialization patterns
        if "prithiv" in model_name.lower():
            # PrithivML model with specific initialization
            model = model_class()
        elif "resnet" in model_name.lower():
            # ResNet model - check if it needs specific paths
            try:
                model = model_class()
            except TypeError:
                # Try with default parameters if it requires them
                model = model_class(
                    preprocessor_path="microsoft/resnet-18",
                    model_path="microsoft/resnet-18",
                )
        elif "vgg" in model_name.lower():
            # VGG model with default parameters
            model = model_class()
        else:
            # Generic model initialization
            try:
                model = model_class()
            except TypeError:
                # Skip models that require specific parameters we don't know about
                raise RuntimeError(
                    f"Model {model_name} requires specific initialization parameters"
                )

        # Show success message briefly, then clear it
        loading_placeholder.success(f"{model_name} model loaded successfully!")
        time.sleep(1.5)  # Show success message for 1.5 seconds
        loading_placeholder.empty()  # Clear the message
        return model

    except PermissionError as e:
        st.error(f"❌ Permission error: {str(e)}")
        if "cache" in str(e).lower():
            st.info(
                "πŸ’‘ This is likely a cache permission issue. Please refresh the page and try again."
            )
        return None

    except Exception as e:
        error_msg = str(e)
        st.error(f"❌ Error loading {model_name} model: {error_msg}")
        st.info("πŸ’‘ Possible solutions:")
        st.info("1. Refresh the page and try again")
        st.info("2. Check if HuggingFace services are available")
        st.info("3. Try a different model")
        return None


def predict_food(model, image_bytes):
    """Classify raw image bytes with *model* and map the index to a label.

    Returns a ``(index, label)`` tuple; on an out-of-range index the pair is
    ``(None, "Unknown")``, and on any classification exception the error is
    shown in the UI and ``(None, "Error")`` is returned.
    """
    try:
        label_idx = model.classify(image_bytes)
    except Exception as exc:
        st.error(f"Error during prediction: {str(exc)}")
        return None, "Error"

    # Guard against indices the label table does not cover.
    if not 0 <= label_idx < len(LABELS):
        return None, "Unknown"
    return label_idx, LABELS[label_idx]


def main():
    """Main Streamlit application.

    Renders the full page: model discovery/selection sidebar, model loading
    with progress feedback, an image uploader, and the classification result
    panel. All failures are reported inline; the function returns early when
    no usable model is available.
    """
    st.set_page_config(
        page_title="TikkaMasalAI Food Classifier", page_icon="🍽️", layout="centered"
    )

    st.title("🍽️ TikkaMasalAI Food Classifier")
    st.markdown("Upload an image of food and let our AI identify what it is!")

    # Discover available models; bail out with guidance if discovery fails
    # or finds nothing usable.
    try:
        available_models = discover_models()
    except Exception as e:
        st.error(f"❌ Error discovering models: {str(e)}")
        st.info("Make sure the src/models directory contains valid model files.")
        return

    if not available_models:
        st.error("❌ No compatible models found in the src/models directory!")
        st.info("Make sure there are models that inherit from FoodClassificationModel.")
        return

    # Model selection in sidebar
    with st.sidebar:
        st.header("🤖 Model Selection")
        selected_model_name = st.selectbox(
            "Choose a model:",
            options=list(available_models.keys()),
            help="Select which AI model to use for food classification",
        )

        selected_model_info = available_models[selected_model_name]

        # Show model information
        st.info(f"**Selected:** {selected_model_name}")
        st.write(f"**Class:** `{selected_model_info['class_name']}`")
        st.write(f"**Module:** `{selected_model_info['module']}`")

    # Show app status
    status_container = st.container()

    # Load model with better UX: a status line plus a coarse progress bar.
    with status_container:
        model_status = st.empty()
        progress_bar = st.progress(0)

        model_status.info("🔄 Initializing AI model...")
        progress_bar.progress(25)

        model = load_model(selected_model_info)
        progress_bar.progress(100)

        if model is None:
            model_status.error("❌ Failed to load the model.")
            st.error("### 🚨 Model Loading Failed")
            st.markdown(
                f"""
            **Failed to load:** {selected_model_name}
            
            **Possible causes:**
            - Model-specific initialization requirements
            - Missing dependencies for this model
            - Temporary HuggingFace services issue
            - Model cache conflicts in HF Spaces
            - Network connectivity problems
            
            **Solutions:**
            1. **Try a different model** from the sidebar
            2. **Refresh the page** and try again
            3. **Wait 2-3 minutes** for any background downloads to complete
            4. If the issue persists, the model will automatically retry
            """
            )

            # Add a retry button. `st.experimental_rerun` was deprecated in
            # Streamlit 1.18 and removed in 1.27; `st.rerun` is the supported API.
            if st.button("🔄 Retry Loading Model"):
                st.rerun()

            return

        model_status.success(f"✅ {selected_model_name} loaded and ready!")
        progress_bar.empty()

    # File uploader
    uploaded_file = st.file_uploader(
        "Choose a food image...",
        type=["png", "jpg", "jpeg"],
        help="Upload an image of food to classify",
    )

    if uploaded_file is not None:
        # Read image bytes
        image_bytes = uploaded_file.read()

        # Display the uploaded image next to the prediction results
        col1, col2 = st.columns([1, 1])

        with col1:
            st.subheader("📸 Uploaded Image")
            image = Image.open(io.BytesIO(image_bytes))
            st.image(image, caption="Your uploaded image", use_container_width=True)

        with col2:
            st.subheader("🔍 Prediction Results")

            # Make prediction
            with st.spinner("Analyzing your image..."):
                prediction_idx, prediction_label = predict_food(model, image_bytes)

            if prediction_idx is not None:
                # Display results
                st.success("Classification complete!")

                # Format the label for display (e.g. "apple_pie" -> "Apple Pie")
                display_label = prediction_label.replace("_", " ").title()

                st.markdown(f"### 🏷️ **{display_label}**")
                st.markdown(f"**Class Index:** {prediction_idx}")

                # Show confidence bar (placeholder since the model doesn't return probabilities)
                st.markdown("**Prediction Details:**")
                st.info(f"The AI model identified this image as **{display_label}**")

                # Show additional info
                with st.expander("ℹ️ About this classification"):
                    st.write(f"- **Model:** {selected_model_name}")
                    st.write(f"- **Classes:** {len(LABELS)} different food types")
                    st.write(f"- **Raw label:** `{prediction_label}`")
                    st.write(f"- **Index:** {prediction_idx}")
            else:
                st.error("Failed to classify the image. Please try another image.")

    # Sidebar with information
    with st.sidebar:
        st.header("📋 About")
        st.write(
            f"""
        This app uses the **{selected_model_name}** model to classify food images into one of 101 different food categories.
        """
        )

        st.header("🎯 How to use")
        st.write(
            """
        1. Choose a model from the dropdown above
        2. Upload an image of food using the file uploader
        3. Wait for the AI to analyze your image
        4. View the classification results
        """
        )

        st.header("🍕 Supported Foods")
        st.write(
            f"The model can recognize **{len(LABELS)}** different types of food including:"
        )

        # Show a sample of labels
        sample_labels = [label.replace("_", " ").title() for label in LABELS[:10]]
        for label in sample_labels:
            st.write(f"• {label}")
        st.write(f"... and {len(LABELS) - 10} more!")

        st.header("🔧 Technical Details")
        st.write(
            f"""
        - **Selected Model:** {selected_model_name}
        - **Available Models:** {len(available_models)}
        - **Dataset:** Food-101
        - **Framework:** PyTorch + Transformers
        """
        )


if __name__ == "__main__":
    main()