devjas1 committed
Commit 64728dc · 1 Parent(s): 0be85e4

(FEAT+UX)[Comparison Tab Revamp]: Redesign the model comparison tab for multi-model, modality-aware, and asynchronous processing.


- Improved model selection UI:
  - Added a modality selector for "raman" and "ftir".
  - Filtered compatible models by modality using new registry functions (a sketch of the assumed helpers follows this list).
  - Displayed model descriptions and citations dynamically.
- Enhanced multi-model selection:
  - Provided rich metadata and descriptions for each selectable model.
  - Showed details for the selected models in an expandable section.
- Added an asynchronous processing option:
  - Checkbox to enable async inference when comparing multiple models (a sketch of the assumed async helpers appears at the end of this page).
  - Integrated async batch inference logic with a progress bar and status updates.
  - Retained synchronous mode for smaller comparisons.
- Improved result display and analysis:
  - Results table now indicates model agreement.
  - Enhanced UI feedback for failed models.
  - Added agreement analysis (success/warning when models agree/disagree).
  - Dashboard with confidence, performance, and detailed breakdown tabs:
    - Confidence tab: bar chart, value labels, summary stats.
    - Performance tab: processing-time chart, stats, and model speed ranking.
    - Detailed tab: per-model expanders with predictions, confidences, and logits.
  - Export options: copy to clipboard and CSV download with dynamic filenames.
- Refactored preprocessing and inference logic for clarity and performance.
- General UI/UX enhancements for clarity and multi-model workflows.
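
The modality filtering and metadata display rely on new helpers in models.registry (models_for_modality, get_models_metadata) that the diff below imports but does not show. A minimal sketch of what they might look like, assuming a simple dict-based registry; the MODEL_REGISTRY layout, field names, and example entries here are illustrative assumptions, not the project's actual registry:

# Hypothetical sketch of the models.registry helpers used by the comparison tab.
# Registry layout and example entries are assumptions for illustration only.
from typing import Dict, List

MODEL_REGISTRY: Dict[str, dict] = {
    "example-raman-cnn": {
        "modalities": ["raman"],
        "description": "Example CNN trained on Raman spectra",
        "citation": "Example et al., 2023",
    },
    "example-ftir-net": {
        "modalities": ["ftir"],
        "description": "Example network trained on FTIR spectra",
    },
}


def choices() -> List[str]:
    """Return all registered model names."""
    return list(MODEL_REGISTRY.keys())


def models_for_modality(modality: str) -> List[str]:
    """Return names of models whose registry entry lists the given modality."""
    return [
        name
        for name, meta in MODEL_REGISTRY.items()
        if modality in meta.get("modalities", [])
    ]


def get_models_metadata() -> Dict[str, dict]:
    """Return a mapping of model name -> metadata (description, citation, ...)."""
    return MODEL_REGISTRY

The comparison tab only reads models_metadata.get(model, {}).get("description", ...) and an optional "citation" key, so any registry exposing that shape would satisfy the UI.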

Files changed (1)
  1. modules/ui_components.py +441 -111
modules/ui_components.py CHANGED
@@ -25,7 +25,6 @@ from core_logic import (
      label_file,
  )
  from utils.results_manager import ResultsManager
- from utils.confidence import calculate_softmax_confidence
  from utils.multifile import process_multiple_files, display_batch_results
  from utils.preprocessing import resample_spectrum
 
@@ -971,7 +970,12 @@ def render_comparison_tab():
  """Render the multi-model comparison interface"""
  import streamlit as st
  import matplotlib.pyplot as plt
- from models.registry import choices, validate_model_list
  from utils.results_manager import ResultsManager
  from core_logic import get_sample_files, run_inference, parse_spectrum_data
  from utils.preprocessing import preprocess_spectrum
@@ -984,19 +988,58 @@ def render_comparison_tab():
  "Compare predictions across different AI models for comprehensive analysis."
  )

- # Model selection for comparison
  st.markdown("##### Select Models for Comparison")

- available_models = choices()
  selected_models = st.multiselect(
  "Choose models to compare",
- available_models,
- default=(
- available_models[:2] if len(available_models) >= 2 else available_models
- ),
  help="Select 2 or more models to compare their predictions side-by-side",
  )

  if len(selected_models) < 2:
  st.warning("⚠️ Please select at least 2 models for comparison.")
 
@@ -1062,101 +1105,380 @@ def render_comparison_tab():
  str(input_text), filename or "unknown_filename"
  )

- # Store results
- comparison_results = {}
- processing_times = {}

- progress_bar = st.progress(0)
- status_text = st.empty()

- for i, model_name in enumerate(selected_models):
- status_text.text(f"Running inference with {model_name}...")

- start_time = time.time()

- # Preprocess spectrum with modality-specific parameters
- _, y_processed = preprocess_spectrum(
- x_raw, y_raw, modality=modality, target_len=500
  )

- # Run inference
- prediction, logits_list, probs, inference_time, logits = (
- run_inference(y_processed, model_name)
  )

- processing_time = time.time() - start_time

- if prediction is not None:
- # Map prediction to class name
- class_names = ["Stable", "Weathered"]
- predicted_class = (
- class_names[int(prediction)]
- if prediction < len(class_names)
- else f"Class_{prediction}"
- )
- confidence = (
- max(probs)
- if probs is not None and len(probs) > 0
- else 0.0
  )

- comparison_results[model_name] = {
- "prediction": prediction,
- "predicted_class": predicted_class,
- "confidence": confidence,
- "probs": probs if probs is not None else [],
- "logits": (
- logits_list if logits_list is not None else []
- ),
- "processing_time": processing_time,
- }
- processing_times[model_name] = processing_time

- progress_bar.progress((i + 1) / len(selected_models))

- status_text.text("Comparison complete!")

- # Display results
  if comparison_results:
- st.markdown("###### Model Predictions")
-
- # Create comparison table
- import pandas as pd
-
- table_data = []
- for model_name, result in comparison_results.items():
- row = {
- "Model": model_name,
- "Prediction": result["predicted_class"],
- "Confidence": f"{result['confidence']:.3f}",
- "Processing Time (s)": f"{result['processing_time']:.3f}",
- }
- table_data.append(row)
-
- df = pd.DataFrame(table_data)
- st.dataframe(df, use_container_width=True)
-
- # Show confidence comparison
- st.markdown("##### Confidence Comparison")
- conf_col1, conf_col2 = st.columns(2)
-
- with conf_col1:
- # Bar chart of confidences
- models = list(comparison_results.keys())
- confidences = [
- comparison_results[m]["confidence"] for m in models
  ]

- fig, ax = plt.subplots(figsize=(8, 5))
- bars = ax.bar(
- models,
- confidences,
- alpha=0.7,
- color=["steelblue", "orange", "green", "red"][
- : len(models)
- ],
  )
- ax.set_ylabel("Confidence")
  ax.set_title("Model Confidence Comparison")
  ax.set_ylim(0, 1)
  plt.xticks(rotation=45)
@@ -1175,39 +1497,47 @@ def render_comparison_tab():
  plt.tight_layout()
  st.pyplot(fig)

- with conf_col2:
- # Agreement analysis
- predictions = [
- comparison_results[m]["prediction"] for m in models
- ]
- unique_predictions = set(predictions)
-
- if len(unique_predictions) == 1:
- st.success("✅ All models agree on the prediction!")
- else:
- st.warning("⚠️ Models disagree on the prediction")

- # Show prediction distribution
- from collections import Counter

- pred_counts = Counter(predictions)

- st.markdown("**Prediction Distribution:**")
- for pred, count in pred_counts.items():
- class_name = (
- ["Stable", "Weathered"][pred]
- if pred < 2
- else f"Class_{pred}"
- )
- percentage = (count / len(predictions)) * 100
- st.write(
- f"- {class_name}: {count}/{len(predictions)} models ({percentage:.1f}%)"
- )

  # Performance metrics
  st.markdown("##### Performance Metrics")
  perf_col1, perf_col2 = st.columns(2)

  with perf_col1:
  avg_time = np.mean(list(processing_times.values()))
  fastest_model = min(
@@ -1228,7 +1558,7 @@ def render_comparison_tab():
  st.metric(
  "Slowest Model",
  f"{slowest_model}",
- f"{processing_times[slowest_model]:.3f}s",
  )

  with perf_col2:
 
  label_file,
  )
  from utils.results_manager import ResultsManager
  from utils.multifile import process_multiple_files, display_batch_results
  from utils.preprocessing import resample_spectrum
 
  """Render the multi-model comparison interface"""
  import streamlit as st
  import matplotlib.pyplot as plt
+ from models.registry import (
+ choices,
+ validate_model_list,
+ models_for_modality,
+ get_models_metadata,
+ )
  from utils.results_manager import ResultsManager
  from core_logic import get_sample_files, run_inference, parse_spectrum_data
  from utils.preprocessing import preprocess_spectrum
 
  "Compare predictions across different AI models for comprehensive analysis."
  )

+ # Modality selector
+ col_mod1, col_mod2 = st.columns([1, 2])
+ with col_mod1:
+ modality = st.selectbox(
+ "Select Modality",
+ ["raman", "ftir"],
+ index=0,
+ help="Choose the spectroscopy modality for analysis",
+ key="comparison_modality",
+ )
+ st.session_state["modality_select"] = modality
+
+ with col_mod2:
+ # Filter models by modality
+ compatible_models = models_for_modality(modality)
+ if not compatible_models:
+ st.error(f"No models available for {modality.upper()} modality")
+ return
+
+ st.info(f"📊 {len(compatible_models)} models available for {modality.upper()}")
+
+ # Enhanced model selection with metadata
  st.markdown("##### Select Models for Comparison")

+ # Display model information
+ models_metadata = get_models_metadata()
+
+ # Create enhanced multiselect with model descriptions
+ model_options = []
+ model_descriptions = {}
+ for model in compatible_models:
+ desc = models_metadata.get(model, {}).get("description", "No description")
+ model_options.append(model)
+ model_descriptions[model] = desc
+
  selected_models = st.multiselect(
  "Choose models to compare",
+ model_options,
+ default=(model_options[:2] if len(model_options) >= 2 else model_options),
  help="Select 2 or more models to compare their predictions side-by-side",
+ key="comparison_model_select",
  )

+ # Display selected model information
+ if selected_models:
+ with st.expander("Selected Model Details", expanded=False):
+ for model in selected_models:
+ info = models_metadata.get(model, {})
+ st.markdown(f"**{model}**: {info.get('description', 'No description')}")
+ if "citation" in info:
+ st.caption(f"Citation: {info['citation']}")
+
  if len(selected_models) < 2:
  st.warning("⚠️ Please select at least 2 models for comparison.")
 
  str(input_text), filename or "unknown_filename"
  )

+ # Enhanced comparison with async processing option
+ use_async = st.checkbox(
+ "Use asynchronous processing",
+ value=len(selected_models) > 2,
+ help="Process models concurrently for faster results",
+ )

+ # Preprocess spectrum once
+ _, y_processed = preprocess_spectrum(
+ x_raw, y_raw, modality=modality, target_len=500
+ )
+
+ if use_async:
+ # Async processing
+ from utils.async_inference import (
+ submit_batch_inference,
+ wait_for_batch_completion,
+ )

+ status_text = st.empty()
+ status_text.text("Starting asynchronous inference...")

+ progress_bar = st.progress(0)

+ # Submit all models for async processing
+ task_ids = submit_batch_inference(
+ model_names=selected_models,
+ input_data=y_processed,
+ inference_func=run_inference,
  )

+ # Progress callback
+ def update_progress(progress_data):
+ completed = sum(
+ 1
+ for p in progress_data.values()
+ if p["status"] in ["completed", "failed"]
+ )
+ progress_bar.progress(completed / len(selected_models))
+ status_text.text(
+ f"Processing: {completed}/{len(selected_models)} models complete"
+ )
+
+ # Wait for completion
+ async_results = wait_for_batch_completion(
+ task_ids, timeout=60.0, progress_callback=update_progress
  )

+ comparison_results = {}
+ for model_name in selected_models:
+ if model_name in async_results:
+ result = async_results[model_name]
+ if "error" not in result:
+ (
+ prediction,
+ logits_list,
+ probs,
+ inference_time,
+ logits,
+ ) = result
+ if prediction is not None:
+ class_names = ["Stable", "Weathered"]
+ predicted_class = (
+ class_names[int(prediction)]
+ if prediction < len(class_names)
+ else f"Class_{prediction}"
+ )
+ confidence = (
+ max(probs)
+ if probs and len(probs) > 0
+ else 0.0
+ )
+
+ comparison_results[model_name] = {
+ "prediction": prediction,
+ "predicted_class": predicted_class,
+ "confidence": confidence,
+ "probs": probs if probs is not None else [],
+ "logits": (
+ logits_list
+ if logits_list is not None
+ else []
+ ),
+ "processing_time": inference_time or 0.0,
+ "status": "success",
+ }
+ else:
+ comparison_results[model_name] = {
+ "status": "failed",
+ "error": result["error"],
+ }

+ status_text.text("Asynchronous processing complete!")
+
+ else:
+ # Synchronous processing (original)
+ comparison_results = {}
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ for i, model_name in enumerate(selected_models):
+ status_text.text(f"Running inference with {model_name}...")
+
+ start_time = time.time()
+
+ # Run inference
+ prediction, logits_list, probs, inference_time, logits = (
+ run_inference(y_processed, model_name)
  )

+ processing_time = time.time() - start_time

+ if prediction is not None:
+ # Map prediction to class name
+ class_names = ["Stable", "Weathered"]
+ predicted_class = (
+ class_names[int(prediction)]
+ if prediction < len(class_names)
+ else f"Class_{prediction}"
+ )
+ confidence = (
+ max(probs)
+ if probs is not None and len(probs) > 0
+ else 0.0
+ )
+
+ comparison_results[model_name] = {
+ "prediction": prediction,
+ "predicted_class": predicted_class,
+ "confidence": confidence,
+ "probs": probs if probs is not None else [],
+ "logits": (
+ logits_list if logits_list is not None else []
+ ),
+ "processing_time": processing_time,
+ "status": "success",
+ }
+
+ progress_bar.progress((i + 1) / len(selected_models))

+ status_text.text("Comparison complete!")
 
+ # Enhanced results display
  if comparison_results:
+ # Filter successful results
+ successful_results = {
+ k: v
+ for k, v in comparison_results.items()
+ if v.get("status") == "success"
+ }
+ failed_results = {
+ k: v
+ for k, v in comparison_results.items()
+ if v.get("status") == "failed"
+ }
+
+ if failed_results:
+ st.error(
+ f"Failed models: {', '.join(failed_results.keys())}"
+ )
+ for model, result in failed_results.items():
+ st.error(
+ f"{model}: {result.get('error', 'Unknown error')}"
+ )
+
+ if successful_results:
+ st.markdown("###### Model Predictions")
+
+ # Create enhanced comparison table
+ import pandas as pd
+
+ table_data = []
+ for model_name, result in successful_results.items():
+ row = {
+ "Model": model_name,
+ "Prediction": result["predicted_class"],
+ "Confidence": f"{result['confidence']:.3f}",
+ "Processing Time (s)": f"{result['processing_time']:.3f}",
+ "Agreement": (
+ "✓"
+ if len(
+ set(
+ r["prediction"]
+ for r in successful_results.values()
+ )
+ )
+ == 1
+ else "✗"
+ ),
+ }
+ table_data.append(row)
+
+ df = pd.DataFrame(table_data)
+ st.dataframe(df, use_container_width=True)
+
+ # Model agreement analysis
+ predictions = [
+ r["prediction"] for r in successful_results.values()
  ]
+ agreement_rate = len(set(predictions)) == 1
+
+ if agreement_rate:
+ st.success("🎯 All models agree on the prediction!")
+ else:
+ st.warning(
+ "⚠️ Models disagree - review individual confidences"
+ )
 
+ # Enhanced visualization section
+ st.markdown("##### Enhanced Analysis Dashboard")
+
+ tab1, tab2, tab3 = st.tabs(
+ [
+ "Confidence Analysis",
+ "Performance Metrics",
+ "Detailed Breakdown",
+ ]
  )
+
+ with tab1:
+ # Enhanced confidence comparison
+ col1, col2 = st.columns(2)
+
+ with col1:
+ # Bar chart of confidences
+ models = list(successful_results.keys())
+ confidences = [
+ successful_results[m]["confidence"]
+ for m in models
+ ]
+
+ fig, ax = plt.subplots(figsize=(8, 5))
+ colors = plt.cm.Set3(np.linspace(0, 1, len(models)))
+ bars = ax.bar(
+ models, confidences, alpha=0.8, color=colors
+ )
+
+ # Add value labels on bars
+ for bar, conf in zip(bars, confidences):
+ height = bar.get_height()
+ ax.text(
+ bar.get_x() + bar.get_width() / 2.0,
+ height + 0.01,
+ f"{conf:.3f}",
+ ha="center",
+ va="bottom",
+ )
+
+ ax.set_ylabel("Confidence")
+ ax.set_title("Model Confidence Comparison")
+ ax.set_ylim(0, 1.1)
+ plt.xticks(rotation=45)
+ plt.tight_layout()
+ st.pyplot(fig)
+
+ with col2:
+ # Confidence distribution
+ st.markdown("**Confidence Statistics**")
+ conf_stats = {
+ "Mean": np.mean(confidences),
+ "Std Dev": np.std(confidences),
+ "Min": np.min(confidences),
+ "Max": np.max(confidences),
+ "Range": np.max(confidences)
+ - np.min(confidences),
+ }
+
+ for stat, value in conf_stats.items():
+ st.metric(stat, f"{value:.4f}")
+
+ with tab2:
+ # Performance metrics
+ times = [
+ successful_results[m]["processing_time"]
+ for m in models
+ ]
+
+ perf_col1, perf_col2 = st.columns(2)
+
+ with perf_col1:
+ # Processing time comparison
+ fig, ax = plt.subplots(figsize=(8, 5))
+ bars = ax.bar(
+ models, times, alpha=0.8, color="skyblue"
+ )
+
+ for bar, time_val in zip(bars, times):
+ height = bar.get_height()
+ ax.text(
+ bar.get_x() + bar.get_width() / 2.0,
+ height + 0.001,
+ f"{time_val:.3f}s",
+ ha="center",
+ va="bottom",
+ )
+
+ ax.set_ylabel("Processing Time (s)")
+ ax.set_title("Model Processing Time Comparison")
+ plt.xticks(rotation=45)
+ plt.tight_layout()
+ st.pyplot(fig)
+
+ with perf_col2:
+ # Performance statistics
+ st.markdown("**Performance Statistics**")
+ perf_stats = {
+ "Fastest Model": models[np.argmin(times)],
+ "Slowest Model": models[np.argmax(times)],
+ "Total Time": f"{np.sum(times):.3f}s",
+ "Average Time": f"{np.mean(times):.3f}s",
+ "Speed Difference": f"{np.max(times) - np.min(times):.3f}s",
+ }
+
+ for stat, value in perf_stats.items():
+ st.write(f"**{stat}**: {value}")
+
+ with tab3:
+ # Detailed breakdown
+ for model_name, result in successful_results.items():
+ with st.expander(
+ f"Detailed Results - {model_name}"
+ ):
+ col1, col2 = st.columns(2)
+
+ with col1:
+ st.write(
+ f"**Prediction**: {result['predicted_class']}"
+ )
+ st.write(
+ f"**Confidence**: {result['confidence']:.4f}"
+ )
+ st.write(
+ f"**Processing Time**: {result['processing_time']:.4f}s"
+ )
+
+ if result["probs"]:
+ st.write("**Class Probabilities**:")
+ class_names = ["Stable", "Weathered"]
+ for i, prob in enumerate(
+ result["probs"]
+ ):
+ if i < len(class_names):
+ st.write(
+ f" - {class_names[i]}: {prob:.4f}"
+ )
+
+ with col2:
+ if result["logits"]:
+ st.write("**Raw Logits**:")
+ for i, logit in enumerate(
+ result["logits"]
+ ):
+ st.write(
+ f" - Class {i}: {logit:.4f}"
+ )
+
+ # Export options
+ st.markdown("##### Export Results")
+ export_col1, export_col2 = st.columns(2)
+
+ with export_col1:
+ if st.button("📋 Copy Results to Clipboard"):
+ results_text = df.to_string(index=False)
+ st.code(results_text)
+
+ with export_col2:
+ # Download results as CSV
+ csv_data = df.to_csv(index=False)
+ st.download_button(
+ label="💾 Download as CSV",
+ data=csv_data,
+ file_name=f"model_comparison_{filename}_{time.strftime('%Y%m%d_%H%M%S')}.csv",
+ mime="text/csv",
+ )
  ax.set_title("Model Confidence Comparison")
  ax.set_ylim(0, 1)
  plt.xticks(rotation=45)

  plt.tight_layout()
  st.pyplot(fig)

+ conf_col2 = st.columns(2)
+ with conf_col2[1]: # Access the second column explicitly
+ # Agreement analysis
+ predictions = [
+ comparison_results[m]["prediction"]
+ for m in comparison_results.keys()
+ ]
+ unique_predictions = set(predictions)
+
+ if len(unique_predictions) == 1:
+ st.success(" All models agree on the prediction!")
+ else:
+ st.warning("⚠️ Models disagree on the prediction")

+ from collections import Counter

+ pred_counts = Counter(predictions)

+ st.markdown("**Prediction Distribution:**")
+ for pred, count in pred_counts.items():
+ class_name = (
+ ["Stable", "Weathered"][pred]
+ if pred < 2
+ else f"Class_{pred}"
+ )
+ percentage = (count / len(predictions)) * 100
+ st.write(
+ f"- {class_name}: {count}/{len(predictions)} models ({percentage:.1f}%)"
+ )

  # Performance metrics
  st.markdown("##### Performance Metrics")
  perf_col1, perf_col2 = st.columns(2)

+ # Collect processing times for each model
+ processing_times = {
+ model_name: result["processing_time"]
+ for model_name, result in comparison_results.items()
+ if result.get("status") == "success"
+ }
+
  with perf_col1:
  avg_time = np.mean(list(processing_times.values()))
  fastest_model = min(

  st.metric(
  "Slowest Model",
  f"{slowest_model}",
+ f"{processing_times.get(slowest_model, 0):.3f}s",
  )

  with perf_col2:
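
The asynchronous path above imports submit_batch_inference and wait_for_batch_completion from utils.async_inference, which this diff does not include. A minimal sketch of helpers matching the call shapes used in the hunk (keyword arguments, results keyed by model name, failures surfaced as {"error": ...}, and a progress_callback receiving per-task status dicts), built on concurrent.futures; this is an assumption about the module's behavior, not its actual implementation:

# Hypothetical sketch of utils/async_inference.py, inferred from the call sites
# in the diff above; the real module may differ.
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional

_EXECUTOR = ThreadPoolExecutor(max_workers=4)


def submit_batch_inference(
    model_names: List[str],
    input_data: Any,
    inference_func: Callable[..., Any],
) -> Dict[str, Future]:
    """Submit one inference task per model; returns {model_name: Future}."""
    return {
        name: _EXECUTOR.submit(inference_func, input_data, name)
        for name in model_names
    }


def wait_for_batch_completion(
    task_ids: Dict[str, Future],
    timeout: float = 60.0,
    progress_callback: Optional[Callable[[Dict[str, dict]], None]] = None,
) -> Dict[str, Any]:
    """Collect results keyed by model name; a failed task yields {"error": ...}."""
    results: Dict[str, Any] = {}
    progress = {name: {"status": "running"} for name in task_ids}
    for name, future in task_ids.items():
        try:
            results[name] = future.result(timeout=timeout)
            progress[name]["status"] = "completed"
        except Exception as exc:  # surface per-model failures instead of raising
            results[name] = {"error": str(exc)}
            progress[name]["status"] = "failed"
        if progress_callback is not None:
            progress_callback(progress)
    return results

A process pool or an external job queue would fit equally well; the UI only depends on the two signatures and the result/progress shapes shown here.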