Spaces:
Sleeping
Sleeping
devjas1
committed on
Commit
·
182c9ce
1
Parent(s):
b1b7e3c
(FIX): Streamline imports by removing the redundant scripts.preprocess_dataset import; enhance resampling logic with diagnostics
Browse files
- Refactors spectrum resampling and improves diagnostics
- Introduces robust checks for strictly increasing sequences in resampling results, logging ambiguous cases
- Adds session state persistence for both raw and resampled data, and enriches diagnostics with detailed statistics about the resampled data
app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from models.resnet_cnn import ResNet1D
|
| 2 |
from models.figure2_cnn import Figure2CNN
|
| 3 |
-
import logging
|
| 4 |
import hashlib
|
| 5 |
import gc
|
| 6 |
import time
|
|
@@ -22,14 +21,7 @@ if utils_path.is_dir() and str(utils_path) not in sys.path:
|
|
| 22 |
matplotlib.use("Agg") # ensure headless rendering in Spaces
|
| 23 |
|
| 24 |
# Import local modules
|
| 25 |
-
|
| 26 |
-
try:
|
| 27 |
-
from scripts.preprocess_dataset import resample_spectrum
|
| 28 |
-
except (ImportError, ModuleNotFoundError):
|
| 29 |
-
try:
|
| 30 |
-
from utils.preprocessing import resample_spectrum
|
| 31 |
-
except (ImportError, ModuleNotFoundError):
|
| 32 |
-
raise ImportError("Could not import 'resample_spectrum' from either 'scripts.preprocess_dataset' or 'utils.preprocessing'. Please ensure the function exists in one of these modules.")
|
| 33 |
|
| 34 |
KEEP_KEYS = {
|
| 35 |
# === global UI context we want to keep after "Reset" ===
|
|
@@ -129,7 +121,7 @@ def label_file(filename: str) -> int:
|
|
| 129 |
def load_state_dict(_mtime, model_path):
|
| 130 |
"""Load state dict with mtime in cache key to detect file changes"""
|
| 131 |
try:
|
| 132 |
-
return torch.load(model_path, map_location="cpu")
|
| 133 |
except (FileNotFoundError, RuntimeError) as e:
|
| 134 |
st.warning(f"Error loading state dict: {e}")
|
| 135 |
return None
|
|
@@ -235,11 +227,11 @@ def parse_spectrum_data(raw_text):
|
|
| 235 |
return x, y
|
| 236 |
|
| 237 |
|
| 238 |
-
def create_spectrum_plot(x_raw, y_raw, y_resampled):
|
| 239 |
"""Create spectrum visualization plot"""
|
| 240 |
fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
|
| 241 |
|
| 242 |
-
# Raw spectrum
|
| 243 |
ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
|
| 244 |
ax[0].set_title("Raw Input Spectrum")
|
| 245 |
ax[0].set_xlabel("Wavenumber (cm⁻¹)")
|
|
@@ -247,19 +239,16 @@ def create_spectrum_plot(x_raw, y_raw, y_resampled):
|
|
| 247 |
ax[0].grid(True, alpha=0.3)
|
| 248 |
ax[0].legend()
|
| 249 |
|
| 250 |
-
# Resampled spectrum
|
| 251 |
-
x_resampled =
|
| 252 |
-
ax[1].
|
| 253 |
-
color="steelblue", linewidth=1)
|
| 254 |
-
ax[1].set_title(f"Resampled ({TARGET_LEN} points)")
|
| 255 |
ax[1].set_xlabel("Wavenumber (cm⁻¹)")
|
| 256 |
ax[1].set_ylabel("Intensity")
|
| 257 |
ax[1].grid(True, alpha=0.3)
|
| 258 |
ax[1].legend()
|
| 259 |
|
| 260 |
plt.tight_layout()
|
| 261 |
-
|
| 262 |
-
# Convert to image
|
| 263 |
buf = io.BytesIO()
|
| 264 |
plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
|
| 265 |
buf.seek(0)
|
|
@@ -546,7 +535,30 @@ def main():
|
|
| 546 |
|
| 547 |
# Resample
|
| 548 |
with st.spinner("Resampling spectrum..."):
|
| 549 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 550 |
|
| 551 |
# Persist results (drives right column)
|
| 552 |
st.session_state["x_raw"] = x_raw
|
|
@@ -571,6 +583,7 @@ def main():
|
|
| 571 |
# Get data from session state
|
| 572 |
x_raw = st.session_state.get('x_raw')
|
| 573 |
y_raw = st.session_state.get('y_raw')
|
|
|
|
| 574 |
y_resampled = st.session_state.get('y_resampled')
|
| 575 |
filename = st.session_state.get('filename', 'Unknown')
|
| 576 |
|
|
@@ -578,8 +591,7 @@ def main():
|
|
| 578 |
|
| 579 |
# Create and display plot
|
| 580 |
try:
|
| 581 |
-
spectrum_plot = create_spectrum_plot(
|
| 582 |
-
x_raw, y_raw, y_resampled)
|
| 583 |
st.image(
|
| 584 |
spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
|
| 585 |
except (ValueError, RuntimeError, TypeError) as e:
|
|
@@ -706,6 +718,33 @@ def main():
|
|
| 706 |
st.text_area("Logs", "\n".join(
|
| 707 |
st.session_state.get("log_messages", [])), height=200)
|
| 708 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 709 |
with tab3:
|
| 710 |
st.markdown("""
|
| 711 |
**🔍 Analysis Process**
|
|
|
|
| 1 |
from models.resnet_cnn import ResNet1D
|
| 2 |
from models.figure2_cnn import Figure2CNN
|
|
|
|
| 3 |
import hashlib
|
| 4 |
import gc
|
| 5 |
import time
|
|
|
|
| 21 |
matplotlib.use("Agg") # ensure headless rendering in Spaces
|
| 22 |
|
| 23 |
# Import local modules
|
| 24 |
+
from utils.preprocessing import resample_spectrum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
KEEP_KEYS = {
|
| 27 |
# === global UI context we want to keep after "Reset" ===
|
|
|
|
| 121 |
def load_state_dict(_mtime, model_path):
|
| 122 |
"""Load state dict with mtime in cache key to detect file changes"""
|
| 123 |
try:
|
| 124 |
+
return torch.load(model_path, map_location="cpu", weights_only=True)
|
| 125 |
except (FileNotFoundError, RuntimeError) as e:
|
| 126 |
st.warning(f"Error loading state dict: {e}")
|
| 127 |
return None
|
|
|
|
| 227 |
return x, y
|
| 228 |
|
| 229 |
|
| 230 |
+
def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled):
|
| 231 |
"""Create spectrum visualization plot"""
|
| 232 |
fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
|
| 233 |
|
| 234 |
+
# == Raw spectrum ==
|
| 235 |
ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
|
| 236 |
ax[0].set_title("Raw Input Spectrum")
|
| 237 |
ax[0].set_xlabel("Wavenumber (cm⁻¹)")
|
|
|
|
| 239 |
ax[0].grid(True, alpha=0.3)
|
| 240 |
ax[0].legend()
|
| 241 |
|
| 242 |
+
# == Resampled spectrum ==
|
| 243 |
+
ax[1].plot(x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1)
|
| 244 |
+
ax[1].set_title(f"Resampled ({len(y_resampled)} points)")
|
|
|
|
|
|
|
| 245 |
ax[1].set_xlabel("Wavenumber (cm⁻¹)")
|
| 246 |
ax[1].set_ylabel("Intensity")
|
| 247 |
ax[1].grid(True, alpha=0.3)
|
| 248 |
ax[1].legend()
|
| 249 |
|
| 250 |
plt.tight_layout()
|
| 251 |
+
# == Convert to image ==
|
|
|
|
| 252 |
buf = io.BytesIO()
|
| 253 |
plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
|
| 254 |
buf.seek(0)
|
|
|
|
| 535 |
|
| 536 |
# Resample
|
| 537 |
with st.spinner("Resampling spectrum..."):
|
| 538 |
+
# ===Resample Unpack===
|
| 539 |
+
r1, r2 = resample_spectrum(x_raw, y_raw, TARGET_LEN)
|
| 540 |
+
|
| 541 |
+
def _is_strictly_increasing(a):
|
| 542 |
+
try:
|
| 543 |
+
a = np.asarray(a)
|
| 544 |
+
return a.ndim == 1 and a.size >= 2 and np.all(np.diff(a) > 0)
|
| 545 |
+
except Exception:
|
| 546 |
+
return False
|
| 547 |
+
|
| 548 |
+
if _is_strictly_increasing(r1) and not _is_strictly_increasing(r2):
|
| 549 |
+
x_resampled, y_resampled = np.asarray(r1), np.asarray(r2)
|
| 550 |
+
elif _is_strictly_increasing(r2) and not _is_strictly_increasing(r1):
|
| 551 |
+
x_resampled, y_resampled = np.asarray(r2), np.asarray(r1)
|
| 552 |
+
else:
|
| 553 |
+
# == Ambiguous; assume (x, y) and log
|
| 554 |
+
x_resampled, y_resampled = np.asarray(r1), np.asarray(r2)
|
| 555 |
+
log_message("Resample outputs ambigous; assumed (x, y).")
|
| 556 |
+
|
| 557 |
+
# ===Persists for plotting + inference===
|
| 558 |
+
st.session_state["x_raw"] = x_raw
|
| 559 |
+
st.session_state["y_raw"] = y_raw
|
| 560 |
+
st.session_state["x_resampled"] = x_resampled # ←-- NEW
|
| 561 |
+
st.session_state["y_resampled"] = y_resampled
|
| 562 |
|
| 563 |
# Persist results (drives right column)
|
| 564 |
st.session_state["x_raw"] = x_raw
|
|
|
|
| 583 |
# Get data from session state
|
| 584 |
x_raw = st.session_state.get('x_raw')
|
| 585 |
y_raw = st.session_state.get('y_raw')
|
| 586 |
+
x_resampled = st.session_state.get('x_resampled') # ← NEW
|
| 587 |
y_resampled = st.session_state.get('y_resampled')
|
| 588 |
filename = st.session_state.get('filename', 'Unknown')
|
| 589 |
|
|
|
|
| 591 |
|
| 592 |
# Create and display plot
|
| 593 |
try:
|
| 594 |
+
spectrum_plot = create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled)
|
|
|
|
| 595 |
st.image(
|
| 596 |
spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
|
| 597 |
except (ValueError, RuntimeError, TypeError) as e:
|
|
|
|
| 718 |
st.text_area("Logs", "\n".join(
|
| 719 |
st.session_state.get("log_messages", [])), height=200)
|
| 720 |
|
| 721 |
+
try:
|
| 722 |
+
resampler_mod = getattr(resample_spectrum, "__module__", "unknown")
|
| 723 |
+
resampler_doc = getattr(resample_spectrum, "__doc__", None)
|
| 724 |
+
resampler_doc = resampler_doc.splitlines()[0] if isinstance(resampler_doc, str) and resampler_doc else "no doc"
|
| 725 |
+
|
| 726 |
+
y_rs = st.session_state.get("y_resampled", None)
|
| 727 |
+
diag = {}
|
| 728 |
+
if y_rs is not None:
|
| 729 |
+
arr = np.asarray(y_rs)
|
| 730 |
+
diag = {
|
| 731 |
+
"y_resampled_len": int(arr.size),
|
| 732 |
+
"y_resampled_min": float(np.min(arr)) if arr.size else None,
|
| 733 |
+
"y_resampled_max": float(np.max(arr)) if arr.size else None,
|
| 734 |
+
"y_resampled_ptp": float(np.ptp(arr)) if arr.size else None,
|
| 735 |
+
"y_resampled_unique": int(np.unique(arr).size) if arr.size else None,
|
| 736 |
+
"y_resampled_all_equal": bool(np.ptp(arr) == 0.0) if arr.size else None,
|
| 737 |
+
}
|
| 738 |
+
|
| 739 |
+
st.markdown("**Resampler Info")
|
| 740 |
+
st.json({
|
| 741 |
+
"module": resampler_mod,
|
| 742 |
+
"doc": resampler_doc,
|
| 743 |
+
**({"y_resampled_stats": diag} if diag else {})
|
| 744 |
+
})
|
| 745 |
+
except Exception as _e:
|
| 746 |
+
st.warning(f"Diagnostics skipped: {_e}")
|
| 747 |
+
|
| 748 |
with tab3:
|
| 749 |
st.markdown("""
|
| 750 |
**🔍 Analysis Process**
|