diff --git "a/modules/ui_components.py" "b/modules/ui_components.py" --- "a/modules/ui_components.py" +++ "b/modules/ui_components.py" @@ -1,950 +1,1601 @@ -import os -import torch -import streamlit as st -import hashlib -import io -from PIL import Image -import numpy as np -import matplotlib.pyplot as plt -from typing import Union -import time -from config import MODEL_CONFIG, TARGET_LEN, LABEL_MAP -from modules.callbacks import ( - on_model_change, - on_input_mode_change, - on_sample_change, - reset_ephemeral_state, - log_message, - clear_batch_results, -) -from core_logic import ( - get_sample_files, - load_model, - run_inference, - parse_spectrum_data, - label_file, -) -from modules.callbacks import reset_results -from utils.results_manager import ResultsManager -from utils.confidence import calculate_softmax_confidence -from utils.multifile import process_multiple_files, display_batch_results -from utils.preprocessing import resample_spectrum - - -def load_css(file_path): - with open(file_path, encoding="utf-8") as f: - st.markdown(f"", unsafe_allow_html=True) - - -@st.cache_data -def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None): - """Create spectrum visualization plot""" - fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100) - - # == Raw spectrum == - ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1) - ax[0].set_title("Raw Input Spectrum") - ax[0].set_xlabel("Wavenumber (cm⁻¹)") - ax[0].set_ylabel("Intensity") - ax[0].grid(True, alpha=0.3) - ax[0].legend() - - # == Resampled spectrum == - ax[1].plot( - x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1 - ) - ax[1].set_title(f"Resampled ({len(y_resampled)} points)") - ax[1].set_xlabel("Wavenumber (cm⁻¹)") - ax[1].set_ylabel("Intensity") - ax[1].grid(True, alpha=0.3) - ax[1].legend() - - fig.tight_layout() - # == Convert to image == - buf = io.BytesIO() - plt.savefig(buf, format="png", bbox_inches="tight", dpi=100) - buf.seek(0) - plt.close(fig) # Prevent memory leaks - - return Image.open(buf) - - -def render_confidence_progress( - probs: np.ndarray, - labels: list[str] = ["Stable", "Weathered"], - highlight_idx: Union[int, None] = None, - side_by_side: bool = True, -): - """Render Streamlit native progress bars with scientific formatting.""" - p = np.asarray(probs, dtype=float) - p = np.clip(p, 0.0, 1.0) - - if side_by_side: - cols = st.columns(len(labels)) - for i, (lbl, val, col) in enumerate(zip(labels, p, cols)): - with col: - is_highlighted = highlight_idx is not None and i == highlight_idx - label_text = f"**{lbl}**" if is_highlighted else lbl - st.markdown(f"{label_text}: {val*100:.1f}%") - st.progress(int(round(val * 100))) - else: - # Vertical layout for better readability - for i, (lbl, val) in enumerate(zip(labels, p)): - is_highlighted = highlight_idx is not None and i == highlight_idx - - # Create a container for each probability - with st.container(): - col1, col2 = st.columns([3, 1]) - with col1: - if is_highlighted: - st.markdown(f"**{lbl}** ← Predicted") - else: - st.markdown(f"{lbl}") - with col2: - st.metric(label="", value=f"{val*100:.1f}%", delta=None) - - # Progress bar with conditional styling - if is_highlighted: - st.progress(int(round(val * 100))) - st.caption("🎯 **Model Prediction**") - else: - st.progress(int(round(val * 100))) - - if i < len(labels) - 1: # Add spacing between items - st.markdown("") - - -def render_kv_grid(d: dict = {}, ncols: int = 2): - if d is None: - d = {} - if not d: - return - items = list(d.items()) - cols = st.columns(ncols) - for i, (k, v) in enumerate(items): - with cols[i % ncols]: - st.caption(f"**{k}:** {v}") - - -def render_model_meta(model_choice: str): - info = MODEL_CONFIG.get(model_choice, {}) - emoji = info.get("emoji", "") - desc = info.get("description", "").strip() - acc = info.get("accuracy", "-") - f1 = info.get("f1", "-") - - st.caption(f"{emoji} **Model Snapshot** - {model_choice}") - cols = st.columns(2) - with cols[0]: - st.metric("Accuracy", acc) - with cols[1]: - st.metric("F1 Score", f1) - if desc: - st.caption(desc) - - -def get_confidence_description(logit_margin): - """Get human-readable confidence description""" - if logit_margin > 1000: - return "VERY HIGH", "🟢" - elif logit_margin > 250: - return "HIGH", "🟡" - elif logit_margin > 100: - return "MODERATE", "🟠" - else: - return "LOW", "🔴" - - -def render_sidebar(): - with st.sidebar: - # Header - st.header("AI-Driven Polymer Classification") - st.caption( - "Predict polymer degradation (Stable vs Weathered) from Raman spectra using validated CNN models. — v0.1" - ) - model_labels = [ - f"{MODEL_CONFIG[name]['emoji']} {name}" for name in MODEL_CONFIG.keys() - ] - selected_label = st.selectbox( - "Choose AI Model", - model_labels, - key="model_select", - on_change=on_model_change, - ) - model_choice = selected_label.split(" ", 1)[1] - - # ===Compact metadata directly under dropdown=== - render_model_meta(model_choice) - - # ===Collapsed info to reduce clutter=== - with st.expander("About This App", icon=":material/info:", expanded=False): - st.markdown( - """ - **AI-Driven Polymer Aging Prediction and Classification** - - **Purpose**: Classify polymer degradation using AI
- **Input**: Raman spectroscopy .txt files
- **Models**: CNN architectures for binary classification
- **Next**: More trained CNNs in evaluation pipeline
- - - **Contributors**
- - Dr. Sanmukh Kuppannagari (Mentor)
- - Dr. Metin Karailyan (Mentor)
- - Jaser Hasan (Author)
- - - **Links**
- [HF Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
- [GitHub Repository](https://github.com/KLab-AI3/ml-polymer-recycling) - - - **Citation Figure2CNN (baseline)** - Neo et al., 2023, *Resour. Conserv. Recycl.*, 188, 106718. - [https://doi.org/10.1016/j.resconrec.2022.106718](https://doi.org/10.1016/j.resconrec.2022.106718) - """, - unsafe_allow_html=True, - ) - - -# col1 goes here - -# In modules/ui_components.py - - -def render_input_column(): - st.markdown("##### Data Input") - - mode = st.radio( - "Input mode", - ["Upload File", "Batch Upload", "Sample Data"], - key="input_mode", - horizontal=True, - on_change=on_input_mode_change, - ) - - # == Input Mode Logic == - # ... (The if/elif/else block for Upload, Batch, and Sample modes remains exactly the same) ... - # ==Upload tab== - if mode == "Upload File": - upload_key = st.session_state["current_upload_key"] - up = st.file_uploader( - "Upload Raman spectrum (.txt)", - type="txt", - help="Upload a text file with wavenumber and intensity columns", - key=upload_key, # ← versioned key - ) - - # ==Process change immediately (no on_change; simpler & reliable)== - if up is not None: - raw = up.read() - text = raw.decode("utf-8") if isinstance(raw, bytes) else raw - # == only reparse if its a different file|source == - if ( - st.session_state.get("filename") != getattr(up, "name", None) - or st.session_state.get("input_source") != "upload" - ): - st.session_state["input_text"] = text - st.session_state["filename"] = getattr(up, "name", None) - st.session_state["input_source"] = "upload" - # Ensure single file mode - st.session_state["batch_mode"] = False - st.session_state["status_message"] = ( - f"File '{st.session_state['filename']}' ready for analysis" - ) - st.session_state["status_type"] = "success" - reset_results("New file uploaded") - - # ==Batch Upload tab== - elif mode == "Batch Upload": - st.session_state["batch_mode"] = True - # --- START: BUG 1 & 3 FIX --- - # Use a versioned key to ensure the file uploader resets properly. - batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}" - uploaded_files = st.file_uploader( - "Upload multiple Raman spectrum files (.txt)", - type="txt", - accept_multiple_files=True, - help="Upload one or more text files with wavenumber and intensity columns.", - key=batch_upload_key, - ) - # --- END: BUG 1 & 3 FIX --- - - if uploaded_files: - # --- START: Bug 1 Fix --- - # Use a dictionary to keep only unique files based on name and size - unique_files = {(file.name, file.size): file for file in uploaded_files} - unique_file_list = list(unique_files.values()) - - num_uploaded = len(uploaded_files) - num_unique = len(unique_file_list) - - # Optionally, inform the user that duplicates were removed - if num_uploaded > num_unique: - st.info( - f"ℹ️ {num_uploaded - num_unique} duplicate file(s) were removed." - ) - - # Use the unique list - st.session_state["batch_files"] = unique_file_list - st.session_state["status_message"] = ( - f"{num_unique} ready for batch analysis" - ) - st.session_state["status_type"] = "success" - # --- END: Bug 1 Fix --- - else: - st.session_state["batch_files"] = [] - # This check prevents resetting the status if files are already staged - if not st.session_state.get("batch_files"): - st.session_state["status_message"] = ( - "No files selected for batch processing" - ) - st.session_state["status_type"] = "info" - - # ==Sample tab== - elif mode == "Sample Data": - st.session_state["batch_mode"] = False - sample_files = get_sample_files() - if sample_files: - options = ["-- Select Sample --"] + [p.name for p in sample_files] - sel = st.selectbox( - "Choose sample spectrum:", - options, - key="sample_select", - on_change=on_sample_change, - ) - if sel != "-- Select Sample --": - st.session_state["status_message"] = ( - f"📁 Sample '{sel}' ready for analysis" - ) - st.session_state["status_type"] = "success" - else: - st.info("No sample data available") - # == Status box (displays the message) == - msg = st.session_state.get("status_message", "Ready") - typ = st.session_state.get("status_type", "info") - if typ == "success": - st.success(msg) - elif typ == "error": - st.error(msg) - else: - st.info(msg) - - # --- DE-NESTED LOGIC STARTS HERE --- - # This code now runs on EVERY execution, guaranteeing the buttons will appear. - - # Safely get model choice from session state - model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1] - model = load_model(model_choice) - - # Determine if the app is ready for inference - is_batch_ready = st.session_state.get("batch_mode", False) and st.session_state.get( - "batch_files" - ) - is_single_ready = not st.session_state.get( - "batch_mode", False - ) and st.session_state.get("input_text") - inference_ready = (is_batch_ready or is_single_ready) and model is not None - # Store for other modules to access - st.session_state["inference_ready"] = inference_ready - - # Render buttons - with st.form("analysis_form", clear_on_submit=False): - submitted = st.form_submit_button( - "Run Analysis", type="primary", disabled=not inference_ready - ) - st.button( - "Reset All", - on_click=reset_ephemeral_state, - help="Clear all uploaded files and results.", - ) - - # Handle form submission - if submitted and inference_ready: - if st.session_state.get("batch_mode"): - batch_files = st.session_state.get("batch_files", []) - with st.spinner(f"Processing {len(batch_files)} files ..."): - st.session_state["batch_results"] = process_multiple_files( - uploaded_files=batch_files, - model_choice=model_choice, - load_model_func=load_model, - run_inference_func=run_inference, - label_file_func=label_file, - ) - else: - try: - x_raw, y_raw = parse_spectrum_data(st.session_state["input_text"]) - x_resampled, y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN) - st.session_state.update( - { - "x_raw": x_raw, - "y_raw": y_raw, - "x_resampled": x_resampled, - "y_resampled": y_resampled, - "inference_run_once": True, - } - ) - except (ValueError, TypeError) as e: - st.error(f"Error processing spectrum data: {e}") - - -# col2 goes here - - -def render_results_column(): - # Get the current mode and check for batch results - is_batch_mode = st.session_state.get("batch_mode", False) - has_batch_results = "batch_results" in st.session_state - - if is_batch_mode and has_batch_results: - # THEN render the main interactive dashboard from ResultsManager - ResultsManager.display_results_table() - - elif st.session_state.get("inference_run_once", False) and not is_batch_mode: - st.markdown("##### Analysis Results") - # Get data from session state - x_raw = st.session_state.get("x_raw") - y_raw = st.session_state.get("y_raw") - x_resampled = st.session_state.get("x_resampled") # ← NEW - y_resampled = st.session_state.get("y_resampled") - filename = st.session_state.get("filename", "Unknown") - - if all(v is not None for v in [x_raw, y_raw, y_resampled]): - # ===Run inference=== - if y_resampled is None: - raise ValueError( - "y_resampled is None. Ensure spectrum data is properly resampled before proceeding." - ) - cache_key = hashlib.md5( - f"{y_resampled.tobytes()}{st.session_state.get('model_select', 'Unknown').split(' ', 1)[1]}".encode() - ).hexdigest() - prediction, logits_list, probs, inference_time, logits = run_inference( - y_resampled, - ( - st.session_state.get("model_select", "").split(" ", 1)[1] - if "model_select" in st.session_state - else None - ), - _cache_key=cache_key, - ) - if prediction is None: - st.error( - "❌ Inference failed: Model not loaded. Please check that weights are available." - ) - st.stop() # prevents the rest of the code in this block from executing - - log_message( - f"Inference completed in {inference_time:.2f}s, prediction: {prediction}" - ) - - # ===Get ground truth=== - true_label_idx = label_file(filename) - true_label_str = ( - LABEL_MAP.get(true_label_idx, "Unknown") - if true_label_idx is not None - else "Unknown" - ) - # ===Get prediction=== - predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}") - - # Enhanced confidence calculation - if logits is not None: - # Use new softmax-based confidence - probs_np, max_confidence, confidence_level, confidence_emoji = ( - calculate_softmax_confidence(logits) - ) - confidence_desc = confidence_level - else: - # Fallback to legace method - logit_margin = abs( - (logits_list[0] - logits_list[1]) - if logits_list is not None and len(logits_list) >= 2 - else 0 - ) - confidence_desc, confidence_emoji = get_confidence_description( - logit_margin - ) - max_confidence = logit_margin / 10.0 # Normalize for display - probs_np = np.array([]) - - # Store result in results manager for single file too - ResultsManager.add_results( - filename=filename, - model_name=( - st.session_state.get("model_select", "").split(" ", 1)[1] - if "model_select" in st.session_state - else "Unknown" - ), - prediction=int(prediction), - predicted_class=predicted_class, - confidence=max_confidence, - logits=logits_list if logits_list else [], - ground_truth=true_label_idx if true_label_idx >= 0 else None, - processing_time=inference_time if inference_time is not None else 0.0, - metadata={ - "confidence_level": confidence_desc, - "confidence_emoji": confidence_emoji, - }, - ) - - # ===Precompute Stats=== - model_choice = ( - st.session_state.get("model_select", "").split(" ", 1)[1] - if "model_select" in st.session_state - else None - ) - if not model_choice: - st.error( - "⚠️ Model choice is not defined. Please select a model from the sidebar." - ) - st.stop() - model_path = MODEL_CONFIG[model_choice]["path"] - mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else None - file_hash = ( - hashlib.md5(open(model_path, "rb").read()).hexdigest() - if os.path.exists(model_path) - else "N/A" - ) - # Removed unused variable 'input_tensor' - - start_render = time.time() - - active_tab = st.selectbox( - "View Results", - ["Details", "Technical", "Explanation"], - key="active_tab", # reuse the key you were managing manually - ) - - if active_tab == "Details": - st.markdown('
', unsafe_allow_html=True) - # Use a dynamic and informative title for the expander - with st.expander(f"Results for {filename}", expanded=True): - - # --- START: STREAMLINED METRICS --- - # A single, powerful row for the most important results. - key_metric_cols = st.columns(3) - - # Metric 1: The Prediction - key_metric_cols[0].metric("Prediction", predicted_class) - - # Metric 2: The Confidence (with level in tooltip) - confidence_icon = ( - "🟢" - if max_confidence >= 0.8 - else "🟡" if max_confidence >= 0.6 else "🔴" - ) - key_metric_cols[1].metric( - "Confidence", - f"{confidence_icon} {max_confidence:.1%}", - help=f"Confidence Level: {confidence_desc}", - ) - - # Metric 3: Ground Truth + Correctness (Combined) - if true_label_idx is not None: - is_correct = predicted_class == true_label_str - delta_text = "✅ Correct" if is_correct else "❌ Incorrect" - # Use delta_color="normal" to let the icon provide the visual cue - key_metric_cols[2].metric( - "Ground Truth", - true_label_str, - delta=delta_text, - delta_color="normal", - ) - else: - key_metric_cols[2].metric("Ground Truth", "N/A") - - st.divider() - # --- END: STREAMLINED METRICS --- - - # --- START: CONSOLIDATED CONFIDENCE ANALYSIS --- - st.markdown("##### Probability Breakdown") - - # This custom bullet bar logic remains as it is highly specific and valuable - def create_bullet_bar(probability, width=20, predicted=False): - filled_count = int(probability * width) - bar = "▤" * filled_count + "▢" * (width - filled_count) - percentage = f"{probability:.1%}" - pred_marker = "↩ Predicted" if predicted else "" - return f"{bar} {percentage} {pred_marker}" - - if probs is not None: - stable_prob, weathered_prob = probs[0], probs[1] - else: - st.error( - "❌ Probability values are missing. Please check the inference process." - ) - # Default values to prevent further errors - stable_prob, weathered_prob = 0.0, 0.0 - is_stable_predicted, is_weathered_predicted = ( - int(prediction) == 0 - ), (int(prediction) == 1) - - st.markdown( - f""" -
- Stable (Unweathered)
- {create_bullet_bar(stable_prob, predicted=is_stable_predicted)}

- Weathered (Degraded)
- {create_bullet_bar(weathered_prob, predicted=is_weathered_predicted)} -
- """, - unsafe_allow_html=True, - ) - # --- END: CONSOLIDATED CONFIDENCE ANALYSIS --- - - st.divider() - - # --- START: CLEAN METADATA FOOTER --- - # Secondary info is now a clean, single-line caption - st.caption( - f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**." - ) - # --- END: CLEAN METADATA FOOTER --- - - st.markdown("
", unsafe_allow_html=True) - - elif active_tab == "Technical": - with st.container(): - st.markdown("Technical Diagnostics") - - # Model performance metrics - with st.container(border=True): - st.markdown("##### **Model Performance**") - tech_col1, tech_col2 = st.columns(2) - - with tech_col1: - st.metric("Inference Time", f"{inference_time:.3f}s") - st.metric( - "Input Length", - f"{len(x_raw) if x_raw is not None else 0} points", - ) - st.metric("Resampled Length", f"{TARGET_LEN} points") - - with tech_col2: - st.metric( - "Model Loaded", - ( - "✅ Yes" - if st.session_state.get("model_loaded", False) - else "❌ No" - ), - ) - st.metric("Device", "CPU") - st.metric("Confidence Score", f"{max_confidence:.3f}") - - # Raw logits display - with st.container(border=True): - st.markdown("##### **Raw Model Outputs (Logits)**") - logits_df = { - "Class": ( - [ - LABEL_MAP.get(i, f"Class {i}") - for i in range(len(logits_list)) - ] - if logits_list is not None - else [] - ), - "Logit Value": ( - [f"{score:.4f}" for score in logits_list] - if logits_list is not None - else [] - ), - "Probability": ( - [f"{prob:.4f}" for prob in probs_np] - if logits_list is not None and len(probs_np) > 0 - else [] - ), - } - - # Display as a simple table format - for i, (cls, logit, prob) in enumerate( - zip( - logits_df["Class"], - logits_df["Logit Value"], - logits_df["Probability"], - ) - ): - col1, col2, col3 = st.columns([2, 1, 1]) - with col1: - if i == prediction: - st.markdown(f"**{cls}** ← Predicted") - else: - st.markdown(cls) - with col2: - st.caption(f"Logit: {logit}") - with col3: - st.caption(f"Prob: {prob}") - - # Spectrum statistics in organized sections - with st.container(border=True): - st.markdown("##### **Spectrum Analysis**") - spec_cols = st.columns(2) - - with spec_cols[0]: - st.markdown("**Original Spectrum:**") - render_kv_grid( - { - "Length": f"{len(x_raw) if x_raw is not None else 0} points", - "Range": ( - f"{min(x_raw):.1f} - {max(x_raw):.1f} cm⁻¹" - if x_raw is not None - else "N/A" - ), - "Min Intensity": ( - f"{min(y_raw):.2e}" - if y_raw is not None - else "N/A" - ), - "Max Intensity": ( - f"{max(y_raw):.2e}" - if y_raw is not None - else "N/A" - ), - }, - ncols=1, - ) - - with spec_cols[1]: - st.markdown("**Processed Spectrum:**") - render_kv_grid( - { - "Length": f"{TARGET_LEN} points", - "Resampling": "Linear interpolation", - "Normalization": "None", - "Input Shape": f"(1, 1, {TARGET_LEN})", - }, - ncols=1, - ) - - # Model information - with st.container(border=True): - st.markdown("##### **Model Information**") - model_info_cols = st.columns(2) - - with model_info_cols[0]: - render_kv_grid( - { - "Architecture": model_choice, - "Path": MODEL_CONFIG[model_choice]["path"], - "Weights Modified": ( - time.strftime( - "%Y-%m-%d %H:%M:%S", time.localtime(mtime) - ) - if mtime - else "N/A" - ), - }, - ncols=1, - ) - - with model_info_cols[1]: - if os.path.exists(model_path): - file_hash = hashlib.md5( - open(model_path, "rb").read() - ).hexdigest() - render_kv_grid( - { - "Weights Hash": f"{file_hash[:16]}...", - "Output Shape": f"(1, {len(LABEL_MAP)})", - "Activation": "Softmax", - }, - ncols=1, - ) - - # Debug logs (collapsed by default) - with st.expander("📋 Debug Logs", expanded=False): - log_content = "\n".join( - st.session_state.get("log_messages", []) - ) - if log_content.strip(): - st.code(log_content, language="text") - else: - st.caption("No debug logs available") - - elif active_tab == "Explanation": - with st.container(): - st.markdown("### 🔍 Methodology & Interpretation") - - # Process explanation - st.markdown("Analysis Pipeline") - process_steps = [ - "📁 **Data Upload**: Raman spectrum file loaded and validated", - "🔍 **Preprocessing**: Spectrum parsed and resampled to 500 data points using linear interpolation", - "🧠 **AI Inference**: Convolutional Neural Network analyzes spectral patterns and molecular signatures", - "📊 **Classification**: Binary prediction with confidence scoring using softmax probabilities", - "✅ **Validation**: Ground truth comparison (when available from filename)", - ] - - for step in process_steps: - st.markdown(step) - - st.markdown("---") - - # Model interpretation - st.markdown("#### Scientific Interpretation") - - interp_col1, interp_col2 = st.columns(2) - - with interp_col1: - st.markdown("**Stable (Unweathered) Polymers:**") - st.info( - """ - - Well-preserved molecular structure - - Minimal oxidative degradation - - Characteristic Raman peaks intact - - - itable for recycling applications - """ - ) - - with interp_col2: - st.markdown("**Weathered (Degraded) Polymers:**") - st.warning( - """ - - Oxidized molecular bonds - - Surface degradation present - - Altered spectral signatures - - May require additional processing - """ - ) - - st.markdown("---") - - # Applications - st.markdown("#### Research Applications") - - applications = [ - "🔬 **Material Science**: Polymer degradation studies", - "♻️ **Recycling Research**: Viability assessment for circular economy", - "🌱 **Environmental Science**: Microplastic weathering analysis", - "🏭 **Quality Control**: Manufacturing process monitoring", - "📈 **Longevity Studies**: Material aging prediction", - ] - - for app in applications: - st.markdown(app) - - # Technical details - # MODIFIED: Wrap the expander in a div with the 'expander-advanced' class - st.markdown( - '
', unsafe_allow_html=True - ) - with st.expander("🔧 Technical Details", expanded=False): - st.markdown( - """ - **Model Architecture:** - - Convolutional layers for feature extraction - - Residual connections for gradient flow - - Fully connected layers for classification - - Softmax activation for probability distribution - - **Performance Metrics:** - - Accuracy: 94.8-96.2% on validation set - - F1-Score: 94.3-95.9% across classes - - Robust to spectral noise and baseline variations - - **Data Processing:** - - Input: Raman spectra (any length) - - Resampling: Linear interpolation to 500 points - - Normalization: None (preserves intensity relationships) - """ - ) - st.markdown( - "
", unsafe_allow_html=True - ) # Close the wrapper div - - render_time = time.time() - start_render - log_message( - f"col2 rendered in {render_time:.2f}s, active tab: {active_tab}" - ) - - with st.expander("Spectrum Preprocessing Results", expanded=False): - st.caption("
Spectral Analysis", unsafe_allow_html=True) - - # Add some context about the preprocessing - st.markdown( - """ - **Preprocessing Overview:** - - **Original Spectrum**: Raw Raman data as uploaded - - **Resampled Spectrum**: Data interpolated to 500 points for model input - - **Purpose**: Ensures consistent input dimensions for neural network - """ - ) - - # Create and display plot - cache_key = hashlib.md5( - f"{(x_raw.tobytes() if x_raw is not None else b'')}" - f"{(y_raw.tobytes() if y_raw is not None else b'')}" - f"{(x_resampled.tobytes() if x_resampled is not None else b'')}" - f"{(y_resampled.tobytes() if y_resampled is not None else b'')}".encode() - ).hexdigest() - spectrum_plot = create_spectrum_plot( - x_raw, y_raw, x_resampled, y_resampled, _cache_key=cache_key - ) - st.image( - spectrum_plot, - caption="Raman Spectrum: Raw vs Processed", - use_container_width=True, - ) - - else: - st.markdown( - """ - ##### How to Get Started - - 1. **Select an AI Model:** Use the dropdown menu in the sidebar to choose a model. - 2. **Provide Your Data:** Select one of the three input modes: - - **Upload File:** Analyze a single spectrum. - - **Batch Upload:** Process multiple files at once. - - **Sample Data:** Explore functionality with pre-loaded examples. - 3. **Run Analysis:** Click the "Run Analysis" button to generate the classification results. - - --- - - ##### Supported Data Format - - - **File Type:** Plain text (`.txt`) - - **Content:** Must contain two columns: `wavenumber` and `intensity`. - - **Separators:** Values can be separated by spaces or commas. - - **Preprocessing:** Your spectrum will be automatically resampled to 500 data points to match the model's input requirements. - - --- - - ##### Example Applications - - 🔬 Research on polymer degradation - - ♻️ Recycling feasibility assessment - - 🌱 Sustainability impact studies - - 🏭 Quality control in manufacturing - """ - ) - else: - # ===Getting Started=== - st.markdown( - """ - ##### How to Get Started - - 1. **Select an AI Model:** Use the dropdown menu in the sidebar to choose a model. - 2. **Provide Your Data:** Select one of the three input modes: - - **Upload File:** Analyze a single spectrum. - - **Batch Upload:** Process multiple files at once. - - **Sample Data:** Explore functionality with pre-loaded examples. - 3. **Run Analysis:** Click the "Run Analysis" button to generate the classification results. - - --- - - ##### Supported Data Format - - - **File Type:** Plain text (`.txt`) - - **Content:** Must contain two columns: `wavenumber` and `intensity`. - - **Separators:** Values can be separated by spaces or commas. - - **Preprocessing:** Your spectrum will be automatically resampled to 500 data points to match the model's input requirements. - - --- - - ##### Example Applications - - 🔬 Research on polymer degradation - - ♻️ Recycling feasibility assessment - - 🌱 Sustainability impact studies - - 🏭 Quality control in manufacturing - """ - ) +import os +import torch +import streamlit as st +import hashlib +import io +from PIL import Image +import numpy as np +import matplotlib.pyplot as plt +from typing import Union +import time +from config import TARGET_LEN, LABEL_MAP, MODEL_WEIGHTS_DIR +from models.registry import choices, get_model_info +from modules.callbacks import ( + on_model_change, + on_input_mode_change, + on_sample_change, + reset_results, + reset_ephemeral_state, + log_message, +) +from core_logic import ( + get_sample_files, + load_model, + run_inference, + parse_spectrum_data, + label_file, +) +from utils.results_manager import ResultsManager +from utils.multifile import process_multiple_files +from utils.preprocessing import resample_spectrum, validate_spectrum_modality +from utils.confidence import calculate_softmax_confidence + + +def load_css(file_path): + with open(file_path, encoding="utf-8") as f: + st.markdown(f"", unsafe_allow_html=True) + + +@st.cache_data +def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled, _cache_key=None): + """Create spectrum visualization plot""" + fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100) + + # Raw spectrum + ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1) + ax[0].set_title("Raw Input Spectrum") + ax[0].set_xlabel("Wavenumber (cm⁻¹)") + ax[0].set_ylabel("Intensity") + ax[0].grid(True, alpha=0.3) + ax[0].legend() + + # Resampled spectrum + ax[1].plot( + x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1 + ) + ax[1].set_title(f"Resampled ({len(y_resampled)} points)") + ax[1].set_xlabel("Wavenumber (cm⁻¹)") + ax[1].set_ylabel("Intensity") + ax[1].grid(True, alpha=0.3) + ax[1].legend() + + fig.tight_layout() + # Convert to image + buf = io.BytesIO() + plt.savefig(buf, format="png", bbox_inches="tight", dpi=100) + buf.seek(0) + plt.close(fig) # Prevent memory leaks + + return Image.open(buf) + + +# ////////////////////////////////////////// + + +def render_confidence_progress( + probs: np.ndarray, + labels: list[str] = ["Stable", "Weathered"], + highlight_idx: Union[int, None] = None, + side_by_side: bool = True, +): + """Render Streamlit native progress bars with scientific formatting.""" + p = np.asarray(probs, dtype=float) + p = np.clip(p, 0.0, 1.0) + + if side_by_side: + cols = st.columns(len(labels)) + for i, (lbl, val, col) in enumerate(zip(labels, p, cols)): + with col: + is_highlighted = highlight_idx is not None and i == highlight_idx + label_text = f"**{lbl}**" if is_highlighted else lbl + st.markdown(f"{label_text}: {val*100:.1f}%") + st.progress(int(round(val * 100))) + else: + # Vertical layout for better readability + for i, (lbl, val) in enumerate(zip(labels, p)): + is_highlighted = highlight_idx is not None and i == highlight_idx + + # Create a container for each probability + with st.container(): + col1, col2 = st.columns([3, 1]) + with col1: + if is_highlighted: + st.markdown(f"**{lbl}** ← Predicted") + else: + st.markdown(f"{lbl}") + with col2: + st.metric(label="", value=f"{val*100:.1f}%", delta=None) + + # Progress bar with conditional styling + if is_highlighted: + st.progress(int(round(val * 100))) + st.caption("🎯 **Model Prediction**") + else: + st.progress(int(round(val * 100))) + + if i < len(labels) - 1: # Add spacing between items + st.markdown("") + + +from typing import Optional + + +def render_kv_grid(d: Optional[dict] = None, ncols: int = 2): + if d is None: + d = {} + if not d: + return + items = list(d.items()) + cols = st.columns(ncols) + for i, (k, v) in enumerate(items): + with cols[i % ncols]: + st.caption(f"**{k}:** {v}") + + +# ////////////////////////////////////////// + + +def render_model_meta(model_choice: str): + info = get_model_info(model_choice) + emoji = info.get("emoji", "") + desc = info.get("description", "").strip() + acc = info.get("performance", {}).get("accuracy", "-") + f1 = info.get("performance", {}).get("f1_score", "-") + + st.caption(f"{emoji} **Model Snapshot** - {model_choice}") + cols = st.columns(2) + with cols[0]: + st.metric("Accuracy", acc) + with cols[1]: + st.metric("F1 Score", f1) + if desc: + st.caption(desc) + + +# ////////////////////////////////////////// + + +def get_confidence_description(logit_margin): + """Get human-readable confidence description""" + if logit_margin > 1000: + return "VERY HIGH", "🟢" + elif logit_margin > 250: + return "HIGH", "🟡" + elif logit_margin > 100: + return "MODERATE", "🟠" + else: + return "LOW", "🔴" + + +# ////////////////////////////////////////// + + +def render_sidebar(): + with st.sidebar: + # Header + st.header("AI-Driven Polymer Classification") + st.caption( + "Predict polymer degradation (Stable vs Weathered) from Raman/FTIR spectra using validated CNN models. — v0.01" + ) + + # Modality Selection + st.markdown("##### Spectroscopy Modality") + modality = st.selectbox( + "Choose Modality", + ["raman", "ftir"], + index=0, + key="modality_select", + format_func=lambda x: f"{'Raman' if x == 'raman' else 'FTIR'}", + ) + + # Display modality info + if modality == "ftir": + st.info("FTIR mode: 400-4000 cm-1 range with atmospheric correction") + else: + st.info("Raman mode: 200-4000 cm-1 range with standard preprocessing") + + # Model selection + st.markdown("##### AI Model Selection") + + model_emojis = { + "figure2": "📄", + "resnet": "🧠", + "resnet18vision": "👁️", + "enhanced_cnn": "✨", + "efficient_cnn": "⚡", + "hybrid_net": "🧬", + } + + available_models = choices() + model_labels = [ + f"{model_emojis.get(name, '🤖')} {name}" for name in available_models + ] + + selected_label = st.selectbox( + "Choose AI Model", + model_labels, + key="model_select", + on_change=on_model_change, + ) + model_choice = selected_label.split(" ", 1)[1] + + # Compact metadata directly under dropdown + render_model_meta(model_choice) + + # Collapsed info to reduce clutter + with st.expander("About This App", icon=":material/info:", expanded=False): + st.markdown( + """ + **AI-Driven Polymer Aging Prediction and Classification** + + **Purpose**: Classify polymer degradation using AI
+ **Input**: Raman spectroscopy .txt files
+ **Models**: CNN architectures for classification
+ **Modalities**: Raman and FTIR spectroscopy support
+ **Features**: Multi-model comparison and analysis
+ + + **Contributors**
+ - Dr. Sanmukh Kuppannagari (Mentor)
+ - Dr. Metin Karailyan (Mentor)
+ - Jaser Hasan (Author)
+ + + **Links**
+ [HF Space](https://huggingface.co/spaces/dev-jas/polymer-aging-ml)
+ [GitHub Repository](https://github.com/KLab-AI3/ml-polymer-recycling) + + + **Citation Figure2CNN (baseline)** + Neo et al., 2023, *Resour. Conserv. Recycl.*, 188, 106718. + [https://doi.org/10.1016/j.resconrec.2022.106718](https://doi.org/10.1016/j.resconrec.2022.106718) + """, + unsafe_allow_html=True, + ) + + +# ////////////////////////////////////////// +def render_input_column(): + st.markdown("##### Data Input") + + mode = st.radio( + "Input mode", + ["Upload File", "Batch Upload", "Sample Data"], + key="input_mode", + horizontal=True, + on_change=on_input_mode_change, + ) + + # == Input Mode Logic == + if mode == "Upload File": + upload_key = st.session_state["current_upload_key"] + up = st.file_uploader( + "Upload spectrum file (.txt, .csv, .json)", + type=["txt", "csv", "json"], + help="Upload spectroscopy data: TXT (2-column), CSV (with headers), or JSON format", + key=upload_key, # ← versioned key + ) + + # Process change immediately + if up is not None: + raw = up.read() + text = raw.decode("utf-8") if isinstance(raw, bytes) else raw + # only reparse if its a different file|source + if ( + st.session_state.get("filename") != getattr(up, "name", None) + or st.session_state.get("input_source") != "upload" + ): + st.session_state["input_text"] = text + st.session_state["filename"] = getattr(up, "name", None) + st.session_state["input_source"] = "upload" + # Ensure single file mode + st.session_state["batch_mode"] = False + st.session_state["status_message"] = ( + f"File '{st.session_state['filename']}' ready for analysis" + ) + st.session_state["status_type"] = "success" + reset_results("New file uploaded") + + # Batch Upload tab + elif mode == "Batch Upload": + st.session_state["batch_mode"] = True + # Use a versioned key to ensure the file uploader resets properly. + batch_upload_key = f"batch_upload_{st.session_state['uploader_version']}" + uploaded_files = st.file_uploader( + "Upload multiple spectrum files (.txt, .csv, .json)", + type=["txt", "csv", "json"], + accept_multiple_files=True, + help="Upload spectroscopy files in TXT, CSV, or JSON format.", + key=batch_upload_key, + ) + + if uploaded_files: + # Use a dictionary to keep only unique files based on name and size + unique_files = {(file.name, file.size): file for file in uploaded_files} + unique_file_list = list(unique_files.values()) + + num_uploaded = len(uploaded_files) + num_unique = len(unique_file_list) + + # Optionally, inform the user that duplicates were removed + if num_uploaded > num_unique: + st.info(f"{num_uploaded - num_unique} duplicate file(s) were removed.") + + # Use the unique list + st.session_state["batch_files"] = unique_file_list + st.session_state["status_message"] = ( + f"{num_unique} ready for batch analysis" + ) + st.session_state["status_type"] = "success" + else: + st.session_state["batch_files"] = [] + # This check prevents resetting the status if files are already staged + if not st.session_state.get("batch_files"): + st.session_state["status_message"] = ( + "No files selected for batch processing" + ) + st.session_state["status_type"] = "info" + + # Sample tab + elif mode == "Sample Data": + st.session_state["batch_mode"] = False + sample_files = get_sample_files() + if sample_files: + options = ["-- Select Sample --"] + [p.name for p in sample_files] + sel = st.selectbox( + "Choose sample spectrum:", + options, + key="sample_select", + on_change=on_sample_change, + ) + if sel != "-- Select Sample --": + st.session_state["status_message"] = ( + f"📁 Sample '{sel}' ready for analysis" + ) + st.session_state["status_type"] = "success" + else: + st.info("No sample data available") + # == Status box (displays the message) == + msg = st.session_state.get("status_message", "Ready") + typ = st.session_state.get("status_type", "info") + if typ == "success": + st.success(msg) + elif typ == "error": + st.error(msg) + else: + st.info(msg) + + # Safely get model choice from session state + model_choice = st.session_state.get("model_select", " ").split(" ", 1)[1] + model = load_model(model_choice) + + # Determine if the app is ready for inference + is_batch_ready = st.session_state.get("batch_mode", False) and st.session_state.get( + "batch_files" + ) + is_single_ready = not st.session_state.get( + "batch_mode", False + ) and st.session_state.get("input_text") + inference_ready = (is_batch_ready or is_single_ready) and model is not None + # Store for other modules to access + st.session_state["inference_ready"] = inference_ready + + # Render buttons + with st.form("analysis_form", clear_on_submit=False): + submitted = st.form_submit_button( + "Run Analysis", type="primary", disabled=not inference_ready + ) + st.button( + "Reset All", + on_click=reset_ephemeral_state, + help="Clear all uploaded files and results.", + ) + + # Handle form submission + if submitted and inference_ready: + if st.session_state.get("batch_mode"): + batch_files = st.session_state.get("batch_files", []) + with st.spinner(f"Processing {len(batch_files)} files ..."): + st.session_state["batch_results"] = process_multiple_files( + uploaded_files=batch_files, + model_choice=model_choice, + run_inference_func=run_inference, + label_file_func=label_file, + modality=st.session_state.get("modality_select", "raman"), + ) + else: + try: + x_raw, y_raw = parse_spectrum_data(st.session_state["input_text"]) + + # Validate that spectrum matches selected modality + selected_modality = st.session_state.get("modality_select", "raman") + is_valid, issues = validate_spectrum_modality( + x_raw, y_raw, selected_modality + ) + + if not is_valid: + st.warning("⚠️ **Spectrum-Modality Mismatch Detected**") + for issue in issues: + st.warning(f"• {issue}") + + # Ask user if they want to continue + st.info( + "💡 **Suggestion**: Check if the correct modality is selected in the sidebar, or verify your data file." + ) + + if st.button("⚠️ Continue Anyway", key="continue_with_mismatch"): + st.warning( + "Proceeding with potentially mismatched data. Results may be unreliable." + ) + else: + st.stop() # Stop processing until user confirms + + x_resampled, y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN) + st.session_state.update( + { + "x_raw": x_raw, + "y_raw": y_raw, + "x_resampled": x_resampled, + "y_resampled": y_resampled, + "inference_run_once": True, + } + ) + except (ValueError, TypeError) as e: + st.error(f"Error processing spectrum data: {e}") + + +# ////////////////////////////////////////// + + +def render_results_column(): + # Get the current mode and check for batch results + is_batch_mode = st.session_state.get("batch_mode", False) + has_batch_results = "batch_results" in st.session_state + + if is_batch_mode and has_batch_results: + # THEN render the main interactive dashboard from ResultsManager + ResultsManager.display_results_table() + + elif st.session_state.get("inference_run_once", False) and not is_batch_mode: + st.markdown("##### Analysis Results") + # Get data from session state + x_raw = st.session_state.get("x_raw") + y_raw = st.session_state.get("y_raw") + x_resampled = st.session_state.get("x_resampled") # ← NEW + y_resampled = st.session_state.get("y_resampled") + filename = st.session_state.get("filename", "Unknown") + + if all(v is not None for v in [x_raw, y_raw, y_resampled]): + # Run inference + if y_resampled is None: + raise ValueError( + "y_resampled is None. Ensure spectrum data is properly resampled before proceeding." + ) + cache_key = hashlib.md5( + f"{y_resampled.tobytes()}{st.session_state.get('model_select', 'Unknown').split(' ', 1)[1]}".encode() + ).hexdigest() + # MODIFIED: Pass modality to run_inference + prediction, logits_list, probs, inference_time, logits = run_inference( + y_resampled, + ( + st.session_state.get("model_select", "").split(" ", 1)[1] + if "model_select" in st.session_state + else None + ), + modality=st.session_state.get("modality_select", "raman"), + _cache_key=cache_key, + ) + if prediction is None: + st.error( + "❌ Inference failed: Model not loaded. Please check that weights are available." + ) + st.stop() # prevents the rest of the code in this block from executing + + log_message( + f"Inference completed in {inference_time:.2f}s, prediction: {prediction}" + ) + + # Get ground truth + true_label_idx = label_file(filename) + true_label_str = ( + LABEL_MAP.get(true_label_idx, "Unknown") + if true_label_idx is not None + else "Unknown" + ) + # Get prediction + predicted_class = LABEL_MAP.get(int(prediction), f"Class {int(prediction)}") + + # Enhanced confidence calculation + if logits is not None: + # Use new softmax-based confidence + probs_np, max_confidence, confidence_level, confidence_emoji = ( + calculate_softmax_confidence(logits) + ) + confidence_desc = confidence_level + else: + # Fallback to legacy method + logit_margin = abs( + (logits_list[0] - logits_list[1]) + if logits_list is not None and len(logits_list) >= 2 + else 0 + ) + confidence_desc, confidence_emoji = get_confidence_description( + logit_margin + ) + max_confidence = logit_margin / 10.0 # Normalize for display + probs_np = np.array([]) + + # Store result in results manager for single file too + ResultsManager.add_results( + filename=filename, + model_name=( + st.session_state.get("model_select", "").split(" ", 1)[1] + if "model_select" in st.session_state + else "Unknown" + ), + prediction=int(prediction), + predicted_class=predicted_class, + confidence=max_confidence, + logits=logits_list if logits_list else [], + ground_truth=true_label_idx if true_label_idx >= 0 else None, + processing_time=inference_time if inference_time is not None else 0.0, + metadata={ + "confidence_level": confidence_desc, + "confidence_emoji": confidence_emoji, + }, + ) + + # Precompute Stats + model_choice = ( + st.session_state.get("model_select", "").split(" ", 1)[1] + if "model_select" in st.session_state + else None + ) + if not model_choice: + st.error( + "⚠️ Model choice is not defined. Please select a model from the sidebar." + ) + st.stop() + model_path = os.path.join(MODEL_WEIGHTS_DIR, f"{model_choice}_model.pth") + mtime = os.path.getmtime(model_path) if os.path.exists(model_path) else None + file_hash = ( + hashlib.md5(open(model_path, "rb").read()).hexdigest() + if os.path.exists(model_path) + else "N/A" + ) + + start_render = time.time() + + active_tab = st.selectbox( + "View Results", + ["Details", "Technical", "Explanation"], + key="active_tab", # reuse the key you were managing manually + ) + + if active_tab == "Details": + st.markdown('
', unsafe_allow_html=True) + # Use a dynamic and informative title for the expander + with st.expander(f"Results for {filename}", expanded=True): + + # --- START: STREAMLINED METRICS --- + # A single, powerful row for the most important results. + key_metric_cols = st.columns(3) + + # Metric 1: The Prediction + key_metric_cols[0].metric("Prediction", predicted_class) + + # Metric 2: The Confidence (with level in tooltip) + confidence_icon = ( + "🟢" + if max_confidence >= 0.8 + else "🟡" if max_confidence >= 0.6 else "🔴" + ) + key_metric_cols[1].metric( + "Confidence", + f"{confidence_icon} {max_confidence:.1%}", + help=f"Confidence Level: {confidence_desc}", + ) + + # Metric 3: Ground Truth + Correctness (Combined) + if true_label_idx is not None: + is_correct = predicted_class == true_label_str + delta_text = "✅ Correct" if is_correct else "❌ Incorrect" + # Use delta_color="normal" to let the icon provide the visual cue + key_metric_cols[2].metric( + "Ground Truth", + true_label_str, + delta=delta_text, + delta_color="normal", + ) + else: + key_metric_cols[2].metric("Ground Truth", "N/A") + + st.divider() + # --- END: STREAMLINED METRICS --- + + # --- START: CONSOLIDATED CONFIDENCE ANALYSIS --- + st.markdown("##### Probability Breakdown") + + # This custom bullet bar logic remains as it is highly specific and valuable + def create_bullet_bar(probability, width=20, predicted=False): + filled_count = int(probability * width) + bar = "▤" * filled_count + "▢" * (width - filled_count) + percentage = f"{probability:.1%}" + pred_marker = "↩ Predicted" if predicted else "" + return f"{bar} {percentage} {pred_marker}" + + if probs is not None: + stable_prob, weathered_prob = probs[0], probs[1] + else: + st.error( + "❌ Probability values are missing. Please check the inference process." + ) + # Default values to prevent further errors + stable_prob, weathered_prob = 0.0, 0.0 + is_stable_predicted, is_weathered_predicted = ( + int(prediction) == 0 + ), (int(prediction) == 1) + + st.markdown( + f""" +
+ Stable (Unweathered)
+ {create_bullet_bar(stable_prob, predicted=is_stable_predicted)}

+ Weathered (Degraded)
+ {create_bullet_bar(weathered_prob, predicted=is_weathered_predicted)} +
+ """, + unsafe_allow_html=True, + ) + + st.divider() + + # METADATA FOOTER + st.caption( + f"Analyzed with **{st.session_state.get('model_select', 'Unknown')}** in **{inference_time:.2f}s**." + ) + st.markdown("
", unsafe_allow_html=True) + + elif active_tab == "Technical": + with st.container(): + st.markdown("Technical Diagnostics") + + # Model performance metrics + with st.container(border=True): + st.markdown("##### **Model Performance**") + tech_col1, tech_col2 = st.columns(2) + + with tech_col1: + st.metric("Inference Time", f"{inference_time:.3f}s") + st.metric( + "Input Length", + f"{len(x_raw) if x_raw is not None else 0} points", + ) + st.metric("Resampled Length", f"{TARGET_LEN} points") + + with tech_col2: + st.metric( + "Model Loaded", + ( + "✅ Yes" + if st.session_state.get("model_loaded", False) + else "❌ No" + ), + ) + st.metric("Device", "CPU") + st.metric("Confidence Score", f"{max_confidence:.3f}") + + # Raw logits display + with st.container(border=True): + st.markdown("##### **Raw Model Outputs (Logits)**") + logits_df = { + "Class": ( + [ + LABEL_MAP.get(i, f"Class {i}") + for i in range(len(logits_list)) + ] + if logits_list is not None + else [] + ), + "Logit Value": ( + [f"{score:.4f}" for score in logits_list] + if logits_list is not None + else [] + ), + "Probability": ( + [f"{prob:.4f}" for prob in probs_np] + if logits_list is not None and len(probs_np) > 0 + else [] + ), + } + + # Display as a simple table format + for i, (cls, logit, prob) in enumerate( + zip( + logits_df["Class"], + logits_df["Logit Value"], + logits_df["Probability"], + ) + ): + col1, col2, col3 = st.columns([2, 1, 1]) + with col1: + if i == prediction: + st.markdown(f"**{cls}** ← Predicted") + else: + st.markdown(cls) + with col2: + st.caption(f"Logit: {logit}") + with col3: + st.caption(f"Prob: {prob}") + + # Spectrum statistics in organized sections + with st.container(border=True): + st.markdown("##### **Spectrum Analysis**") + spec_cols = st.columns(2) + + with spec_cols[0]: + st.markdown("**Original Spectrum:**") + render_kv_grid( + { + "Length": f"{len(x_raw) if x_raw is not None else 0} points", + "Range": ( + f"{min(x_raw):.1f} - {max(x_raw):.1f} cm⁻¹" + if x_raw is not None + else "N/A" + ), + "Min Intensity": ( + f"{min(y_raw):.2e}" + if y_raw is not None + else "N/A" + ), + "Max Intensity": ( + f"{max(y_raw):.2e}" + if y_raw is not None + else "N/A" + ), + }, + ncols=1, + ) + + with spec_cols[1]: + st.markdown("**Processed Spectrum:**") + render_kv_grid( + { + "Length": f"{TARGET_LEN} points", + "Resampling": "Linear interpolation", + "Normalization": "None", + "Input Shape": f"(1, 1, {TARGET_LEN})", + }, + ncols=1, + ) + + # Model information + with st.container(border=True): + st.markdown("##### **Model Information**") + model_info_cols = st.columns(2) + + with model_info_cols[0]: + render_kv_grid( + { + "Architecture": model_choice, + "Path": model_path, + "Weights Modified": ( + time.strftime( + "%Y-%m-%d %H:%M:%S", time.localtime(mtime) + ) + if mtime + else "N/A" + ), + }, + ncols=1, + ) + + with model_info_cols[1]: + if os.path.exists(model_path): + file_hash = hashlib.md5( + open(model_path, "rb").read() + ).hexdigest() + render_kv_grid( + { + "Weights Hash": f"{file_hash[:16]}...", + "Output Shape": f"(1, {len(LABEL_MAP)})", + "Activation": "Softmax", + }, + ncols=1, + ) + + # Debug logs (collapsed by default) + with st.expander("📋 Debug Logs", expanded=False): + log_content = "\n".join( + st.session_state.get("log_messages", []) + ) + if log_content.strip(): + st.code(log_content, language="text") + else: + st.caption("No debug logs available") + + elif active_tab == "Explanation": + with st.container(): + st.markdown("### 🔍 Methodology & Interpretation") + + # Process explanation + st.markdown("Analysis Pipeline") + process_steps = [ + "📁 **Data Upload**: Raman spectrum file loaded and validated", + "🔍 **Preprocessing**: Spectrum parsed and resampled to 500 data points using linear interpolation", + "🧠 **AI Inference**: Convolutional Neural Network analyzes spectral patterns and molecular signatures", + "📊 **Classification**: Binary prediction with confidence scoring using softmax probabilities", + "✅ **Validation**: Ground truth comparison (when available from filename)", + ] + + for step in process_steps: + st.markdown(step) + + st.markdown("---") + + # Model interpretation + st.markdown("#### Scientific Interpretation") + + interp_col1, interp_col2 = st.columns(2) + + with interp_col1: + st.markdown("**Stable (Unweathered) Polymers:**") + st.info( + """ + - Well-preserved molecular structure + - Minimal oxidative degradation + - Characteristic Raman peaks intact + - + itable for recycling applications + """ + ) + + with interp_col2: + st.markdown("**Weathered (Degraded) Polymers:**") + st.warning( + """ + - Oxidized molecular bonds + - Surface degradation present + - Altered spectral signatures + - May require additional processing + """ + ) + + st.markdown("---") + + # Applications + st.markdown("#### Research Applications") + + applications = [ + "🔬 **Material Science**: Polymer degradation studies", + "♻️ **Recycling Research**: Viability assessment for circular economy", + "🌱 **Environmental Science**: Microplastic weathering analysis", + "🏭 **Quality Control**: Manufacturing process monitoring", + "📈 **Longevity Studies**: Material aging prediction", + ] + + for app in applications: + st.markdown(app) + + # Technical details + # MODIFIED: Wrap the expander in a div with the 'expander-advanced' class + st.markdown( + '
', unsafe_allow_html=True + ) + with st.expander("🔧 Technical Details", expanded=False): + st.markdown( + """ + **Model Architecture:** + - Convolutional layers for feature extraction + - Residual connections for gradient flow + - Fully connected layers for classification + - Softmax activation for probability distribution + + **Performance Metrics:** + - Accuracy: 94.8-96.2% on validation set + - F1-Score: 94.3-95.9% across classes + - Robust to spectral noise and baseline variations + + **Data Processing:** + - Input: Raman spectra (any length) + - Resampling: Linear interpolation to 500 points + - Normalization: None (preserves intensity relationships) + """ + ) + st.markdown( + "
", unsafe_allow_html=True + ) # Close the wrapper div + + render_time = time.time() - start_render + log_message( + f"col2 rendered in {render_time:.2f}s, active tab: {active_tab}" + ) + + with st.expander("Spectrum Preprocessing Results", expanded=False): + st.caption("
Spectral Analysis", unsafe_allow_html=True) + + # Add some context about the preprocessing + st.markdown( + """ + **Preprocessing Overview:** + - **Original Spectrum**: Raw Raman data as uploaded + - **Resampled Spectrum**: Data interpolated to 500 points for model input + - **Purpose**: Ensures consistent input dimensions for neural network + """ + ) + + # Create and display plot + cache_key = hashlib.md5( + f"{(x_raw.tobytes() if x_raw is not None else b'')}" + f"{(y_raw.tobytes() if y_raw is not None else b'')}" + f"{(x_resampled.tobytes() if x_resampled is not None else b'')}" + f"{(y_resampled.tobytes() if y_resampled is not None else b'')}".encode() + ).hexdigest() + spectrum_plot = create_spectrum_plot( + x_raw, y_raw, x_resampled, y_resampled, _cache_key=cache_key + ) + st.image( + spectrum_plot, + caption="Raman Spectrum: Raw vs Processed", + use_container_width=True, + ) + + else: + st.markdown( + """ + ##### How to Get Started + + 1. **Select an AI Model:** Use the dropdown menu in the sidebar to choose a model. + 2. **Provide Your Data:** Select one of the three input modes: + - **Upload File:** Analyze a single spectrum. + - **Batch Upload:** Process multiple files at once. + - **Sample Data:** Explore functionality with pre-loaded examples. + 3. **Run Analysis:** Click the "Run Analysis" button to generate the classification results. + + --- + + ##### Supported Data Format + + - **File Type(s):** `.txt`, `.csv`, `.json` + - **Content:** Must contain two columns: `wavenumber` and `intensity`. + - **Separators:** Values can be separated by spaces or commas. + - **Preprocessing:** Your spectrum will be automatically resampled to 500 data points to match the model's input requirements. + """ + ) + else: + # Getting Started + st.markdown( + """ + ##### How to Get Started + + 1. **Select an AI Model:** Use the dropdown menu in the sidebar to choose a model. + 2. **Provide Your Data:** Select one of the three input modes: + - **Upload File:** Analyze a single spectrum. + - **Batch Upload:** Process multiple files at once. + - **Sample Data:** Explore functionality with pre-loaded examples. + 3. **Run Analysis:** Click the "Run Analysis" button to generate the classification results. + + --- + + ##### Supported Data Format + + - **File Type(s):** `.txt`, `.csv`, `.json` + - **Content:** Must contain two columns: `wavenumber` and `intensity`. + - **Separators:** Values can be separated by spaces or commas. + - **Preprocessing:** Your spectrum will be automatically resampled to 500 data points to match the model's input requirements. + """ + ) + + +# ////////////////////////////////////////// + + +def render_comparison_tab(): + """Render the multi-model comparison interface""" + import streamlit as st + import matplotlib.pyplot as plt + from models.registry import ( + choices, + validate_model_list, + models_for_modality, + get_models_metadata, + ) + from utils.results_manager import ResultsManager + from core_logic import get_sample_files, run_inference, parse_spectrum_data + from utils.preprocessing import preprocess_spectrum + from utils.multifile import parse_spectrum_data + import numpy as np + import time + + st.markdown("### Multi-Model Comparison Analysis") + st.markdown( + "Compare predictions across different AI models for comprehensive analysis." + ) + + # Modality selector - Use independant state for comparison tab + col_mod1, col_mod2 = st.columns([1, 2]) + with col_mod1: + # Get the current sidebar modality but don't try to sync back + current_modality = st.session_state.get("modality_select", "raman") + modality = st.selectbox( + "Select Modality", + ["raman", "ftir"], + index=0 if current_modality == "raman" else 1, + help="Choose the spectroscopy modality for analysis", + key="comparison_tab_modality", # Independant key for session state to avoid duplication of UI elements + ) # Note: Intentially not synching back to avoid state conflicts + + with col_mod2: + # Filter models by modality + compatible_models = models_for_modality(modality) + if not compatible_models: + st.error(f"No models available for {modality.upper()} modality") + return + + st.info(f"📊 {len(compatible_models)} models available for {modality.upper()}") + + # Enhanced model selection with metadata + st.markdown("##### Select Models for Comparison") + + # Display model information + models_metadata = get_models_metadata() + + # Create enhanced multiselect with model descriptions + model_options = [] + model_descriptions = {} + for model in compatible_models: + desc = models_metadata.get(model, {}).get("description", "No description") + model_options.append(model) + model_descriptions[model] = desc + + selected_models = st.multiselect( + "Choose models to compare", + model_options, + default=(model_options[:2] if len(model_options) >= 2 else model_options), + help="Select 2 or more models to compare their predictions side-by-side", + key="comparison_model_select", + ) + + # Display selected model information + if selected_models: + with st.expander("Selected Model Details", expanded=False): + for model in selected_models: + info = models_metadata.get(model, {}) + st.markdown(f"**{model}**: {info.get('description', 'No description')}") + if "citation" in info: + st.caption(f"Citation: {info['citation']}") + + if len(selected_models) < 2: + st.warning("⚠️ Please select at least 2 models for comparison.") + + # Input selection for comparison + col1, col2 = st.columns([1, 1.5]) + + with col1: + st.markdown("###### Input Data") + + # File upload for comparison + comparison_file = st.file_uploader( + "Upload spectrum for comparison", + type=["txt", "csv", "json"], + key="comparison_file_upload", + help="Upload a spectrum file to test across all selected models", + ) + + # Or select sample data + selected_sample = None # Initialize with a default value + sample_files = get_sample_files() + if sample_files: + sample_options = ["-- Select Sample --"] + [p.name for p in sample_files] + selected_sample = st.selectbox( + "Or choose sample data", sample_options, key="comparison_sample_select" + ) + + # Get modality from session state + modality = st.session_state.get("modality_select", "raman") + st.info(f"Using {modality.upper()} preprocessing parameters") + + # Run comparison button + run_comparison = st.button( + "Run Multi-Model Comparison", + type="primary", + disabled=not ( + comparison_file + or (sample_files and selected_sample != "-- Select Sample --") + ), + ) + + with col2: + st.markdown("###### Comparison Results") + + if run_comparison: + # Determine input source + input_text = None + filename = "unknown" + + if comparison_file: + raw = comparison_file.read() + input_text = raw.decode("utf-8") if isinstance(raw, bytes) else raw + filename = comparison_file.name + elif sample_files and selected_sample != "-- Select Sample --": + sample_path = next(p for p in sample_files if p.name == selected_sample) + with open(sample_path, "r", encoding="utf-8") as f: + input_text = f.read() + filename = selected_sample + + if input_text: + try: + # Parse spectrum data + x_raw, y_raw = parse_spectrum_data( + str(input_text), filename or "unknown_filename" + ) + + # Validate spectrum modality + is_valid, issues = validate_spectrum_modality( + x_raw, y_raw, modality + ) + if not is_valid: + st.error("**Spectrum-Modality Mismatch in Comparison**") + for issue in issues: + st.error(f"• {issue}") + st.info( + "Please check the selected modality or verify your data file." + ) + return # Exit comparison if validation fails + + # Preprocess spectrum once + _, y_processed = preprocess_spectrum( + x_raw, y_raw, modality=modality, target_len=500 + ) + + # Synchronous processing + comparison_results = {} + progress_bar = st.progress(0) + status_text = st.empty() + + for i, model_name in enumerate(selected_models): + status_text.text(f"Running inference with {model_name}...") + + start_time = time.time() + + # Run inference + prediction, logits_list, probs, inference_time, logits = ( + run_inference(y_processed, model_name, modality=modality) + ) + + processing_time = time.time() - start_time + + # --- FIX FOR SYNCHRONOUS PATH: Handle silent failure --- + if prediction is None: + comparison_results[model_name] = { + "status": "failed", + "error": "Model failed to load or returned None.", + } + else: + # Map prediction to class name + class_names = ["Stable", "Weathered"] + predicted_class = ( + class_names[int(prediction)] + if int(prediction) < len(class_names) + else f"Class_{prediction}" + ) + confidence = ( + float(np.max(probs)) + if probs is not None and probs.size > 0 + else 0.0 + ) + + comparison_results[model_name] = { + "prediction": prediction, + "predicted_class": predicted_class, + "confidence": confidence, + "probs": (probs.tolist() if probs is not None else []), + "logits": ( + logits_list if logits_list is not None else [] + ), + "processing_time": inference_time or 0.0, + "status": "success", + } + + progress_bar.progress((i + 1) / len(selected_models)) + + status_text.text("Comparison complete!") + + # Enhanced results display + if comparison_results: + # Filter successful results + successful_results = { + k: v + for k, v in comparison_results.items() + if v.get("status") == "success" + } + failed_results = { + k: v + for k, v in comparison_results.items() + if v.get("status") == "failed" + } + + if failed_results: + st.error( + f"Failed models: {', '.join(failed_results.keys())}" + ) + for model, result in failed_results.items(): + st.error( + f"{model}: {result.get('error', 'Unknown error')}" + ) + + if successful_results: + try: + st.markdown("###### Model Predictions") + + # Create enhanced comparison table + import pandas as pd + + table_data = [] + for model_name, result in successful_results.items(): + row = { + "Model": model_name, + "Prediction": result["predicted_class"], + "Confidence": f"{result['confidence']:.3f}", + "Processing Time (s)": f"{result['processing_time']:.3f}", + "Agreement": ( + "✓" + if len( + set( + r["prediction"] + for r in successful_results.values() + ) + ) + == 1 + else "✗" + ), + } + table_data.append(row) + + df = pd.DataFrame(table_data) + st.dataframe(df, use_container_width=True) + + # Model agreement analysis + predictions = [ + r["prediction"] for r in successful_results.values() + ] + agreement_rate = len(set(predictions)) == 1 + + if agreement_rate: + st.success("🎯 All models agree on the prediction!") + else: + st.warning( + "⚠️ Models disagree - review individual confidences" + ) + + # Enhanced visualization section + st.markdown("##### Enhanced Analysis Dashboard") + + tab1, tab2, tab3 = st.tabs( + [ + "Confidence Analysis", + "Performance Metrics", + "Detailed Breakdown", + ] + ) + + with tab1: + try: + # Enhanced confidence comparison + col1, col2 = st.columns(2) + + with col1: + # Bar chart of confidences + models = list(successful_results.keys()) + confidences = [ + successful_results[m]["confidence"] + for m in models + ] + + if len(confidences) == 0: + st.warning( + "No confidence data available for visualization." + ) + else: + fig, ax = plt.subplots(figsize=(8, 5)) + colors = plt.cm.Set3( + np.linspace(0, 1, len(models)) + ) + + bars = ax.bar( + models, + confidences, + alpha=0.8, + color=colors, + ) + + # Add value labels on bars + for bar, conf in zip(bars, confidences): + height = bar.get_height() + ax.text( + bar.get_x() + + bar.get_width() / 2.0, + height + 0.01, + f"{conf:.3f}", + ha="center", + va="bottom", + ) + + ax.set_ylabel("Confidence") + ax.set_title( + "Model Confidence Comparison" + ) + ax.set_ylim(0, 1.1) + plt.xticks(rotation=45) + plt.tight_layout() + st.pyplot(fig) + + with col2: + # Confidence distribution + st.markdown("**Confidence Statistics**") + if len(confidences) == 0: + st.warning( + "No confidence data available for statistics." + ) + else: + conf_stats = { + "Mean": np.mean(confidences), + "Std Dev": np.std(confidences), + "Min": np.min(confidences), + "Max": np.max(confidences), + "Range": np.max(confidences) + - np.min(confidences), + } + + for stat, value in conf_stats.items(): + st.metric(stat, f"{value:.4f}") + + except ValueError as e: + st.error(f"Error rendering results: {e}") + + except ValueError as e: + st.error(f"Error rendering results: {e}") + st.error(f"Error in Confidence Analysis tab: {e}") + + with tab2: + # Performance metrics + models = list(successful_results.keys()) + times = [ + successful_results[m]["processing_time"] + for m in models + ] + if len(times) == 0: + st.warning( + "No performance data available for visualization" + ) + else: + + perf_col1, perf_col2 = st.columns(2) + + with perf_col1: + # Processing time comparison + fig, ax = plt.subplots(figsize=(8, 5)) + bars = ax.bar( + models, times, alpha=0.8, color="skyblue" + ) + + for bar, time_val in zip(bars, times): + height = bar.get_height() + ax.text( + bar.get_x() + bar.get_width() / 2.0, + height + 0.001, + f"{time_val:.3f}s", + ha="center", + va="bottom", + ) + + ax.set_ylabel("Processing Time (s)") + ax.set_title("Model Processing Time Comparison") + plt.xticks(rotation=45) + plt.tight_layout() + st.pyplot(fig) + + with perf_col2: + # Performance statistics + st.markdown("**Performance Statistics**") + perf_stats = { + "Fastest Model": models[np.argmin(times)], + "Slowest Model": models[np.argmax(times)], + "Total Time": f"{np.sum(times):.3f}s", + "Average Time": f"{np.mean(times):.3f}s", + "Speed Difference": f"{np.max(times) - np.min(times):.3f}s", + } + + for stat, value in perf_stats.items(): + st.write(f"**{stat}**: {value}") + + with tab3: + # Detailed breakdown + for ( + model_name, + result, + ) in successful_results.items(): + with st.expander( + f"Detailed Results - {model_name}" + ): + col1, col2 = st.columns(2) + + with col1: + st.write( + f"**Prediction**: {result['predicted_class']}" + ) + st.write( + f"**Confidence**: {result['confidence']:.4f}" + ) + st.write( + f"**Processing Time**: {result['processing_time']:.4f}s" + ) + + # ROBUST CHECK FOR PROBABILITIES + if ( + "probs" in result + and result["probs"] is not None + and len(result["probs"]) > 0 + ): + st.write("**Class Probabilities**:") + class_names = [ + "Stable", + "Weathered", + ] + for i, prob in enumerate( + result["probs"] + ): + if i < len(class_names): + st.write( + f" - {class_names[i]}: {prob:.4f}" + ) + + with col2: + # ROBUST CHECK FOR LOGITS + if ( + "logits" in result + and result["logits"] is not None + and len(result["logits"]) > 0 + ): + st.write("**Raw Logits**:") + for i, logit in enumerate( + result["logits"] + ): + st.write( + f" - Class {i}: {logit:.4f}" + ) + + # Export options + st.markdown("##### Export Results") + export_col1, export_col2 = st.columns(2) + + with export_col1: + if st.button("📋 Copy Results to Clipboard"): + results_text = df.to_string(index=False) + st.code(results_text) + + with export_col2: + # Download results as CSV + csv_data = df.to_csv(index=False) + st.download_button( + label="💾 Download as CSV", + data=csv_data, + file_name=f"model_comparison_{filename}_{time.strftime('%Y%m%d_%H%M%S')}.csv", + mime="text/csv", + ) + except Exception as e: + import traceback + + st.error(f"Error during comparison: {str(e)}") + st.code(traceback.format_exc()) # Add traceback for debugging + + # Show recent comparison results if available + elif "last_comparison_results" in st.session_state: + st.info( + "Previous comparison results available. Upload a new file or select a sample to run new comparison." + ) + + # Show comparison history + comparison_stats = ResultsManager.get_comparison_stats() + if comparison_stats: + st.markdown("#### Comparison History") + + with st.expander("View detailed comparison statistics", expanded=False): + # Show model statistics table + stats_data = [] + for model_name, stats in comparison_stats.items(): + row = { + "Model": model_name, + "Total Predictions": stats["total_predictions"], + "Avg Confidence": f"{stats['avg_confidence']:.3f}", + "Avg Processing Time": f"{stats['avg_processing_time']:.3f}s", + "Accuracy": ( + f"{stats['accuracy']:.3f}" + if stats["accuracy"] is not None + else "N/A" + ), + } + stats_data.append(row) + + if stats_data: + import pandas as pd + + stats_df = pd.DataFrame(stats_data) + st.dataframe(stats_df, use_container_width=True) + + # Show agreement matrix if multiple models + agreement_matrix = ResultsManager.get_agreement_matrix() + if not agreement_matrix.empty and len(agreement_matrix) > 1: + st.markdown("**Model Agreement Matrix**") + st.dataframe(agreement_matrix.round(3), use_container_width=True) + + # Plot agreement heatmap + fig, ax = plt.subplots(figsize=(8, 6)) + im = ax.imshow( + agreement_matrix.values, cmap="RdYlGn", vmin=0, vmax=1 + ) + + # Add text annotations + for i in range(len(agreement_matrix)): + for j in range(len(agreement_matrix.columns)): + text = ax.text( + j, + i, + f"{agreement_matrix.iloc[i, j]:.2f}", + ha="center", + va="center", + color="black", + ) + + ax.set_xticks(range(len(agreement_matrix.columns))) + ax.set_yticks(range(len(agreement_matrix))) + ax.set_xticklabels(agreement_matrix.columns, rotation=45) + ax.set_yticklabels(agreement_matrix.index) + ax.set_title("Model Agreement Matrix") + + plt.colorbar(im, ax=ax, label="Agreement Rate") + plt.tight_layout() + st.pyplot(fig) + + # Export functionality + if "last_comparison_results" in st.session_state: + st.markdown("##### Export Results") + + export_col1, export_col2 = st.columns(2) + + with export_col1: + if st.button("📥 Export Comparison (JSON)"): + import json + + results = st.session_state["last_comparison_results"] + json_str = json.dumps(results, indent=2, default=str) + st.download_button( + label="Download JSON", + data=json_str, + file_name=f"comparison_{results['filename'].split('.')[0]}.json", + mime="application/json", + ) + + with export_col2: + if st.button("📊 Export Full Report"): + report = ResultsManager.export_comparison_report() + st.download_button( + label="Download Full Report", + data=report, + file_name="model_comparison_report.json", + mime="application/json", + ) + + +# ////////////////////////////////////////// + + +from utils.performance_tracker import display_performance_dashboard + + +def render_performance_tab(): + """Render the performance tracking and analysis tab.""" + display_performance_dashboard() + + +# //////////////////////////////////////////