(FEAT)[UI/UX]: Add support for FTIR, multi-format upload, and model comparison tab
Sidebar:
- Added spectroscopy modality selection (Raman/FTIR) to the sidebar, with explanatory info for each mode (see the preprocessing sketch below).
- Expanded model selection and updated the project description to reflect the new FTIR and multi-model features.
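The modality switch is meant to drive modality-specific preprocessing: the new sidebar info text quotes a 400-4000 cm⁻¹ range with atmospheric correction for FTIR versus 200-4000 cm⁻¹ with standard preprocessing for Raman, and the comparison tab calls `preprocess_spectrum(x_raw, y_raw, modality=modality, target_len=500)`. That function lives in `utils/preprocessing.py`, which this PR does not touch, so the sketch below is only an illustration of how such a switch could look; the range constants come from the sidebar text, everything else (the resampling helper and the placeholder correction step) is an assumption.

```python
# Illustrative sketch only: the real preprocess_spectrum is in utils/preprocessing.py,
# which is not part of this diff. Ranges are taken from the sidebar info text.
import numpy as np

MODALITY_RANGES = {"raman": (200, 4000), "ftir": (400, 4000)}  # cm^-1, per sidebar text

def preprocess_spectrum_sketch(x, y, modality="raman", target_len=500):
    lo, hi = MODALITY_RANGES[modality]
    x, y = np.asarray(x, float), np.asarray(y, float)
    mask = (x >= lo) & (x <= hi)               # clip to the modality's wavenumber range
    xm, ym = x[mask], y[mask]
    order = np.argsort(xm)                     # np.interp needs increasing x
    x_new = np.linspace(lo, hi, target_len)    # resample onto a fixed-length grid
    y_new = np.interp(x_new, xm[order], ym[order])
    # An FTIR-specific atmospheric correction would be applied here (not shown).
    return x_new, y_new
```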
Input column:
- The file uploader now accepts .txt, .csv, and .json for both single and batch uploads (see the parsing sketch below).
- Updated the help text and file-type validation accordingly.
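Only the uploader widgets change in this file; the actual multi-format parsing happens in `parse_spectrum_data` (imported from `core_logic`), whose implementation is not shown here. A minimal sketch of the kind of extension-based branching involved, assuming two-column TXT, a headered two-column CSV, and a JSON object with `x`/`y` arrays (all three schemas are assumptions):

```python
# Minimal sketch, not the project's parser: parse_spectrum_data lives in core_logic
# and its real CSV/JSON schemas are not visible in this diff.
import io
import json
import numpy as np
import pandas as pd

def parse_spectrum_sketch(text: str, filename: str):
    ext = filename.rsplit(".", 1)[-1].lower()
    if ext == "csv":
        df = pd.read_csv(io.StringIO(text))       # assumed: header row, x/y columns
        x, y = df.iloc[:, 0].to_numpy(float), df.iloc[:, 1].to_numpy(float)
    elif ext == "json":
        obj = json.loads(text)                    # assumed: {"x": [...], "y": [...]}
        x, y = np.asarray(obj["x"], float), np.asarray(obj["y"], float)
    else:                                         # .txt: whitespace-delimited 2 columns
        arr = np.loadtxt(io.StringIO(text))
        x, y = arr[:, 0], arr[:, 1]
    return x, y
```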
New function 'render_comparison_tab':
- Lets users select multiple models and upload a file or pick sample data for side-by-side prediction.
- Displays comparison results as tables and visualizations (confidence bar chart, agreement stats, performance metrics).
- Supports exporting results as JSON or a full report (example payload below).
- Shows historical comparison statistics with an agreement matrix and heatmap.
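The JSON export dumps the `last_comparison_results` dictionary that `render_comparison_tab` stores in session state (filename, modality, per-model results, and a summary block). Its shape, with placeholder values and the per-model `probs`/`logits` lists omitted for brevity:

```python
# Shape of st.session_state["last_comparison_results"] as built in render_comparison_tab.
# Model names and numbers are placeholders, not real measurements.
last_comparison_results = {
    "filename": "sample-001.txt",
    "modality": "raman",
    "models": {
        "ModelA": {"prediction": 0, "predicted_class": "Stable",
                   "confidence": 0.97, "processing_time": 0.12},
        "ModelB": {"prediction": 1, "predicted_class": "Weathered",
                   "confidence": 0.64, "processing_time": 0.08},
    },
    "summary": {
        "agreement": False,
        "avg_processing_time": 0.10,
        "fastest_model": "ModelB",
        "most_confident": "ModelA",
    },
}
```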
New function 'render_performance_tab':
- Integrates the performance dashboard from the tracker utility (see the tab-wiring sketch below).
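Neither new render function is wired up inside `ui_components.py` itself; presumably the app's entry point mounts them as tabs. A minimal sketch of that wiring, assuming Streamlit tabs (the entry-point file name and tab labels are assumptions):

```python
# Hypothetical wiring in the app entry point (e.g. app.py); not part of this diff.
import streamlit as st
from modules.ui_components import render_comparison_tab, render_performance_tab

comparison_tab, performance_tab = st.tabs(["Model Comparison", "Performance"])
with comparison_tab:
    render_comparison_tab()
with performance_tab:
    render_performance_tab()
```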
- modules/ui_components.py +478 -51
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -13,9 +13,9 @@ from modules.callbacks import (
     on_model_change,
     on_input_mode_change,
     on_sample_change,
+    reset_results,
     reset_ephemeral_state,
     log_message,
-    clear_batch_results,
 )
 from core_logic import (
     get_sample_files,
@@ -24,7 +24,6 @@ from core_logic import (
     parse_spectrum_data,
     label_file,
 )
-from modules.callbacks import reset_results
 from utils.results_manager import ResultsManager
 from utils.confidence import calculate_softmax_confidence
 from utils.multifile import process_multiple_files, display_batch_results
@@ -41,7 +40,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
     """Create spectrum visualization plot"""
     fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
 
-    #
+    # Raw spectrum
     ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
     ax[0].set_title("Raw Input Spectrum")
     ax[0].set_xlabel("Wavenumber (cm⁻¹)")
@@ -49,7 +48,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
     ax[0].grid(True, alpha=0.3)
     ax[0].legend()
 
-    #
+    # Resampled spectrum
     ax[1].plot(
         x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1
     )
@@ -60,7 +59,7 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
     ax[1].legend()
 
     fig.tight_layout()
-    #
+    # Convert to image
     buf = io.BytesIO()
     plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
     buf.seek(0)
@@ -69,6 +68,9 @@ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None
     return Image.open(buf)
 
 
+# //////////////////////////////////////////
+
+
 def render_confidence_progress(
     probs: np.ndarray,
     labels: list[str] = ["Stable", "Weathered"],
@@ -114,7 +116,10 @@ def render_confidence_progress(
         st.markdown("")
 
 
-def render_kv_grid(d: dict = {}, ncols: int = 2):
+from typing import Optional
+
+
+def render_kv_grid(d: Optional[dict] = None, ncols: int = 2):
     if d is None:
         d = {}
     if not d:
@@ -126,6 +131,9 @@ def render_kv_grid(d: dict = {}, ncols: int = 2):
             st.caption(f"**{k}:** {v}")
 
 
+# //////////////////////////////////////////
+
+
 def render_model_meta(model_choice: str):
     info = MODEL_CONFIG.get(model_choice, {})
     emoji = info.get("emoji", "")
@@ -143,6 +151,9 @@ def render_model_meta(model_choice: str):
     st.caption(desc)
 
 
+# //////////////////////////////////////////
+
+
 def get_confidence_description(logit_margin):
     """Get human-readable confidence description"""
     if logit_margin > 1000:
@@ -155,13 +166,35 @@ def get_confidence_description(logit_margin):
         return "LOW", "🔴"
 
 
+# //////////////////////////////////////////
+
+
 def render_sidebar():
     with st.sidebar:
         # Header
         st.header("AI-Driven Polymer Classification")
         st.caption(
-            "Predict polymer degradation (Stable vs Weathered) from Raman spectra using validated CNN models. — v0.
+            "Predict polymer degradation (Stable vs Weathered) from Raman/FTIR spectra using validated CNN models. — v0.01"
+        )
+
+        # Modality Selection
+        st.markdown("##### Spectroscopy Modality")
+        modality = st.selectbox(
+            "Choose Modality",
+            ["raman", "ftir"],
+            index=0,
+            key="modality_select",
+            format_func=lambda x: f"{'Raman' if x == 'raman' else 'FTIR'}",
         )
+
+        # Display modality info
+        if modality == "ftir":
+            st.info("FTIR mode: 400-4000 cm-1 range with atmospheric correction")
+        else:
+            st.info("Raman mode: 200-4000 cm-1 range with standard preprocessing")
+
+        # Model selection
+        st.markdown("##### AI Model Selection")
         model_labels = [
             f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys()
         ]
@@ -173,10 +206,10 @@ def render_sidebar():
         )
         model_choice = selected_label.split(" ", 1)[1]
 
-        #
+        # Compact metadata directly under dropdown
        render_model_meta(model_choice)
 
-        #
+        # Collapsed info to reduce clutter
        with st.expander("About This App", icon=":material/info:", expanded=False):
            st.markdown(
                """
@@ -184,8 +217,9 @@
 
                **Purpose**: Classify polymer degradation using AI<br>
                **Input**: Raman spectroscopy .txt files<br>
-                **Models**: CNN architectures for
-                **
+                **Models**: CNN architectures for classification<br>
+                **Modalities**: Raman and FTIR spectroscopy support<br>
+                **Features**: Multi-model comparison and analysis<br>
 
 
                **Contributors**<br>
@@ -207,11 +241,7 @@
            )
 
 
-    #
-
-    # In modules/ui_components.py
-
-
+# //////////////////////////////////////////
 def render_input_column():
     st.markdown("##### Data Input")
 
@@ -224,22 +254,20 @@ def render_input_column():
     )
 
     # == Input Mode Logic ==
-    # ... (The if/elif/else block for Upload, Batch, and Sample modes remains exactly the same) ...
-    # ==Upload tab==
     if mode == "Upload File":
         upload_key = st.session_state["current_upload_key"]
         up = st.file_uploader(
-            "Upload
-            type="txt",
-            help="Upload
+            "Upload spectrum file (.txt, .csv, .json)",
+            type=["txt", "csv", "json"],
+            help="Upload spectroscopy data: TXT (2-column), CSV (with headers), or JSON format",
             key=upload_key,  # ← versioned key
         )
 
-        #
+        # Process change immediately
        if up is not None:
            raw = up.read()
            text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
-            #
+            # only reparse if its a different file|source
            if (
                st.session_state.get("filename") != getattr(up, "name", None)
                or st.session_state.get("input_source") != "upload"
@@ -255,23 +283,20 @@ def render_input_column():
                st.session_state["status_type"] = "success"
                reset_results("New file uploaded")
 
-    #
+    # Batch Upload tab
    elif mode == "Batch Upload":
        st.session_state["batch_mode"] = True
-        # --- START: BUG 1 & 3 FIX ---
        # Use a versioned key to ensure the file uploader resets properly.
        batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}"
        uploaded_files = st.file_uploader(
-            "Upload multiple
-            type="txt",
+            "Upload multiple spectrum files (.txt, .csv, .json)",
+            type=["txt", "csv", "json"],
            accept_multiple_files=True,
-            help="Upload
+            help="Upload spectroscopy files in TXT, CSV, or JSON format.",
            key=batch_upload_key,
        )
-        # --- END: BUG 1 & 3 FIX ---
 
        if uploaded_files:
-            # --- START: Bug 1 Fix ---
            # Use a dictionary to keep only unique files based on name and size
            unique_files = {(file.name, file.size): file for file in uploaded_files}
            unique_file_list = list(unique_files.values())
@@ -281,9 +306,7 @@ def render_input_column():
 
            # Optionally, inform the user that duplicates were removed
            if num_uploaded > num_unique:
-                st.info(
-                    f"ℹ️ {num_uploaded - num_unique} duplicate file(s) were removed."
-                )
+                st.info(f"{num_uploaded - num_unique} duplicate file(s) were removed.")
 
            # Use the unique list
            st.session_state["batch_files"] = unique_file_list
@@ -291,7 +314,6 @@ def render_input_column():
                f"{num_unique} ready for batch analysis"
            )
            st.session_state["status_type"] = "success"
-            # --- END: Bug 1 Fix ---
        else:
            st.session_state["batch_files"] = []
            # This check prevents resetting the status if files are already staged
@@ -301,7 +323,7 @@ def render_input_column():
                )
                st.session_state["status_type"] = "info"
 
-    #
+    # Sample tab
    elif mode == "Sample Data":
        st.session_state["batch_mode"] = False
        sample_files = get_sample_files()
@@ -330,9 +352,6 @@ def render_input_column():
        else:
            st.info(msg)
 
-    # --- DE-NESTED LOGIC STARTS HERE ---
-    # This code now runs on EVERY execution, guaranteeing the buttons will appear.
-
    # Safely get model choice from session state
    model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1]
    model = load_model(model_choice)
@@ -388,7 +407,7 @@ def render_input_column():
        st.error(f"Error processing spectrum data: {e}")
 
 
-#
+# //////////////////////////////////////////
 
 
 def render_results_column():
@@ -410,7 +429,7 @@ def render_results_column():
    filename = st.session_state.get("filename", "Unknown")
 
    if all(v is not None for v in [x_raw, y_raw, y_resampled]):
-        #
+        # Run inference
        if y_resampled is None:
            raise ValueError(
                "y_resampled is None. Ensure spectrum data is properly resampled before proceeding."
@@ -437,14 +456,14 @@ def render_results_column():
                f"Inference completed in {inference_time:.2f}s, prediction: {prediction}"
            )
 
-            #
+            # Get ground truth
            true_label_idx = label_file(filename)
            true_label_str = (
                LABEL_MAP.get(true_label_idx, "Unknown")
                if true_label_idx is not None
                else "Unknown"
            )
-            #
+            # Get prediction
            predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}")
 
            # Enhanced confidence calculation
@@ -455,7 +474,7 @@ def render_results_column():
                )
                confidence_desc = confidence_level
            else:
-                # Fallback to
+                # Fallback to legacy method
                logit_margin = abs(
                    (logits_list[0] - logits_list[1])
                    if logits_list is not None and len(logits_list) >= 2
@@ -487,7 +506,7 @@ def render_results_column():
                },
            )
 
-            #
+            # Precompute Stats
            model_choice = (
                st.session_state.get("model_select", "").split(" ", 1)[1]
                if "model_select" in st.session_state
@@ -505,7 +524,6 @@ def render_results_column():
                if os.path.exists(model_path)
                else "N/A"
            )
-            # Removed unused variable 'input_tensor'
 
            start_render = time.time()
 
@@ -590,17 +608,13 @@ def render_results_column():
                    """,
                    unsafe_allow_html=True,
                )
-                # --- END: CONSOLIDATED CONFIDENCE ANALYSIS ---
 
                st.divider()
 
-                #
-                # Secondary info is now a clean, single-line caption
+                # METADATA FOOTER
                st.caption(
                    f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**."
                )
-                # --- END: CLEAN METADATA FOOTER ---
-
                st.markdown("</div>", unsafe_allow_html=True)
 
            elif active_tab == "Technical":
@@ -918,7 +932,7 @@ def render_results_column():
            """
        )
    else:
-        #
+        # Getting Started
        st.markdown(
            """
            ##### How to Get Started
@@ -948,3 +962,416 @@ def render_results_column():
            - 🏭 Quality control in manufacturing
            """
        )
+
+
+# //////////////////////////////////////////
+
+
+def render_comparison_tab():
+    """Render the multi-model comparison interface"""
+    import streamlit as st
+    import matplotlib.pyplot as plt
+    from models.registry import choices, validate_model_list
+    from utils.results_manager import ResultsManager
+    from core_logic import get_sample_files, run_inference, parse_spectrum_data
+    from utils.preprocessing import preprocess_spectrum
+    from utils.multifile import parse_spectrum_data
+    import numpy as np
+    import time
+
+    st.markdown("### Multi-Model Comparison Analysis")
+    st.markdown(
+        "Compare predictions across different AI models for comprehensive analysis."
+    )
+
+    # Model selection for comparison
+    st.markdown("##### Select Models for Comparison")
+
+    available_models = choices()
+    selected_models = st.multiselect(
+        "Choose models to compare",
+        available_models,
+        default=(
+            available_models[:2] if len(available_models) >= 2 else available_models
+        ),
+        help="Select 2 or more models to compare their predictions side-by-side",
+    )
+
+    if len(selected_models) < 2:
+        st.warning("⚠️ Please select at least 2 models for comparison.")
+
+    # Input selection for comparison
+    col1, col2 = st.columns([1, 1.5])
+
+    with col1:
+        st.markdown("###### Input Data")
+
+        # File upload for comparison
+        comparison_file = st.file_uploader(
+            "Upload spectrum for comparison",
+            type=["txt", "csv", "json"],
+            key="comparison_file_upload",
+            help="Upload a spectrum file to test across all selected models",
+        )
+
+        # Or select sample data
+        selected_sample = None  # Initialize with a default value
+        sample_files = get_sample_files()
+        if sample_files:
+            sample_options = ["-- Select Sample --"] + [p.name for p in sample_files]
+            selected_sample = st.selectbox(
+                "Or choose sample data", sample_options, key="comparison_sample_select"
+            )
+
+        # Get modality from session state
+        modality = st.session_state.get("modality_select", "raman")
+        st.info(f"Using {modality.upper()} preprocessing parameters")
+
+        # Run comparison button
+        run_comparison = st.button(
+            "Run Multi-Model Comparison",
+            type="primary",
+            disabled=not (
+                comparison_file
+                or (sample_files and selected_sample != "-- Select Sample --")
+            ),
+        )
+
+    with col2:
+        st.markdown("###### Comparison Results")
+
+        if run_comparison:
+            # Determine input source
+            input_text = None
+            filename = "unknown"
+
+            if comparison_file:
+                raw = comparison_file.read()
+                input_text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
+                filename = comparison_file.name
+            elif sample_files and selected_sample != "-- Select Sample --":
+                sample_path = next(p for p in sample_files if p.name == selected_sample)
+                with open(sample_path, "r") as f:
+                    input_text = f.read()
+                filename = selected_sample
+
+            if input_text:
+                try:
+                    # Parse spectrum data
+                    x_raw, y_raw = parse_spectrum_data(
+                        str(input_text), filename or "unknown_filename"
+                    )
+
+                    # Store results
+                    comparison_results = {}
+                    processing_times = {}
+
+                    progress_bar = st.progress(0)
+                    status_text = st.empty()
+
+                    for i, model_name in enumerate(selected_models):
+                        status_text.text(f"Running inference with {model_name}...")
+
+                        start_time = time.time()
+
+                        # Preprocess spectrum with modality-specific parameters
+                        _, y_processed = preprocess_spectrum(
+                            x_raw, y_raw, modality=modality, target_len=500
+                        )
+
+                        # Run inference
+                        prediction, logits_list, probs, inference_time, logits = (
+                            run_inference(y_processed, model_name)
+                        )
+
+                        processing_time = time.time() - start_time
+
+                        if prediction is not None:
+                            # Map prediction to class name
+                            class_names = ["Stable", "Weathered"]
+                            predicted_class = (
+                                class_names[int(prediction)]
+                                if prediction < len(class_names)
+                                else f"Class_{prediction}"
+                            )
+                            confidence = (
+                                max(probs)
+                                if probs is not None and len(probs) > 0
+                                else 0.0
+                            )
+
+                            comparison_results[model_name] = {
+                                "prediction": prediction,
+                                "predicted_class": predicted_class,
+                                "confidence": confidence,
+                                "probs": probs if probs is not None else [],
+                                "logits": (
+                                    logits_list if logits_list is not None else []
+                                ),
+                                "processing_time": processing_time,
+                            }
+                            processing_times[model_name] = processing_time
+
+                        progress_bar.progress((i + 1) / len(selected_models))
+
+                    status_text.text("Comparison complete!")
+
+                    # Display results
+                    if comparison_results:
+                        st.markdown("###### Model Predictions")
+
+                        # Create comparison table
+                        import pandas as pd
+
+                        table_data = []
+                        for model_name, result in comparison_results.items():
+                            row = {
+                                "Model": model_name,
+                                "Prediction": result["predicted_class"],
+                                "Confidence": f"{result['confidence']:.3f}",
+                                "Processing Time (s)": f"{result['processing_time']:.3f}",
+                            }
+                            table_data.append(row)
+
+                        df = pd.DataFrame(table_data)
+                        st.dataframe(df, use_container_width=True)
+
+                        # Show confidence comparison
+                        st.markdown("##### Confidence Comparison")
+                        conf_col1, conf_col2 = st.columns(2)
+
+                        with conf_col1:
+                            # Bar chart of confidences
+                            models = list(comparison_results.keys())
+                            confidences = [
+                                comparison_results[m]["confidence"] for m in models
+                            ]
+
+                            fig, ax = plt.subplots(figsize=(8, 5))
+                            bars = ax.bar(
+                                models,
+                                confidences,
+                                alpha=0.7,
+                                color=["steelblue", "orange", "green", "red"][
+                                    : len(models)
+                                ],
+                            )
+                            ax.set_ylabel("Confidence")
+                            ax.set_title("Model Confidence Comparison")
+                            ax.set_ylim(0, 1)
+                            plt.xticks(rotation=45)
+
+                            # Add value labels on bars
+                            for bar, conf in zip(bars, confidences):
+                                height = bar.get_height()
+                                ax.text(
+                                    bar.get_x() + bar.get_width() / 2.0,
+                                    height + 0.01,
+                                    f"{conf:.3f}",
+                                    ha="center",
+                                    va="bottom",
+                                )
+
+                            plt.tight_layout()
+                            st.pyplot(fig)
+
+                        with conf_col2:
+                            # Agreement analysis
+                            predictions = [
+                                comparison_results[m]["prediction"] for m in models
+                            ]
+                            unique_predictions = set(predictions)
+
+                            if len(unique_predictions) == 1:
+                                st.success("✅ All models agree on the prediction!")
+                            else:
+                                st.warning("⚠️ Models disagree on the prediction")
+
+                            # Show prediction distribution
+                            from collections import Counter
+
+                            pred_counts = Counter(predictions)
+
+                            st.markdown("**Prediction Distribution:**")
+                            for pred, count in pred_counts.items():
+                                class_name = (
+                                    ["Stable", "Weathered"][pred]
+                                    if pred < 2
+                                    else f"Class_{pred}"
+                                )
+                                percentage = (count / len(predictions)) * 100
+                                st.write(
+                                    f"- {class_name}: {count}/{len(predictions)} models ({percentage:.1f}%)"
+                                )
+
+                        # Performance metrics
+                        st.markdown("##### Performance Metrics")
+                        perf_col1, perf_col2 = st.columns(2)
+
+                        with perf_col1:
+                            avg_time = np.mean(list(processing_times.values()))
+                            fastest_model = min(
+                                processing_times.keys(),
+                                key=lambda k: processing_times[k],
+                            )
+                            slowest_model = max(
+                                processing_times.keys(),
+                                key=lambda k: processing_times[k],
+                            )
+
+                            st.metric("Average Processing Time", f"{avg_time:.3f}s")
+                            st.metric(
+                                "Fastest Model",
+                                f"{fastest_model}",
+                                f"{processing_times[fastest_model]:.3f}s",
+                            )
+                            st.metric(
+                                "Slowest Model",
+                                f"{slowest_model}",
+                                f"{processing_times[slowest_model]:.3f}s",
+                            )
+
+                        with perf_col2:
+                            most_confident = max(
+                                comparison_results.keys(),
+                                key=lambda k: comparison_results[k]["confidence"],
+                            )
+                            least_confident = min(
+                                comparison_results.keys(),
+                                key=lambda k: comparison_results[k]["confidence"],
+                            )
+
+                            st.metric(
+                                "Most Confident",
+                                f"{most_confident}",
+                                f"{comparison_results[most_confident]['confidence']:.3f}",
+                            )
+                            st.metric(
+                                "Least Confident",
+                                f"{least_confident}",
+                                f"{comparison_results[least_confident]['confidence']:.3f}",
+                            )
+
+                        # Store results in session state for potential export
+                        # Store results in session state for potential export
+                        st.session_state["last_comparison_results"] = {
+                            "filename": filename,
+                            "modality": modality,
+                            "models": comparison_results,
+                            "summary": {
+                                "agreement": len(unique_predictions) == 1,
+                                "avg_processing_time": avg_time,
+                                "fastest_model": fastest_model,
+                                "most_confident": most_confident,
+                            },
+                        }
+
+                except Exception as e:
+                    st.error(f"Error during comparison: {str(e)}")
+
+        # Show recent comparison results if available
+        elif "last_comparison_results" in st.session_state:
+            st.info(
+                "Previous comparison results available. Upload a new file or select a sample to run new comparison."
+            )
+
+    # Show comparison history
+    comparison_stats = ResultsManager.get_comparison_stats()
+    if comparison_stats:
+        st.markdown("#### Comparison History")
+
+        with st.expander("View detailed comparison statistics", expanded=False):
+            # Show model statistics table
+            stats_data = []
+            for model_name, stats in comparison_stats.items():
+                row = {
+                    "Model": model_name,
+                    "Total Predictions": stats["total_predictions"],
+                    "Avg Confidence": f"{stats['avg_confidence']:.3f}",
+                    "Avg Processing Time": f"{stats['avg_processing_time']:.3f}s",
+                    "Accuracy": (
+                        f"{stats['accuracy']:.3f}"
+                        if stats["accuracy"] is not None
+                        else "N/A"
+                    ),
+                }
+                stats_data.append(row)
+
+            if stats_data:
+                import pandas as pd
+
+                stats_df = pd.DataFrame(stats_data)
+                st.dataframe(stats_df, use_container_width=True)
+
+            # Show agreement matrix if multiple models
+            agreement_matrix = ResultsManager.get_agreement_matrix()
+            if not agreement_matrix.empty and len(agreement_matrix) > 1:
+                st.markdown("**Model Agreement Matrix**")
+                st.dataframe(agreement_matrix.round(3), use_container_width=True)
+
+                # Plot agreement heatmap
+                fig, ax = plt.subplots(figsize=(8, 6))
+                im = ax.imshow(
+                    agreement_matrix.values, cmap="RdYlGn", vmin=0, vmax=1
+                )
+
+                # Add text annotations
+                for i in range(len(agreement_matrix)):
+                    for j in range(len(agreement_matrix.columns)):
+                        text = ax.text(
+                            j,
+                            i,
+                            f"{agreement_matrix.iloc[i, j]:.2f}",
+                            ha="center",
+                            va="center",
+                            color="black",
+                        )
+
+                ax.set_xticks(range(len(agreement_matrix.columns)))
+                ax.set_yticks(range(len(agreement_matrix)))
+                ax.set_xticklabels(agreement_matrix.columns, rotation=45)
+                ax.set_yticklabels(agreement_matrix.index)
+                ax.set_title("Model Agreement Matrix")
+
+                plt.colorbar(im, ax=ax, label="Agreement Rate")
+                plt.tight_layout()
+                st.pyplot(fig)
+
+    # Export functionality
+    if "last_comparison_results" in st.session_state:
+        st.markdown("##### Export Results")
+
+        export_col1, export_col2 = st.columns(2)
+
+        with export_col1:
+            if st.button("📥 Export Comparison (JSON)"):
+                import json
+
+                results = st.session_state["last_comparison_results"]
+                json_str = json.dumps(results, indent=2, default=str)
+                st.download_button(
+                    label="Download JSON",
+                    data=json_str,
+                    file_name=f"comparison_{results['filename'].split('.')[0]}.json",
+                    mime="application/json",
+                )
+
+        with export_col2:
+            if st.button("📊 Export Full Report"):
+                report = ResultsManager.export_comparison_report()
+                st.download_button(
+                    label="Download Full Report",
+                    data=report,
+                    file_name="model_comparison_report.json",
+                    mime="application/json",
+                )
+
+
+# //////////////////////////////////////////
+
+
+def render_performance_tab():
+    """Render the performance tracking and analysis tab."""
+    from utils.performance_tracker import display_performance_dashboard
+
+    display_performance_dashboard()
|