devjas1 commited on
Commit
c0f3328
·
1 Parent(s): 1751cd3

(MILESTONE)[Feature: Image Processing. ]: Add comprehensive spectral image processing module

Browse files

- Allows users to upload images, select preprocessing and extraction options, process images, visualize results, and optionally run inference on extracted spectra.
- Stores processed spectrum and peaks in session state for downstream analysis.
- Included utility `image_to_spectrum_converter` for direct file-to-spectrum conversion.
- Added extensive docstrings and inline comments to facilitate maintainability and future extension.
- Created new module `utils/image_processing.py` to support image-based analysis for polymer classification.
- Implemented `SpectralImageProcessor` class for loading, preprocessing, and extracting spectral data from images.
- Methods include:
- `load_image`: Handles various image sources.
- `preprocess_image`: Resizes, enhances contrast, applies blur, and normalizes images.
- `extract_spectral_profile`: Converts 2D image data to 1D spectral profile via several selectable methods.
- `image_to_spectrum`: Maps extracted profile to wavenumber domain.
- `detect_spectral_peaks`: Identifies peaks using SciPy's signal library.
- `create_visualization`: Generates matplotlib figures of image and spectrum, with peak markers.
- Allows users to upload images, select preprocessing and extraction options, process images, visualize results, and optionally run inference on extracted spectra.
- Stores processed spectrum and peaks in session state for downstream analysis.
- Included utility `image_to_spectrum_converter` for direct file-to-spectrum conversion.
- Added `render_image_upload_interface` for Streamlit UI.
- Added extensive docstrings and inline comments to facilitate maintainability and future extension.

Files changed (1) hide show
  1. utils/image_processing.py +380 -0
