Spaces:
Running
Running
| import streamlit as st | |
| from PIL import Image | |
| import io | |
| import os | |
| import time | |
| import tempfile | |
| from pathlib import Path | |
| from src.models.model_discovery import discover_models | |
| from src.labels import LABELS | |
def _configure_hf_cache():
    """Point the Hugging Face cache at a writable directory under the system temp dir.

    Raises:
        PermissionError: if the cache directory cannot be created.
    """
    custom_cache = Path(tempfile.gettempdir()) / "tikka_masalai_cache"
    custom_cache.mkdir(parents=True, exist_ok=True)
    # NOTE(review): huggingface_hub / transformers read these variables when they
    # are first imported. If discover_models() already imported the model modules,
    # these assignments may not take effect - confirm, and consider setting them
    # at the very top of the script instead.
    os.environ["HF_HOME"] = str(custom_cache)
    # Legacy variable name, kept for older transformers releases.
    os.environ["TRANSFORMERS_CACHE"] = str(custom_cache)


@st.cache_resource(show_spinner=False)
def _instantiate_model(cache_key, model_name, _model_class):
    """Construct ``_model_class`` once per ``cache_key`` and reuse it across reruns.

    Streamlit re-executes the whole script on every interaction; caching here
    keeps the model from being rebuilt each time. The leading underscore on
    ``_model_class`` tells Streamlit not to hash that argument, so the cache key
    is ``cache_key`` plus ``model_name``. A call that raises is not cached, so a
    failed load is attempted again on the next run.

    NOTE(review): the cached instance is shared by every session of the app;
    confirm the model's classify() is safe to call concurrently.

    Raises:
        RuntimeError: a model outside the known families needs constructor args.
        Exception: anything the model constructor itself raises.
    """
    lowered = model_name.lower()
    if "prithiv" in lowered:
        # PrithivML model: constructed with no arguments.
        return _model_class()
    if "resnet" in lowered:
        try:
            return _model_class()
        except TypeError:
            # Constructor wants explicit paths: fall back to the ResNet-18 checkpoint.
            return _model_class(
                preprocessor_path="microsoft/resnet-18",
                model_path="microsoft/resnet-18",
            )
    if "vgg" in lowered:
        # VGG model: constructed with no arguments.
        return _model_class()
    try:
        return _model_class()
    except TypeError as err:
        # Unknown model whose constructor needs parameters we cannot supply.
        raise RuntimeError(
            f"Model {model_name} requires specific initialization parameters"
        ) from err


def load_model(model_info):
    """Instantiate the selected model, reusing a cached instance when possible.

    Args:
        model_info: Entry from ``discover_models()``. It must provide ``"class"``
            (the model class) and ``"class_name"``. ``"module"`` is used, when
            present, so that same-named classes do not share a cache entry.

    Returns:
        The model instance, or ``None`` if loading failed. Failures are reported
        to the user via ``st.error``/``st.info`` rather than raised.
    """
    model_class = model_info["class"]
    model_name = model_info["class_name"]
    cache_key = f"{model_info.get('module', '')}.{model_name}"

    # Placeholder so the loading message can be removed once we are done.
    loading_placeholder = st.empty()
    loading_placeholder.info(f"Loading {model_name} model...")
    try:
        # Inside the try so a cache-directory PermissionError reaches the handler below.
        _configure_hf_cache()
        return _instantiate_model(cache_key, model_name, model_class)
    except PermissionError as e:
        st.error(f"β Permission error: {str(e)}")
        if "cache" in str(e).lower():
            st.info(
                "π‘ This is likely a cache permission issue. Please refresh the page and try again."
            )
        return None
    except Exception as e:
        error_msg = str(e)
        st.error(f"β Error loading {model_name} model: {error_msg}")
        st.info("π‘ Possible solutions:")
        st.info("1. Refresh the page and try again")
        st.info("2. Check if HuggingFace services are available")
        st.info("3. Try a different model")
        return None
    finally:
        # Clear the loading message on success and on failure alike.
        loading_placeholder.empty()
def predict_food(model, image_bytes):
    """Classify ``image_bytes`` with ``model`` and translate the index into a label.

    Args:
        model: Object exposing ``classify(image_bytes)``, which returns a class index.
        image_bytes: Raw bytes of the uploaded image.

    Returns:
        ``(index, label)`` when the index falls inside ``LABELS``;
        ``(None, "Unknown")`` when it is out of range;
        ``(None, "Error")`` when classification raised. In that case the error
        is also shown to the user via ``st.error``.
    """
    try:
        idx = model.classify(image_bytes)
        in_range = 0 <= idx < len(LABELS)
        if not in_range:
            return None, "Unknown"
        return idx, LABELS[idx]
    except Exception as exc:
        st.error(f"Error during prediction: {str(exc)}")
        return None, "Error"
def _select_model(available_models):
    """Render the sidebar model picker and return ``(name, info)`` for the choice."""
    with st.sidebar:
        st.header("π€ Model Selection")
        selected_model_name = st.selectbox(
            "Choose a model:",
            options=list(available_models.keys()),
            help="Select which AI model to use for food classification",
        )
        selected_model_info = available_models[selected_model_name]
        # Show model information
        st.info(f"**Selected:** {selected_model_name}")
        st.write(f"**Class:** `{selected_model_info['class_name']}`")
        st.write(f"**Module:** `{selected_model_info['module']}`")
    return selected_model_name, selected_model_info


