# TikkaMasalAI food classifier — Streamlit demo app.
import streamlit as st
from PIL import Image
import io
import os
import time
import tempfile
from pathlib import Path
from src.models.model_discovery import discover_models
from src.labels import LABELS
def load_model(model_info):
    """Instantiate the selected model with user-facing error handling.

    Note: despite earlier wording, this function does NOT cache the model
    instance itself — it only redirects the HuggingFace download cache to a
    writable temp directory.

    Args:
        model_info: Mapping with at least a "class" key (the model class to
            instantiate) and a "class_name" key (display name), as produced
            by discover_models().

    Returns:
        A model instance on success, or None on failure (errors are
        reported to the Streamlit UI rather than raised).
    """
    model_class = model_info["class"]
    model_name = model_info["class_name"]

    # Redirect HuggingFace downloads to a writable temp dir to avoid
    # permission issues on hosted environments (e.g. HF Spaces).
    custom_cache = Path(tempfile.gettempdir()) / "tikka_masalai_cache"
    custom_cache.mkdir(parents=True, exist_ok=True)

    # Set HuggingFace cache directory (HF_HOME supersedes the deprecated
    # TRANSFORMERS_CACHE, which is kept for backward compatibility).
    # NOTE(review): these env vars only take effect if the HF libraries read
    # them after this point — confirm transformers was not already imported
    # with a different cache location.
    os.environ["HF_HOME"] = str(custom_cache)
    os.environ["TRANSFORMERS_CACHE"] = str(custom_cache)

    try:
        # Placeholder lets us clear the transient loading message later.
        loading_placeholder = st.empty()
        loading_placeholder.info(f"Loading {model_name} model...")

        # Different model families have different constructor signatures.
        if "prithiv" in model_name.lower():
            # PrithivML model with specific initialization
            model = model_class()
        elif "resnet" in model_name.lower():
            try:
                model = model_class()
            except TypeError:
                # Constructor requires explicit paths; fall back to defaults.
                model = model_class(
                    preprocessor_path="microsoft/resnet-18",
                    model_path="microsoft/resnet-18",
                )
        elif "vgg" in model_name.lower():
            # VGG model with default parameters
            model = model_class()
        else:
            # Generic model initialization
            try:
                model = model_class()
            except TypeError:
                # Skip models that require specific parameters we don't know about
                raise RuntimeError(
                    f"Model {model_name} requires specific initialization parameters"
                )

        # Show success message briefly, then clear it so the UI stays tidy.
        loading_placeholder.success(f"{model_name} model loaded successfully!")
        time.sleep(1.5)
        loading_placeholder.empty()
        return model

    except PermissionError as e:
        st.error(f"❌ Permission error: {str(e)}")
        if "cache" in str(e).lower():
            st.info(
                "πŸ’‘ This is likely a cache permission issue. Please refresh the page and try again."
            )
        return None
    except Exception as e:
        error_msg = str(e)
        st.error(f"❌ Error loading {model_name} model: {error_msg}")
        st.info("πŸ’‘ Possible solutions:")
        st.info("1. Refresh the page and try again")
        st.info("2. Check if HuggingFace services are available")
        st.info("3. Try a different model")
        return None
def predict_food(model, image_bytes):
    """Make a prediction on the uploaded image."""
    try:
        # Model returns an integer class index for the image.
        idx = model.classify(image_bytes)
        # Guard: only map indices that fall inside the label table.
        if not (0 <= idx < len(LABELS)):
            return None, "Unknown"
        return idx, LABELS[idx]
    except Exception as e:
        st.error(f"Error during prediction: {str(e)}")
        return None, "Error"
