Spaces:
Sleeping
Sleeping
devjas1
committed on
Commit
·
9318b04
1
Parent(s):
345529d
FEAT(analyzer): Introduce centralized plot styling helper for theme-aware visualizations; enhance render_visual_diagnostics method with improved aesthetics and interactive filtering
Browse files- modules/analyzer.py +144 -102
modules/analyzer.py
CHANGED
|
@@ -17,21 +17,21 @@ from modules.ui_components import create_spectrum_plot
|
|
| 17 |
import hashlib
|
| 18 |
|
| 19 |
|
| 20 |
-
# --- NEW
|
| 21 |
@contextmanager
|
| 22 |
-
def
|
| 23 |
-
"""
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
try:
|
| 26 |
theme_opts = st.get_option("theme") or {}
|
| 27 |
except RuntimeError:
|
| 28 |
# Fallback to empty dict if theme config is not available
|
| 29 |
theme_opts = {}
|
| 30 |
-
|
| 31 |
text_color = theme_opts.get("textColor", "#000000")
|
| 32 |
bg_color = theme_opts.get("backgroundColor", "#FFFFFF")
|
| 33 |
|
| 34 |
-
# Set Matplotlib's rcParams to match the theme
|
| 35 |
with plt.rc_context(
|
| 36 |
{
|
| 37 |
"figure.facecolor": bg_color,
|
|
@@ -42,12 +42,18 @@ def theme_aware_plot():
|
|
| 42 |
"ytick.color": text_color,
|
| 43 |
"grid.color": text_color,
|
| 44 |
"axes.edgecolor": text_color,
|
|
|
|
|
|
|
| 45 |
}
|
| 46 |
):
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
-
# --- END HELPER
|
| 51 |
|
| 52 |
|
| 53 |
class BatchAnalysis:
|
|
@@ -105,12 +111,10 @@ class BatchAnalysis:
|
|
| 105 |
),
|
| 106 |
)
|
| 107 |
|
| 108 |
-
# In modules/analyzer.py
|
| 109 |
-
|
| 110 |
def render_visual_diagnostics(self):
|
| 111 |
"""
|
| 112 |
-
Renders
|
| 113 |
-
|
| 114 |
"""
|
| 115 |
st.markdown("##### Visual Analysis")
|
| 116 |
if not self.has_ground_truth:
|
|
@@ -118,22 +122,18 @@ class BatchAnalysis:
|
|
| 118 |
return
|
| 119 |
|
| 120 |
valid_gt_df = self.df.dropna(subset=["Ground Truth"])
|
| 121 |
-
|
| 122 |
-
# Use a single row of columns for the two main plots
|
| 123 |
plot_col1, plot_col2 = st.columns(2)
|
| 124 |
|
| 125 |
-
# --- Chart 1: Confusion Matrix ---
|
| 126 |
-
with plot_col1:
|
| 127 |
-
with st.container(border=True):
|
| 128 |
st.markdown("**Confusion Matrix**")
|
| 129 |
cm = confusion_matrix(
|
| 130 |
valid_gt_df["Ground Truth"],
|
| 131 |
valid_gt_df["Prediction"],
|
| 132 |
labels=list(LABEL_MAP.keys()),
|
| 133 |
)
|
| 134 |
-
|
| 135 |
-
with theme_aware_plot(): # Apply theme-aware styling
|
| 136 |
-
fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)
|
| 137 |
sns.heatmap(
|
| 138 |
cm,
|
| 139 |
annot=True,
|
|
@@ -145,58 +145,98 @@ class BatchAnalysis:
|
|
| 145 |
)
|
| 146 |
ax.set_ylabel("Actual Class", fontsize=12)
|
| 147 |
ax.set_xlabel("Predicted Class", fontsize=12)
|
| 148 |
-
|
|
|
|
|
|
|
| 149 |
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
st.caption("Click a cell below to filter the data grid:")
|
| 153 |
-
|
| 154 |
-
# Render CM filter buttons directly below the plot in the same column
|
| 155 |
-
cm_labels = list(LABEL_MAP.values())
|
| 156 |
-
for i, actual_label in enumerate(cm_labels):
|
| 157 |
-
btn_cols_row = st.columns(
|
| 158 |
-
len(cm_labels)
|
| 159 |
-
) # Create a row of columns for buttons
|
| 160 |
-
for j, predicted_label in enumerate(cm_labels):
|
| 161 |
-
cell_value = cm[i, j]
|
| 162 |
-
btn_cols_row[j].button( # Button for each cell
|
| 163 |
-
f"Actual: {actual_label}\nPred: {predicted_label} ({cell_value})",
|
| 164 |
-
key=f"cm_cell_{i}_{j}",
|
| 165 |
-
on_click=self._set_cm_filter,
|
| 166 |
-
args=(i, j, actual_label, predicted_label),
|
| 167 |
-
use_container_width=True,
|
| 168 |
-
)
|
| 169 |
-
# Clear filter button for CM
|
| 170 |
-
if st.session_state.get("cm_filter_active", False):
|
| 171 |
-
st.button(
|
| 172 |
-
"Clear Matrix Filter",
|
| 173 |
-
on_click=self._clear_cm_filter,
|
| 174 |
-
key="clear_cm_filter_btn_below",
|
| 175 |
-
)
|
| 176 |
|
| 177 |
-
# --- Chart 2: Confidence vs. Correctness Box Plot ---
|
| 178 |
-
with plot_col2:
|
| 179 |
-
with st.container(border=True):
|
| 180 |
st.markdown("**Confidence Analysis**")
|
| 181 |
valid_gt_df["Result"] = np.where(
|
| 182 |
valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
|
| 183 |
-
"Correct",
|
| 184 |
-
"Incorrect",
|
| 185 |
)
|
| 186 |
-
|
| 187 |
-
with theme_aware_plot(): # Apply theme-aware styling
|
| 188 |
-
fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)
|
| 189 |
sns.boxplot(
|
| 190 |
x="Result",
|
| 191 |
y="Confidence",
|
| 192 |
data=valid_gt_df,
|
| 193 |
ax=ax,
|
| 194 |
-
palette={"Correct": "
|
| 195 |
)
|
| 196 |
ax.set_ylabel("Model Confidence", fontsize=12)
|
| 197 |
-
ax.set_xlabel("Prediction
|
|
|
|
| 198 |
st.pyplot(fig, use_container_width=True)
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
def _set_cm_filter(
|
| 202 |
self,
|
|
@@ -235,56 +275,58 @@ class BatchAnalysis:
|
|
| 235 |
# Start with a full copy of the dataframe to apply filters to
|
| 236 |
filtered_df = self.df.copy()
|
| 237 |
|
| 238 |
-
# --- Filter Section ---
|
| 239 |
-
st.
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
# Filter 2: By Ground Truth Correctness (if available)
|
| 251 |
-
if self.has_ground_truth:
|
| 252 |
-
filtered_df["Correct"] = (
|
| 253 |
-
filtered_df["Prediction"] == filtered_df["Ground Truth"]
|
| 254 |
)
|
| 255 |
-
|
|
|
|
|
|
|
| 256 |
|
| 257 |
-
#
|
| 258 |
-
|
| 259 |
-
filtered_df["Correct"]
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
)
|
| 267 |
-
# Filter based on the boolean 'Correct' column
|
| 268 |
-
filter_correctness_bools = [
|
| 269 |
-
True if c == "✅ Correct" else False for c in selected_correctness
|
| 270 |
-
]
|
| 271 |
filtered_df = filtered_df[
|
| 272 |
-
filtered_df["
|
|
|
|
| 273 |
]
|
| 274 |
-
|
| 275 |
-
# --- NEW: Filter 3: By Confidence Range ---
|
| 276 |
-
min_conf, max_conf = filter_cols[2].slider(
|
| 277 |
-
"Filter by Confidence Range:",
|
| 278 |
-
min_value=0.0,
|
| 279 |
-
max_value=1.0,
|
| 280 |
-
value=(0.0, 1.0), # Default to the full range
|
| 281 |
-
step=0.01,
|
| 282 |
-
)
|
| 283 |
-
filtered_df = filtered_df[
|
| 284 |
-
(filtered_df["Confidence"] >= min_conf)
|
| 285 |
-
& (filtered_df["Confidence"] <= max_conf)
|
| 286 |
-
]
|
| 287 |
-
# --- END NEW FILTER ---
|
| 288 |
|
| 289 |
# Apply Confusion Matrix Drill-Down Filter (if active)
|
| 290 |
if st.session_state.get("cm_filter_active", False):
|
|
|
|
| 17 |
import hashlib
|
| 18 |
|
| 19 |
|
| 20 |
+
# --- NEW: Centralized plot styling helper ---
|
| 21 |
@contextmanager
|
| 22 |
+
def plot_style_context(figsize=(5, 4), constrained_layout=True, **kwargs):
|
| 23 |
+
"""
|
| 24 |
+
A context manager to apply consistent Matplotlib styling and
|
| 25 |
+
make plots theme-aware.
|
| 26 |
+
"""
|
| 27 |
try:
|
| 28 |
theme_opts = st.get_option("theme") or {}
|
| 29 |
except RuntimeError:
|
| 30 |
# Fallback to empty dict if theme config is not available
|
| 31 |
theme_opts = {}
|
|
|
|
| 32 |
text_color = theme_opts.get("textColor", "#000000")
|
| 33 |
bg_color = theme_opts.get("backgroundColor", "#FFFFFF")
|
| 34 |
|
|
|
|
| 35 |
with plt.rc_context(
|
| 36 |
{
|
| 37 |
"figure.facecolor": bg_color,
|
|
|
|
| 42 |
"ytick.color": text_color,
|
| 43 |
"grid.color": text_color,
|
| 44 |
"axes.edgecolor": text_color,
|
| 45 |
+
"axes.titlecolor": text_color, # Ensure title color matches
|
| 46 |
+
"figure.autolayout": True, # Auto-adjusts subplot params for a tight layout
|
| 47 |
}
|
| 48 |
):
|
| 49 |
+
fig, ax = plt.subplots(
|
| 50 |
+
figsize=figsize, constrained_layout=constrained_layout, **kwargs
|
| 51 |
+
)
|
| 52 |
+
yield fig, ax
|
| 53 |
+
plt.close(fig) # Always close figure to prevent memory leaks
|
| 54 |
|
| 55 |
|
| 56 |
+
# --- END NEW HELPER ---
|
| 57 |
|
| 58 |
|
| 59 |
class BatchAnalysis:
|
|
|
|
| 111 |
),
|
| 112 |
)
|
| 113 |
|
|
|
|
|
|
|
| 114 |
def render_visual_diagnostics(self):
|
| 115 |
"""
|
| 116 |
+
Renders diagnostic plots with corrected aesthetics and a robust,
|
| 117 |
+
interactive drill-down filter using st.selectbox.
|
| 118 |
"""
|
| 119 |
st.markdown("##### Visual Analysis")
|
| 120 |
if not self.has_ground_truth:
|
|
|
|
| 122 |
return
|
| 123 |
|
| 124 |
valid_gt_df = self.df.dropna(subset=["Ground Truth"])
|
|
|
|
|
|
|
| 125 |
plot_col1, plot_col2 = st.columns(2)
|
| 126 |
|
| 127 |
+
# --- Chart 1: Confusion Matrix (Aesthetically Corrected) ---
|
| 128 |
+
with plot_col1:
|
| 129 |
+
with st.container(border=True):
|
| 130 |
st.markdown("**Confusion Matrix**")
|
| 131 |
cm = confusion_matrix(
|
| 132 |
valid_gt_df["Ground Truth"],
|
| 133 |
valid_gt_df["Prediction"],
|
| 134 |
labels=list(LABEL_MAP.keys()),
|
| 135 |
)
|
| 136 |
+
with plot_style_context() as (fig, ax):
|
|
|
|
|
|
|
| 137 |
sns.heatmap(
|
| 138 |
cm,
|
| 139 |
annot=True,
|
|
|
|
| 145 |
)
|
| 146 |
ax.set_ylabel("Actual Class", fontsize=12)
|
| 147 |
ax.set_xlabel("Predicted Class", fontsize=12)
|
| 148 |
+
|
| 149 |
+
# --- AESTHETIC FIX: Rotate X-labels vertically for a compact look ---
|
| 150 |
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
| 151 |
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
|
| 152 |
+
ax.set_title("Prediction vs. Actual (Counts)", fontsize=14)
|
| 153 |
+
st.pyplot(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
+
# --- Chart 2: Confidence vs. Correctness Box Plot (Unchanged) ---
|
| 156 |
+
with plot_col2:
|
| 157 |
+
with st.container(border=True):
|
| 158 |
st.markdown("**Confidence Analysis**")
|
| 159 |
valid_gt_df["Result"] = np.where(
|
| 160 |
valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
|
| 161 |
+
"✅ Correct",
|
| 162 |
+
"❌ Incorrect",
|
| 163 |
)
|
| 164 |
+
with plot_style_context() as (fig, ax):
|
|
|
|
|
|
|
| 165 |
sns.boxplot(
|
| 166 |
x="Result",
|
| 167 |
y="Confidence",
|
| 168 |
data=valid_gt_df,
|
| 169 |
ax=ax,
|
| 170 |
+
palette={"✅ Correct": "lightgreen", "❌ Incorrect": "salmon"},
|
| 171 |
)
|
| 172 |
ax.set_ylabel("Model Confidence", fontsize=12)
|
| 173 |
+
ax.set_xlabel("Prediction Outcome", fontsize=12)
|
| 174 |
+
ax.set_title("Confidence Distribution by Outcome", fontsize=14)
|
| 175 |
st.pyplot(fig, use_container_width=True)
|
| 176 |
+
|
| 177 |
+
st.divider()
|
| 178 |
+
|
| 179 |
+
# --- FUNCTIONALITY FIX: Replace Button Grid with st.selectbox ---
|
| 180 |
+
st.markdown("###### Interactive Confusion Matrix Drill-Down")
|
| 181 |
+
st.caption(
|
| 182 |
+
"Select a cell from the dropdown to filter the data grid in the 'Results Explorer' tab."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# Create a list of options for the selectbox from the confusion matrix
|
| 186 |
+
cm = confusion_matrix(
|
| 187 |
+
valid_gt_df["Ground Truth"],
|
| 188 |
+
valid_gt_df["Prediction"],
|
| 189 |
+
labels=list(LABEL_MAP.keys()),
|
| 190 |
+
)
|
| 191 |
+
cm_labels = list(LABEL_MAP.values())
|
| 192 |
+
options = ["-- Select a cell to filter --"]
|
| 193 |
+
|
| 194 |
+
# This nested loop creates the human-readable options for the dropdown
|
| 195 |
+
for i, actual_label in enumerate(cm_labels):
|
| 196 |
+
for j, predicted_label in enumerate(cm_labels):
|
| 197 |
+
cell_value = cm[i, j]
|
| 198 |
+
# We only add cells with content to the dropdown to avoid clutter
|
| 199 |
+
if cell_value > 0:
|
| 200 |
+
option_str = f"Actual: {actual_label} | Predicted: {predicted_label} ({cell_value} files)"
|
| 201 |
+
options.append(option_str)
|
| 202 |
+
|
| 203 |
+
# The selectbox widget, which is more robust for state management
|
| 204 |
+
selected_option = st.selectbox(
|
| 205 |
+
"Drill-Down Filter",
|
| 206 |
+
options=options,
|
| 207 |
+
key="cm_selectbox", # Give it a key to track its state
|
| 208 |
+
index=0, # Default to the placeholder
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Logic to activate or deactivate the filter based on selection
|
| 212 |
+
if selected_option != "-- Select a cell to filter --":
|
| 213 |
+
# Parse the selection to get the actual and predicted classes
|
| 214 |
+
parts = selected_option.split("|")
|
| 215 |
+
actual_str = parts[0].replace("Actual: ", "").strip()
|
| 216 |
+
# FIX: Split on " (" to get the full label without the file count
|
| 217 |
+
predicted_str = parts[1].replace("Predicted: ", "").split(" (")[0].strip()
|
| 218 |
+
|
| 219 |
+
# Find the corresponding numeric indices with error handling
|
| 220 |
+
actual_matching = [k for k, v in LABEL_MAP.items() if v == actual_str]
|
| 221 |
+
predicted_matching = [k for k, v in LABEL_MAP.items() if v == predicted_str]
|
| 222 |
+
|
| 223 |
+
if not actual_matching or not predicted_matching:
|
| 224 |
+
return
|
| 225 |
+
|
| 226 |
+
actual_idx = actual_matching[0]
|
| 227 |
+
predicted_idx = predicted_matching[0]
|
| 228 |
+
|
| 229 |
+
# Use a simplified callback-like update to session state
|
| 230 |
+
st.session_state["cm_actual_filter"] = actual_idx
|
| 231 |
+
st.session_state["cm_predicted_filter"] = predicted_idx
|
| 232 |
+
st.session_state["cm_filter_label"] = (
|
| 233 |
+
f"Actual: {actual_str}, Predicted: {predicted_str}"
|
| 234 |
+
)
|
| 235 |
+
st.session_state["cm_filter_active"] = True
|
| 236 |
+
else:
|
| 237 |
+
# If the user selects the placeholder, deactivate the filter
|
| 238 |
+
if st.session_state.get("cm_filter_active", False):
|
| 239 |
+
self._clear_cm_filter()
|
| 240 |
|
| 241 |
def _set_cm_filter(
|
| 242 |
self,
|
|
|
|
| 275 |
# Start with a full copy of the dataframe to apply filters to
|
| 276 |
filtered_df = self.df.copy()
|
| 277 |
|
| 278 |
+
# --- Filter Section (STREAMLINED LAYOUT) ---
|
| 279 |
+
with st.container(border=True):
|
| 280 |
+
st.markdown("**Filters**")
|
| 281 |
+
filter_row1 = st.columns([1, 1])
|
| 282 |
+
filter_row2 = st.columns(1) # Slider takes full width
|
| 283 |
+
|
| 284 |
+
# Filter 1: By Predicted Class
|
| 285 |
+
selected_classes = filter_row1[0].multiselect(
|
| 286 |
+
"Filter by Prediction:",
|
| 287 |
+
options=self.df["Predicted Class"].unique(),
|
| 288 |
+
default=self.df["Predicted Class"].unique(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
)
|
| 290 |
+
filtered_df = filtered_df[
|
| 291 |
+
filtered_df["Predicted Class"].isin(selected_classes)
|
| 292 |
+
]
|
| 293 |
|
| 294 |
+
# Filter 2: By Ground Truth Correctness (if available)
|
| 295 |
+
if self.has_ground_truth:
|
| 296 |
+
filtered_df["Correct"] = (
|
| 297 |
+
filtered_df["Prediction"] == filtered_df["Ground Truth"]
|
| 298 |
+
)
|
| 299 |
+
correctness_options = ["✅ Correct", "❌ Incorrect"]
|
| 300 |
+
filtered_df["Result_Display"] = np.where(
|
| 301 |
+
filtered_df["Correct"], "✅ Correct", "❌ Incorrect"
|
| 302 |
+
)
|
| 303 |
|
| 304 |
+
selected_correctness = filter_row1[1].multiselect(
|
| 305 |
+
"Filter by Result:",
|
| 306 |
+
options=correctness_options,
|
| 307 |
+
default=correctness_options,
|
| 308 |
+
)
|
| 309 |
+
filter_correctness_bools = [
|
| 310 |
+
True if c == "✅ Correct" else False for c in selected_correctness
|
| 311 |
+
]
|
| 312 |
+
filtered_df = filtered_df[
|
| 313 |
+
filtered_df["Correct"].isin(filter_correctness_bools)
|
| 314 |
+
]
|
| 315 |
+
|
| 316 |
+
# Filter 3: By Confidence Range (full width below others)
|
| 317 |
+
min_conf, max_conf = filter_row2[0].slider(
|
| 318 |
+
"Filter by Confidence Range:",
|
| 319 |
+
min_value=0.0,
|
| 320 |
+
max_value=1.0,
|
| 321 |
+
value=(0.0, 1.0),
|
| 322 |
+
step=0.01,
|
| 323 |
+
format="%.2f", # Format slider display for clarity
|
| 324 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
filtered_df = filtered_df[
|
| 326 |
+
(filtered_df["Confidence"] >= min_conf)
|
| 327 |
+
& (filtered_df["Confidence"] <= max_conf)
|
| 328 |
]
|
| 329 |
+
# --- END FILTER SECTION ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
# Apply Confusion Matrix Drill-Down Filter (if active)
|
| 332 |
if st.session_state.get("cm_filter_active", False):
|