File size: 13,950 Bytes
c0f3328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e6ebd2
c0f3328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
"""

Image loading and transformation utilities for polymer classification.

Supports conversion of spectral images to processable data.

"""

from typing import Tuple, Optional, List, Dict
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import streamlit as st
import pandas as pd

# Use existing inference pipeline
from utils.preprocessing import preprocess_spectrum
from core_logic import run_inference


class SpectralImageProcessor:
    """Load spectral images and convert them into 1D spectrum-like data.

    Covers the pipeline used by the upload UI: loading (file path, file-like
    object or array), resizing/enhancement, collapsing the 2D image into a 1D
    intensity profile, peak detection, and a side-by-side image/spectrum plot.
    """

    def __init__(self):
        # File extensions this processor is intended to accept.
        self.support_formats = [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
        # (width, height) used by preprocess_image when no size is given.
        self.default_target_size = (224, 224)

    def load_image(self, image_source) -> Optional[np.ndarray]:
        """Load an image and return it as a NumPy array.

        Args:
            image_source: A file path (``str``), a file-like object with a
                ``read`` method (e.g. a Streamlit uploaded file), or an
                ``np.ndarray``, which is returned unchanged.

        Returns:
            An RGB array for path/file-like sources, the input array itself
            for ``np.ndarray`` sources, or ``None`` if loading failed (the
            error is reported to the UI with ``st.error``).
        """
        if isinstance(image_source, np.ndarray):
            # Already-decoded image data: nothing to load.
            return image_source

        try:
            if not isinstance(image_source, str) and not hasattr(
                image_source, "read"
            ):
                raise ValueError("Unsupported image source type")

            # Use the image as a context manager so the file handle opened for
            # path sources is released once the pixels have been copied out.
            with Image.open(image_source) as img:
                rgb = img if img.mode == "RGB" else img.convert("RGB")
                return np.array(rgb)

        except (OSError, ValueError) as e:  # OSError == IOError; covers FileNotFoundError
            st.error(f"Error loading image: {e}")
            return None

    def preprocess_image(
        self,
        image: np.ndarray,
        target_size: Optional[Tuple[int, int]] = None,
        enhance_contrast: bool = True,
        apply_gaussian_blur: bool = False,
        normalize: bool = True,
    ) -> np.ndarray:
        """Resize and optionally enhance an image before profile extraction.

        Args:
            image: Image array; values are interpreted on a 0-255 scale.
            target_size: ``(width, height)`` passed to PIL's ``resize``;
                defaults to ``self.default_target_size``.
            enhance_contrast: Boost contrast by a factor of 1.2.
            apply_gaussian_blur: Apply a radius-1 Gaussian blur.
            normalize: Return ``float32`` values scaled to ``[0, 1]``.

        Returns:
            The processed image: ``uint8`` when ``normalize`` is False,
            otherwise ``float32`` in ``[0, 1]``.
        """
        if target_size is None:
            target_size = self.default_target_size

        # Saturate before casting: astype(uint8) wraps out-of-range values
        # (e.g. 256 -> 0, -1 -> 255) instead of clamping them.
        if image.dtype != np.uint8:
            image = np.clip(image, 0, 255)
        img = Image.fromarray(image.astype(np.uint8))

        img = img.resize(target_size, Image.Resampling.LANCZOS)

        if enhance_contrast:
            img = ImageEnhance.Contrast(img).enhance(1.2)

        if apply_gaussian_blur:
            img = img.filter(ImageFilter.GaussianBlur(radius=1))

        processed = np.array(img)

        if normalize:
            processed = processed.astype(np.float32) / 255.0

        return processed

    def extract_spectral_profile(
        self,
        image: np.ndarray,
        method: str = "average",
        roi: Optional[Tuple[int, int, int, int]] = None,
    ) -> np.ndarray:
        """Extract a 1D profile from a 2D (or multi-channel) image.

        The profile runs along the image columns, so its length equals the
        width of the (cropped) image.

        Args:
            image: Input image array, 2D or with channels on the last axis.
            method: ``'average'`` (mean over rows), ``'center_line'`` (the
                middle row), or ``'max_intensity'`` (max over rows).
            roi: Optional region of interest ``(x1, y1, x2, y2)``; the image is
                cropped to rows ``y1:y2`` and columns ``x1:x2``.

        Returns:
            The 1D intensity profile.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if roi:
            x1, y1, x2, y2 = roi
            image_roi = image[y1:y2, x1:x2]
        else:
            image_roi = image

        if len(image_roi.shape) == 3:
            # Collapse colour channels to a single grayscale intensity.
            image_roi = np.mean(image_roi, axis=2)

        if method == "average":
            profile = np.mean(image_roi, axis=0)
        elif method == "center_line":
            center_y = image_roi.shape[0] // 2
            profile = image_roi[center_y, :]
        elif method == "max_intensity":
            profile = np.max(image_roi, axis=0)
        else:
            raise ValueError(f"Unknown method: {method}")

        return profile

    def image_to_spectrum(
        self,
        image: np.ndarray,
        wavenumber_range: Tuple[float, float] = (400, 4000),
        method: str = "average",
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Convert an image into ``(wavenumbers, intensities)``.

        The profile from :meth:`extract_spectral_profile` is paired with an
        evenly spaced axis from ``wavenumber_range[0]`` to
        ``wavenumber_range[1]`` (inclusive) with one point per profile sample.
        """
        profile = self.extract_spectral_profile(image, method=method)

        wavenumbers = np.linspace(
            wavenumber_range[0], wavenumber_range[1], len(profile)
        )

        return wavenumbers, profile

    def detect_spectral_peaks(
        self,
        spectrum: np.ndarray,
        wavenumbers: np.ndarray,
        prominence: float = 0.1,
        height: float = 0.1,
    ) -> List[Dict[str, float]]:
        """Find peaks in a 1D spectrum using ``scipy.signal.find_peaks``.

        Args:
            spectrum: Intensity values.
            wavenumbers: x-axis values, one per entry of ``spectrum``.
            prominence: Minimum peak prominence (units of ``spectrum``).
            height: Minimum peak height (units of ``spectrum``).

        Returns:
            One dict per peak, in order of position, with keys
            ``wavenumber``, ``intensity``, ``prominence`` and ``width``. The
            width is measured at half the peak's prominence and expressed in
            the units of ``wavenumbers``.
        """
        from scipy.signal import find_peaks, peak_widths

        peaks, properties = find_peaks(spectrum, prominence=prominence, height=height)
        if len(peaks) == 0:
            return []

        # find_peaks only reports widths when a `width` condition is passed,
        # so compute them here, reusing the prominence data it already found.
        _, _, left_ips, right_ips = peak_widths(
            spectrum,
            peaks,
            rel_height=0.5,
            prominence_data=(
                properties["prominences"],
                properties["left_bases"],
                properties["right_bases"],
            ),
        )
        # Map the fractional sample positions onto the wavenumber axis so the
        # width is in wavenumber units rather than samples (abs() handles a
        # descending axis).
        sample_idx = np.arange(len(wavenumbers))
        widths = np.abs(
            np.interp(right_ips, sample_idx, wavenumbers)
            - np.interp(left_ips, sample_idx, wavenumbers)
        )

        return [
            {
                "wavenumber": float(wavenumbers[idx]),
                "intensity": float(spectrum[idx]),
                "prominence": float(prom),
                "width": float(width),
            }
            for idx, prom, width in zip(peaks, properties["prominences"], widths)
        ]

    def create_visualization(
        self,
        image: np.ndarray,
        spectrum_x: np.ndarray,
        spectrum_y: np.ndarray,
        peaks: Optional[List[Dict]] = None,
    ) -> "Figure":
        """Plot the image next to its extracted spectrum.

        Args:
            image: Image to show on the left (2D images use ``viridis``).
            spectrum_x: Wavenumber axis of the spectrum.
            spectrum_y: Spectrum intensities.
            peaks: Optional peak dicts (as returned by
                :meth:`detect_spectral_peaks`) to mark with red dots.

        Returns:
            A matplotlib ``Figure`` with two axes.
        """
        # Build the Figure directly rather than via pyplot so it is not added
        # to pyplot's global figure registry, where figures created on every
        # Streamlit rerun would otherwise accumulate.
        fig = Figure(figsize=(12, 5))
        ax1, ax2 = fig.subplots(1, 2)

        ax1.imshow(image, cmap="viridis" if len(image.shape) == 2 else None)
        ax1.set_title("Input Image")
        ax1.axis("off")

        ax2.plot(
            spectrum_x, spectrum_y, "b-", linewidth=1.5, label="Extracted Spectrum"
        )

        if peaks:
            ax2.plot(
                [p["wavenumber"] for p in peaks],
                [p["intensity"] for p in peaks],
                "ro",
                markersize=6,
                label="Detected Peaks",
            )

        ax2.set_xlabel("Wavenumber (cm⁻¹)")
        ax2.set_ylabel("Intensity")
        ax2.set_title("Extracted Spectral Profile")
        ax2.grid(True, alpha=0.3)
        ax2.legend()

        fig.tight_layout()
        return fig


def _run_inference_on_extracted_spectrum(
    wavenumbers: np.ndarray, spectrum: np.ndarray
) -> None:
    """Preprocess an extracted spectrum, run the selected model, and show results."""
    # Modality and model are read from session-state keys set elsewhere in the
    # app ("modality_select" / "model_select"), with defaults if absent.
    modality = st.session_state.get("modality_select", "raman")
    _, y_processed = preprocess_spectrum(
        wavenumbers, spectrum, modality=modality, target_len=500
    )

    model_choice = st.session_state.get("model_select", "figure2")
    # NOTE(review): presumably the select-box label carries a prefix before
    # the model name; keep only the text after the first space — confirm.
    if " " in model_choice:
        model_choice = model_choice.split(" ", 1)[1]

    prediction, _logits_list, probs, inference_time, _logits = run_inference(
        y_processed, model_choice
    )

    if prediction is None:
        st.warning("Inference did not return a prediction.")
        return

    class_names = ["Stable", "Weathered"]
    pred_idx = int(prediction)
    # Guard both ends: a negative index would silently wrap into class_names.
    predicted_class = (
        class_names[pred_idx]
        if 0 <= pred_idx < len(class_names)
        else f"Class_{prediction}"
    )
    # `probs` may be a list or an array; check emptiness explicitly because the
    # truth value of a multi-element NumPy array is ambiguous.
    has_probs = probs is not None and len(probs) > 0
    confidence = max(probs) if has_probs else 0.0

    st.markdown("##### Inference Results")
    result_col1, result_col2 = st.columns(2)

    with result_col1:
        st.metric("Prediction", predicted_class)
        st.metric("Confidence", f"{confidence:.3f}")

    with result_col2:
        st.metric("Model Used", model_choice)
        st.metric("Processing Time", f"{inference_time:.3f}s")

    if has_probs:
        st.markdown("**Class Probabilities**")
        # zip stops at the shorter sequence, so extra probabilities are skipped.
        for name, prob in zip(class_names, probs):
            st.write(f"- {name}: {prob:.4f}")


def render_image_upload_interface():
    """Render the Streamlit UI for image-based spectral analysis.

    The user uploads an image, chooses processing options and clicks
    "Process Image" to extract a 1D spectrum, detect peaks and plot both. The
    extracted spectrum is stored in ``st.session_state`` and the
    "Run Inference on Extracted Spectrum" button is rendered outside the
    processing branch: each button click reruns the script, and on that rerun
    "Process Image" reads ``False``, so a button nested inside its branch
    could never act.
    """
    st.markdown("#### Image-Based Spectral Analysis")
    st.markdown(
        "Upload spectral images for analysis and conversion to spectroscopic data."
    )

    processor = SpectralImageProcessor()

    uploaded_image = st.file_uploader(
        "Upload spectral image",
        type=["png", "jpg", "jpeg", "tiff", "bmp"],
        help="Upload an image containing spectral data",
    )
    if uploaded_image is None:
        return

    image = processor.load_image(uploaded_image)
    if image is None:
        # load_image has already reported the failure via st.error.
        return

    col1, col2 = st.columns([1, 1])

    with col1:
        st.markdown("##### Original Image")
        st.image(image, use_container_width=True)

        st.write(f"**Dimensions**: {image.shape}")
        st.write(f"**Size**: {uploaded_image.size} bytes")

    with col2:
        st.markdown("##### Processing Options")

        target_width = st.slider("Target Width", 100, 1000, 500)
        target_height = st.slider("Target Height", 100, 1000, 300)
        enhance_contrast = st.checkbox("Enhance Contrast", value=True)
        apply_blur = st.checkbox("Apply Gaussian Blur", value=False)

        extraction_method = st.selectbox(
            "Spectrum Extraction Method",
            ["average", "center_line", "max_intensity"],
            help="Method for converting 2D image to 1D spectrum",
        )

        st.markdown("**Wavenumber Range (cm⁻¹)**")
        wn_col1, wn_col2 = st.columns(2)
        with wn_col1:
            wn_min = st.number_input("Min", value=400.0, step=10.0)
        with wn_col2:
            wn_max = st.number_input("Max", value=4000.0, step=10.0)

    # Tag stored results with the upload they came from so a spectrum from a
    # previously uploaded image is not offered for inference.
    source_name = getattr(uploaded_image, "name", None)

    if st.button("Process Image", type="primary"):
        if wn_min >= wn_max:
            st.error("Min wavenumber must be smaller than Max wavenumber.")
        else:
            with st.spinner("Processing image..."):
                processed_image = processor.preprocess_image(
                    image,
                    target_size=(target_width, target_height),
                    enhance_contrast=enhance_contrast,
                    apply_gaussian_blur=apply_blur,
                )

                wavenumbers, spectrum = processor.image_to_spectrum(
                    processed_image,
                    wavenumber_range=(wn_min, wn_max),
                    method=extraction_method,
                )

                peaks = processor.detect_spectral_peaks(spectrum, wavenumbers)

                fig = processor.create_visualization(
                    processed_image, wavenumbers, spectrum, peaks
                )
                st.pyplot(fig)
                # Drop pyplot's reference (if any) so figures don't pile up
                # across reruns; a no-op for figures pyplot doesn't manage.
                plt.close(fig)

                if peaks:
                    st.markdown("##### Detected Peaks")
                    peak_df = pd.DataFrame(peaks)
                    peak_df["wavenumber"] = peak_df["wavenumber"].round(2)
                    peak_df["intensity"] = peak_df["intensity"].round(4)
                    st.dataframe(peak_df)

                st.session_state["image_spectrum_x"] = wavenumbers
                st.session_state["image_spectrum_y"] = spectrum
                st.session_state["image_peaks"] = peaks
                st.session_state["image_spectrum_source"] = source_name

                st.success(
                    "Image processing complete! You can now use this data for model inference."
                )

    has_spectrum = (
        "image_spectrum_x" in st.session_state
        and "image_spectrum_y" in st.session_state
        and st.session_state.get("image_spectrum_source") == source_name
    )
    if has_spectrum:
        if st.button("Run Inference on Extracted Spectrum"):
            _run_inference_on_extracted_spectrum(
                st.session_state["image_spectrum_x"],
                st.session_state["image_spectrum_y"],
            )


def image_to_spectrum_converter(
    image_path: str,
    wavenumber_range: Tuple[float, float] = (400, 4000),
    method: str = "average",
) -> Tuple[np.ndarray, np.ndarray]:
    """Load an image file and return it as ``(wavenumbers, intensities)``.

    Convenience wrapper around ``SpectralImageProcessor`` for callers that
    only need the spectrum and not the interactive UI.

    Args:
        image_path: Path of the image file to read.
        wavenumber_range: ``(start, end)`` of the generated wavenumber axis.
        method: Profile extraction method passed to ``image_to_spectrum``.

    Raises:
        ValueError: If the image cannot be loaded, or if ``method`` is not a
            supported extraction method.
    """
    converter = SpectralImageProcessor()

    loaded = converter.load_image(image_path)
    if loaded is not None:
        return converter.image_to_spectrum(loaded, wavenumber_range, method)

    # load_image signals failure by returning None; turn that into an exception.
    raise ValueError(f"Could not load image from {image_path}.")