def main():
    """Run the Streamlit UI: model selection, image upload, and prediction.

    Flow: configure the page, discover available model classes, let the user
    pick one in the sidebar, load it with progress feedback, then classify
    any uploaded image and render the result. Returns early (with UI error
    messages) when discovery or loading fails.
    """
    st.set_page_config(
        page_title="TikkaMasalAI Food Classifier", page_icon="🍽️", layout="centered"
    )

    st.title("🍽️ TikkaMasalAI Food Classifier")
    st.markdown("Upload an image of food and let our AI identify what it is!")

    # Discover available models; abort with guidance if discovery fails.
    try:
        available_models = discover_models()
    except Exception as e:
        st.error(f"❌ Error discovering models: {str(e)}")
        st.info("Make sure the src/models directory contains valid model files.")
        return

    if not available_models:
        st.error("❌ No compatible models found in the src/models directory!")
        st.info("Make sure there are models that inherit from FoodClassificationModel.")
        return

    # Model selection in sidebar
    with st.sidebar:
        st.header("πŸ€– Model Selection")
        selected_model_name = st.selectbox(
            "Choose a model:",
            options=list(available_models.keys()),
            help="Select which AI model to use for food classification",
        )
        selected_model_info = available_models[selected_model_name]

        # Show model information
        st.info(f"**Selected:** {selected_model_name}")
        st.write(f"**Class:** `{selected_model_info['class_name']}`")
        st.write(f"**Module:** `{selected_model_info['module']}`")

    # Load the model with progress feedback for better UX.
    status_container = st.container()
    with status_container:
        model_status = st.empty()
        progress_bar = st.progress(0)

        model_status.info("πŸ”„ Initializing AI model...")
        progress_bar.progress(25)

        model = load_model(selected_model_info)
        progress_bar.progress(100)

        if model is None:
            model_status.error("❌ Failed to load the model.")
            st.error("### 🚨 Model Loading Failed")
            st.markdown(
                f"""
                **Failed to load:** {selected_model_name}

                **Possible causes:**
                - Model-specific initialization requirements
                - Missing dependencies for this model
                - Temporary HuggingFace services issue
                - Model cache conflicts in HF Spaces
                - Network connectivity problems

                **Solutions:**
                1. **Try a different model** from the sidebar
                2. **Refresh the page** and try again
                3. **Wait 2-3 minutes** for any background downloads to complete
                4. If the issue persists, the model will automatically retry
                """
            )
            # Manual retry button for transient failures.
            if st.button("πŸ”„ Retry Loading Model"):
                # st.experimental_rerun() was deprecated and later removed;
                # st.rerun() is the supported replacement.
                st.rerun()
            return

        model_status.success(f"βœ… {selected_model_name} loaded and ready!")
        progress_bar.empty()

    # File uploader
    uploaded_file = st.file_uploader(
        "Choose a food image...",
        type=["png", "jpg", "jpeg"],
        help="Upload an image of food to classify",
    )

    if uploaded_file is not None:
        # Read image bytes once; both display and prediction use them.
        image_bytes = uploaded_file.read()

        # Side-by-side layout: image on the left, results on the right.
        col1, col2 = st.columns([1, 1])

        with col1:
            st.subheader("πŸ“Έ Uploaded Image")
            image = Image.open(io.BytesIO(image_bytes))
            st.image(image, caption="Your uploaded image", use_container_width=True)

        with col2:
            st.subheader("πŸ” Prediction Results")

            # Make prediction
            with st.spinner("Analyzing your image..."):
                prediction_idx, prediction_label = predict_food(model, image_bytes)

            if prediction_idx is not None:
                # Display results
                st.success("Classification complete!")

                # Format the raw snake_case label for display.
                display_label = prediction_label.replace("_", " ").title()
                st.markdown(f"### 🏷️ **{display_label}**")
                st.markdown(f"**Class Index:** {prediction_idx}")

                # Placeholder details section (the model API returns only an
                # index, not class probabilities).
                st.markdown("**Prediction Details:**")
                st.info(f"The AI model identified this image as **{display_label}**")

                # Show additional info
                with st.expander("ℹ️ About this classification"):
                    st.write(f"- **Model:** {selected_model_name}")
                    st.write(f"- **Classes:** {len(LABELS)} different food types")
                    st.write(f"- **Raw label:** `{prediction_label}`")
                    st.write(f"- **Index:** {prediction_idx}")
            else:
                st.error("Failed to classify the image. Please try another image.")

    # Sidebar with information
    with st.sidebar:
        st.header("πŸ“‹ About")
        st.write(
            f"""
        This app uses the **{selected_model_name}** model to classify food images into one of 101 different food categories.
        """
        )

        st.header("🎯 How to use")
        st.write(
            """
        1. Choose a model from the dropdown above
        2. Upload an image of food using the file uploader
        3. Wait for the AI to analyze your image
        4. View the classification results
        """
        )

        st.header("πŸ• Supported Foods")
        st.write(
            f"The model can recognize **{len(LABELS)}** different types of food including:"
        )

        # Show a sample of labels
        sample_labels = [label.replace("_", " ").title() for label in LABELS[:10]]
        for label in sample_labels:
            st.write(f"β€’ {label}")
        st.write(f"... and {len(LABELS) - 10} more!")

        st.header("πŸ”§ Technical Details")
        st.write(
            f"""
        - **Selected Model:** {selected_model_name}
        - **Available Models:** {len(available_models)}
        - **Dataset:** Food-101
        - **Framework:** PyTorch + Transformers
        """
        )
# Entry point when executed directly (e.g. `streamlit run <this file>`).
if __name__ == "__main__":
    main()