devjas1 committed
Commit 342a6af · 1 Parent(s): ec8ba60

(FEAT/REFAC)[Multifile Support Streamlined]: Enhance multifile batch processing for robust spectrum analysis


- Refactored `utils/multifile.py` to support multi-format batch uploads (TXT, CSV, JSON) with automatic format detection.
- Added robust parsing functions for JSON, CSV, and TXT spectrum data, including flexible key mapping and error handling (see the parsing sketch below).
- Implemented validation for monotonicity, NaN values, and reasonable wavenumber ranges (see the validation sketch below).
- Improved single-file and batch processing workflows (see the usage sketch below):
  - `process_single_file`: handles parsing, preprocessing, inference, confidence calculation, and result aggregation.
  - `process_multiple_files`: iterates over uploaded files, updates progress, and logs results.
- Integrated with ResultsManager for session-wide result tracking and visualization.
- Enhanced error logging and warning reporting via ErrorHandler.
- Supports modality-aware preprocessing and model selection.
- Optimized for large-scale, multi-model batch analysis.
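
For reviewers, a minimal sketch of the hardened TXT parsing path. It relies only on behavior visible in this diff (comment and blank lines skipped; comma, semicolon, tab, and space separators accepted; at least 10 points required); `demo.txt` is a made-up filename:

```python
# Sketch: exercises parse_txt_spectrum as changed in this commit.
from utils.multifile import parse_txt_spectrum

content = "# wavenumber, intensity\n" + "\n".join(
    f"{400 + 10 * i}; {0.01 * i:.3f}" for i in range(12)  # 12 points >= 10 minimum
)
x, y = parse_txt_spectrum(content, filename="demo.txt")
print(len(x), len(y))  # -> 12 12, as aligned wavenumber/intensity arrays
```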
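
The validation helpers themselves fall outside the hunks shown below, so the following is only an illustrative sketch of the checks described above; the function name and the range bounds are assumptions, not the shipped implementation:

```python
import numpy as np

# Hypothetical sketch only: the real validator in utils/multifile.py is not
# shown in this diff. The name and the 0-10000 cm^-1 bounds are assumptions.
def validate_spectrum_sketch(x: np.ndarray, y: np.ndarray) -> None:
    if np.isnan(x).any() or np.isnan(y).any():
        raise ValueError("Spectrum contains NaN values")
    diffs = np.diff(x)
    if not (np.all(diffs > 0) or np.all(diffs < 0)):
        raise ValueError("Wavenumber axis is not monotonic")
    if x.min() < 0 or x.max() > 10000:
        raise ValueError("Wavenumber range outside the expected window")
```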
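
A sketch of how a caller wires the new signatures together now that `load_model_func` is gone and `modality` is threaded through; `run_inference` and `label_file` stand in for the host app's real functions, and the `progress_callback` contract is not shown in this diff:

```python
from utils.multifile import process_multiple_files

# Assumes uploaded_files, run_inference, and label_file exist in the host app.
# run_inference must now accept modality as a keyword argument (see diff).
results = process_multiple_files(
    uploaded_files=uploaded_files,      # e.g. Streamlit UploadedFile objects
    model_choice="resnet",              # hypothetical model name
    run_inference_func=run_inference,   # (y_resampled, model_choice, modality=...)
    label_file_func=label_file,         # filename -> int label; negative/None = unknown
    modality="raman",                   # forwarded to preprocessing and inference
)
for r in results:
    if r.get("success"):
        print(r["filename"], r["predicted_class"], f"{r['confidence']:.3f}")
```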

Files changed (1)
  1. utils/multifile.py +66 -199
utils/multifile.py CHANGED
@@ -12,10 +12,11 @@ import csv
 import io
 from pathlib import Path
 
-from .preprocessing import resample_spectrum
+from .preprocessing import preprocess_spectrum
 from .errors import ErrorHandler, safe_execute
 from .results_manager import ResultsManager
 from .confidence import calculate_softmax_confidence
+from config import TARGET_LEN
 
 
 def detect_file_format(filename: str, content: str) -> str:
@@ -255,49 +256,45 @@ def parse_spectrum_data(
 def parse_txt_spectrum(
     content: str, filename: str = "unknown"
 ) -> Tuple[np.ndarray, np.ndarray]:
-    """
-    Parse spectrum data from TXT format (original implementation).
-    """
+    """Robustly parse spectrum data from TXT format."""
     lines = content.strip().split("\n")
+    x_vals, y_vals = [], []
 
-    # ==Remove empty lines and comments==
-    data_lines = []
-    for line in lines:
+    for i, line in enumerate(lines):
         line = line.strip()
-        if line and not line.startswith("#") and not line.startswith("%"):
-            data_lines.append(line)
-
-    if not data_lines:
-        raise ValueError("No data lines found in file")
-
-    # ==Try to parse==
-    x_vals, y_vals = [], []
-
-    for i, line in enumerate(data_lines):
+        if not line or line.startswith(("#", "%")):
+            continue
+
         try:
             # Handle different separators
-            parts = line.replace(",", " ").split()
-            numbers = [
-                p
-                for p in parts
-                if p.replace(".", "", 1)
-                .replace("-", "", 1)
-                .replace("+", "", 1)
-                .isdigit()
-            ]
+            parts = line.replace(",", " ").replace(";", " ").replace("\t", " ").split()
+
+            # Find the first two valid numbers in the line
+            numbers = []
+            for part in parts:
+                if part:  # Skip empty strings from multiple spaces
+                    try:
+                        numbers.append(float(part))
+                    except ValueError:
+                        continue  # Ignore non-numeric parts
 
             if len(numbers) >= 2:
-                x_val = float(numbers[0])
-                y_val = float(numbers[1])
-                x_vals.append(x_val)
-                y_vals.append(y_val)
+                x_vals.append(numbers[0])
+                y_vals.append(numbers[1])
+            else:
+                ErrorHandler.log_warning(
+                    f"Could not find two numbers on line {i+1}: '{line}'",
+                    f"Parsing {filename}",
+                )
 
-        except ValueError:
+        except Exception as e:
             ErrorHandler.log_warning(
-                f"Could not parse line {i+1}: {line}", f"Parsing {filename}"
+                f"Error parsing line {i+1}: '{line}'. Error: {e}",
+                f"Parsing {filename}",
             )
             continue
 
-    if len(x_vals) < 10:  # ==Need minimum points for interpolation==
+    if len(x_vals) < 10:
         raise ValueError(
             f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
         )
@@ -342,9 +339,10 @@ def process_single_file(
     filename: str,
     text_content: str,
     model_choice: str,
-    load_model_func,
     run_inference_func,
     label_file_func,
+    modality: str,
+    target_len: int,
 ) -> Optional[Dict[str, Any]]:
     """
     Process a single spectrum file
@@ -353,7 +351,6 @@
         filename: Name of the file
         text_content: Raw text content
         model_choice: Selected model name
-        load_model_func: Function to load the model
         run_inference_func: Function to run inference
         label_file_func: Function to extract ground truth label
 
@@ -363,51 +360,21 @@ def process_single_file(
     start_time = time.time()
 
     try:
-        # ==Parse spectrum data==
-        result, success = safe_execute(
-            parse_spectrum_data,
-            text_content,
-            filename,
-            error_context=f"parsing {filename}",
-            show_error=False,
-        )
-
-        if not success or result is None:
-            return None
-
-        x_raw, y_raw = result
+        # 1. Parse spectrum data
+        x_raw, y_raw = parse_spectrum_data(text_content, filename)
 
-        # ==Resample spectrum==
-        result, success = safe_execute(
-            resample_spectrum,
-            x_raw,
-            y_raw,
-            500,  # TARGET_LEN
-            error_context=f"resampling {filename}",
-            show_error=False,
+        # 2. Preprocess spectrum using the full, modality-aware pipeline
+        x_resampled, y_resampled = preprocess_spectrum(
+            x_raw, y_raw, modality=modality, target_len=target_len
        )
 
-        if not success or result is None:
-            return None
-
-        x_resampled, y_resampled = result
-
-        # ==Run inference==
-        result, success = safe_execute(
-            run_inference_func,
-            y_resampled,
-            model_choice,
-            error_context=f"inference on {filename}",
-            show_error=False,
+        # 3. Run inference, passing modality
+        prediction, logits_list, probs, inference_time, logits = run_inference_func(
+            y_resampled, model_choice, modality=modality
        )
 
-        if not success or result is None:
-            ErrorHandler.log_error(
-                Exception("Inference failed"), f"processing {filename}"
-            )
-            return None
-
-        prediction, logits_list, probs, inference_time, logits = result
+        if prediction is None:
+            raise ValueError("Inference returned None. Model may have failed to load.")
 
         # ==Calculate confidence==
         if logits is not None:
@@ -415,28 +382,28 @@ def process_single_file(
                 calculate_softmax_confidence(logits)
             )
         else:
-            probs_np = np.array([])
-            max_confidence = 0.0
+            # Fallback for older models or if logits are not returned
+            probs_np = np.array(probs) if probs is not None else np.array([])
+            max_confidence = float(np.max(probs_np)) if probs_np.size > 0 else 0.0
             confidence_level = "LOW"
             confidence_emoji = "🔴"
 
         # ==Get ground truth==
-        try:
-            ground_truth = label_file_func(filename)
-            ground_truth = ground_truth if ground_truth >= 0 else None
-        except Exception:
-            ground_truth = None
+        ground_truth = label_file_func(filename)
+        ground_truth = (
+            ground_truth if ground_truth is not None and ground_truth >= 0 else None
+        )
 
         # ==Get predicted class==
         label_map = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}
-        predicted_class = label_map.get(prediction, f"Unknown ({prediction})")
+        predicted_class = label_map.get(int(prediction), f"Unknown ({prediction})")
 
         processing_time = time.time() - start_time
 
         return {
             "filename": filename,
             "success": True,
-            "prediction": prediction,
+            "prediction": int(prediction),
             "predicted_class": predicted_class,
             "confidence": max_confidence,
             "confidence_level": confidence_level,
@@ -464,9 +431,9 @@ def process_single_file(
 def process_multiple_files(
     uploaded_files: List,
     model_choice: str,
-    load_model_func,
     run_inference_func,
     label_file_func,
+    modality: str,
     progress_callback=None,
 ) -> List[Dict[str, Any]]:
     """
@@ -475,7 +442,6 @@
     Args:
         uploaded_files: List of uploaded file objects
         model_choice: Selected model name
-        load_model_func: Function to load the model
         run_inference_func: Function to run inference
         label_file_func: Function to extract ground truth label
         progress_callback: Optional callback to update progress
@@ -486,7 +452,9 @@
     results = []
     total_files = len(uploaded_files)
 
-    ErrorHandler.log_info(f"Starting batch processing of {total_files} files")
+    ErrorHandler.log_info(
+        f"Starting batch processing of {total_files} files with modality '{modality}'"
+    )
 
     for i, uploaded_file in enumerate(uploaded_files):
         if progress_callback:
@@ -499,12 +467,13 @@
 
         # ==Process the file==
         result = process_single_file(
-            uploaded_file.name,
-            text_content,
-            model_choice,
-            load_model_func,
-            run_inference_func,
-            label_file_func,
+            filename=uploaded_file.name,
+            text_content=text_content,
+            model_choice=model_choice,
+            run_inference_func=run_inference_func,
+            label_file_func=label_file_func,
+            modality=modality,
+            target_len=TARGET_LEN,
         )
 
         if result:
@@ -524,6 +493,11 @@
                 metadata={
                     "confidence_level": result["confidence_level"],
                     "confidence_emoji": result["confidence_emoji"],
+                    # Storing the spectrum data for later visualization
+                    "x_raw": result["x_raw"],
+                    "y_raw": result["y_raw"],
+                    "x_resampled": result["x_resampled"],
+                    "y_resampled": result["y_resampled"],
                 },
             )
 
@@ -545,110 +519,3 @@
         )
 
     return results
-
-
-def display_batch_results(batch_results: list):
-    """Renders a clean, consolidated summary of batch processing results using metrics and a pandas DataFrame replacing the old expander list"""
-    if not batch_results:
-        st.info("No batch results to display.")
-        return
-
-    successful_runs = [r for r in batch_results if r.get("success", False)]
-    failed_runs = [r for r in batch_results if not r.get("success", False)]
-
-    # 1. High Level Metrics
-    st.markdown("###### Batch Summary")
-    metric_cols = st.columns(3)
-    metric_cols[0].metric("Total Files Processed", f"{len(batch_results)}")
-    metric_cols[1].metric("✔️ Successful", f"{len(successful_runs)}")
-    metric_cols[2].metric("❌ Failed", f"{len(failed_runs)}")
-
-    # 3 Hidden Failure Details
-    if failed_runs:
-        with st.expander(
-            f"View details for {len(failed_runs)} failed file(s)", expanded=False
-        ):
-            for r in failed_runs:
-                st.error(f"**File:** `{r.get('filename', 'unknown')}`")
-                st.caption(
-                    f"Reason for failure: {r.get('error', 'No details provided')}"
-                )
-
-
-# Legacy display batch results
-# def display_batch_results(results: List[Dict[str, Any]]) -> None:
-#     """
-#     Display batch processing results in the UI
-
-#     Args:
-#         results: List of processing results
-#     """
-#     if not results:
-#         st.warning("No results to display")
-#         return
-
-#     successful = [r for r in results if r.get("success", False)]
-#     failed = [r for r in results if not r.get("success", False)]
-
-#     # ==Summary==
-#     col1, col2, col3 = st.columns(3, border=True)
-#     with col1:
-#         st.metric("Total Files", len(results))
-#     with col2:
-#         st.metric("Successful", len(successful),
-#                   delta=f"{len(successful)/len(results)*100:.1f}%")
-#     with col3:
-#         st.metric("Failed", len(
-#             failed), delta=f"-{len(failed)/len(results)*100:.1f}%" if failed else "0%")
-
-#     # ==Results tabs==
-#     tab1, tab2 = st.tabs(["✅Successful", "❌ Failed"], width="stretch")
-
-#     with tab1:
-#         with st.expander("Successful"):
-#             if successful:
-#                 for result in successful:
-#                     with st.expander(f"{result['filename']}", expanded=False):
-#                         col1, col2 = st.columns(2)
-#                         with col1:
-#                             st.write(
-#                                 f"**Prediction:** {result['predicted_class']}")
-#                             st.write(
-#                                 f"**Confidence:** {result['confidence_emoji']} {result['confidence_level']} ({result['confidence']:.3f})")
-#                         with col2:
-#                             st.write(
-#                                 f"**Processing Time:** {result['processing_time']:.3f}s")
-#                             if result['ground_truth'] is not None:
-#                                 gt_label = {0: "Stable", 1: "Weathered"}.get(
-#                                     result['ground_truth'], "Unknown")
-#                                 correct = "✅" if result['prediction'] == result['ground_truth'] else "❌"
-#                                 st.write(
-#                                     f"**Ground Truth:** {gt_label} {correct}")
-#             else:
-#                 st.info("No successful results")
-
-#     with tab2:
-#         if failed:
-#             for result in failed:
-#                 with st.expander(f"❌ {result['filename']}", expanded=False):
-#                     st.error(f"Error: {result.get('error', 'Unknown error')}")
-#         else:
-#             st.success("No failed files!")
-
-
-def create_batch_uploader() -> List:
-    """
-    Create multi-file uploader widget
-
-    Returns:
-        List of uploaded files
-    """
-    uploaded_files = st.file_uploader(
-        "Upload multiple Raman spectrum files (.txt)",
-        type="txt",
-        accept_multiple_files=True,
-        help="Select multiple .txt files with wavenumber and intensity columns",
-        key="batch_uploader",
-    )
-
-    return uploaded_files if uploaded_files else []