Spaces:
Sleeping
Sleeping
devjas1
commited on
Commit
·
2c41fa3
1
Parent(s):
ff443f3
(FEAT): Enhance results management with utility functions for session state initialization and reset
Browse files- .gitignore +1 -0
- utils/results_manager.py +61 -16
.gitignore
CHANGED
|
@@ -25,3 +25,4 @@ datasets/**
|
|
| 25 |
!datasets/.README.md
|
| 26 |
# ---------------------------------------
|
| 27 |
|
|
|
|
|
|
| 25 |
!datasets/.README.md
|
| 26 |
# ---------------------------------------
|
| 27 |
|
| 28 |
+
__pycache__.py
|
utils/results_manager.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Dict, List, Any, Optional
|
|
| 9 |
from pathlib import Path
|
| 10 |
import io
|
| 11 |
|
|
|
|
| 12 |
class ResultsManager:
|
| 13 |
"""Manages session-wide results for multi-file inference"""
|
| 14 |
|
|
@@ -73,7 +74,7 @@ class ResultsManager:
|
|
| 73 |
if not results:
|
| 74 |
return pd.DataFrame()
|
| 75 |
|
| 76 |
-
|
| 77 |
df_data = []
|
| 78 |
for result in results:
|
| 79 |
row = {
|
|
@@ -99,7 +100,7 @@ class ResultsManager:
|
|
| 99 |
if df.empty:
|
| 100 |
return b""
|
| 101 |
|
| 102 |
-
|
| 103 |
csv_buffer = io.StringIO()
|
| 104 |
df.to_csv(csv_buffer, index=False)
|
| 105 |
return csv_buffer.getvalue().encode('utf-8')
|
|
@@ -128,9 +129,9 @@ class ResultsManager:
|
|
| 128 |
"avg_processing_time": sum(r["processing_time"] for r in results) / len(results),
|
| 129 |
"files_with_ground_truth": sum(1 for r in results if r["ground_truth"] is not None),
|
| 130 |
}
|
| 131 |
-
|
| 132 |
correct_predictions = sum(
|
| 133 |
-
1 for r in results
|
| 134 |
if r["ground_truth"] is not None and r["prediction"] == r["ground_truth"]
|
| 135 |
)
|
| 136 |
total_with_gt = stats["files_with_ground_truth"]
|
|
@@ -138,7 +139,7 @@ class ResultsManager:
|
|
| 138 |
stats["accuracy"] = correct_predictions / total_with_gt
|
| 139 |
else:
|
| 140 |
stats["accuracy"] = None
|
| 141 |
-
|
| 142 |
return stats
|
| 143 |
|
| 144 |
@staticmethod
|
|
@@ -146,26 +147,70 @@ class ResultsManager:
|
|
| 146 |
"""Remove a result by filename. Returns True if removed, False if not found."""
|
| 147 |
results = ResultsManager.get_results()
|
| 148 |
original_length = len(results)
|
| 149 |
-
|
| 150 |
# Filter out results with matching filename
|
| 151 |
st.session_state[ResultsManager.RESULTS_KEY] = [
|
| 152 |
r for r in results if r["filename"] != filename
|
| 153 |
]
|
| 154 |
-
|
| 155 |
return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
@staticmethod
|
| 158 |
def display_results_table() -> None:
|
| 159 |
"""Display the results table in Streamlit UI"""
|
| 160 |
df = ResultsManager.get_results_dataframe()
|
| 161 |
|
| 162 |
if df.empty:
|
| 163 |
-
st.info(
|
|
|
|
| 164 |
return
|
| 165 |
|
| 166 |
st.subheader(f"Inference Results ({len(df)} files)")
|
| 167 |
|
| 168 |
-
|
| 169 |
stats = ResultsManager.get_summary_stats()
|
| 170 |
if stats:
|
| 171 |
col1, col2, col3, col4 = st.columns(4)
|
|
@@ -174,17 +219,18 @@ class ResultsManager:
|
|
| 174 |
with col2:
|
| 175 |
st.metric("Avg Confidence", f"{stats['avg_confidence']:.3f}")
|
| 176 |
with col3:
|
| 177 |
-
st.metric(
|
|
|
|
| 178 |
with col4:
|
| 179 |
if stats["accuracy"] is not None:
|
| 180 |
st.metric("Accuracy", f"{stats['accuracy']:.3f}")
|
| 181 |
else:
|
| 182 |
st.metric("Accuracy", "N/A")
|
| 183 |
|
| 184 |
-
|
| 185 |
st.dataframe(df, use_container_width=True)
|
| 186 |
|
| 187 |
-
|
| 188 |
col1, col2, col3 = st.columns([1, 1, 2])
|
| 189 |
|
| 190 |
with col1:
|
|
@@ -208,6 +254,5 @@ class ResultsManager:
|
|
| 208 |
)
|
| 209 |
|
| 210 |
with col3:
|
| 211 |
-
if st.button("Clear All Results", help="Clear all stored results"):
|
| 212 |
-
|
| 213 |
-
st.rerun()
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
import io
|
| 11 |
|
| 12 |
+
|
| 13 |
class ResultsManager:
|
| 14 |
"""Manages session-wide results for multi-file inference"""
|
| 15 |
|
|
|
|
| 74 |
if not results:
|
| 75 |
return pd.DataFrame()
|
| 76 |
|
| 77 |
+
# ===Flatten the results for DataFrame===
|
| 78 |
df_data = []
|
| 79 |
for result in results:
|
| 80 |
row = {
|
|
|
|
| 100 |
if df.empty:
|
| 101 |
return b""
|
| 102 |
|
| 103 |
+
# ===Use StringIO to create CSV in memory===
|
| 104 |
csv_buffer = io.StringIO()
|
| 105 |
df.to_csv(csv_buffer, index=False)
|
| 106 |
return csv_buffer.getvalue().encode('utf-8')
|
|
|
|
| 129 |
"avg_processing_time": sum(r["processing_time"] for r in results) / len(results),
|
| 130 |
"files_with_ground_truth": sum(1 for r in results if r["ground_truth"] is not None),
|
| 131 |
}
|
| 132 |
+
# ===Calculate accuracy if ground truth is available===
|
| 133 |
correct_predictions = sum(
|
| 134 |
+
1 for r in results
|
| 135 |
if r["ground_truth"] is not None and r["prediction"] == r["ground_truth"]
|
| 136 |
)
|
| 137 |
total_with_gt = stats["files_with_ground_truth"]
|
|
|
|
| 139 |
stats["accuracy"] = correct_predictions / total_with_gt
|
| 140 |
else:
|
| 141 |
stats["accuracy"] = None
|
| 142 |
+
|
| 143 |
return stats
|
| 144 |
|
| 145 |
@staticmethod
|
|
|
|
| 147 |
"""Remove a result by filename. Returns True if removed, False if not found."""
|
| 148 |
results = ResultsManager.get_results()
|
| 149 |
original_length = len(results)
|
| 150 |
+
|
| 151 |
# Filter out results with matching filename
|
| 152 |
st.session_state[ResultsManager.RESULTS_KEY] = [
|
| 153 |
r for r in results if r["filename"] != filename
|
| 154 |
]
|
| 155 |
+
|
| 156 |
return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
|
| 157 |
+
|
| 158 |
+
@staticmethod
|
| 159 |
+
# ==UTILITY FUNCTIONS==
|
| 160 |
+
def init_session_state():
|
| 161 |
+
"""Keep a persistent session state"""
|
| 162 |
+
defaults = {
|
| 163 |
+
"status_message": "Ready to analyze polymer spectra 🔬",
|
| 164 |
+
"status_type": "info",
|
| 165 |
+
"input_text": None,
|
| 166 |
+
"filename": None,
|
| 167 |
+
"input_source": None, # "upload", "batch" or "sample"
|
| 168 |
+
"sample_select": "-- Select Sample --",
|
| 169 |
+
"input_mode": "Upload File", # controls which pane is visible
|
| 170 |
+
"inference_run_once": False,
|
| 171 |
+
"x_raw": None, "y_raw": None, "y_resampled": None,
|
| 172 |
+
"log_messages": [],
|
| 173 |
+
"uploader_version": 0,
|
| 174 |
+
"current_upload_key": "upload_txt_0",
|
| 175 |
+
"active_tab": "Details",
|
| 176 |
+
"batch_mode": False,
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
# Init session state with defaults
|
| 180 |
+
for key, value in defaults.items():
|
| 181 |
+
if key not in st.session_state:
|
| 182 |
+
st.session_state[key] = value
|
| 183 |
+
|
| 184 |
+
@staticmethod
|
| 185 |
+
def reset_ephemeral_state():
|
| 186 |
+
"""Comprehensive reset for the entire app state."""
|
| 187 |
+
# Define keys that should NOT be cleared by a full reset
|
| 188 |
+
keep_keys = {"model_select", "input_mode"}
|
| 189 |
+
|
| 190 |
+
for k in list(st.session_state.keys()):
|
| 191 |
+
if k not in keep_keys:
|
| 192 |
+
st.session_state.pop(k, None)
|
| 193 |
+
|
| 194 |
+
# Re-initialize the core state after clearing
|
| 195 |
+
ResultsManager.init_session_state()
|
| 196 |
+
|
| 197 |
+
# CRITICAL: Bump the uploader version to force a widget reset
|
| 198 |
+
st.session_state["uploader_version"] += 1
|
| 199 |
+
st.session_state["current_upload_key"] = f"upload_txt_{st.session_state['uploader_version']}"
|
| 200 |
+
|
| 201 |
@staticmethod
|
| 202 |
def display_results_table() -> None:
|
| 203 |
"""Display the results table in Streamlit UI"""
|
| 204 |
df = ResultsManager.get_results_dataframe()
|
| 205 |
|
| 206 |
if df.empty:
|
| 207 |
+
st.info(
|
| 208 |
+
"No inference results yet. Upload files and run analysis to see results here.")
|
| 209 |
return
|
| 210 |
|
| 211 |
st.subheader(f"Inference Results ({len(df)} files)")
|
| 212 |
|
| 213 |
+
# ==Summary stats==
|
| 214 |
stats = ResultsManager.get_summary_stats()
|
| 215 |
if stats:
|
| 216 |
col1, col2, col3, col4 = st.columns(4)
|
|
|
|
| 219 |
with col2:
|
| 220 |
st.metric("Avg Confidence", f"{stats['avg_confidence']:.3f}")
|
| 221 |
with col3:
|
| 222 |
+
st.metric(
|
| 223 |
+
"Stable/Weathered", f"{stats['stable_predictions']}/{stats['weathered_predictions']}")
|
| 224 |
with col4:
|
| 225 |
if stats["accuracy"] is not None:
|
| 226 |
st.metric("Accuracy", f"{stats['accuracy']:.3f}")
|
| 227 |
else:
|
| 228 |
st.metric("Accuracy", "N/A")
|
| 229 |
|
| 230 |
+
# ==Results Table==
|
| 231 |
st.dataframe(df, use_container_width=True)
|
| 232 |
|
| 233 |
+
# ==Export Button==
|
| 234 |
col1, col2, col3 = st.columns([1, 1, 2])
|
| 235 |
|
| 236 |
with col1:
|
|
|
|
| 254 |
)
|
| 255 |
|
| 256 |
with col3:
|
| 257 |
+
if st.button("Clear All Results", help="Clear all stored results", on_click=ResultsManager.reset_ephemeral_state):
|
| 258 |
+
st.rerun()
|
|
|