AdrianHagen commited on
Commit
21fb9ff
·
0 Parent(s):

Clean history: code-only (no models)

Browse files
Files changed (39) hide show
  1. .dockerignore +64 -0
  2. .gitattributes +35 -0
  3. Dockerfile +55 -0
  4. README.md +12 -0
  5. app.py +275 -0
  6. config.toml +24 -0
  7. pyproject.toml +70 -0
  8. src/.DS_Store +0 -0
  9. src/__init__.py +8 -0
  10. src/__pycache__/__init__.cpython-310.pyc +0 -0
  11. src/__pycache__/__init__.cpython-311.pyc +0 -0
  12. src/__pycache__/cli.cpython-310.pyc +0 -0
  13. src/__pycache__/constants.cpython-310.pyc +0 -0
  14. src/__pycache__/constants.cpython-311.pyc +0 -0
  15. src/__pycache__/evaluate_imagenet.cpython-310.pyc +0 -0
  16. src/__pycache__/labels.cpython-310.pyc +0 -0
  17. src/data/download_data.py +8 -0
  18. src/eval/__pycache__/base_evaluator.cpython-310.pyc +0 -0
  19. src/eval/__pycache__/evaluate_food101.cpython-310.pyc +0 -0
  20. src/eval/__pycache__/evaluate_food101.cpython-311.pyc +0 -0
  21. src/eval/__pycache__/evaluate_imagenet.cpython-310.pyc +0 -0
  22. src/eval/eval.py +67 -0
  23. src/eval/evaluate_food101.py +326 -0
  24. src/labels.py +135 -0
  25. src/models/__pycache__/food_classification_model.cpython-310.pyc +0 -0
  26. src/models/__pycache__/food_classification_model.cpython-311.pyc +0 -0
  27. src/models/__pycache__/model_discovery.cpython-310.pyc +0 -0
  28. src/models/__pycache__/prithiv_ml_food101.cpython-310.pyc +0 -0
  29. src/models/__pycache__/prithiv_ml_food101.cpython-311.pyc +0 -0
  30. src/models/__pycache__/resnet18.cpython-310.pyc +0 -0
  31. src/models/__pycache__/resnet18.cpython-311.pyc +0 -0
  32. src/models/__pycache__/vgg16.cpython-310.pyc +0 -0
  33. src/models/food_classification_model.py +17 -0
  34. src/models/model_discovery.py +121 -0
  35. src/models/prithiv_ml_food101.py +89 -0
  36. src/models/resnet18.py +63 -0
  37. src/models/vgg16.py +130 -0
  38. src/train/__pycache__/preprocess_data.cpython-310.pyc +0 -0
  39. uv.lock +0 -0