def _render_load_failure(selected_model_name):
    """Explain a failed model load and offer a button that reruns the script."""
    st.error("### π¨ Model Loading Failed")
    st.markdown(
        f"""
**Failed to load:** {selected_model_name}

**Possible causes:**
- Model-specific initialization requirements
- Missing dependencies for this model
- Temporary HuggingFace services issue
- Model cache conflicts in HF Spaces
- Network connectivity problems

**Solutions:**
1. **Try a different model** from the sidebar
2. **Refresh the page** and try again
3. **Wait 2-3 minutes** for any background downloads to complete
4. If the issue persists, use the retry button below
"""
    )
    if st.button("π Retry Loading Model"):
        # st.experimental_rerun() was removed from Streamlit; st.rerun() replaces it.
        st.rerun()


def _load_model_with_status(selected_model_name, selected_model_info):
    """Load the chosen model while showing progress; render recovery help on failure.

    Returns:
        The model instance, or ``None`` if it could not be loaded.
    """
    with st.container():
        model_status = st.empty()
        progress_bar = st.progress(0)
        model_status.info("π Initializing AI model...")
        progress_bar.progress(25)
        model = load_model(selected_model_info)
        progress_bar.progress(100)
        # Clear the bar on both paths so a failed load does not leave a full bar behind.
        progress_bar.empty()
        if model is None:
            model_status.error("β Failed to load the model.")
            _render_load_failure(selected_model_name)
        else:
            model_status.success(f"β {selected_model_name} loaded and ready!")
    return model


def _render_prediction(model, uploaded_file, selected_model_name):
    """Show the uploaded image next to the model's classification of it."""
    # getvalue() returns the whole upload regardless of the stream's read position.
    image_bytes = uploaded_file.getvalue()
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Decode now so corrupt or truncated files fail here instead of inside st.image.
        image.load()
    except OSError:
        # PIL signals unreadable image data with OSError subclasses
        # (e.g. UnidentifiedImageError).
        st.error(
            "Could not read this file as an image. Please upload a valid PNG or JPEG."
        )
        return

    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("πΈ Uploaded Image")
        st.image(image, caption="Your uploaded image", use_container_width=True)
    with col2:
        st.subheader("π Prediction Results")
        with st.spinner("Analyzing your image..."):
            prediction_idx, prediction_label = predict_food(model, image_bytes)
        if prediction_idx is None:
            st.error("Failed to classify the image. Please try another image.")
            return
        st.success("Classification complete!")
        # Labels are stored snake_case; present them as Title Case.
        display_label = prediction_label.replace("_", " ").title()
        st.markdown(f"### π·οΈ **{display_label}**")
        st.markdown(f"**Class Index:** {prediction_idx}")
        # Only the top class is shown; predict_food returns no probabilities.
        st.markdown("**Prediction Details:**")
        st.info(f"The AI model identified this image as **{display_label}**")
        with st.expander("βΉοΈ About this classification"):
            st.write(f"- **Model:** {selected_model_name}")
            st.write(f"- **Classes:** {len(LABELS)} different food types")
            st.write(f"- **Raw label:** `{prediction_label}`")
            st.write(f"- **Index:** {prediction_idx}")


def _render_about_sidebar(selected_model_name, available_models):
    """Render the informational sidebar sections (about, usage, labels, tech)."""
    sample_size = 10
    with st.sidebar:
        st.header("π About")
        st.write(
            f"""
This app uses the **{selected_model_name}** model to classify food images into one of {len(LABELS)} different food categories.
"""
        )
        st.header("π― How to use")
        st.write(
            """
1. Choose a model from the dropdown above
2. Upload an image of food using the file uploader
3. Wait for the AI to analyze your image
4. View the classification results
"""
        )
        st.header("π Supported Foods")
        st.write(
            f"The model can recognize **{len(LABELS)}** different types of food including:"
        )
        for label in LABELS[:sample_size]:
            st.write(f"β’ {label.replace('_', ' ').title()}")
        remaining = len(LABELS) - sample_size
        # Only mention the rest when there actually are more labels than shown.
        if remaining > 0:
            st.write(f"... and {remaining} more!")
        st.header("π§ Technical Details")
        st.write(
            f"""
- **Selected Model:** {selected_model_name}
- **Available Models:** {len(available_models)}
- **Dataset:** Food-101
- **Framework:** PyTorch + Transformers
"""
        )


def main():
    """Main Streamlit application: pick and load a model, then classify uploads."""
    st.set_page_config(
        page_title="TikkaMasalAI Food Classifier", page_icon="π½οΈ", layout="centered"
    )
    st.title("π½οΈ TikkaMasalAI Food Classifier")
    st.markdown("Upload an image of food and let our AI identify what it is!")

    try:
        available_models = discover_models()
    except Exception as e:
        st.error(f"β Error discovering models: {str(e)}")
        st.info("Make sure the src/models directory contains valid model files.")
        return
    if not available_models:
        st.error("β No compatible models found in the src/models directory!")
        st.info("Make sure there are models that inherit from FoodClassificationModel.")
        return

    selected_model_name, selected_model_info = _select_model(available_models)
    model = _load_model_with_status(selected_model_name, selected_model_info)
    if model is None:
        return

    uploaded_file = st.file_uploader(
        "Choose a food image...",
        type=["png", "jpg", "jpeg"],
        help="Upload an image of food to classify",
    )
    if uploaded_file is not None:
        _render_prediction(model, uploaded_file, selected_model_name)

    _render_about_sidebar(selected_model_name, available_models)
if __name__ == "__main__":
    # `streamlit run <this file>` executes the script as __main__ on every rerun.
    main()