utils/image_processing.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image loading and transformation utilities for polymer classification.
3
+ Supports conversion of spectral images to processable data.
4
+ """
5
+
6
+ from typing import Tuple, Optional, List, Dict
7
+ import base64
8
+ import io
9
+ import numpy as np
10
+ from PIL import Image, ImageEnhance, ImageFilter
11
+ import cv2
12
+ import matplotlib.pyplot as plt
13
+ from matplotlib.figure import Figure
14
+ import streamlit as st
15
+ import pandas as pd
16
+
17
+ # Use existing inference pipeline
18
+ from utils.preprocessing import preprocess_spectrum
19
+ from core_logic import run_inference
20
+
21
+
22
+ class SpectralImageProcessor:
23
+ """Handles loading and processing of spectral images."""
24
+
25
+ def __init__(self):
26
+ self.support_formats = [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
27
+ self.default_target_size = (224, 224)
28
+
29
+ def load_image(self, image_source) -> Optional[np.ndarray]:
30
+ """Load image from various sources."""
31
+ try:
32
+ if isinstance(image_source, str):
33
+ # File path
34
+ img = Image.open(image_source)
35
+ elif hasattr(image_source, "read"):
36
+ # File-like object (Streamlit uploaded file)
37
+ img = Image.open(image_source)
38
+ elif isinstance(image_source, np.ndarray):
39
+ # NumPy array
40
+ return image_source
41
+ else:
42
+ raise ValueError("Unsupported image source type")
43
+
44
+ # Convert to RGB if needed
45
+ if img.mode != "RGB":
46
+ img = img.convert("RGB")
47
+
48
+ return np.array(img)
49
+
50
+ except (FileNotFoundError, IOError, ValueError) as e:
51
+ st.error(f"Error loading image: {e}")
52
+ return None
53
+
54
+ def preprocess_image(
55
+ self,
56
+ image: np.ndarray,
57
+ target_size: Optional[Tuple[int, int]] = None,
58
+ enhance_contrast: bool = True,
59
+ apply_gaussian_blur: bool = False,
60
+ normalize: bool = True,
61
+ ) -> np.ndarray:
62
+ """Preprocess image for analysis."""
63
+ if target_size is None:
64
+ target_size = self.default_target_size
65
+
66
+ # Convert to PIL for processing
67
+ img = Image.fromarray(image.astype(np.uint8))
68
+
69
+ # Resize
70
+ img = img.resize(target_size, Image.Resampling.LANCZOS)
71
+
72
+ # Enhance contrast if required
73
+ if enhance_contrast:
74
+ enhancer = ImageEnhance.Contrast(img)
75
+ img = enhancer.enhance(1.2)
76
+
77
+ # Apply Gaussian blur if requested
78
+ if apply_gaussian_blur:
79
+ img = img.filter(ImageFilter.GaussianBlur(radius=1))
80
+
81
+ # Convert back to numpy
82
+ processed = np.array(img)
83
+
84
+ # Normalize to [0, 1] if requested
85
+ if normalize:
86
+ processed = processed.astype(np.float32) / 255.0
87
+
88
+ return processed
89
+
90
+ def extract_spectral_profile(
91
+ self,
92
+ image: np.ndarray,
93
+ method: str = "average",
94
+ roi: Optional[Tuple[int, int, int, int]] = None,
95
+ ) -> np.ndarray:
96
+ """
97
+ Extract 1D spectral profile from 2D image.
98
+
99
+ Args:
100
+ image: Input image array
101
+ method: 'average', 'center_line', 'max_intensity'
102
+ roi: Region of interest (x1, y1, x2, y2)
103
+ """
104
+ if roi:
105
+ x1, y1, x2, y2 = roi
106
+ image_roi = image[y1:y2, x1:x2]
107
+ else:
108
+ image_roi = image
109
+
110
+ if len(image_roi.shape) == 3:
111
+ # Convert to grayscale if color
112
+ image_roi = np.mean(image_roi, axis=2)
113
+
114
+ if method == "average":
115
+ # Average along one axis
116
+ profile = np.mean(image_roi, axis=0)
117
+ elif method == "center_line":
118
+ # Extract center line
119
+ center_y = image_roi.shape[0] // 2
120
+ profile = image_roi[center_y, :]
121
+ elif method == "max_intensity":
122
+ # Maximum intensity projection
123
+ profile = np.max(image_roi, axis=0)
124
+ else:
125
+ raise ValueError(f"Unknown method: {method}")
126
+
127
+ return profile
128
+
129
+ def image_to_spectrum(
130
+ self,
131
+ image: np.ndarray,
132
+ wavenumber_range: Tuple[float, float] = (400, 4000),
133
+ method: str = "average",
134
+ ) -> Tuple[np.ndarray, np.ndarray]:
135
+ """Convert image to spectrum-like data."""
136
+ # Extract 1D profile
137
+ profile = self.extract_spectral_profile(image, method=method)
138
+
139
+ # Create wavenumber axis
140
+ wavenumbers = np.linspace(
141
+ wavenumber_range[0], wavenumber_range[1], len(profile)
142
+ )
143
+
144
+ return wavenumbers, profile
145
+
146
+ def detect_spectral_peaks(
147
+ self,
148
+ spectrum: np.ndarray,
149
+ wavenumbers: np.ndarray,
150
+ prominence: float = 0.1,
151
+ height: float = 0.1,
152
+ ) -> List[Dict[str, float]]:
153
+ """Detect peaks in spectral data."""
154
+ from scipy.signal import find_peaks
155
+
156
+ peaks, properties = find_peaks(spectrum, prominence=prominence, height=height)
157
+
158
+ peak_info = []
159
+ for i, peak_idx in enumerate(peaks):
160
+ peak_info.append(
161
+ {
162
+ "wavenumber": wavenumbers[peak_idx],
163
+ "intensity": spectrum[peak_idx],
164
+ "prominence": properties["prominences"][i],
165
+ "width": (
166
+ properties.get("widths", [None])[i]
167
+ if "widths" in properties
168
+ else None
169
+ ),
170
+ }
171
+ )
172
+
173
+ return peak_info
174
+
175
+ def create_visualization(
176
+ self,
177
+ image: np.ndarray,
178
+ spectrum_x: np.ndarray,
179
+ spectrum_y: np.ndarray,
180
+ peaks: Optional[List[Dict]] = None,
181
+ ) -> Figure:
182
+ """Create visualization of image and extracted spectrum."""
183
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
184
+
185
+ # Display image
186
+ ax1.imshow(image, cmap="viridis" if len(image.shape) == 2 else None)
187
+ ax1.set_title("Input Image")
188
+ ax1.axis("off")
189
+
190
+ # Display spectrum
191
+ ax2.plot(
192
+ spectrum_x, spectrum_y, "b-", linewidth=1.5, label="Extracted Spectrum"
193
+ )
194
+
195
+ # Mark peaks if provided
196
+ if peaks:
197
+ peak_wavenumbers = [p["wavenumber"] for p in peaks]
198
+ peak_intensities = [p["intensity"] for p in peaks]
199
+ ax2.plot(
200
+ peak_wavenumbers,
201
+ peak_intensities,
202
+ "ro",
203
+ markersize=6,
204
+ label="Detected Peaks",
205
+ )
206
+
207
+ ax2.set_xlabel("Wavenumber (cm⁻¹)")
208
+ ax2.set_ylabel("Intensity")
209
+ ax2.set_title("Extracted Spectral Profile")
210
+ ax2.grid(True, alpha=0.3)
211
+ ax2.legend()
212
+
213
+ plt.tight_layout()
214
+ return fig
215
+
216
+
217
+ def render_image_upload_interface():
218
+ """Render UI for image upload and processing."""
219
+ st.markdown("#### Image-Based Spectral Analysis")
220
+ st.markdown(
221
+ "Upload spectral images for analysis and conversion to spectroscopic data."
222
+ )
223
+
224
+ processor = SpectralImageProcessor()
225
+
226
+ # Image upload
227
+ uploaded_image = st.file_uploader(
228
+ "Upload spectral image",
229
+ type=["png", "jpg", "jpeg", "tiff", "bmp"],
230
+ help="Upload an image containing spectral data",
231
+ )
232
+
233
+ if uploaded_image is not None:
234
+ # Load and display original image
235
+ image = processor.load_image(uploaded_image)
236
+
237
+ if image is not None:
238
+ col1, col2 = st.columns([1, 1])
239
+
240
+ with col1:
241
+ st.markdown("##### Original Image")
242
+ st.image(image, use_column_width=True)
243
+
244
+ # Image info
245
+ st.write(f"**Dimensions**: {image.shape}")
246
+ st.write(f"**Size**: {uploaded_image.size} bytes")
247
+
248
+ with col2:
249
+ st.markdown("##### Processing Options")
250
+
251
+ # Processing parameters
252
+ target_width = st.slider("Target Width", 100, 1000, 500)
253
+ target_height = st.slider("Target Height", 100, 1000, 300)
254
+ enhance_contrast = st.checkbox("Enhance Contrast", value=True)
255
+ apply_blur = st.checkbox("Apply Gaussian Blur", value=False)
256
+
257
+ # Extraction method
258
+ extraction_method = st.selectbox(
259
+ "Spectrum Extraction Method",
260
+ ["average", "center_line", "max_intensity"],
261
+ help="Method for converting 2D image to 1D spectrum",
262
+ )
263
+
264
+ # Wavenumber range
265
+ st.markdown("**Wavenumber Range (cm⁻¹)**")
266
+ wn_col1, wn_col2 = st.columns(2)
267
+ with wn_col1:
268
+ wn_min = st.number_input("Min", value=400.0, step=10.0)
269
+ with wn_col2:
270
+ wn_max = st.number_input("Max", value=4000.0, step=10.0)
271
+
272
+ # Process image
273
+ if st.button("Process Image", type="primary"):
274
+ with st.spinner("Processing image..."):
275
+ # Preprocess image
276
+ processed_image = processor.preprocess_image(
277
+ image,
278
+ target_size=(target_width, target_height),
279
+ enhance_contrast=enhance_contrast,
280
+ apply_gaussian_blur=apply_blur,
281
+ )
282
+
283
+ # Extract spectrum
284
+ wavenumbers, spectrum = processor.image_to_spectrum(
285
+ processed_image,
286
+ wavenumber_range=(wn_min, wn_max),
287
+ method=extraction_method,
288
+ )
289
+
290
+ # Detect peaks
291
+ peaks = processor.detect_spectral_peaks(spectrum, wavenumbers)
292
+
293
+ # Create visualization
294
+ fig = processor.create_visualization(
295
+ processed_image, wavenumbers, spectrum, peaks
296
+ )
297
+
298
+ # Display visualization
299
+ st.pyplot(fig)
300
+
301
+ # Display peaks information
302
+ if peaks:
303
+ st.markdown("##### Detected Peaks")
304
+ peak_df = pd.DataFrame(peaks)
305
+ peak_df["wavenumber"] = peak_df["wavenumber"].round(2)
306
+ peak_df["intensity"] = peak_df["intensity"].round(4)
307
+ st.dataframe(peak_df)
308
+
309
+ # Store in session state for further analysis
310
+ st.session_state["image_spectrum_x"] = wavenumbers
311
+ st.session_state["image_spectrum_y"] = spectrum
312
+ st.session_state["image_peaks"] = peaks
313
+
314
+ st.success(
315
+ "Image processing complete! You can now use this data for model inference."
316
+ )
317
+
318
+ # Option to run inference on extracted spectrum
319
+ if st.button("Run Inference on Extracted Spectrum"):
320
+
321
+ # Preprocess extracted spectrum
322
+ modality = st.session_state.get("modality_select", "raman")
323
+ _, y_processed = preprocess_spectrum(
324
+ wavenumbers, spectrum, modality=modality, target_len=500
325
+ )
326
+
327
+ # Get selected model
328
+ model_choice = st.session_state.get("model_select", "figure2")
329
+ if " " in model_choice:
330
+ model_choice = model_choice.split(" ", 1)[1]
331
+
332
+ # Run inference
333
+ prediction, logits_list, probs, inference_time, logits = (
334
+ run_inference(y_processed, model_choice)
335
+ )
336
+
337
+ if prediction is not None:
338
+ class_names = ["Stable", "Weathered"]
339
+ predicted_class = (
340
+ class_names[int(prediction)]
341
+ if prediction < len(class_names)
342
+ else f"Class_{prediction}"
343
+ )
344
+ confidence = max(probs) if probs and len(probs) > 0 else 0.0
345
+
346
+ # Display results
347
+ st.markdown("##### Inference Results")
348
+ result_col1, result_col2 = st.columns(2)
349
+
350
+ with result_col1:
351
+ st.metric("Prediction", predicted_class)
352
+ st.metric("Confidence", f"{confidence:.3f}")
353
+
354
+ with result_col2:
355
+ st.metric("Model Used", model_choice)
356
+ st.metric("Processing Time", f"{inference_time:.3f}s")
357
+
358
+ # Show class probabilities
359
+ if probs:
360
+ st.markdown("**Class Probabilities**")
361
+ for i, prob in enumerate(probs):
362
+ if i < len(class_names):
363
+ st.write(f"- {class_names[i]}: {prob:.4f}")
364
+
365
+
366
+ def image_to_spectrum_converter(
367
+ image_path: str,
368
+ wavenumber_range: Tuple[float, float] = (400, 4000),
369
+ method: str = "average",
370
+ ) -> Tuple[np.ndarray, np.ndarray]:
371
+ """Convert image file to spectrum data (utility function)."""
372
+ processor = SpectralImageProcessor()
373
+
374
+ # Load image
375
+ image = processor.load_image(image_path)
376
+ if image is None:
377
+ raise ValueError(f"Could not load image from {image_path}.")
378
+
379
+ # Convert to spectrum
380
+ return processor.image_to_spectrum(image, wavenumber_range, method)