.dockerignore ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Git and version control
2
+ .git
3
+ .gitignore
4
+ .gitattributes
5
+
6
+ # Python cache and virtual environments
7
+ __pycache__/
8
+ *.pyc
9
+ *.pyo
10
+ *.pyd
11
+ .Python
12
+ .venv/
13
+ venv/
14
+ .env
15
+
16
+ # Documentation and development files
17
+ docs/
18
+ references/
19
+ reports/
20
+ notebooks/
21
+ tests/
22
+
23
+ # Data files (too large for Docker image)
24
+ data/
25
+ mlruns/
26
+ models/*.pth
27
+
28
+ # Development and configuration files
29
+ .pytest_cache/
30
+ .coverage
31
+ .mypy_cache/
32
+ .ruff_cache/
33
+ *.log
34
+
35
+ # IDE and editor files
36
+ .vscode/
37
+ .idea/
38
+ *.swp
39
+ *.swo
40
+ *~
41
+
42
+ # OS files
43
+ .DS_Store
44
+ Thumbs.db
45
+
46
+ # Build artifacts
47
+ build/
48
+ dist/
49
+ *.egg-info/
50
+
51
+ # Streamlit specific
52
+ .streamlit/
53
+
54
+ # Project specific
55
+ Makefile
56
+ LICENSE
57
+ streamlit/README.md
58
+
59
+ # Deployment scripts
60
+ deploy_to_hf.sh
61
+ COMMANDS.md
62
+
63
+ # Food101 HF Space (don't include in Docker)
64
+ Food101/
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Python 3.10 slim image for smaller size
2
+ FROM python:3.10-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install uv
8
+ RUN pip install uv
9
+
10
+ # Copy dependency files first (for better caching)
11
+ COPY pyproject.toml ./
12
+ COPY uv.lock ./
13
+ COPY README.md ./
14
+
15
+ # Install dependencies using uv (creates .venv)
16
+ RUN uv sync --frozen
17
+
18
+ # Use /tmp for all caches and Streamlit config to avoid permission issues in read-only paths
19
+ ENV HOME=/tmp \
20
+ XDG_CACHE_HOME=/tmp/.cache \
21
+ UV_CACHE_DIR=/tmp/.cache/uv \
22
+ PIP_CACHE_DIR=/tmp/.cache/pip \
23
+ HF_HOME=/tmp/.cache/huggingface \
24
+ TRANSFORMERS_CACHE=/tmp/.cache/huggingface/transformers \
25
+ TORCH_HOME=/tmp/.cache/torch \
26
+ STREAMLIT_CONFIG_DIR=/tmp/.streamlit \
27
+ STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
28
+ RUN mkdir -p \
29
+ /tmp/.cache/uv \
30
+ /tmp/.cache/pip \
31
+ /tmp/.cache/huggingface/transformers \
32
+ /tmp/.cache/torch \
33
+ /tmp/.streamlit \
34
+ && chmod -R 777 /tmp/.cache /tmp/.streamlit
35
+
36
+ # Copy application code
37
+ COPY src/ ./src/
38
+ COPY app.py ./
39
+ COPY config.toml ./
40
+
41
+ # Copy models
42
+ COPY models/ ./models/
43
+
44
+ # Expose Streamlit port (Hugging Face Spaces uses 7860)
45
+ EXPOSE 7860
46
+
47
+ # Set environment variables for Streamlit
48
+ ENV STREAMLIT_SERVER_PORT=7860
49
+ ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0
50
+ ENV STREAMLIT_SERVER_HEADLESS=true
51
+ ENV STREAMLIT_SERVER_ENABLE_XSRF_PROTECTION=false
52
+
53
+ # Run the application using the created virtual environment
54
+ ENV PATH="/app/.venv/bin:${PATH}"
55
+ CMD ["python", "-m", "streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Food101 Streamlit
3
+ emoji: 👁
4
+ colorFrom: yellow
5
+ colorTo: pink
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ short_description: Space for Detecting the Type of Dish in an Image
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import io
4
+ import os
5
+ import time
6
+ import tempfile
7
+ from pathlib import Path
8
+
9
+ from src.models.model_discovery import discover_models
10
+ from src.labels import LABELS
11
+
12
+
13
+ def load_model(model_info):
14
+ """Load and cache the selected model with proper error handling."""
15
+ model_class = model_info["class"]
16
+ model_name = model_info["class_name"]
17
+
18
+ # Set up custom cache directory to avoid permission issues
19
+ custom_cache = Path(tempfile.gettempdir()) / "tikka_masalai_cache"
20
+ custom_cache.mkdir(exist_ok=True)
21
+
22
+ # Set HuggingFace cache directory (use HF_HOME instead of deprecated TRANSFORMERS_CACHE)
23
+ os.environ["HF_HOME"] = str(custom_cache)
24
+ os.environ["TRANSFORMERS_CACHE"] = str(
25
+ custom_cache
26
+ ) # Keep for backward compatibility
27
+
28
+ try:
29
+ # Use a placeholder for the loading message that we can clear
30
+ loading_placeholder = st.empty()
31
+ loading_placeholder.info(f"Loading {model_name} model...")
32
+
33
+ # Try to load the model - handle different model initialization patterns
34
+ if "prithiv" in model_name.lower():
35
+ # PrithivML model with specific initialization
36
+ model = model_class()
37
+ elif "resnet" in model_name.lower():
38
+ # ResNet model - check if it needs specific paths
39
+ try:
40
+ model = model_class()
41
+ except TypeError:
42
+ # Try with default parameters if it requires them
43
+ model = model_class(
44
+ preprocessor_path="microsoft/resnet-18",
45
+ model_path="microsoft/resnet-18",
46
+ )
47
+ elif "vgg" in model_name.lower():
48
+ # VGG model with default parameters
49
+ model = model_class()
50
+ else:
51
+ # Generic model initialization
52
+ try:
53
+ model = model_class()
54
+ except TypeError:
55
+ # Skip models that require specific parameters we don't know about
56
+ raise RuntimeError(
57
+ f"Model {model_name} requires specific initialization parameters"
58
+ )
59
+
60
+ # Show success message briefly, then clear it
61
+ loading_placeholder.success(f"{model_name} model loaded successfully!")
62
+ time.sleep(1.5) # Show success message for 1.5 seconds
63
+ loading_placeholder.empty() # Clear the message
64
+ return model
65
+
66
+ except PermissionError as e:
67
+ st.error(f"❌ Permission error: {str(e)}")
68
+ if "cache" in str(e).lower():
69
+ st.info(
70
+ "💡 This is likely a cache permission issue. Please refresh the page and try again."
71
+ )
72
+ return None
73
+
74
+ except Exception as e:
75
+ error_msg = str(e)
76
+ st.error(f"❌ Error loading {model_name} model: {error_msg}")
77
+ st.info("💡 Possible solutions:")
78
+ st.info("1. Refresh the page and try again")
79
+ st.info("2. Check if HuggingFace services are available")
80
+ st.info("3. Try a different model")
81
+ return None
82
+
83
+
84
+ def predict_food(model, image_bytes):
85
+ """Make a prediction on the uploaded image."""
86
+ try:
87
+ # Get prediction index
88
+ prediction_idx = model.classify(image_bytes)
89
+
90
+ # Get the label name
91
+ if 0 <= prediction_idx < len(LABELS):
92
+ prediction_label = LABELS[prediction_idx]
93
+ return prediction_idx, prediction_label
94
+ else:
95
+ return None, "Unknown"
96
+ except Exception as e:
97
+ st.error(f"Error during prediction: {str(e)}")
98
+ return None, "Error"
99
+
100
+
101
+ def main():
102
+ """Main Streamlit application."""
103
+ st.set_page_config(
104
+ page_title="TikkaMasalAI Food Classifier", page_icon="🍽️", layout="centered"
105
+ )
106
+
107
+ st.title("🍽️ TikkaMasalAI Food Classifier")
108
+ st.markdown("Upload an image of food and let our AI identify what it is!")
109
+
110
+ # Discover available models
111
+ try:
112
+ available_models = discover_models()
113
+ except Exception as e:
114
+ st.error(f"❌ Error discovering models: {str(e)}")
115
+ st.info("Make sure the src/models directory contains valid model files.")
116
+ return
117
+
118
+ if not available_models:
119
+ st.error("❌ No compatible models found in the src/models directory!")
120
+ st.info("Make sure there are models that inherit from FoodClassificationModel.")
121
+ return
122
+
123
+ # Model selection in sidebar
124
+ with st.sidebar:
125
+ st.header("🤖 Model Selection")
126
+ selected_model_name = st.selectbox(
127
+ "Choose a model:",
128
+ options=list(available_models.keys()),
129
+ help="Select which AI model to use for food classification",
130
+ )
131
+
132
+ selected_model_info = available_models[selected_model_name]
133
+
134
+ # Show model information
135
+ st.info(f"**Selected:** {selected_model_name}")
136
+ st.write(f"**Class:** `{selected_model_info['class_name']}`")
137
+ st.write(f"**Module:** `{selected_model_info['module']}`")
138
+
139
+ # Show app status
140
+ status_container = st.container()
141
+
142
+ # Load model with better UX
143
+ with status_container:
144
+ model_status = st.empty()
145
+ progress_bar = st.progress(0)
146
+
147
+ model_status.info("🔄 Initializing AI model...")
148
+ progress_bar.progress(25)
149
+
150
+ model = load_model(selected_model_info)
151
+ progress_bar.progress(100)
152
+
153
+ if model is None:
154
+ model_status.error("❌ Failed to load the model.")
155
+ st.error("### 🚨 Model Loading Failed")
156
+ st.markdown(
157
+ f"""
158
+ **Failed to load:** {selected_model_name}
159
+
160
+ **Possible causes:**
161
+ - Model-specific initialization requirements
162
+ - Missing dependencies for this model
163
+ - Temporary HuggingFace services issue
164
+ - Model cache conflicts in HF Spaces
165
+ - Network connectivity problems
166
+
167
+ **Solutions:**
168
+ 1. **Try a different model** from the sidebar
169
+ 2. **Refresh the page** and try again
170
+ 3. **Wait 2-3 minutes** for any background downloads to complete
171
+ 4. If the issue persists, the model will automatically retry
172
+ """
173
+ )
174
+
175
+ # Add a retry button
176
+ if st.button("🔄 Retry Loading Model"):
177
+ st.experimental_rerun()
178
+
179
+ return
180
+
181
+ model_status.success(f"✅ {selected_model_name} loaded and ready!")
182
+ progress_bar.empty()
183
+
184
+ # File uploader
185
+ uploaded_file = st.file_uploader(
186
+ "Choose a food image...",
187
+ type=["png", "jpg", "jpeg"],
188
+ help="Upload an image of food to classify",
189
+ )
190
+
191
+ if uploaded_file is not None:
192
+ # Read image bytes
193
+ image_bytes = uploaded_file.read()
194
+
195
+ # Display the uploaded image
196
+ col1, col2 = st.columns([1, 1])
197
+
198
+ with col1:
199
+ st.subheader("📸 Uploaded Image")
200
+ image = Image.open(io.BytesIO(image_bytes))
201
+ st.image(image, caption="Your uploaded image", use_container_width=True)
202
+
203
+ with col2:
204
+ st.subheader("🔍 Prediction Results")
205
+
206
+ # Make prediction
207
+ with st.spinner("Analyzing your image..."):
208
+ prediction_idx, prediction_label = predict_food(model, image_bytes)
209
+
210
+ if prediction_idx is not None:
211
+ # Display results
212
+ st.success("Classification complete!")
213
+
214
+ # Format the label for display
215
+ display_label = prediction_label.replace("_", " ").title()
216
+
217
+ st.markdown(f"### 🏷️ **{display_label}**")
218
+ st.markdown(f"**Class Index:** {prediction_idx}")
219
+
220
+ # Show confidence bar (placeholder since the model doesn't return probabilities)
221
+ st.markdown("**Prediction Details:**")
222
+ st.info(f"The AI model identified this image as **{display_label}**")
223
+
224
+ # Show additional info
225
+ with st.expander("ℹ️ About this classification"):
226
+ st.write(f"- **Model:** {selected_model_name}")
227
+ st.write(f"- **Classes:** {len(LABELS)} different food types")
228
+ st.write(f"- **Raw label:** `{prediction_label}`")
229
+ st.write(f"- **Index:** {prediction_idx}")
230
+ else:
231
+ st.error("Failed to classify the image. Please try another image.")
232
+
233
+ # Sidebar with information
234
+ with st.sidebar:
235
+ st.header("📋 About")
236
+ st.write(
237
+ f"""
238
+ This app uses the **{selected_model_name}** model to classify food images into one of 101 different food categories.
239
+ """
240
+ )
241
+
242
+ st.header("🎯 How to use")
243
+ st.write(
244
+ """
245
+ 1. Choose a model from the dropdown above
246
+ 2. Upload an image of food using the file uploader
247
+ 3. Wait for the AI to analyze your image
248
+ 4. View the classification results
249
+ """
250
+ )
251
+
252
+ st.header("🍕 Supported Foods")
253
+ st.write(
254
+ f"The model can recognize **{len(LABELS)}** different types of food including:"
255
+ )
256
+
257
+ # Show a sample of labels
258
+ sample_labels = [label.replace("_", " ").title() for label in LABELS[:10]]
259
+ for label in sample_labels:
260
+ st.write(f"• {label}")
261
+ st.write(f"... and {len(LABELS) - 10} more!")
262
+
263
+ st.header("🔧 Technical Details")
264
+ st.write(
265
+ f"""
266
+ - **Selected Model:** {selected_model_name}
267
+ - **Available Models:** {len(available_models)}
268
+ - **Dataset:** Food-101
269
+ - **Framework:** PyTorch + Transformers
270
+ """
271
+ )
272
+
273
+
274
+ if __name__ == "__main__":
275
+ main()
config.toml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [global]
2
+ developmentMode = false
3
+
4
+ [server]
5
+ port = 7860
6
+ address = "0.0.0.0"
7
+ headless = true
8
+ enableCORS = false
9
+ enableXsrfProtection = false
10
+
11
+ [browser]
12
+ gatherUsageStats = false
13
+
14
+ [client]
15
+ toolbarMode = "minimal"
16
+
17
+ [runner]
18
+ magicEnabled = true
19
+ installTracer = false
20
+ fixMatplotlib = true
21
+
22
+ [logger]
23
+ level = "info"
24
+ messageFormat = "%(asctime)s %(message)s"
pyproject.toml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "tikka-masalai"
7
+ version = "0.0.1"
8
+ description = "MLOPS project FIB"
9
+ authors = [
10
+ { name = "Team Tikka MasalAI" },
11
+ ]
12
+
13
+ readme = "README.md"
14
+ classifiers = [
15
+ "Programming Language :: Python :: 3",
16
+ ]
17
+ dependencies = [
18
+ "dagshub>=0.6.3",
19
+ "datasets<4.1.1",
20
+ "dvc>=3.63.0",
21
+ "dvc-s3>=3.2.2",
22
+ "huggingface-hub>=0.35.0",
23
+ "ipykernel>=6.30.1",
24
+ "ipywidgets>=8.1.7",
25
+ "matplotlib>=3.10.6",
26
+ "mlflow>=2,<3",
27
+ "numpy>=2.2.6",
28
+ "pandas>=2.3.2",
29
+ "pillow>=11.3.0",
30
+ "polars>=1.0.0",
31
+ "pyarrow>=4.0.0,<20.0.0",
32
+ "pytest",
33
+ "python-dotenv",
34
+ "ruff",
35
+ "streamlit>=1.31.0",
36
+ "torch>=2.8.0",
37
+ "torchvision>=0.23.0",
38
+ "tqdm>=4.67.1",
39
+ "transformers>=4.56.2",
40
+ ]
41
+ requires-python = ">=3.10"
42
+
43
+ [project.optional-dependencies]
44
+ dev = []
45
+
46
+
47
+ # This makes src/ discoverable as a package
48
+ [tool.hatch.build.targets.wheel]
49
+ packages = ["src"]
50
+
51
+
52
+ [tool.ruff]
53
+ line-length = 99
54
+ src = ["src"]
55
+ include = ["pyproject.toml", "src/**/*.py"]
56
+
57
+ [tool.ruff.lint]
58
+ extend-select = ["I"] # Add import sorting
59
+
60
+ [tool.ruff.lint.isort]
61
+ known-first-party = ["src"]
62
+ force-sort-within-sections = true
63
+
64
+ [dependency-groups]
65
+ dev = [
66
+ "black>=25.1.0",
67
+ "pylint>=3.3.8",
68
+ "pytest>=8.4.2",
69
+ ]
70
+
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TikkaMasalAI source package.
3
+ """
4
+
5
+ # Export commonly used utilities
6
+ from .models.model_discovery import discover_models, get_model_names, get_model_info
7
+
8
+ __all__ = ["discover_models", "get_model_names", "get_model_info"]
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (333 Bytes). View file
 
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (170 Bytes). View file
 
src/__pycache__/cli.cpython-310.pyc ADDED
Binary file (6.94 kB). View file
 
src/__pycache__/constants.cpython-310.pyc ADDED
Binary file (1.46 kB). View file
 
src/__pycache__/constants.cpython-311.pyc ADDED
Binary file (2.64 kB). View file
 
src/__pycache__/evaluate_imagenet.cpython-310.pyc ADDED
Binary file (6.93 kB). View file
 
src/__pycache__/labels.cpython-310.pyc ADDED
Binary file (2.29 kB). View file
 
src/data/download_data.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import snapshot_download
2
+
3
+ snapshot_download(
4
+ repo_id="ethz/food101",
5
+ repo_type="dataset",
6
+ local_dir="./data/raw/food101",
7
+ local_dir_use_symlinks=False, # ensures actual files, not symlinks
8
+ )
src/eval/__pycache__/base_evaluator.cpython-310.pyc ADDED
Binary file (9.03 kB). View file
 
src/eval/__pycache__/evaluate_food101.cpython-310.pyc ADDED
Binary file (9.38 kB). View file
 
src/eval/__pycache__/evaluate_food101.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
src/eval/__pycache__/evaluate_imagenet.cpython-310.pyc ADDED
Binary file (4.87 kB). View file
 
src/eval/eval.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example script demonstrating how to evaluate multiple models on different datasets.
4
+
5
+ This script shows how to use the enhanced evaluation framework
6
+ with different model implementations including VGG16, ResNet18, and PrithivMlFood101.
7
+ """
8
+ from src.models.vgg16 import VGG16
9
+ from src.models.resnet18 import Resnet18
10
+ from src.models.prithiv_ml_food101 import PrithivMlFood101
11
+ from src.eval.evaluate_food101 import Food101Evaluator
12
+ from src.models.food_classification_model import FoodClassificationModel
13
+
14
+
15
+ def evaluate_food101(
16
+ model: FoodClassificationModel,
17
+ experiment_name: str = "food101_evaluation",
18
+ sample_limit: int = 50,
19
+ random_seed: int = 42,
20
+ run_name: str = None,
21
+ ):
22
+ """Main evaluation function."""
23
+ evaluator = Food101Evaluator(
24
+ model, experiment_name, sample_limit, random_seed, run_name
25
+ )
26
+ evaluator.run_evaluation()
27
+
28
+
29
+ def main():
30
+ """Demonstrate evaluation with multiple model architectures."""
31
+
32
+ print("=" * 90)
33
+ print("Multi-Model Evaluation: VGG16 vs ResNet-18 vs PrithivMlFood101")
34
+
35
+ # Food101 Evaluations
36
+ print("\n=== Food101 Evaluations ===")
37
+
38
+ print("\n1. Evaluating PrithivMlFood101 on Food101...")
39
+ prithiv_model = PrithivMlFood101()
40
+ evaluate_food101(
41
+ experiment_name="Food101_Model_Comparison",
42
+ run_name="PrithivML_Baseline_50samples",
43
+ sample_limit=50, # Small sample for demonstration
44
+ model=prithiv_model,
45
+ )
46
+
47
+ print("\n2. Evaluating ResNet-18 on Food101 ...")
48
+ resnet18_food_model = Resnet18()
49
+ evaluate_food101(
50
+ experiment_name="Food101_Model_Comparison",
51
+ run_name="ResNet18_Transfer_50samples",
52
+ sample_limit=50, # Small sample for demonstration
53
+ model=resnet18_food_model,
54
+ )
55
+
56
+ print("\n3. Evaluating VGG16 on Food101 ...")
57
+ vgg16_food_model = VGG16()
58
+ evaluate_food101(
59
+ experiment_name="Food101_Model_Comparison",
60
+ run_name="VGG16_Transfer_50samples",
61
+ sample_limit=50, # Small sample for demonstration
62
+ model=vgg16_food_model,
63
+ )
64
+
65
+
66
+ if __name__ == "__main__":
67
+ main()
src/eval/evaluate_food101.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Food101 evaluation script for model evaluation with MLflow tracking.
4
+
5
+ This script evaluates models on the Food101 dataset with MLflow experiment tracking.
6
+ """
7
+
8
+ from typing import List, Dict, Tuple, Any, Union
9
+ from pathlib import Path
10
+ import mlflow
11
+ from datetime import datetime
12
+ import pandas as pd
13
+ import glob
14
+ import random
15
+ import dagshub
16
+
17
+ from src.models.food_classification_model import FoodClassificationModel
18
+ from src.labels import LABELS, index_to_label
19
+
20
+ dagshub.init(repo_owner="HubertWojcik10", repo_name="TikkaMasalAI", mlflow=True)
21
+ mlflow.autolog()
22
+
23
+
24
+ class Food101Evaluator:
25
+ """Model evaluator for Food101 dataset with MLflow tracking."""
26
+
27
+ def __init__(
28
+ self,
29
+ model: FoodClassificationModel,
30
+ experiment_name: str = "food101_evaluation",
31
+ sample_limit: int = 50,
32
+ random_seed: int = 42,
33
+ run_name: str = None,
34
+ ):
35
+ """
36
+ Initialize the Food101 evaluator.
37
+
38
+ Args:
39
+ model: FoodClassificationModel instance to use for evaluation (required)
40
+ experiment_name: Name of the MLflow experiment
41
+ sample_limit: Maximum number of samples to evaluate
42
+ random_seed: Random seed for reproducible sampling
43
+ run_name: Custom name for the MLflow run (optional)
44
+ """
45
+ self.DATASET_NAME = "Food101"
46
+ self.experiment_name = experiment_name
47
+ self.sample_limit = sample_limit
48
+ self.model = model
49
+ self.random_seed = random_seed
50
+ self.custom_run_name = run_name
51
+ self.model_name = self.model.__class__.__name__
52
+ self.data_dir = (
53
+ Path(__file__).parent.parent.parent / "data" / "raw" / "food101" / "data"
54
+ )
55
+
56
+ def load_validation_data(self) -> List[Tuple[bytes, int]]:
57
+ """
58
+ Load validation data from parquet files with random sampling.
59
+
60
+ Returns:
61
+ List of tuples: (image_bytes, true_index)
62
+ """
63
+ random.seed(self.random_seed)
64
+
65
+ validation_files = glob.glob(f"{self.data_dir}/validation-*.parquet")
66
+ print(f"Found {len(validation_files)} validation files")
67
+
68
+ # Load all samples first
69
+ all_samples = []
70
+
71
+ for file_path in validation_files:
72
+ print(f"Loading from {Path(file_path).name}...")
73
+ df = pd.read_parquet(file_path)
74
+
75
+ for _, row in df.iterrows():
76
+ image_data = row["image"]["bytes"]
77
+ true_index = row["label"]
78
+
79
+ all_samples.append((image_data, true_index))
80
+
81
+ print(f"Total available samples: {len(all_samples)}")
82
+
83
+ # Randomly sample the requested number of samples
84
+ if len(all_samples) <= self.sample_limit:
85
+ selected_samples = all_samples
86
+ print(f"Using all {len(selected_samples)} available samples")
87
+ else:
88
+ selected_samples = random.sample(all_samples, self.sample_limit)
89
+ print(
90
+ f"Randomly selected {len(selected_samples)} samples from {len(all_samples)} available"
91
+ )
92
+
93
+ print(f"Random seed used: {self.random_seed}")
94
+ return selected_samples
95
+
96
+ def calculate_accuracy(
97
+ self, predictions: List[Union[int, str]], ground_truths: List[int]
98
+ ) -> float:
99
+ """
100
+ Calculate exact accuracy for Food101 dataset.
101
+
102
+ Args:
103
+ predictions: List of predicted indices or label names
104
+ ground_truths: List of true labels
105
+
106
+ Returns:
107
+ Accuracy score as float
108
+ """
109
+ if not predictions or not ground_truths:
110
+ return 0.0
111
+
112
+ # Check for exact matches
113
+ exact_matches = 0
114
+ for pred, true in zip(predictions, ground_truths):
115
+ if pred == true:
116
+ exact_matches += 1
117
+
118
+ return exact_matches / len(predictions)
119
+
120
+ def evaluate_model(
121
+ self, samples: List[Tuple[bytes, int]], verbose: bool = True
122
+ ) -> Dict[str, Any]:
123
+ """
124
+ Evaluate the model on the provided samples.
125
+
126
+ Args:
127
+ samples: List of (image_bytes, true_index) tuples
128
+ verbose: Whether to print detailed results
129
+
130
+ Returns:
131
+ Dictionary with evaluation metrics
132
+ """
133
+ print(f"\nEvaluating model on {len(samples)} samples...")
134
+
135
+ predictions = []
136
+ ground_truths = []
137
+ prediction_examples = []
138
+ correct_predictions = 0
139
+
140
+ for i, (image_bytes, true_index) in enumerate(samples):
141
+ try:
142
+ predicted_index = self.model.classify(image_bytes)
143
+ predictions.append(predicted_index)
144
+ ground_truths.append(true_index)
145
+
146
+ # Check if prediction is correct using dataset-specific logic
147
+ is_correct = predicted_index == true_index
148
+ if is_correct:
149
+ correct_predictions += 1
150
+
151
+ # Convert index to label name for display and logging
152
+ predicted_label_name = index_to_label(predicted_index)
153
+
154
+ # Store first 10 examples for MLflow
155
+ if i < 10:
156
+ prediction_examples.append(
157
+ {
158
+ "sample_id": i + 1,
159
+ "true_label": LABELS[true_index],
160
+ "predicted_label": predicted_label_name,
161
+ "predicted_index": predicted_index,
162
+ "true_index": true_index,
163
+ "is_correct": is_correct,
164
+ }
165
+ )
166
+
167
+ if verbose and i < 10: # Print first 10 predictions
168
+ status = "✓" if is_correct else "✗"
169
+ print(
170
+ f"Sample {i+1:2d}: {status} True='{LABELS[true_index]:25s}' (idx: {true_index}) | Predicted='{predicted_label_name}' (idx: {predicted_index})"
171
+ )
172
+
173
+ except Exception as e:
174
+ print(f"Error processing sample {i+1}: {e}")
175
+ predictions.append("ERROR")
176
+ ground_truths.append(true_index)
177
+
178
+ # Calculate metrics
179
+ total_samples = len(samples)
180
+ successful_predictions = len([p for p in predictions if p != "ERROR"])
181
+
182
+ # Calculate accuracy using dataset-specific method
183
+ accuracy = self.calculate_accuracy(predictions, ground_truths)
184
+ success_rate = (
185
+ successful_predictions / total_samples if total_samples > 0 else 0
186
+ )
187
+
188
+ results = {
189
+ "total_samples": total_samples,
190
+ "successful_predictions": successful_predictions,
191
+ "correct_predictions": correct_predictions,
192
+ "success_rate": success_rate,
193
+ "accuracy": accuracy,
194
+ "prediction_examples": prediction_examples,
195
+ }
196
+
197
+ return results
198
+
199
+ def log_mlflow_metrics(self, results: Dict[str, Any]) -> None:
200
+ """
201
+ Log evaluation metrics to MLflow.
202
+
203
+ Args:
204
+ results: The results from the evaluation.
205
+ """
206
+ mlflow.log_metric("total_samples", results["total_samples"])
207
+ mlflow.log_metric("successful_predictions", results["successful_predictions"])
208
+ mlflow.log_metric("success_rate", results["success_rate"])
209
+ mlflow.log_metric("correct_predictions", results["correct_predictions"])
210
+ mlflow.log_metric("accuracy", results["accuracy"])
211
+
212
+ def log_mlflow_artifacts(self, results: Dict[str, Any]) -> None:
213
+ """
214
+ Log evaluation artifacts to MLflow.
215
+
216
+ Args:
217
+ results: The results from the evaluation.
218
+
219
+ """
220
+ examples_data = []
221
+ for example in results["prediction_examples"]:
222
+ status = "✓" if example.get("is_correct", False) else "✗"
223
+ examples_data.append(
224
+ f"Sample {example['sample_id']}: {status} {example['true_label']} -> {example['predicted_label']}"
225
+ )
226
+
227
+ examples_text = "\n".join(examples_data)
228
+ examples_file = f"{self.DATASET_NAME.lower()}_evaluation_examples.txt"
229
+ with open(examples_file, "w", encoding="utf-8", newline="") as f:
230
+ f.write(examples_text)
231
+ mlflow.log_artifact(examples_file)
232
+
233
+ model_source = (
234
+ getattr(self.model, "model_path", "N/A")
235
+ if hasattr(self.model, "model_path")
236
+ else "N/A"
237
+ )
238
+ summary = f"""{self.model_name} {self.DATASET_NAME} Evaluation Summary
239
+ ========================================={'=' * len(self.DATASET_NAME)}
240
+ Model: {self.model_name} ({model_source})
241
+ Dataset: {self.DATASET_NAME} validation set
242
+ Samples: {results['total_samples']}
243
+ Success Rate: {results['success_rate']:.2%}
244
+ Accuracy: {results['accuracy']:.2%}
245
+ Correct Predictions: {results['correct_predictions']}
246
+ """
247
+
248
+ summary_file = f"{self.DATASET_NAME.lower()}_evaluation_summary.txt"
249
+ with open(summary_file, "w", encoding="utf-8", newline="") as f:
250
+ f.write(summary)
251
+ mlflow.log_artifact(summary_file)
252
+
253
+ # Clean up temporary files
254
+ Path(examples_file).unlink(missing_ok=True)
255
+ Path(summary_file).unlink(missing_ok=True)
256
+
257
+ def run_evaluation(self) -> None:
258
+ """Run the complete evaluation pipeline with MLflow tracking."""
259
+ print("=" * 60)
260
+ print(f"{self.model_name} {self.DATASET_NAME} Evaluation with MLflow")
261
+ print("=" * 60)
262
+
263
+ mlflow.set_experiment(self.experiment_name)
264
+
265
+ # Create descriptive run name
266
+ if self.custom_run_name:
267
+ run_name = self.custom_run_name
268
+ else:
269
+ timestamp = datetime.now().strftime("%m%d_%H%M")
270
+ run_name = f"{self.model_name}_Food101_n{self.sample_limit}_seed{self.random_seed}_{timestamp}"
271
+
272
+ with mlflow.start_run(run_name=run_name):
273
+ # Add useful tags for filtering and organization
274
+ mlflow.set_tag("model_type", self.model_name)
275
+ mlflow.set_tag("dataset", self.DATASET_NAME)
276
+ mlflow.set_tag("sample_size", str(self.sample_limit))
277
+ mlflow.set_tag("evaluation_type", "validation")
278
+
279
+ mlflow.log_param("model_name", self.model_name)
280
+ mlflow.log_param("model_class", self.model.__class__.__name__)
281
+ mlflow.log_param("dataset", self.DATASET_NAME)
282
+ mlflow.log_param("sample_limit", self.sample_limit)
283
+ mlflow.log_param("random_seed", self.random_seed)
284
+ mlflow.log_param("evaluation_date", datetime.now().isoformat())
285
+
286
+ # Log model-specific parameters if available
287
+ if hasattr(self.model, "model_path"):
288
+ mlflow.log_param(
289
+ "model_source", getattr(self.model, "model_path", "Unknown")
290
+ )
291
+ if hasattr(self.model, "preprocessor_path"):
292
+ mlflow.log_param(
293
+ "preprocessor_path",
294
+ getattr(self.model, "preprocessor_path", "Unknown"),
295
+ )
296
+
297
+ samples = self.load_validation_data()
298
+
299
+ if not samples:
300
+ print(
301
+ f"No validation samples loaded. Check the {self.DATASET_NAME} dataset connection."
302
+ )
303
+ mlflow.log_param("status", "failed - no data")
304
+ return
305
+
306
+ mlflow.log_param("samples_loaded", len(samples))
307
+
308
+ results = self.evaluate_model(samples, verbose=True)
309
+
310
+ self.log_mlflow_metrics(results)
311
+ self.log_mlflow_artifacts(results)
312
+
313
+ self._print_results(results)
314
+
315
+ def _print_results(self, results: Dict[str, Any]) -> None:
316
+ """Print evaluation results to console."""
317
+ print("\n" + "=" * 60)
318
+ print("EVALUATION RESULTS")
319
+ print("=" * 60)
320
+ print(f"Total samples processed: {results['total_samples']}")
321
+ print(f"Successful predictions: {results['successful_predictions']}")
322
+ print(f"Success rate: {results['success_rate']:.2%}")
323
+ print(f"Correct predictions: {results['correct_predictions']}")
324
+ print(f"Accuracy: {results['accuracy']:.2%}")
325
+
326
+ print(f"\nMLflow run ID: {mlflow.active_run().info.run_id}")
src/labels.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LABELS = [
2
+ "apple_pie",
3
+ "baby_back_ribs",
4
+ "baklava",
5
+ "beef_carpaccio",
6
+ "beef_tartare",
7
+ "beet_salad",
8
+ "beignets",
9
+ "bibimbap",
10
+ "bread_pudding",
11
+ "breakfast_burrito",
12
+ "bruschetta",
13
+ "caesar_salad",
14
+ "cannoli",
15
+ "caprese_salad",
16
+ "carrot_cake",
17
+ "ceviche",
18
+ "cheesecake",
19
+ "cheese_plate",
20
+ "chicken_curry",
21
+ "chicken_quesadilla",
22
+ "chicken_wings",
23
+ "chocolate_cake",
24
+ "chocolate_mousse",
25
+ "churros",
26
+ "clam_chowder",
27
+ "club_sandwich",
28
+ "crab_cakes",
29
+ "creme_brulee",
30
+ "croque_madame",
31
+ "cup_cakes",
32
+ "deviled_eggs",
33
+ "donuts",
34
+ "dumplings",
35
+ "edamame",
36
+ "eggs_benedict",
37
+ "escargots",
38
+ "falafel",
39
+ "filet_mignon",
40
+ "fish_and_chips",
41
+ "foie_gras",
42
+ "french_fries",
43
+ "french_onion_soup",
44
+ "french_toast",
45
+ "fried_calamari",
46
+ "fried_rice",
47
+ "frozen_yogurt",
48
+ "garlic_bread",
49
+ "gnocchi",
50
+ "greek_salad",
51
+ "grilled_cheese_sandwich",
52
+ "grilled_salmon",
53
+ "guacamole",
54
+ "gyoza",
55
+ "hamburger",
56
+ "hot_and_sour_soup",
57
+ "hot_dog",
58
+ "huevos_rancheros",
59
+ "hummus",
60
+ "ice_cream",
61
+ "lasagna",
62
+ "lobster_bisque",
63
+ "lobster_roll_sandwich",
64
+ "macaroni_and_cheese",
65
+ "macarons",
66
+ "miso_soup",
67
+ "mussels",
68
+ "nachos",
69
+ "omelette",
70
+ "onion_rings",
71
+ "oysters",
72
+ "pad_thai",
73
+ "paella",
74
+ "pancakes",
75
+ "panna_cotta",
76
+ "peking_duck",
77
+ "pho",
78
+ "pizza",
79
+ "pork_chop",
80
+ "poutine",
81
+ "prime_rib",
82
+ "pulled_pork_sandwich",
83
+ "ramen",
84
+ "ravioli",
85
+ "red_velvet_cake",
86
+ "risotto",
87
+ "samosa",
88
+ "sashimi",
89
+ "scallops",
90
+ "seaweed_salad",
91
+ "shrimp_and_grits",
92
+ "spaghetti_bolognese",
93
+ "spaghetti_carbonara",
94
+ "spring_rolls",
95
+ "steak",
96
+ "strawberry_shortcake",
97
+ "sushi",
98
+ "tacos",
99
+ "takoyaki",
100
+ "tiramisu",
101
+ "tuna_tartare",
102
+ "waffles",
103
+ ]
104
+
105
+
106
+ def index_to_label(index: int) -> str:
107
+ """
108
+ Convert a class index to its corresponding label name.
109
+
110
+ Args:
111
+ index: The class index (0-based)
112
+
113
+ Returns:
114
+ str: The label name corresponding to the index, or a fallback string if index is out of bounds
115
+ """
116
+ if 0 <= index < len(LABELS):
117
+ return LABELS[index]
118
+ else:
119
+ return f"unknown_class_{index}"
120
+
121
+
122
+ def label_to_index(label: str) -> int:
123
+ """
124
+ Convert a label name to its corresponding class index.
125
+
126
+ Args:
127
+ label: The label name
128
+
129
+ Returns:
130
+ int: The index corresponding to the label, or -1 if label is not found
131
+ """
132
+ try:
133
+ return LABELS.index(label)
134
+ except ValueError:
135
+ return -1
src/models/__pycache__/food_classification_model.cpython-310.pyc ADDED
Binary file (943 Bytes). View file
 
src/models/__pycache__/food_classification_model.cpython-311.pyc ADDED
Binary file (1.11 kB). View file
 
src/models/__pycache__/model_discovery.cpython-310.pyc ADDED
Binary file (2.95 kB). View file
 
src/models/__pycache__/prithiv_ml_food101.cpython-310.pyc ADDED
Binary file (3.12 kB). View file
 
src/models/__pycache__/prithiv_ml_food101.cpython-311.pyc ADDED
Binary file (9.91 kB). View file
 
src/models/__pycache__/resnet18.cpython-310.pyc ADDED
Binary file (2.18 kB). View file
 
src/models/__pycache__/resnet18.cpython-311.pyc ADDED
Binary file (2.28 kB). View file
 
src/models/__pycache__/vgg16.cpython-310.pyc ADDED
Binary file (4.48 kB). View file
 
src/models/food_classification_model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class FoodClassificationModel(ABC):
5
+ """Abstract Base Class that serves as a common interface for all models."""
6
+
7
+ @abstractmethod
8
+ def classify(self, image: bytes) -> int:
9
+ """
10
+ Abstract method to classify an image into a food category.
11
+
12
+ Args:
13
+ image: The image bytes to classify.
14
+
15
+ Returns:
16
+ int: The index of the predicted class. This returns the class index, not the class name.
17
+ """
src/models/model_discovery.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model discovery utility for dynamically finding all models that inherit from FoodClassificationModel.
3
+ """
4
+
5
+ import importlib
6
+ import inspect
7
+ from pathlib import Path
8
+ from typing import Dict, Any
9
+
10
+ from .food_classification_model import FoodClassificationModel
11
+
12
+
13
+ def discover_models(models_dir: Path = None) -> Dict[str, Dict[str, Any]]:
14
+ """
15
+ Dynamically discover all models that inherit from FoodClassificationModel.
16
+
17
+ Args:
18
+ models_dir: Path to the models directory. If None, uses the current module's directory.
19
+
20
+ Returns:
21
+ Dict mapping display names to model information containing:
22
+ - 'class': The model class
23
+ - 'module': The module name
24
+ - 'class_name': The class name
25
+ """
26
+ if models_dir is None:
27
+ models_dir = Path(__file__).parent
28
+
29
+ available_models = {}
30
+
31
+ # Iterate through all Python files in the models directory
32
+ for py_file in models_dir.glob("*.py"):
33
+ if (
34
+ py_file.name.startswith("__")
35
+ or py_file.name == "food_classification_model.py"
36
+ or py_file.name == "model_discovery.py"
37
+ ):
38
+ continue
39
+
40
+ try:
41
+ # Import the module dynamically
42
+ module_name = f"src.models.{py_file.stem}"
43
+ module = importlib.import_module(module_name)
44
+
45
+ # Find all classes in the module that inherit from FoodClassificationModel
46
+ for name, obj in inspect.getmembers(module, inspect.isclass):
47
+ if (
48
+ issubclass(obj, FoodClassificationModel)
49
+ and obj != FoodClassificationModel
50
+ and obj.__module__ == module_name
51
+ ):
52
+
53
+ # Create a user-friendly name
54
+ display_name = _create_display_name(name)
55
+
56
+ available_models[display_name] = {
57
+ "class": obj,
58
+ "module": module_name,
59
+ "class_name": name,
60
+ }
61
+
62
+ except Exception as e:
63
+ # In a non-Streamlit context, we might want to log or handle this differently
64
+ print(f"Warning: Could not load model from {py_file.name}: {str(e)}")
65
+ continue
66
+
67
+ return available_models
68
+
69
+
70
+ def _create_display_name(class_name: str) -> str:
71
+ """
72
+ Create a user-friendly display name from a class name.
73
+
74
+ Args:
75
+ class_name: The original class name
76
+
77
+ Returns:
78
+ A user-friendly display name
79
+ """
80
+ # Create a user-friendly name
81
+ display_name = class_name
82
+
83
+ if "prithiv" in class_name.lower():
84
+ display_name = "PrithivML Food-101 (Benchmark)"
85
+ elif "resnet" in class_name.lower():
86
+ display_name = "ResNet-18"
87
+ elif "vgg" in class_name.lower():
88
+ display_name = "VGG-16"
89
+ elif "efficientnet" in class_name.lower():
90
+ display_name = "EfficientNet"
91
+ elif "mobilenet" in class_name.lower():
92
+ display_name = "MobileNet"
93
+ elif "densenet" in class_name.lower():
94
+ display_name = "DenseNet"
95
+
96
+ return display_name
97
+
98
+
99
+ def get_model_names() -> list:
100
+ """
101
+ Get a list of all available model display names.
102
+
103
+ Returns:
104
+ List of model display names
105
+ """
106
+ models = discover_models()
107
+ return list(models.keys())
108
+
109
+
110
+ def get_model_info(display_name: str) -> Dict[str, Any]:
111
+ """
112
+ Get model information for a specific model by display name.
113
+
114
+ Args:
115
+ display_name: The display name of the model
116
+
117
+ Returns:
118
+ Model information dictionary or None if not found
119
+ """
120
+ models = discover_models()
121
+ return models.get(display_name)
src/models/prithiv_ml_food101.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoImageProcessor, SiglipForImageClassification
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ import tempfile
7
+ from pathlib import Path
8
+
9
+ from src.models.food_classification_model import FoodClassificationModel
10
+
11
+
12
+ class PrithivMlFood101(FoodClassificationModel):
13
+ """
14
+ Interface for accessing the PrithivML Food-101 model architecture.
15
+ This model was already trained on the Food-101 dataset and performs well on it.
16
+ It is supposed to serve as a benchmark for our own model finetuning and potentially
17
+ as an alternative to be deployed.
18
+ See it on Huggingface: https://huggingface.co/prithivMLmods/Food-101-93M.
19
+ """
20
+
21
+ def __init__(self, model_name: str = "prithivMLmods/Food-101-93M"):
22
+ """
23
+ Load the PrithivML Food-101 model.
24
+
25
+ Preference order:
26
+ 1) Load from local repo snapshot at <repo_root>/models/prithivMLmods/Food-101-93M
27
+ 2) If not present, prompt the user to download using Makefile
28
+ make download-hf-model MODEL_PATH=prithivMLmods/Food-101-93M
29
+ """
30
+
31
+ # Set up proper cache directory for HF Spaces (safe no-op locally)
32
+ if not os.environ.get("HF_HOME"):
33
+ cache_dir = Path(tempfile.gettempdir()) / "transformers_cache"
34
+ cache_dir.mkdir(exist_ok=True)
35
+ os.environ["HF_HOME"] = str(cache_dir)
36
+
37
+ # Resolve repo root robustly from this file's location
38
+ repo_root = Path(__file__).resolve().parents[2]
39
+ local_model_dir = repo_root / "models" / "prithivMLmods" / "Food-101-93M"
40
+
41
+ # Determine whether a local copy exists (safetensors or bin)
42
+ local_exists = local_model_dir.exists() and (
43
+ (local_model_dir / "model.safetensors").exists()
44
+ or (local_model_dir / "pytorch_model.bin").exists()
45
+ )
46
+
47
+ if not local_exists:
48
+ # Provide a clear, actionable message to fetch the model snapshot
49
+ make_cmd = "make download-hf-model MODEL_PATH=prithivMLmods/Food-101-93M"
50
+ raise RuntimeError(
51
+ "Local model not found at 'models/prithivMLmods/Food-101-93M'. "
52
+ "Please download it first using:\n"
53
+ f" {make_cmd}\n"
54
+ "After download completes, re-run your program."
55
+ )
56
+
57
+ # Load from local directory snapshot
58
+ try:
59
+ self.model = SiglipForImageClassification.from_pretrained(
60
+ str(local_model_dir),
61
+ cache_dir=os.environ.get("HF_HOME"),
62
+ local_files_only=True,
63
+ force_download=False,
64
+ )
65
+ self.processor = AutoImageProcessor.from_pretrained(
66
+ str(local_model_dir),
67
+ cache_dir=os.environ.get("HF_HOME"),
68
+ local_files_only=True,
69
+ force_download=False,
70
+ use_fast=True, # Use fast processor to avoid warning
71
+ )
72
+ self.model_name = str(local_model_dir)
73
+ except Exception as e:
74
+ raise RuntimeError(
75
+ "Failed to load local model from 'models/prithivMLmods/Food-101-93M': "
76
+ f"{str(e)}"
77
+ )
78
+
79
+ def classify(self, image: bytes) -> int:
80
+ pil_image = Image.open(io.BytesIO(image)).convert("RGB")
81
+ inputs = self.processor(images=pil_image, return_tensors="pt")
82
+
83
+ with torch.no_grad():
84
+ outputs = self.model(**inputs)
85
+ logits = outputs.logits
86
+ probs = torch.nn.functional.softmax(logits, dim=1).squeeze()
87
+
88
+ predicted_idx = torch.argmax(probs).item()
89
+ return predicted_idx
src/models/resnet18.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ from pathlib import Path
7
+
8
+ from src.models.food_classification_model import FoodClassificationModel
9
+
10
+
11
+ class Resnet18(FoodClassificationModel):
12
+ """
13
+ Interface for accessing the Resnet-18 model architecture.
14
+ See the base model here: https://huggingface.co/microsoft/resnet-18.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ preprocessor_path: str = "microsoft/resnet-18",
20
+ model_path: str = "microsoft/resnet-18",
21
+ ):
22
+ """
23
+ Prefer loading from a local snapshot under models/microsoft/resnet-18.
24
+ If the local snapshot doesn't exist, prompt the user to download it via Makefile.
25
+ """
26
+
27
+ # Resolve repo root and local model dir
28
+ repo_root = Path(__file__).resolve().parents[2]
29
+ local_model_dir = repo_root / "models" / "microsoft" / "resnet-18"
30
+
31
+ # Check if a local HF snapshot exists (config + weights)
32
+ local_exists = local_model_dir.exists() and (
33
+ (local_model_dir / "pytorch_model.bin").exists()
34
+ or (local_model_dir / "model.safetensors").exists()
35
+ )
36
+
37
+ if not local_exists:
38
+ make_cmd = "make download-hf-model MODEL_PATH=microsoft/resnet-18"
39
+ raise RuntimeError(
40
+ "Local model not found at 'models/microsoft/resnet-18'. "
41
+ "Please download it first using:\n"
42
+ f" {make_cmd}\n"
43
+ "After download completes, re-run your program."
44
+ )
45
+
46
+ # Load from local folder
47
+ self.image_processor = AutoImageProcessor.from_pretrained(str(local_model_dir))
48
+ self.model = AutoModelForImageClassification.from_pretrained(
49
+ str(local_model_dir)
50
+ )
51
+
52
+ def classify(self, image: bytes) -> int:
53
+ pil_image = Image.open(io.BytesIO(image))
54
+ inputs = self.image_processor(pil_image, return_tensors="pt")
55
+
56
+ with torch.no_grad():
57
+ logits = self.model(**inputs).logits
58
+
59
+ # model predicts one of the 101 Food-101 classes (if fine-tuned for Food-101).
60
+ # If using the default microsoft/resnet-18 weights, this will predict one
61
+ # of the 1000 ImageNet classes, not Food-101.
62
+ predicted_label = logits.argmax(-1).item()
63
+ return predicted_label
src/models/vgg16.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ import torchvision.models as models
5
+ from PIL import Image
6
+ import io
7
+ from pathlib import Path
8
+ from typing import Dict, Any
9
+ from src.models.food_classification_model import FoodClassificationModel
10
+
11
+
12
+ class VGG16(FoodClassificationModel):
13
+ """Interface for accessing the VGG-16 model architecture."""
14
+
15
+ def __init__(self, weights: str = "IMAGENET1K_V1", num_classes: int = 101):
16
+ """
17
+ Initialize VGG-16. Prefer loading local fine-tuned weights if available.
18
+
19
+ Priority:
20
+ 1) Load ImageNet base and replace classifier, then load local fine-tuned checkpoint
21
+ from <repo_root>/models/vgg16/vgg16-397923af.pth (if exists).
22
+ 2) Otherwise, fall back to ImageNet weights only (not Food-101 trained), and
23
+ instruct user to provide or train a .pth for Food-101 fine-tuning.
24
+ """
25
+ repo_root = Path(__file__).resolve().parents[2]
26
+ local_weights = repo_root / "models" / "vgg16/vgg16-397923af.pth"
27
+
28
+ # Base model with ImageNet weights
29
+ self.model = models.vgg16(weights=weights)
30
+
31
+ num_features = self.model.classifier[6].in_features
32
+ self.model.classifier[6] = nn.Linear(num_features, num_classes)
33
+
34
+ # If local fine-tuned weights exist, load them
35
+ if local_weights.exists():
36
+ try:
37
+ raw_ckpt: Dict[str, Any] = torch.load(local_weights, map_location="cpu")
38
+
39
+ # Unwrap common checkpoint formats
40
+ if isinstance(raw_ckpt, dict) and "state_dict" in raw_ckpt:
41
+ ckpt = raw_ckpt["state_dict"]
42
+ else:
43
+ ckpt = raw_ckpt
44
+
45
+ # Normalize key prefixes commonly introduced by wrappers
46
+ def strip_prefix(
47
+ sd: Dict[str, torch.Tensor], prefix: str
48
+ ) -> Dict[str, torch.Tensor]:
49
+ if all(k.startswith(prefix) for k in sd.keys()):
50
+ return {k[len(prefix) :]: v for k, v in sd.items()}
51
+ return sd
52
+
53
+ for p in ("module.", "model.", "net."):
54
+ ckpt = strip_prefix(ckpt, p)
55
+
56
+ # Filter out mismatched keys (e.g., classifier.6 for 1000->101 classes)
57
+ model_sd = self.model.state_dict()
58
+ filtered_ckpt = {}
59
+ skipped = []
60
+ for k, v in ckpt.items():
61
+ if k in model_sd and isinstance(v, torch.Tensor):
62
+ if model_sd[k].shape == v.shape:
63
+ filtered_ckpt[k] = v
64
+ else:
65
+ skipped.append(
66
+ (k, tuple(v.shape), tuple(model_sd[k].shape))
67
+ )
68
+ # Silently ignore keys not present in the current model
69
+
70
+ missing_before = set(model_sd.keys()) - set(filtered_ckpt.keys())
71
+ self.model.load_state_dict(filtered_ckpt, strict=False)
72
+
73
+ # Optional: print a brief summary to logs for transparency
74
+ if skipped:
75
+ skipped_str = ", ".join(
76
+ [f"{k}: {src} -> {dst}" for k, src, dst in skipped[:5]]
77
+ )
78
+ more = "" if len(skipped) <= 5 else f" (+{len(skipped)-5} more)"
79
+ print(
80
+ f"[VGG16] Partially loaded checkpoint from '{local_weights}'. "
81
+ f"Skipped mismatched keys: {skipped_str}{more}"
82
+ )
83
+
84
+ if (
85
+ "classifier.6.weight" in missing_before
86
+ or "classifier.6.bias" in missing_before
87
+ ):
88
+ print(
89
+ "[VGG16] Final classifier layer initialized for 101 classes and was not loaded from checkpoint."
90
+ )
91
+
92
+ except Exception as e:
93
+ raise RuntimeError(
94
+ f"Failed to load local VGG16 weights from '{local_weights}': {e}"
95
+ )
96
+ else:
97
+ # No local fine-tuned weights: keep ImageNet weights but warn with action
98
+ raise RuntimeError(
99
+ "Local fine-tuned weights not found at 'models/vgg16/vgg16-397923af.pth'.\n"
100
+ "Please place your fine-tuned checkpoint there, or train/export one.\n"
101
+ "Alternatively, switch to a HF model with a Makefile download target."
102
+ )
103
+
104
+ self.model.eval()
105
+
106
+ self.transform = transforms.Compose(
107
+ [
108
+ transforms.Resize(256),
109
+ transforms.CenterCrop(224),
110
+ transforms.ToTensor(),
111
+ transforms.Normalize(
112
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
113
+ ),
114
+ ]
115
+ )
116
+
117
+ def classify(self, image: bytes) -> int:
118
+ pil_image = Image.open(io.BytesIO(image))
119
+
120
+ if pil_image.mode != "RGB":
121
+ pil_image = pil_image.convert("RGB")
122
+
123
+ input_tensor = self.transform(pil_image)
124
+ input_batch = input_tensor.unsqueeze(0) # Add batch dimension
125
+
126
+ with torch.no_grad():
127
+ outputs = self.model(input_batch)
128
+ predicted_idx = torch.argmax(outputs).item()
129
+
130
+ return predicted_idx
src/train/__pycache__/preprocess_data.cpython-310.pyc ADDED
Binary file (4.08 kB). View file
 
uv.lock ADDED
The diff for this file is too large to render. See raw diff