Commit 21fb9ff: Clean history: code-only (no models)

Files changed:
- .dockerignore +64 -0
- .gitattributes +35 -0
- Dockerfile +55 -0
- README.md +12 -0
- app.py +275 -0
- config.toml +24 -0
- pyproject.toml +70 -0
- src/.DS_Store +0 -0
- src/__init__.py +8 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/cli.cpython-310.pyc +0 -0
- src/__pycache__/constants.cpython-310.pyc +0 -0
- src/__pycache__/constants.cpython-311.pyc +0 -0
- src/__pycache__/evaluate_imagenet.cpython-310.pyc +0 -0
- src/__pycache__/labels.cpython-310.pyc +0 -0
- src/data/download_data.py +8 -0
- src/eval/__pycache__/base_evaluator.cpython-310.pyc +0 -0
- src/eval/__pycache__/evaluate_food101.cpython-310.pyc +0 -0
- src/eval/__pycache__/evaluate_food101.cpython-311.pyc +0 -0
- src/eval/__pycache__/evaluate_imagenet.cpython-310.pyc +0 -0
- src/eval/eval.py +67 -0
- src/eval/evaluate_food101.py +326 -0
- src/labels.py +135 -0
- src/models/__pycache__/food_classification_model.cpython-310.pyc +0 -0
- src/models/__pycache__/food_classification_model.cpython-311.pyc +0 -0
- src/models/__pycache__/model_discovery.cpython-310.pyc +0 -0
- src/models/__pycache__/prithiv_ml_food101.cpython-310.pyc +0 -0
- src/models/__pycache__/prithiv_ml_food101.cpython-311.pyc +0 -0
- src/models/__pycache__/resnet18.cpython-310.pyc +0 -0
- src/models/__pycache__/resnet18.cpython-311.pyc +0 -0
- src/models/__pycache__/vgg16.cpython-310.pyc +0 -0
- src/models/food_classification_model.py +17 -0
- src/models/model_discovery.py +121 -0
- src/models/prithiv_ml_food101.py +89 -0
- src/models/resnet18.py +63 -0
- src/models/vgg16.py +130 -0
- src/train/__pycache__/preprocess_data.cpython-310.pyc +0 -0
- uv.lock +0 -0
.dockerignore
ADDED
@@ -0,0 +1,64 @@
# Git and version control
.git
.gitignore
.gitattributes

# Python cache and virtual environments
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
.venv/
venv/
.env

# Documentation and development files
docs/
references/
reports/
notebooks/
tests/

# Data files (too large for Docker image)
data/
mlruns/
models/*.pth

# Development and configuration files
.pytest_cache/
.coverage
.mypy_cache/
.ruff_cache/
*.log

# IDE and editor files
.vscode/
.idea/
*.swp
*.swo
*~

# OS files
.DS_Store
Thumbs.db

# Build artifacts
build/
dist/
*.egg-info/

# Streamlit specific
.streamlit/

# Project specific
Makefile
LICENSE
streamlit/README.md

# Deployment scripts
deploy_to_hf.sh
COMMANDS.md

# Food101 HF Space (don't include in Docker)
Food101/
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,55 @@
# Use Python 3.10 slim image for smaller size
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install uv
RUN pip install uv

# Copy dependency files first (for better caching)
COPY pyproject.toml ./
COPY uv.lock ./
COPY README.md ./

# Install dependencies using uv (creates .venv)
RUN uv sync --frozen

# Use /tmp for all caches and Streamlit config to avoid permission issues in read-only paths
ENV HOME=/tmp \
    XDG_CACHE_HOME=/tmp/.cache \
    UV_CACHE_DIR=/tmp/.cache/uv \
    PIP_CACHE_DIR=/tmp/.cache/pip \
    HF_HOME=/tmp/.cache/huggingface \
    TRANSFORMERS_CACHE=/tmp/.cache/huggingface/transformers \
    TORCH_HOME=/tmp/.cache/torch \
    STREAMLIT_CONFIG_DIR=/tmp/.streamlit \
    STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
RUN mkdir -p \
    /tmp/.cache/uv \
    /tmp/.cache/pip \
    /tmp/.cache/huggingface/transformers \
    /tmp/.cache/torch \
    /tmp/.streamlit \
    && chmod -R 777 /tmp/.cache /tmp/.streamlit

# Copy application code
COPY src/ ./src/
COPY app.py ./
COPY config.toml ./

# Copy models
COPY models/ ./models/

# Expose Streamlit port (Hugging Face Spaces uses 7860)
EXPOSE 7860

# Set environment variables for Streamlit
ENV STREAMLIT_SERVER_PORT=7860
ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0
ENV STREAMLIT_SERVER_HEADLESS=true
ENV STREAMLIT_SERVER_ENABLE_XSRF_PROTECTION=false

# Run the application using the created virtual environment
ENV PATH="/app/.venv/bin:${PATH}"
CMD ["python", "-m", "streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
README.md
ADDED
@@ -0,0 +1,12 @@
---
title: Food101 Streamlit
emoji: 👁
colorFrom: yellow
colorTo: pink
sdk: docker
pinned: false
license: mit
short_description: Space for Detecting the Type of Dish in an Image
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,275 @@
import streamlit as st
from PIL import Image
import io
import os
import time
import tempfile
from pathlib import Path

from src.models.model_discovery import discover_models
from src.labels import LABELS


def load_model(model_info):
    """Load and cache the selected model with proper error handling."""
    model_class = model_info["class"]
    model_name = model_info["class_name"]

    # Set up custom cache directory to avoid permission issues
    custom_cache = Path(tempfile.gettempdir()) / "tikka_masalai_cache"
    custom_cache.mkdir(exist_ok=True)

    # Set HuggingFace cache directory (use HF_HOME instead of deprecated TRANSFORMERS_CACHE)
    os.environ["HF_HOME"] = str(custom_cache)
    os.environ["TRANSFORMERS_CACHE"] = str(
        custom_cache
    )  # Keep for backward compatibility

    try:
        # Use a placeholder for the loading message that we can clear
        loading_placeholder = st.empty()
        loading_placeholder.info(f"Loading {model_name} model...")

        # Try to load the model - handle different model initialization patterns
        if "prithiv" in model_name.lower():
            # PrithivML model with specific initialization
            model = model_class()
        elif "resnet" in model_name.lower():
            # ResNet model - check if it needs specific paths
            try:
                model = model_class()
            except TypeError:
                # Try with default parameters if it requires them
                model = model_class(
                    preprocessor_path="microsoft/resnet-18",
                    model_path="microsoft/resnet-18",
                )
        elif "vgg" in model_name.lower():
            # VGG model with default parameters
            model = model_class()
        else:
            # Generic model initialization
            try:
                model = model_class()
            except TypeError:
                # Skip models that require specific parameters we don't know about
                raise RuntimeError(
                    f"Model {model_name} requires specific initialization parameters"
                )

        # Show success message briefly, then clear it
        loading_placeholder.success(f"{model_name} model loaded successfully!")
        time.sleep(1.5)  # Show success message for 1.5 seconds
        loading_placeholder.empty()  # Clear the message
        return model

    except PermissionError as e:
        st.error(f"❌ Permission error: {str(e)}")
        if "cache" in str(e).lower():
            st.info(
                "💡 This is likely a cache permission issue. Please refresh the page and try again."
            )
        return None

    except Exception as e:
        error_msg = str(e)
        st.error(f"❌ Error loading {model_name} model: {error_msg}")
        st.info("💡 Possible solutions:")
        st.info("1. Refresh the page and try again")
        st.info("2. Check if HuggingFace services are available")
        st.info("3. Try a different model")
        return None


def predict_food(model, image_bytes):
    """Make a prediction on the uploaded image."""
    try:
        # Get prediction index
        prediction_idx = model.classify(image_bytes)

        # Get the label name
        if 0 <= prediction_idx < len(LABELS):
            prediction_label = LABELS[prediction_idx]
            return prediction_idx, prediction_label
        else:
            return None, "Unknown"
    except Exception as e:
        st.error(f"Error during prediction: {str(e)}")
        return None, "Error"


def main():
    """Main Streamlit application."""
    st.set_page_config(
        page_title="TikkaMasalAI Food Classifier", page_icon="🍽️", layout="centered"
    )

    st.title("🍽️ TikkaMasalAI Food Classifier")
    st.markdown("Upload an image of food and let our AI identify what it is!")

    # Discover available models
    try:
        available_models = discover_models()
    except Exception as e:
        st.error(f"❌ Error discovering models: {str(e)}")
        st.info("Make sure the src/models directory contains valid model files.")
        return

    if not available_models:
        st.error("❌ No compatible models found in the src/models directory!")
        st.info("Make sure there are models that inherit from FoodClassificationModel.")
        return

    # Model selection in sidebar
    with st.sidebar:
        st.header("🤖 Model Selection")
        selected_model_name = st.selectbox(
            "Choose a model:",
            options=list(available_models.keys()),
            help="Select which AI model to use for food classification",
        )

        selected_model_info = available_models[selected_model_name]

        # Show model information
        st.info(f"**Selected:** {selected_model_name}")
        st.write(f"**Class:** `{selected_model_info['class_name']}`")
        st.write(f"**Module:** `{selected_model_info['module']}`")

    # Show app status
    status_container = st.container()

    # Load model with better UX
    with status_container:
        model_status = st.empty()
        progress_bar = st.progress(0)

        model_status.info("🔄 Initializing AI model...")
        progress_bar.progress(25)

        model = load_model(selected_model_info)
        progress_bar.progress(100)

        if model is None:
            model_status.error("❌ Failed to load the model.")
            st.error("### 🚨 Model Loading Failed")
            st.markdown(
                f"""
**Failed to load:** {selected_model_name}

**Possible causes:**
- Model-specific initialization requirements
- Missing dependencies for this model
- Temporary HuggingFace services issue
- Model cache conflicts in HF Spaces
- Network connectivity problems

**Solutions:**
1. **Try a different model** from the sidebar
2. **Refresh the page** and try again
3. **Wait 2-3 minutes** for any background downloads to complete
4. If the issue persists, the model will automatically retry
"""
            )

            # Add a retry button
            if st.button("🔄 Retry Loading Model"):
                st.rerun()  # replaces the deprecated st.experimental_rerun()

            return

        model_status.success(f"✅ {selected_model_name} loaded and ready!")
        progress_bar.empty()

    # File uploader
    uploaded_file = st.file_uploader(
        "Choose a food image...",
        type=["png", "jpg", "jpeg"],
        help="Upload an image of food to classify",
    )

    if uploaded_file is not None:
        # Read image bytes
        image_bytes = uploaded_file.read()

        # Display the uploaded image
        col1, col2 = st.columns([1, 1])

        with col1:
            st.subheader("📸 Uploaded Image")
            image = Image.open(io.BytesIO(image_bytes))
            st.image(image, caption="Your uploaded image", use_container_width=True)

        with col2:
            st.subheader("🔍 Prediction Results")

            # Make prediction
            with st.spinner("Analyzing your image..."):
                prediction_idx, prediction_label = predict_food(model, image_bytes)

            if prediction_idx is not None:
                # Display results
                st.success("Classification complete!")

                # Format the label for display
                display_label = prediction_label.replace("_", " ").title()

                st.markdown(f"### 🏷️ **{display_label}**")
                st.markdown(f"**Class Index:** {prediction_idx}")

                # Show confidence bar (placeholder since the model doesn't return probabilities)
                st.markdown("**Prediction Details:**")
                st.info(f"The AI model identified this image as **{display_label}**")

                # Show additional info
                with st.expander("ℹ️ About this classification"):
                    st.write(f"- **Model:** {selected_model_name}")
                    st.write(f"- **Classes:** {len(LABELS)} different food types")
                    st.write(f"- **Raw label:** `{prediction_label}`")
                    st.write(f"- **Index:** {prediction_idx}")
            else:
                st.error("Failed to classify the image. Please try another image.")

    # Sidebar with information
    with st.sidebar:
        st.header("📋 About")
        st.write(
            f"""
This app uses the **{selected_model_name}** model to classify food images into one of 101 different food categories.
"""
        )

        st.header("🎯 How to use")
        st.write(
            """
1. Choose a model from the dropdown above
2. Upload an image of food using the file uploader
3. Wait for the AI to analyze your image
4. View the classification results
"""
        )

        st.header("🍕 Supported Foods")
        st.write(
            f"The model can recognize **{len(LABELS)}** different types of food including:"
        )

        # Show a sample of labels
        sample_labels = [label.replace("_", " ").title() for label in LABELS[:10]]
        for label in sample_labels:
            st.write(f"• {label}")
        st.write(f"... and {len(LABELS) - 10} more!")

        st.header("🔧 Technical Details")
        st.write(
            f"""
- **Selected Model:** {selected_model_name}
- **Available Models:** {len(available_models)}
- **Dataset:** Food-101
- **Framework:** PyTorch + Transformers
"""
        )


if __name__ == "__main__":
    main()
config.toml
ADDED
@@ -0,0 +1,24 @@
[global]
developmentMode = false

[server]
port = 7860
address = "0.0.0.0"
headless = true
enableCORS = false
enableXsrfProtection = false

[browser]
gatherUsageStats = false

[client]
toolbarMode = "minimal"

[runner]
magicEnabled = true
installTracer = false
fixMatplotlib = true

[logger]
level = "info"
messageFormat = "%(asctime)s %(message)s"
pyproject.toml
ADDED
@@ -0,0 +1,70 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "tikka-masalai"
version = "0.0.1"
description = "MLOPS project FIB"
authors = [
    { name = "Team Tikka MasalAI" },
]

readme = "README.md"
classifiers = [
    "Programming Language :: Python :: 3",
]
dependencies = [
    "dagshub>=0.6.3",
    "datasets<4.1.1",
    "dvc>=3.63.0",
    "dvc-s3>=3.2.2",
    "huggingface-hub>=0.35.0",
    "ipykernel>=6.30.1",
    "ipywidgets>=8.1.7",
    "matplotlib>=3.10.6",
    "mlflow>=2,<3",
    "numpy>=2.2.6",
    "pandas>=2.3.2",
    "pillow>=11.3.0",
    "polars>=1.0.0",
    "pyarrow>=4.0.0,<20.0.0",
    "pytest",
    "python-dotenv",
    "ruff",
    "streamlit>=1.31.0",
    "torch>=2.8.0",
    "torchvision>=0.23.0",
    "tqdm>=4.67.1",
    "transformers>=4.56.2",
]
requires-python = ">=3.10"

[project.optional-dependencies]
dev = []


# This makes src/ discoverable as a package
[tool.hatch.build.targets.wheel]
packages = ["src"]


[tool.ruff]
line-length = 99
src = ["src"]
include = ["pyproject.toml", "src/**/*.py"]

[tool.ruff.lint]
extend-select = ["I"]  # Add import sorting

[tool.ruff.lint.isort]
known-first-party = ["src"]
force-sort-within-sections = true

[dependency-groups]
dev = [
    "black>=25.1.0",
    "pylint>=3.3.8",
    "pytest>=8.4.2",
]
src/.DS_Store
ADDED
Binary file (6.15 kB)
src/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""
TikkaMasalAI source package.
"""

# Export commonly used utilities
from .models.model_discovery import discover_models, get_model_names, get_model_info

__all__ = ["discover_models", "get_model_names", "get_model_info"]
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (333 Bytes)
src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (170 Bytes)
src/__pycache__/cli.cpython-310.pyc
ADDED
Binary file (6.94 kB)
src/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (1.46 kB)
src/__pycache__/constants.cpython-311.pyc
ADDED
Binary file (2.64 kB)
src/__pycache__/evaluate_imagenet.cpython-310.pyc
ADDED
Binary file (6.93 kB)
src/__pycache__/labels.cpython-310.pyc
ADDED
Binary file (2.29 kB)
src/data/download_data.py
ADDED
@@ -0,0 +1,8 @@
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="ethz/food101",
    repo_type="dataset",
    local_dir="./data/raw/food101",
    local_dir_use_symlinks=False,  # ensures actual files, not symlinks
)
src/eval/__pycache__/base_evaluator.cpython-310.pyc
ADDED
Binary file (9.03 kB)
src/eval/__pycache__/evaluate_food101.cpython-310.pyc
ADDED
Binary file (9.38 kB)
src/eval/__pycache__/evaluate_food101.cpython-311.pyc
ADDED
Binary file (17 kB)
src/eval/__pycache__/evaluate_imagenet.cpython-310.pyc
ADDED
Binary file (4.87 kB)
src/eval/eval.py
ADDED
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Example script demonstrating how to evaluate multiple models on different datasets.

This script shows how to use the enhanced evaluation framework
with different model implementations including VGG16, ResNet18, and PrithivMlFood101.
"""
from src.models.vgg16 import VGG16
from src.models.resnet18 import Resnet18
from src.models.prithiv_ml_food101 import PrithivMlFood101
from src.eval.evaluate_food101 import Food101Evaluator
from src.models.food_classification_model import FoodClassificationModel


def evaluate_food101(
    model: FoodClassificationModel,
    experiment_name: str = "food101_evaluation",
    sample_limit: int = 50,
    random_seed: int = 42,
    run_name: str = None,
):
    """Main evaluation function."""
    evaluator = Food101Evaluator(
        model, experiment_name, sample_limit, random_seed, run_name
    )
    evaluator.run_evaluation()


def main():
    """Demonstrate evaluation with multiple model architectures."""

    print("=" * 90)
    print("Multi-Model Evaluation: VGG16 vs ResNet-18 vs PrithivMlFood101")

    # Food101 Evaluations
    print("\n=== Food101 Evaluations ===")

    print("\n1. Evaluating PrithivMlFood101 on Food101...")
    prithiv_model = PrithivMlFood101()
    evaluate_food101(
        experiment_name="Food101_Model_Comparison",
        run_name="PrithivML_Baseline_50samples",
        sample_limit=50,  # Small sample for demonstration
        model=prithiv_model,
    )

    print("\n2. Evaluating ResNet-18 on Food101 ...")
    resnet18_food_model = Resnet18()
    evaluate_food101(
        experiment_name="Food101_Model_Comparison",
        run_name="ResNet18_Transfer_50samples",
        sample_limit=50,  # Small sample for demonstration
        model=resnet18_food_model,
    )

    print("\n3. Evaluating VGG16 on Food101 ...")
    vgg16_food_model = VGG16()
    evaluate_food101(
        experiment_name="Food101_Model_Comparison",
        run_name="VGG16_Transfer_50samples",
        sample_limit=50,  # Small sample for demonstration
        model=vgg16_food_model,
    )


if __name__ == "__main__":
    main()
src/eval/evaluate_food101.py
ADDED
@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""
Food101 evaluation script for model evaluation with MLflow tracking.

This script evaluates models on the Food101 dataset with MLflow experiment tracking.
"""

from typing import List, Dict, Tuple, Any, Union
from pathlib import Path
import mlflow
from datetime import datetime
import pandas as pd
import glob
import random
import dagshub

from src.models.food_classification_model import FoodClassificationModel
from src.labels import LABELS, index_to_label

dagshub.init(repo_owner="HubertWojcik10", repo_name="TikkaMasalAI", mlflow=True)
mlflow.autolog()


class Food101Evaluator:
    """Model evaluator for Food101 dataset with MLflow tracking."""

    def __init__(
        self,
        model: FoodClassificationModel,
        experiment_name: str = "food101_evaluation",
        sample_limit: int = 50,
        random_seed: int = 42,
        run_name: str = None,
    ):
        """
        Initialize the Food101 evaluator.

        Args:
            model: FoodClassificationModel instance to use for evaluation (required)
            experiment_name: Name of the MLflow experiment
            sample_limit: Maximum number of samples to evaluate
            random_seed: Random seed for reproducible sampling
            run_name: Custom name for the MLflow run (optional)
        """
        self.DATASET_NAME = "Food101"
        self.experiment_name = experiment_name
        self.sample_limit = sample_limit
        self.model = model
        self.random_seed = random_seed
        self.custom_run_name = run_name
        self.model_name = self.model.__class__.__name__
        self.data_dir = (
            Path(__file__).parent.parent.parent / "data" / "raw" / "food101" / "data"
        )

    def load_validation_data(self) -> List[Tuple[bytes, int]]:
        """
        Load validation data from parquet files with random sampling.

        Returns:
            List of tuples: (image_bytes, true_index)
        """
        random.seed(self.random_seed)

        validation_files = glob.glob(f"{self.data_dir}/validation-*.parquet")
        print(f"Found {len(validation_files)} validation files")

        # Load all samples first
        all_samples = []

        for file_path in validation_files:
            print(f"Loading from {Path(file_path).name}...")
            df = pd.read_parquet(file_path)

            for _, row in df.iterrows():
                image_data = row["image"]["bytes"]
                true_index = row["label"]

                all_samples.append((image_data, true_index))

        print(f"Total available samples: {len(all_samples)}")

        # Randomly sample the requested number of samples
        if len(all_samples) <= self.sample_limit:
            selected_samples = all_samples
            print(f"Using all {len(selected_samples)} available samples")
        else:
            selected_samples = random.sample(all_samples, self.sample_limit)
            print(
                f"Randomly selected {len(selected_samples)} samples from {len(all_samples)} available"
            )

        print(f"Random seed used: {self.random_seed}")
        return selected_samples

    def calculate_accuracy(
        self, predictions: List[Union[int, str]], ground_truths: List[int]
    ) -> float:
        """
        Calculate exact accuracy for Food101 dataset.

        Args:
            predictions: List of predicted indices or label names
            ground_truths: List of true labels

        Returns:
            Accuracy score as float
        """
        if not predictions or not ground_truths:
            return 0.0

        # Check for exact matches
        exact_matches = 0
        for pred, true in zip(predictions, ground_truths):
            if pred == true:
                exact_matches += 1

        return exact_matches / len(predictions)

    def evaluate_model(
        self, samples: List[Tuple[bytes, int]], verbose: bool = True
    ) -> Dict[str, Any]:
        """
        Evaluate the model on the provided samples.

        Args:
            samples: List of (image_bytes, true_index) tuples
            verbose: Whether to print detailed results

        Returns:
            Dictionary with evaluation metrics
        """
        print(f"\nEvaluating model on {len(samples)} samples...")

        predictions = []
        ground_truths = []
        prediction_examples = []
        correct_predictions = 0

        for i, (image_bytes, true_index) in enumerate(samples):
            try:
                predicted_index = self.model.classify(image_bytes)
                predictions.append(predicted_index)
                ground_truths.append(true_index)

                # Check if prediction is correct using dataset-specific logic
                is_correct = predicted_index == true_index
                if is_correct:
                    correct_predictions += 1

                # Convert index to label name for display and logging
                predicted_label_name = index_to_label(predicted_index)

                # Store first 10 examples for MLflow
                if i < 10:
                    prediction_examples.append(
                        {
                            "sample_id": i + 1,
                            "true_label": LABELS[true_index],
                            "predicted_label": predicted_label_name,
                            "predicted_index": predicted_index,
                            "true_index": true_index,
                            "is_correct": is_correct,
                        }
                    )

                if verbose and i < 10:  # Print first 10 predictions
                    status = "✓" if is_correct else "✗"
                    print(
                        f"Sample {i+1:2d}: {status} True='{LABELS[true_index]:25s}' (idx: {true_index}) | Predicted='{predicted_label_name}' (idx: {predicted_index})"
                    )

            except Exception as e:
                print(f"Error processing sample {i+1}: {e}")
                predictions.append("ERROR")
                ground_truths.append(true_index)

        # Calculate metrics
        total_samples = len(samples)
        successful_predictions = len([p for p in predictions if p != "ERROR"])

        # Calculate accuracy using dataset-specific method
        accuracy = self.calculate_accuracy(predictions, ground_truths)
        success_rate = (
            successful_predictions / total_samples if total_samples > 0 else 0
        )

        results = {
            "total_samples": total_samples,
            "successful_predictions": successful_predictions,
            "correct_predictions": correct_predictions,
            "success_rate": success_rate,
            "accuracy": accuracy,
            "prediction_examples": prediction_examples,
        }

        return results

    def log_mlflow_metrics(self, results: Dict[str, Any]) -> None:
        """
        Log evaluation metrics to MLflow.

        Args:
            results: The results from the evaluation.
        """
        mlflow.log_metric("total_samples", results["total_samples"])
        mlflow.log_metric("successful_predictions", results["successful_predictions"])
        mlflow.log_metric("success_rate", results["success_rate"])
        mlflow.log_metric("correct_predictions", results["correct_predictions"])
        mlflow.log_metric("accuracy", results["accuracy"])

    def log_mlflow_artifacts(self, results: Dict[str, Any]) -> None:
        """
        Log evaluation artifacts to MLflow.

        Args:
            results: The results from the evaluation.
        """
        examples_data = []
        for example in results["prediction_examples"]:
            status = "✓" if example.get("is_correct", False) else "✗"
            examples_data.append(
                f"Sample {example['sample_id']}: {status} {example['true_label']} -> {example['predicted_label']}"
            )

        examples_text = "\n".join(examples_data)
        examples_file = f"{self.DATASET_NAME.lower()}_evaluation_examples.txt"
        with open(examples_file, "w", encoding="utf-8", newline="") as f:
            f.write(examples_text)
        mlflow.log_artifact(examples_file)

        model_source = (
            getattr(self.model, "model_path", "N/A")
            if hasattr(self.model, "model_path")
            else "N/A"
        )
        summary = f"""{self.model_name} {self.DATASET_NAME} Evaluation Summary
========================================={'=' * len(self.DATASET_NAME)}
Model: {self.model_name} ({model_source})
Dataset: {self.DATASET_NAME} validation set
Samples: {results['total_samples']}
Success Rate: {results['success_rate']:.2%}
Accuracy: {results['accuracy']:.2%}
Correct Predictions: {results['correct_predictions']}
"""

        summary_file = f"{self.DATASET_NAME.lower()}_evaluation_summary.txt"
        with open(summary_file, "w", encoding="utf-8", newline="") as f:
            f.write(summary)
        mlflow.log_artifact(summary_file)

        # Clean up temporary files
        Path(examples_file).unlink(missing_ok=True)
        Path(summary_file).unlink(missing_ok=True)

    def run_evaluation(self) -> None:
        """Run the complete evaluation pipeline with MLflow tracking."""
        print("=" * 60)
        print(f"{self.model_name} {self.DATASET_NAME} Evaluation with MLflow")
        print("=" * 60)

        mlflow.set_experiment(self.experiment_name)

        # Create descriptive run name
        if self.custom_run_name:
            run_name = self.custom_run_name
        else:
            timestamp = datetime.now().strftime("%m%d_%H%M")
            run_name = f"{self.model_name}_Food101_n{self.sample_limit}_seed{self.random_seed}_{timestamp}"

        with mlflow.start_run(run_name=run_name):
            # Add useful tags for filtering and organization
            mlflow.set_tag("model_type", self.model_name)
            mlflow.set_tag("dataset", self.DATASET_NAME)
            mlflow.set_tag("sample_size", str(self.sample_limit))
            mlflow.set_tag("evaluation_type", "validation")

            mlflow.log_param("model_name", self.model_name)
            mlflow.log_param("model_class", self.model.__class__.__name__)
            mlflow.log_param("dataset", self.DATASET_NAME)
            mlflow.log_param("sample_limit", self.sample_limit)
            mlflow.log_param("random_seed", self.random_seed)
            mlflow.log_param("evaluation_date", datetime.now().isoformat())

            # Log model-specific parameters if available
            if hasattr(self.model, "model_path"):
                mlflow.log_param(
                    "model_source", getattr(self.model, "model_path", "Unknown")
                )
            if hasattr(self.model, "preprocessor_path"):
                mlflow.log_param(
                    "preprocessor_path",
                    getattr(self.model, "preprocessor_path", "Unknown"),
                )

            samples = self.load_validation_data()

            if not samples:
                print(
                    f"No validation samples loaded. Check the {self.DATASET_NAME} dataset connection."
                )
                mlflow.log_param("status", "failed - no data")
                return

            mlflow.log_param("samples_loaded", len(samples))

            results = self.evaluate_model(samples, verbose=True)

            self.log_mlflow_metrics(results)
            self.log_mlflow_artifacts(results)

            self._print_results(results)

    def _print_results(self, results: Dict[str, Any]) -> None:
        """Print evaluation results to console."""
        print("\n" + "=" * 60)
        print("EVALUATION RESULTS")
        print("=" * 60)
        print(f"Total samples processed: {results['total_samples']}")
        print(f"Successful predictions: {results['successful_predictions']}")
        print(f"Success rate: {results['success_rate']:.2%}")
        print(f"Correct predictions: {results['correct_predictions']}")
        print(f"Accuracy: {results['accuracy']:.2%}")

        print(f"\nMLflow run ID: {mlflow.active_run().info.run_id}")
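For orientation, the evaluator above can also be driven directly, without the eval.py demo script. A minimal sketch (not part of this commit), assuming the Food101 parquet files have already been fetched (e.g. via src/data/download_data.py) and DagsHub/MLflow credentials are configured, since importing the module calls dagshub.init():

# Minimal usage sketch for Food101Evaluator.
from src.eval.evaluate_food101 import Food101Evaluator
from src.models.prithiv_ml_food101 import PrithivMlFood101

evaluator = Food101Evaluator(
    model=PrithivMlFood101(),
    experiment_name="food101_evaluation",
    sample_limit=10,  # keep the smoke test cheap
    random_seed=42,   # reproducible sampling
    run_name="smoke_test_10samples",
)
evaluator.run_evaluation()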
src/labels.py
ADDED
@@ -0,0 +1,135 @@
LABELS = [
    "apple_pie",
    "baby_back_ribs",
    "baklava",
    "beef_carpaccio",
    "beef_tartare",
    "beet_salad",
    "beignets",
    "bibimbap",
    "bread_pudding",
    "breakfast_burrito",
    "bruschetta",
    "caesar_salad",
    "cannoli",
    "caprese_salad",
    "carrot_cake",
    "ceviche",
    "cheesecake",
    "cheese_plate",
    "chicken_curry",
    "chicken_quesadilla",
    "chicken_wings",
    "chocolate_cake",
    "chocolate_mousse",
    "churros",
    "clam_chowder",
    "club_sandwich",
    "crab_cakes",
    "creme_brulee",
    "croque_madame",
    "cup_cakes",
    "deviled_eggs",
    "donuts",
    "dumplings",
    "edamame",
    "eggs_benedict",
    "escargots",
    "falafel",
    "filet_mignon",
    "fish_and_chips",
    "foie_gras",
    "french_fries",
    "french_onion_soup",
    "french_toast",
    "fried_calamari",
    "fried_rice",
    "frozen_yogurt",
    "garlic_bread",
    "gnocchi",
    "greek_salad",
    "grilled_cheese_sandwich",
    "grilled_salmon",
    "guacamole",
    "gyoza",
    "hamburger",
    "hot_and_sour_soup",
    "hot_dog",
    "huevos_rancheros",
    "hummus",
    "ice_cream",
    "lasagna",
    "lobster_bisque",
    "lobster_roll_sandwich",
    "macaroni_and_cheese",
    "macarons",
    "miso_soup",
    "mussels",
    "nachos",
    "omelette",
    "onion_rings",
    "oysters",
    "pad_thai",
    "paella",
    "pancakes",
    "panna_cotta",
    "peking_duck",
    "pho",
    "pizza",
    "pork_chop",
    "poutine",
    "prime_rib",
    "pulled_pork_sandwich",
    "ramen",
    "ravioli",
    "red_velvet_cake",
    "risotto",
    "samosa",
    "sashimi",
    "scallops",
    "seaweed_salad",
    "shrimp_and_grits",
    "spaghetti_bolognese",
    "spaghetti_carbonara",
    "spring_rolls",
    "steak",
    "strawberry_shortcake",
    "sushi",
    "tacos",
    "takoyaki",
    "tiramisu",
    "tuna_tartare",
    "waffles",
]


def index_to_label(index: int) -> str:
    """
    Convert a class index to its corresponding label name.

    Args:
        index: The class index (0-based)

    Returns:
        str: The label name corresponding to the index, or a fallback string if index is out of bounds
    """
    if 0 <= index < len(LABELS):
        return LABELS[index]
    else:
        return f"unknown_class_{index}"


def label_to_index(label: str) -> int:
    """
    Convert a label name to its corresponding class index.

    Args:
        label: The label name

    Returns:
        int: The index corresponding to the label, or -1 if label is not found
    """
    try:
        return LABELS.index(label)
    except ValueError:
        return -1
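The two helpers above are inverses over the 101-entry LABELS list. A quick sanity check (illustrative only, not part of this commit):

# Round-tripping between indices and label names.
from src.labels import LABELS, index_to_label, label_to_index

print(index_to_label(0))             # "apple_pie"
print(label_to_index("pizza"))       # same as LABELS.index("pizza")
print(index_to_label(999))           # out of range -> "unknown_class_999"
print(label_to_index("not_a_dish"))  # unknown label -> -1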
src/models/__pycache__/food_classification_model.cpython-310.pyc
ADDED
Binary file (943 Bytes)
src/models/__pycache__/food_classification_model.cpython-311.pyc
ADDED
Binary file (1.11 kB)
src/models/__pycache__/model_discovery.cpython-310.pyc
ADDED
Binary file (2.95 kB)
src/models/__pycache__/prithiv_ml_food101.cpython-310.pyc
ADDED
Binary file (3.12 kB)
src/models/__pycache__/prithiv_ml_food101.cpython-311.pyc
ADDED
Binary file (9.91 kB)
src/models/__pycache__/resnet18.cpython-310.pyc
ADDED
Binary file (2.18 kB)
src/models/__pycache__/resnet18.cpython-311.pyc
ADDED
Binary file (2.28 kB)
src/models/__pycache__/vgg16.cpython-310.pyc
ADDED
Binary file (4.48 kB)
src/models/food_classification_model.py
ADDED
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod


class FoodClassificationModel(ABC):
    """Abstract Base Class that serves as a common interface for all models."""

    @abstractmethod
    def classify(self, image: bytes) -> int:
        """
        Abstract method to classify an image into a food category.

        Args:
            image: The image bytes to classify.

        Returns:
            int: The index of the predicted class. This returns the class index, not the class name.
        """
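Any concrete model only has to implement classify(bytes) -> int against this interface; model_discovery.py (next file) then picks it up automatically from src/models/. A hypothetical sketch (ConstantModel is illustrative and not part of this commit):

# Hypothetical subclass showing the interface contract.
from src.models.food_classification_model import FoodClassificationModel


class ConstantModel(FoodClassificationModel):
    """Trivial model that always predicts class 0 ("apple_pie")."""

    def classify(self, image: bytes) -> int:
        # A real model would decode the image bytes and run inference here.
        return 0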
src/models/model_discovery.py
ADDED
@@ -0,0 +1,121 @@
"""
Model discovery utility for dynamically finding all models that inherit from FoodClassificationModel.
"""

import importlib
import inspect
from pathlib import Path
from typing import Dict, Any

from .food_classification_model import FoodClassificationModel


def discover_models(models_dir: Path = None) -> Dict[str, Dict[str, Any]]:
    """
    Dynamically discover all models that inherit from FoodClassificationModel.

    Args:
        models_dir: Path to the models directory. If None, uses the current module's directory.

    Returns:
        Dict mapping display names to model information containing:
        - 'class': The model class
        - 'module': The module name
        - 'class_name': The class name
    """
    if models_dir is None:
        models_dir = Path(__file__).parent

    available_models = {}

    # Iterate through all Python files in the models directory
    for py_file in models_dir.glob("*.py"):
        if (
            py_file.name.startswith("__")
            or py_file.name == "food_classification_model.py"
            or py_file.name == "model_discovery.py"
        ):
            continue

        try:
            # Import the module dynamically
            module_name = f"src.models.{py_file.stem}"
            module = importlib.import_module(module_name)

            # Find all classes in the module that inherit from FoodClassificationModel
            for name, obj in inspect.getmembers(module, inspect.isclass):
                if (
                    issubclass(obj, FoodClassificationModel)
                    and obj != FoodClassificationModel
                    and obj.__module__ == module_name
                ):
                    # Create a user-friendly name
                    display_name = _create_display_name(name)

                    available_models[display_name] = {
                        "class": obj,
                        "module": module_name,
                        "class_name": name,
                    }

        except Exception as e:
            # In a non-Streamlit context, we might want to log or handle this differently
            print(f"Warning: Could not load model from {py_file.name}: {str(e)}")
            continue

    return available_models


def _create_display_name(class_name: str) -> str:
    """
    Create a user-friendly display name from a class name.

    Args:
        class_name: The original class name

    Returns:
        A user-friendly display name
    """
    # Create a user-friendly name
    display_name = class_name

    if "prithiv" in class_name.lower():
        display_name = "PrithivML Food-101 (Benchmark)"
    elif "resnet" in class_name.lower():
        display_name = "ResNet-18"
    elif "vgg" in class_name.lower():
        display_name = "VGG-16"
    elif "efficientnet" in class_name.lower():
        display_name = "EfficientNet"
    elif "mobilenet" in class_name.lower():
        display_name = "MobileNet"
    elif "densenet" in class_name.lower():
        display_name = "DenseNet"

    return display_name


def get_model_names() -> list:
    """
    Get a list of all available model display names.

    Returns:
        List of model display names
    """
    models = discover_models()
    return list(models.keys())


def get_model_info(display_name: str) -> Dict[str, Any]:
    """
    Get model information for a specific model by display name.

    Args:
        display_name: The display name of the model

    Returns:
        Model information dictionary or None if not found
    """
    models = discover_models()
    return models.get(display_name)
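Typical usage of these discovery helpers, as app.py does at startup (a sketch; the exact display names depend on which model files import cleanly):

# Listing and instantiating discovered models.
from src.models.model_discovery import discover_models, get_model_info

models = discover_models()
print(list(models.keys()))  # e.g. ["PrithivML Food-101 (Benchmark)", "ResNet-18", "VGG-16"]

info = get_model_info("ResNet-18")
if info is not None:
    model = info["class"]()  # instantiate via the discovered class object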
src/models/prithiv_ml_food101.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import io
import os
import tempfile
from pathlib import Path

from src.models.food_classification_model import FoodClassificationModel


class PrithivMlFood101(FoodClassificationModel):
    """
    Interface for accessing the PrithivML Food-101 model architecture.
    This model was already trained on the Food-101 dataset and performs well on it.
    It serves as a benchmark for our own model fine-tuning and potentially
    as an alternative to be deployed.
    See it on Huggingface: https://huggingface.co/prithivMLmods/Food-101-93M.
    """

    def __init__(self, model_name: str = "prithivMLmods/Food-101-93M"):
        """
        Load the PrithivML Food-101 model.

        Preference order:
        1) Load from the local repo snapshot at <repo_root>/models/prithivMLmods/Food-101-93M
        2) If not present, prompt the user to download it via the Makefile:
               make download-hf-model MODEL_PATH=prithivMLmods/Food-101-93M
        """

        # Set up a proper cache directory for HF Spaces (safe no-op locally)
        if not os.environ.get("HF_HOME"):
            cache_dir = Path(tempfile.gettempdir()) / "transformers_cache"
            cache_dir.mkdir(exist_ok=True)
            os.environ["HF_HOME"] = str(cache_dir)

        # Resolve repo root robustly from this file's location
        repo_root = Path(__file__).resolve().parents[2]
        local_model_dir = repo_root / "models" / "prithivMLmods" / "Food-101-93M"

        # Determine whether a local copy exists (safetensors or bin)
        local_exists = local_model_dir.exists() and (
            (local_model_dir / "model.safetensors").exists()
            or (local_model_dir / "pytorch_model.bin").exists()
        )

        if not local_exists:
            # Provide a clear, actionable message to fetch the model snapshot
            make_cmd = "make download-hf-model MODEL_PATH=prithivMLmods/Food-101-93M"
            raise RuntimeError(
                "Local model not found at 'models/prithivMLmods/Food-101-93M'. "
                "Please download it first using:\n"
                f"  {make_cmd}\n"
                "After the download completes, re-run your program."
            )

        # Load from the local directory snapshot
        try:
            self.model = SiglipForImageClassification.from_pretrained(
                str(local_model_dir),
                cache_dir=os.environ.get("HF_HOME"),
                local_files_only=True,
                force_download=False,
            )
            self.processor = AutoImageProcessor.from_pretrained(
                str(local_model_dir),
                cache_dir=os.environ.get("HF_HOME"),
                local_files_only=True,
                force_download=False,
                use_fast=True,  # Use the fast processor to avoid a warning
            )
            self.model_name = str(local_model_dir)
        except Exception as e:
            raise RuntimeError(
                "Failed to load local model from 'models/prithivMLmods/Food-101-93M': "
                f"{str(e)}"
            ) from e

    def classify(self, image: bytes) -> int:
        pil_image = Image.open(io.BytesIO(image)).convert("RGB")
        inputs = self.processor(images=pil_image, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probs = torch.nn.functional.softmax(logits, dim=1).squeeze()

        predicted_idx = torch.argmax(probs).item()
        return predicted_idx
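classify returns only a class index. As an illustrative sketch (not part of the commit), the index can be mapped back to a Food-101 label through the checkpoint's id2label config; the image path here is hypothetical.

from src.models.prithiv_ml_food101 import PrithivMlFood101

clf = PrithivMlFood101()  # raises RuntimeError if the local snapshot is missing
with open("example_food.jpg", "rb") as f:  # hypothetical input image
    idx = clf.classify(f.read())

# Transformers configs typically expose an int-keyed id2label mapping
label = clf.model.config.id2label.get(idx, str(idx))
print(idx, label)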
src/models/resnet18.py
ADDED
@@ -0,0 +1,63 @@
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import io
import os
from pathlib import Path

from src.models.food_classification_model import FoodClassificationModel


class Resnet18(FoodClassificationModel):
    """
    Interface for accessing the ResNet-18 model architecture.
    See the base model here: https://huggingface.co/microsoft/resnet-18.
    """

    def __init__(
        self,
        preprocessor_path: str = "microsoft/resnet-18",
        model_path: str = "microsoft/resnet-18",
    ):
        """
        Prefer loading from a local snapshot under models/microsoft/resnet-18.
        If the local snapshot doesn't exist, prompt the user to download it via the Makefile.
        """

        # Resolve repo root and local model dir
        repo_root = Path(__file__).resolve().parents[2]
        local_model_dir = repo_root / "models" / "microsoft" / "resnet-18"

        # Check whether a local HF snapshot exists (config + weights)
        local_exists = local_model_dir.exists() and (
            (local_model_dir / "pytorch_model.bin").exists()
            or (local_model_dir / "model.safetensors").exists()
        )

        if not local_exists:
            make_cmd = "make download-hf-model MODEL_PATH=microsoft/resnet-18"
            raise RuntimeError(
                "Local model not found at 'models/microsoft/resnet-18'. "
                "Please download it first using:\n"
                f"  {make_cmd}\n"
                "After the download completes, re-run your program."
            )

        # Load from the local folder
        self.image_processor = AutoImageProcessor.from_pretrained(str(local_model_dir))
        self.model = AutoModelForImageClassification.from_pretrained(
            str(local_model_dir)
        )

    def classify(self, image: bytes) -> int:
        pil_image = Image.open(io.BytesIO(image))
        inputs = self.image_processor(pil_image, return_tensors="pt")

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # The model predicts one of the 101 Food-101 classes if it was fine-tuned
        # for Food-101. With the default microsoft/resnet-18 weights it predicts
        # one of the 1000 ImageNet classes instead.
        predicted_label = logits.argmax(-1).item()
        return predicted_label
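As the inline comment notes, the stock microsoft/resnet-18 snapshot predicts ImageNet classes, not Food-101. A small sketch (illustrative only; the image path is hypothetical) that makes the label space explicit via the model config:

from src.models.resnet18 import Resnet18

clf = Resnet18()
with open("example_food.jpg", "rb") as f:  # hypothetical input image
    idx = clf.classify(f.read())

n = clf.model.config.num_labels  # 101 only if the snapshot was fine-tuned on Food-101
print(f"Predicted index {idx} of {n} classes:", clf.model.config.id2label[idx])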
src/models/vgg16.py
ADDED
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import io
from pathlib import Path
from typing import Dict, Any
from src.models.food_classification_model import FoodClassificationModel


class VGG16(FoodClassificationModel):
    """Interface for accessing the VGG-16 model architecture."""

    def __init__(self, weights: str = "IMAGENET1K_V1", num_classes: int = 101):
        """
        Initialize VGG-16. Prefer loading local fine-tuned weights if available.

        Priority:
        1) Load the ImageNet base and replace the classifier, then load the local
           fine-tuned checkpoint from <repo_root>/models/vgg16/vgg16-397923af.pth
           (if it exists).
        2) Otherwise, raise and instruct the user to provide or train a .pth
           checkpoint for Food-101 fine-tuning.
        """
        repo_root = Path(__file__).resolve().parents[2]
        local_weights = repo_root / "models" / "vgg16" / "vgg16-397923af.pth"

        # Base model with ImageNet weights
        self.model = models.vgg16(weights=weights)

        num_features = self.model.classifier[6].in_features
        self.model.classifier[6] = nn.Linear(num_features, num_classes)

        # If local fine-tuned weights exist, load them
        if local_weights.exists():
            try:
                raw_ckpt: Dict[str, Any] = torch.load(local_weights, map_location="cpu")

                # Unwrap common checkpoint formats
                if isinstance(raw_ckpt, dict) and "state_dict" in raw_ckpt:
                    ckpt = raw_ckpt["state_dict"]
                else:
                    ckpt = raw_ckpt

                # Normalize key prefixes commonly introduced by wrappers
                def strip_prefix(
                    sd: Dict[str, torch.Tensor], prefix: str
                ) -> Dict[str, torch.Tensor]:
                    if all(k.startswith(prefix) for k in sd.keys()):
                        return {k[len(prefix):]: v for k, v in sd.items()}
                    return sd

                for p in ("module.", "model.", "net."):
                    ckpt = strip_prefix(ckpt, p)

                # Filter out mismatched keys (e.g., classifier.6 for 1000 -> 101 classes)
                model_sd = self.model.state_dict()
                filtered_ckpt = {}
                skipped = []
                for k, v in ckpt.items():
                    if k in model_sd and isinstance(v, torch.Tensor):
                        if model_sd[k].shape == v.shape:
                            filtered_ckpt[k] = v
                        else:
                            skipped.append(
                                (k, tuple(v.shape), tuple(model_sd[k].shape))
                            )
                # Silently ignore keys not present in the current model

                missing_before = set(model_sd.keys()) - set(filtered_ckpt.keys())
                self.model.load_state_dict(filtered_ckpt, strict=False)

                # Optional: print a brief summary to the logs for transparency
                if skipped:
                    skipped_str = ", ".join(
                        f"{k}: {src} -> {dst}" for k, src, dst in skipped[:5]
                    )
                    more = "" if len(skipped) <= 5 else f" (+{len(skipped) - 5} more)"
                    print(
                        f"[VGG16] Partially loaded checkpoint from '{local_weights}'. "
                        f"Skipped mismatched keys: {skipped_str}{more}"
                    )

                if (
                    "classifier.6.weight" in missing_before
                    or "classifier.6.bias" in missing_before
                ):
                    print(
                        "[VGG16] Final classifier layer was initialized for 101 classes "
                        "and not loaded from the checkpoint."
                    )

            except Exception as e:
                raise RuntimeError(
                    f"Failed to load local VGG16 weights from '{local_weights}': {e}"
                ) from e
        else:
            # No local fine-tuned weights: fail with an actionable message
            raise RuntimeError(
                "Local fine-tuned weights not found at 'models/vgg16/vgg16-397923af.pth'.\n"
                "Please place your fine-tuned checkpoint there, or train/export one.\n"
                "Alternatively, switch to a HF model with a Makefile download target."
            )

        self.model.eval()

        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def classify(self, image: bytes) -> int:
        pil_image = Image.open(io.BytesIO(image))

        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        input_tensor = self.transform(pil_image)
        input_batch = input_tensor.unsqueeze(0)  # Add batch dimension

        with torch.no_grad():
            outputs = self.model(input_batch)
            predicted_idx = torch.argmax(outputs).item()

        return predicted_idx
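A matching usage sketch for the VGG-16 wrapper (illustrative only; assumes a fine-tuned checkpoint is present at models/vgg16/vgg16-397923af.pth, and the image path is hypothetical). A name lookup for the returned index could then go through the project's label table (e.g., src/labels.py).

from src.models.vgg16 import VGG16

clf = VGG16()  # raises RuntimeError if the local checkpoint is missing
with open("example_food.jpg", "rb") as f:  # hypothetical input image
    idx = clf.classify(f.read())
print("Predicted Food-101 class index:", idx)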
src/train/__pycache__/preprocess_data.cpython-310.pyc
ADDED
Binary file (4.08 kB)
uv.lock
ADDED
The diff for this file is too large to render.