Commit dd1d7f5 · Parent(s): 96f1578
Fahimeh Orvati Nia committed

make pipeline minimal
app.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 import tempfile
+from pathlib import Path
 from wrapper import run_pipeline_on_image
 
 def process(image):
@@ -10,7 +11,11 @@ def process(image):
         img_path = Path(tmpdir) / "input.png"
         image.save(img_path)
         outputs = run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True)
-        return list(outputs.values())
+        # Keep order consistent: return exactly the 7 images
+        order = [
+            'NDVI', 'ARI', 'GNDVI', 'LBP', 'HOG', 'Lacunarity', 'SizeAnalysis'
+        ]
+        return [outputs[k] for k in order if k in outputs]
 
 with gr.Blocks() as demo:
     gr.Markdown("# 🌿 Sorghum Single-Image Demo")
@@ -20,4 +25,4 @@ with gr.Blocks() as demo:
     run.click(process, inputs=inp, outputs=gallery)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
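Note on the new return contract: `process` now assumes `run_pipeline_on_image` returns a dict keyed by artifact name. A minimal sketch of a compatible wrapper stub, for local testing only — the real wrapper's internals are not part of this commit, and only the seven key names are taken from the `order` list above:

# Hypothetical stand-in for wrapper.run_pipeline_on_image (illustration only).
from PIL import Image

def run_pipeline_on_image(img_path: str, out_dir: str, save_artifacts: bool = True) -> dict:
    img = Image.open(img_path).convert("RGB")
    # A real implementation would run segmentation, vegetation indices,
    # texture, and morphology here; this returns the input as a placeholder.
    keys = ['NDVI', 'ARI', 'GNDVI', 'LBP', 'HOG', 'Lacunarity', 'SizeAnalysis']
    return {k: img.copy() for k in keys}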
sorghum_pipeline/config.py CHANGED
@@ -1,249 +1,76 @@
 """
-Configuration management for the Sorghum Pipeline.
-
-This module handles all configuration settings, paths, and parameters
-used throughout the pipeline.
+Minimal configuration for the Sorghum Pipeline.
 """
 
 import os
-import yaml
 from pathlib import Path
-from typing import Dict, Any, Optional
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 
 
 @dataclass
 class Paths:
-    """Configuration for all file paths."""
+    """Configuration for file paths."""
     input_folder: str
     output_folder: str
-    boundingbox_dir: Optional[str] = None
-    labels_folder: Optional[str] = None
+    boundingbox_dir: str = ""
 
     def __post_init__(self):
-        """Ensure all paths are absolute where provided."""
+        """Ensure paths are absolute."""
         self.input_folder = os.path.abspath(self.input_folder)
         self.output_folder = os.path.abspath(self.output_folder)
-        if self.boundingbox_dir:
-            self.boundingbox_dir = os.path.abspath(self.boundingbox_dir)
-        if self.labels_folder:
-            self.labels_folder = os.path.abspath(self.labels_folder)
 
 
 @dataclass
 class ProcessingParams:
-    """Parameters for image processing."""
-    # Image processing
-    target_size: tuple = (1024, 1024)
-    gaussian_blur_kernel: int = 5
-    morphology_kernel_size: int = 7
+    """Minimal processing parameters."""
+    target_size: tuple = None
     min_component_area: int = 1000
-
-    # Segmentation
+    morphology_kernel_size: int = 7
     segmentation_threshold: float = 0.5
-    max_components: int = 10
-
-    # Texture analysis
-    lbp_points: int = 8
-    lbp_radius: int = 1
-    hog_orientations: int = 9
-    hog_pixels_per_cell: tuple = (8, 8)
-    hog_cells_per_block: tuple = (2, 2)
-    lacunarity_window: int = 15
-    ehd_threshold: float = 0.3
-    angle_resolution: int = 45
-
-    # Vegetation indices
-    epsilon: float = 1e-10
-    soil_factor: float = 0.16
-
-    # Morphology
-    pixel_to_cm: float = 0.1099609375
-    prune_sizes: list = field(default_factory=lambda: [200, 100, 50, 30, 10])
 
 
 @dataclass
 class OutputSettings:
-    """Settings for output generation."""
+    """Output settings."""
    save_images: bool = True
-    save_plots: bool = True
-    save_metadata: bool = True
-    image_dpi: int = 150
+    save_plots: bool = False
+    save_metadata: bool = False
    plot_dpi: int = 100
-    image_format: str = "png"
-
-    # Subdirectories
-    segmentation_dir: str = "segmentation"
-    features_dir: str = "features"
-    texture_dir: str = "texture"
-    morphology_dir: str = "morphology"
-    vegetation_dir: str = "vegetation_indices"
-    analysis_dir: str = "analysis"
+    segmentation_dir: str = "results"
+    texture_dir: str = "texture_output"
+    morphology_dir: str = "results"
+    vegetation_dir: str = "Vegetation_indices_images"
 
 
 @dataclass
 class ModelSettings:
-    """Settings for ML models."""
-    device: str = "auto"  # auto, cpu, cuda
+    """Model settings."""
+    device: str = "auto"
     model_name: str = "briaai/RMBG-2.0"
-    batch_size: int = 1
     trust_remote_code: bool = True
     cache_dir: str = ""
     local_files_only: bool = False
 
 
 class Config:
-    """Main configuration class for the Sorghum Pipeline."""
+    """Minimal configuration class."""
 
-    def __init__(self, config_path: Optional[str] = None):
-        """
-        Initialize configuration.
-
-        Args:
-            config_path: Path to YAML configuration file. If None, uses defaults.
-        """
-        self.paths = Paths(
-            input_folder="",
-            output_folder="",
-            boundingbox_dir=""
-        )
+    def __init__(self):
+        """Initialize with defaults."""
+        self.paths = Paths(input_folder="", output_folder="", boundingbox_dir="")
         self.processing = ProcessingParams()
         self.output = OutputSettings()
         self.model = ModelSettings()
-
-        if config_path:
-            self.load_from_file(config_path)
-
-    def load_from_file(self, config_path: str) -> None:
-        """Load configuration from YAML file."""
-        config_path = Path(config_path)
-        if not config_path.exists():
-            raise FileNotFoundError(f"Configuration file not found: {config_path}")
-
-        with open(config_path, 'r') as f:
-            config_data = yaml.safe_load(f)
-
-        # Update paths
-        if 'paths' in config_data:
-            self.paths = Paths(**config_data['paths'])
-
-        # Update processing parameters
-        if 'processing' in config_data:
-            for key, value in config_data['processing'].items():
-                if hasattr(self.processing, key):
-                    setattr(self.processing, key, value)
-
-        # Update output settings
-        if 'output' in config_data:
-            for key, value in config_data['output'].items():
-                if hasattr(self.output, key):
-                    setattr(self.output, key, value)
-
-        # Update model settings
-        if 'model' in config_data:
-            for key, value in config_data['model'].items():
-                if hasattr(self.model, key):
-                    setattr(self.model, key, value)
-
-    def save_to_file(self, config_path: str) -> None:
-        """Save current configuration to YAML file."""
-        config_data = {
-            'paths': {
-                'input_folder': self.paths.input_folder,
-                'output_folder': self.paths.output_folder,
-                'boundingbox_dir': self.paths.boundingbox_dir,
-                'labels_folder': self.paths.labels_folder
-            },
-            'processing': {
-                'target_size': self.processing.target_size,
-                'gaussian_blur_kernel': self.processing.gaussian_blur_kernel,
-                'morphology_kernel_size': self.processing.morphology_kernel_size,
-                'min_component_area': self.processing.min_component_area,
-                'segmentation_threshold': self.processing.segmentation_threshold,
-                'max_components': self.processing.max_components,
-                'lbp_points': self.processing.lbp_points,
-                'lbp_radius': self.processing.lbp_radius,
-                'hog_orientations': self.processing.hog_orientations,
-                'hog_pixels_per_cell': self.processing.hog_pixels_per_cell,
-                'hog_cells_per_block': self.processing.hog_cells_per_block,
-                'lacunarity_window': self.processing.lacunarity_window,
-                'ehd_threshold': self.processing.ehd_threshold,
-                'angle_resolution': self.processing.angle_resolution,
-                'epsilon': self.processing.epsilon,
-                'soil_factor': self.processing.soil_factor,
-                'pixel_to_cm': self.processing.pixel_to_cm,
-                'prune_sizes': self.processing.prune_sizes
-            },
-            'output': {
-                'save_images': self.output.save_images,
-                'save_plots': self.output.save_plots,
-                'save_metadata': self.output.save_metadata,
-                'image_dpi': self.output.image_dpi,
-                'plot_dpi': self.output.plot_dpi,
-                'image_format': self.output.image_format,
-                'segmentation_dir': self.output.segmentation_dir,
-                'features_dir': self.output.features_dir,
-                'texture_dir': self.output.texture_dir,
-                'morphology_dir': self.output.morphology_dir,
-                'vegetation_dir': self.output.vegetation_dir,
-                'analysis_dir': self.output.analysis_dir
-            },
-            'model': {
-                'device': self.model.device,
-                'model_name': self.model.model_name,
-                'batch_size': self.model.batch_size,
-                'trust_remote_code': self.model.trust_remote_code,
-                'cache_dir': self.model.cache_dir,
-                'local_files_only': self.model.local_files_only,
-            }
-        }
-
-        with open(config_path, 'w') as f:
-            yaml.dump(config_data, f, default_flow_style=False, indent=2)
 
     def get_device(self) -> str:
-        """Get the appropriate device for processing."""
+        """Get processing device."""
         if self.model.device == "auto":
             import torch
             return "cuda" if torch.cuda.is_available() else "cpu"
         return self.model.device
 
-    def create_output_directories(self, base_path: str) -> None:
-        """Ensure base output directory exists only.
-
-        Subdirectories are created per plant in the output manager.
-        """
-        base_path = Path(base_path)
-        base_path.mkdir(parents=True, exist_ok=True)
-
     def validate(self) -> bool:
-        """Validate configuration settings."""
-        # Check if input directory exists
-        if not os.path.exists(self.paths.input_folder):
+        """Validate configuration."""
+        if self.paths.input_folder and not os.path.exists(self.paths.input_folder):
             raise FileNotFoundError(f"Input folder does not exist: {self.paths.input_folder}")
-
-        # Check if bounding box directory exists (optional)
-        if hasattr(self.paths, 'boundingbox_dir') and self.paths.boundingbox_dir and not os.path.exists(self.paths.boundingbox_dir):
-            raise FileNotFoundError(f"Bounding box directory does not exist: {self.paths.boundingbox_dir}")
-
-        # Validate processing parameters
-        if self.processing.target_size[0] <= 0 or self.processing.target_size[1] <= 0:
-            raise ValueError("Target size must be positive")
-
-        if self.processing.segmentation_threshold < 0 or self.processing.segmentation_threshold > 1:
-            raise ValueError("Segmentation threshold must be between 0 and 1")
-
-        return True
-
-
-def create_default_config(output_path: str) -> None:
-    """Create a default configuration file."""
-    config = Config()
-    config.paths = Paths(
-        input_folder="Sorghum_dataset",
-        output_folder="Sorghum_pipeline_Results",
-        boundingbox_dir="boundingbox",
-        labels_folder="labels"
-    )
-    config.save_to_file(output_path)
-    print(f"Default configuration created at: {output_path}")
+        return True
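Usage sketch for the slimmed-down Config, assuming only the API shown above. One subtlety worth noting: os.path.abspath("") resolves the empty-string path defaults to the current working directory, so a default Config validates against the CWD.

from sorghum_pipeline.config import Config

cfg = Config()               # defaults; empty paths resolve to the CWD via abspath
cfg.validate()               # raises FileNotFoundError only if input_folder is set and missing
device = cfg.get_device()    # "cuda" when torch sees a GPU, else "cpu"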
sorghum_pipeline/data/mask_handler.py CHANGED
@@ -1,296 +1,28 @@
 """
-Mask handling functionality for the Sorghum Pipeline.
-
-This module handles mask creation, processing, and validation
-for plant segmentation tasks.
+Minimal mask handling for the Sorghum Pipeline.
 """
 
 import numpy as np
 import cv2
-from typing import Dict, Tuple, Optional, List
 import logging
 
 logger = logging.getLogger(__name__)
 
 
 class MaskHandler:
-    """Handles mask creation, processing, and validation."""
+    """Minimal mask handling."""
 
     def __init__(self, min_area: int = 1000, kernel_size: int = 7):
-        """
-        Initialize the mask handler.
-
-        Args:
-            min_area: Minimum area for connected components
-            kernel_size: Kernel size for morphological operations
-        """
+        """Initialize mask handler."""
         self.min_area = min_area
         self.kernel_size = kernel_size
 
-    def create_bounding_box_mask(self, image_shape: Tuple[int, int],
-                                 bbox: Tuple[int, int, int, int]) -> np.ndarray:
-        """
-        Create a mask from bounding box coordinates.
-
-        Args:
-            image_shape: Shape of the image (height, width)
-            bbox: Bounding box coordinates (x1, y1, x2, y2)
-
-        Returns:
-            Binary mask array
-        """
-        h, w = image_shape[:2]
-        mask = np.zeros((h, w), dtype=np.uint8)
-
-        x1, y1, x2, y2 = bbox
-        # Clamp coordinates to image bounds
-        x1 = max(0, min(w, x1))
-        y1 = max(0, min(h, y1))
-        x2 = max(0, min(w, x2))
-        y2 = max(0, min(h, y2))
-
-        mask[y1:y2, x1:x2] = 255
-        return mask
-
-    def preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
-        """
-        Preprocess mask by cleaning and filtering.
-
-        Args:
-            mask: Input mask
-
-        Returns:
-            Cleaned mask
-        """
-        if mask is None:
-            return None
-
-        # Convert to binary if needed
-        if isinstance(mask, tuple):
-            mask = mask[0]
-
-        # Ensure binary format
-        mask = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255
-
-        # Morphological opening to remove noise
-        kernel = cv2.getStructuringElement(
-            cv2.MORPH_ELLIPSE,
-            (self.kernel_size, self.kernel_size)
-        )
-        opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
-
-        # Remove small connected components
-        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
-            opened, connectivity=8
-        )
-
-        clean_mask = np.zeros_like(opened)
-        for label in range(1, num_labels):  # Skip background
-            if stats[label, cv2.CC_STAT_AREA] >= self.min_area:
-                clean_mask[labels == label] = 255
-
-        return clean_mask
-
-    def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
-        """
-        Keep only the largest connected component in the mask.
-
-        Args:
-            mask: Input mask
-
-        Returns:
-            Mask with only the largest component
-        """
-        if mask is None:
-            return None
-
-        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 8)
-
-        if num_labels <= 1:
-            return mask
-
-        # Find the largest component (excluding background)
-        areas = stats[1:, cv2.CC_STAT_AREA]
-        largest_label = 1 + np.argmax(areas)
-
-        # Create mask with only the largest component
-        largest_mask = (labels == largest_label).astype(np.uint8) * 255
-
-        return largest_mask
-
     def apply_mask_to_image(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
-        """
-        Apply mask to image.
-
-        Args:
-            image: Input image
-            mask: Binary mask
-
-        Returns:
-            Masked image
-        """
+        """Apply mask to image."""
         if mask is None:
             return image
-
-        return cv2.bitwise_and(image, image, mask=mask)
-
-    def create_overlay(self, image: np.ndarray, mask: np.ndarray,
-                       color: Tuple[int, int, int] = (0, 255, 0),
-                       alpha: float = 0.5) -> np.ndarray:
-        """
-        Create overlay of mask on image.
-
-        Args:
-            image: Base image
-            mask: Binary mask
-            color: Overlay color (B, G, R)
-            alpha: Overlay transparency
-
-        Returns:
-            Image with mask overlay
-        """
-        overlay = image.copy()
-        overlay[mask == 255] = color
-        return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)
-
-    def get_mask_properties(self, mask: np.ndarray) -> Dict[str, float]:
-        """
-        Get properties of the mask.
-
-        Args:
-            mask: Binary mask
-
-        Returns:
-            Dictionary of mask properties
-        """
-        if mask is None:
-            return {}
-
-        # Convert to binary
-        binary_mask = (mask > 127).astype(np.uint8)
-
-        # Calculate properties
-        area = np.sum(binary_mask)
-        perimeter = cv2.arcLength(
-            cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0][0],
-            True
-        ) if len(cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]) > 0 else 0
-
-        # Bounding box
-        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-        if contours:
-            x, y, w, h = cv2.boundingRect(contours[0])
-            bbox_area = w * h
-            aspect_ratio = w / h if h > 0 else 0
-        else:
-            bbox_area = 0
-            aspect_ratio = 0
-
-        return {
-            "area": float(area),
-            "perimeter": float(perimeter),
-            "bbox_area": float(bbox_area),
-            "aspect_ratio": float(aspect_ratio),
-            "coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0
-        }
-
-    def validate_mask(self, mask: np.ndarray) -> bool:
-        """
-        Validate mask format and content.
-
-        Args:
-            mask: Mask to validate
-
-        Returns:
-            True if valid, False otherwise
-        """
-        if mask is None:
-            return False
-
-        if not isinstance(mask, np.ndarray):
-            return False
-
-        if mask.ndim != 2:
-            return False
-
-        if mask.dtype not in [np.uint8, np.bool_]:
-            return False
-
-        # Check if mask has any foreground pixels
-        if np.sum(mask > 0) == 0:
-            logger.warning("Mask has no foreground pixels")
-            return False
-
-        return True
-
-    def resize_mask(self, mask: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
-        """
-        Resize mask to target size.
-
-        Args:
-            mask: Input mask
-            target_size: Target size (width, height)
-
-        Returns:
-            Resized mask
-        """
-        if mask is None:
-            return None
-
-        return cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
-
-    def dilate_mask(self, mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
-        """
-        Dilate mask to expand foreground regions.
-
-        Args:
-            mask: Input mask
-            kernel_size: Size of dilation kernel
-
-        Returns:
-            Dilated mask
-        """
-        if mask is None:
-            return None
-
-        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
-        return cv2.dilate(mask, kernel, iterations=1)
-
-    def erode_mask(self, mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
-        """
-        Erode mask to shrink foreground regions.
-
-        Args:
-            mask: Input mask
-            kernel_size: Size of erosion kernel
-
-        Returns:
-            Eroded mask
-        """
-        if mask is None:
-            return None
-
-        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
-        return cv2.erode(mask, kernel, iterations=1)
-
-    def fill_holes(self, mask: np.ndarray) -> np.ndarray:
-        """
-        Fill holes in the mask.
-
-        Args:
-            mask: Input mask
-
-        Returns:
-            Mask with filled holes
-        """
-        if mask is None:
-            return None
-
-        # Find contours
-        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-        # Create filled mask
-        filled_mask = np.zeros_like(mask)
-        cv2.fillPoly(filled_mask, contours, 255)
-
-        return filled_mask
+        if mask.shape[:2] != image.shape[:2]:
+            mask = cv2.resize(mask.astype(np.uint8), (image.shape[1], image.shape[0]),
+                              interpolation=cv2.INTER_NEAREST)
+        binary = (mask.astype(np.int32) > 0).astype(np.uint8) * 255
+        return cv2.bitwise_and(image, image, mask=binary)
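The surviving apply_mask_to_image now resizes and re-binarizes mismatched masks before masking. A minimal sketch on synthetic data (illustration only, exercising the shown code path):

import numpy as np
from sorghum_pipeline.data.mask_handler import MaskHandler

image = np.full((100, 100, 3), 200, dtype=np.uint8)
mask = np.zeros((50, 50), dtype=np.uint8)             # deliberately the wrong size
mask[10:40, 10:40] = 1                                # any nonzero pixel counts as foreground
out = MaskHandler().apply_mask_to_image(image, mask)  # mask is upscaled to 100x100, then binarized to 0/255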
sorghum_pipeline/data/preprocessor.py CHANGED
@@ -1,14 +1,11 @@
 """
-Image preprocessing functionality for the Sorghum Pipeline.
-
-This module handles image preprocessing, composite creation,
-and basic image transformations.
+Minimal image preprocessing for the Sorghum Pipeline.
 """
 
 import numpy as np
 import cv2
 from PIL import Image
-from typing import Dict, Tuple, Any, Optional
+from typing import Dict, Tuple, Any
 from itertools import product
 import logging
 
@@ -16,72 +13,36 @@ logger = logging.getLogger(__name__)
 
 
 class ImagePreprocessor:
-    """Handles image preprocessing and composite creation."""
+    """Minimal image preprocessing."""
 
-    def __init__(self, target_size: Optional[Tuple[int, int]] = None):
-        """
-        Initialize the image preprocessor.
-
-        Args:
-            target_size: Target size for image resizing (width, height)
-        """
+    def __init__(self, target_size=None):
+        """Initialize preprocessor."""
         self.target_size = target_size
 
     def convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
-        """
-        Convert array to uint8 format with proper normalization.
-
-        Args:
-            arr: Input array
-
-        Returns:
-            Normalized uint8 array
-        """
-        # Handle NaN and infinite values
+        """Convert array to uint8."""
         arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
-
-        # Normalize to 0-255 range
         if arr.ptp() > 0:
             normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
         else:
             normalized = np.zeros_like(arr)
-
         return np.clip(normalized, 0, 255).astype(np.uint8)
 
     def process_raw_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
-        """
-        Process raw 4-band image into composite and spectral bands.
-
-        Args:
-            pil_img: PIL Image object containing 4-band data
-
-        Returns:
-            Tuple of (composite_image, spectral_bands_dict)
-        """
-        # Split the 4-band RAW into tiles and stack them
+        """Process 4-band image into composite and spectral bands."""
         d = pil_img.size[0] // 2
         boxes = [
             (j, i, j + d, i + d)
-            for i, j in product(
-                range(0, pil_img.height, d),
-                range(0, pil_img.width, d)
-            )
+            for i, j in product(range(0, pil_img.height, d), range(0, pil_img.width, d))
         ]
 
-        # Extract tiles and stack them
-        stack = np.stack([
-            np.array(pil_img.crop(box), dtype=float)
-            for box in boxes
-        ], axis=-1)
-
-        # Bands come in order: [green, red, red_edge, nir]
+        stack = np.stack([np.array(pil_img.crop(box), dtype=float) for box in boxes], axis=-1)
         green, red, red_edge, nir = np.split(stack, 4, axis=-1)
 
-        # Build pseudo-RGB composite as (green, red_edge, red)
+        # Pseudo-RGB composite: (green, red_edge, red)
         composite = np.concatenate([green, red_edge, red], axis=-1)
         composite_uint8 = self.convert_to_uint8(composite)
 
-        # Prepare spectral stack
         spectral_bands = {
             "green": green,
             "red": red,
@@ -92,188 +53,14 @@ class ImagePreprocessor:
         return composite_uint8, spectral_bands
 
     def create_composites(self, plants: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
-        """
-        Create composites for all plants in the dataset.
-
-        Args:
-            plants: Dictionary of plant data
-
-        Returns:
-            Updated plant data with composites and spectral stacks
-        """
-        logger.info("Creating composites for all plants...")
-
+        """Create composites for all plants."""
         for key, pdata in plants.items():
             try:
-                # Find the PIL Image
                 if "raw_image" in pdata:
                     image, _ = pdata["raw_image"]
-                elif "raw_images" in pdata and pdata["raw_images"]:
-                    image, _ = pdata["raw_images"][0]
-                else:
-                    logger.warning(f"No raw image found for {key}")
-                    continue
-
-                # Process the image
-                composite, spectral_stack = self.process_raw_image(image)
-
-                # Store results
-                pdata["composite"] = composite
-                pdata["spectral_stack"] = spectral_stack
-
-                logger.debug(f"Created composite for {key}")
-
+                    composite, spectral_stack = self.process_raw_image(image)
+                    pdata["composite"] = composite
+                    pdata["spectral_stack"] = spectral_stack
             except Exception as e:
                 logger.error(f"Failed to create composite for {key}: {e}")
-                continue
-
-        logger.info("Composite creation completed")
-        return plants
-
-    def resize_image(self, image: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
-        """
-        Resize image to target size.
-
-        Args:
-            image: Input image
-            target_size: Target size (width, height). If None, uses self.target_size
-
-        Returns:
-            Resized image
-        """
-        if target_size is None:
-            target_size = self.target_size
-
-        if target_size is None:
-            return image
-
-        return cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
-
-    def normalize_image(self, image: np.ndarray, method: str = "minmax") -> np.ndarray:
-        """
-        Normalize image using specified method.
-
-        Args:
-            image: Input image
-            method: Normalization method ("minmax", "zscore", "robust")
-
-        Returns:
-            Normalized image
-        """
-        if method == "minmax":
-            if image.dtype == np.uint8:
-                return image.astype(np.float32) / 255.0
-            else:
-                img_min, img_max = image.min(), image.max()
-                if img_max > img_min:
-                    return (image - img_min) / (img_max - img_min)
-                else:
-                    return np.zeros_like(image, dtype=np.float32)
-
-        elif method == "zscore":
-            mean, std = image.mean(), image.std()
-            if std > 0:
-                return (image - mean) / std
-            else:
-                return np.zeros_like(image, dtype=np.float32)
-
-        elif method == "robust":
-            q25, q75 = np.percentile(image, [25, 75])
-            if q75 > q25:
-                return (image - q25) / (q75 - q25)
-            else:
-                return np.zeros_like(image, dtype=np.float32)
-
-        else:
-            raise ValueError(f"Unknown normalization method: {method}")
-
-    def apply_gaussian_blur(self, image: np.ndarray, kernel_size: int = 5) -> np.ndarray:
-        """
-        Apply Gaussian blur to image.
-
-        Args:
-            image: Input image
-            kernel_size: Size of Gaussian kernel
-
-        Returns:
-            Blurred image
-        """
-        if kernel_size % 2 == 0:
-            kernel_size += 1
-
-        return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
-
-    def apply_sharpening(self, image: np.ndarray) -> np.ndarray:
-        """
-        Apply sharpening filter to image.
-
-        Args:
-            image: Input image
-
-        Returns:
-            Sharpened image
-        """
-        kernel = np.array([
-            [0, -1, 0],
-            [-1, 5, -1],
-            [0, -1, 0]
-        ])
-
-        return cv2.filter2D(image, -1, kernel)
-
-    def enhance_contrast(self, image: np.ndarray, alpha: float = 1.2, beta: int = 15) -> np.ndarray:
-        """
-        Enhance image contrast.
-
-        Args:
-            image: Input image
-            alpha: Contrast control (1.0 = no change)
-            beta: Brightness control (0 = no change)
-
-        Returns:
-            Enhanced image
-        """
-        return cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
-
-    def create_overlay(self, base_image: np.ndarray, mask: np.ndarray,
-                       color: Tuple[int, int, int] = (0, 255, 0),
-                       alpha: float = 0.5) -> np.ndarray:
-        """
-        Create overlay of mask on base image.
-
-        Args:
-            base_image: Base image
-            mask: Binary mask
-            color: Overlay color (B, G, R)
-            alpha: Overlay transparency
-
-        Returns:
-            Image with overlay
-        """
-        overlay = base_image.copy()
-        overlay[mask == 255] = color
-        return cv2.addWeighted(base_image, 1.0 - alpha, overlay, alpha, 0)
-
-    def validate_composite(self, composite: np.ndarray) -> bool:
-        """
-        Validate composite image.
-
-        Args:
-            composite: Composite image to validate
-
-        Returns:
-            True if valid, False otherwise
-        """
-        if composite is None:
-            return False
-
-        if not isinstance(composite, np.ndarray):
-            return False
-
-        if composite.ndim != 3 or composite.shape[2] != 3:
-            return False
-
-        if composite.dtype != np.uint8:
-            return False
-
-        return True
+        return plants
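For reference, process_raw_image treats the RAW frame as a 2x2 mosaic of single-band tiles of side d = width // 2, stacks the four crops into a (d, d, 4) array in the order green, red, red_edge, nir, and builds the pseudo-RGB composite from (green, red_edge, red). A sketch on a synthetic mosaic (illustration only):

import numpy as np
from PIL import Image
from sorghum_pipeline.data.preprocessor import ImagePreprocessor

raw = Image.fromarray(np.random.randint(0, 256, (64, 64), dtype=np.uint8))  # 2x2 grid of 32x32 tiles
composite, bands = ImagePreprocessor().process_raw_image(raw)
print(composite.shape)   # (32, 32, 3) uint8 pseudo-RGB
print(sorted(bands))     # per-band tiles keyed by band name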
sorghum_pipeline/features/morphology.py CHANGED
@@ -1,44 +1,32 @@
 """
-Morphological feature extraction for the Sorghum Pipeline.
-
-This module handles extraction of morphological features using PlantCV
-and other computer vision techniques.
+Minimal morphological feature extraction (PlantCV size analysis only).
 """
 
 import numpy as np
 import cv2
 import contextlib
 import sys
-from typing import Dict, Any, Optional, List, Tuple
+from typing import Dict, Any, List
 import logging
 
-# Try to import PlantCV, but don't fail if not available
 try:
     from plantcv import plantcv as pcv
     PLANT_CV_AVAILABLE = True
 except ImportError:
     PLANT_CV_AVAILABLE = False
-    logger.warning("PlantCV not available. Morphological features will be limited.")
 
 logger = logging.getLogger(__name__)
 
 
 class MorphologyExtractor:
-    """Extracts morphological features from plant images."""
+    """Minimal morphology extraction (PlantCV size analysis)."""
 
     def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None):
-        """
-        Initialize morphology extractor.
-
-        Args:
-            pixel_to_cm: Conversion factor from pixels to centimeters
-            prune_sizes: List of pruning sizes for skeleton processing
-        """
+        """Initialize."""
         self.pixel_to_cm = pixel_to_cm
         self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
 
         if PLANT_CV_AVAILABLE:
-            # Configure PlantCV
             pcv.params.debug = None
             pcv.params.text_size = 0.7
             pcv.params.text_thickness = 2
@@ -46,283 +34,53 @@ class MorphologyExtractor:
             pcv.params.dpi = 100
 
     def extract_morphology_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
-        """
-        Extract morphological features from plant image and mask.
-
-        Args:
-            image: Plant image (BGR format)
-            mask: Binary mask of the plant
-
-        Returns:
-            Dictionary containing morphological features and images
-        """
-        features = {
-            'traits': {},
-            'images': {},
-            'success': False
-        }
+        """Extract only PlantCV size analysis image."""
+        features = {'traits': {}, 'images': {}, 'success': False}
+
+        if not PLANT_CV_AVAILABLE:
+            logger.warning("PlantCV not available")
+            return features
 
         try:
-            # Preprocess mask
             clean_mask = self._preprocess_mask(mask)
             if clean_mask is None:
-                logger.warning("Failed to preprocess mask")
                 return features
 
-            # Extract basic morphological features
-            basic_traits = self._extract_basic_features(clean_mask)
-            features['traits'].update(basic_traits)
-
-            # Extract skeleton-based features if PlantCV is available
-            if PLANT_CV_AVAILABLE:
-                skeleton_features = self._extract_skeleton_features(image, clean_mask)
-                features['traits'].update(skeleton_features['traits'])
-                features['images'].update(skeleton_features['images'])
-            else:
-                # Fallback to basic OpenCV features
-                cv_features = self._extract_opencv_features(image, clean_mask)
-                features['traits'].update(cv_features['traits'])
-                features['images'].update(cv_features['images'])
-
-            features['success'] = True
-            logger.debug("Morphological features extracted successfully")
+            # Size analysis only
+            with contextlib.redirect_stdout(self._FilteredStream(sys.stdout)), \
+                 contextlib.redirect_stderr(self._FilteredStream(sys.stderr)):
+                try:
+                    labeled_mask, n_labels = pcv.create_labels(clean_mask)
+                    size_analysis = pcv.analyze.size(image, labeled_mask, n_labels, label="default")
+                    features['images']['size_analysis'] = size_analysis
+                    features['success'] = True
+                except Exception as e:
+                    logger.warning(f"Size analysis failed: {e}")
 
         except Exception as e:
-            logger.error(f"Morphological feature extraction failed: {e}")
+            logger.error(f"Morphology extraction failed: {e}")
 
         return features
 
-    def _preprocess_mask(self, mask: np.ndarray) -> Optional[np.ndarray]:
-        """Preprocess mask for morphological analysis."""
+    def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
+        """Preprocess mask."""
         if mask is None:
             return None
-
-        # Convert to binary if needed
-        if isinstance(mask, tuple):
-            mask = mask[0]
-
-        # Ensure binary format
         mask = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255
-
-        # Morphological opening to remove noise
         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
         opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
 
-        # Remove small connected components
         num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(opened, connectivity=8)
         clean_mask = np.zeros_like(opened)
 
-        for label in range(1, num_labels):  # Skip background
+        for label in range(1, num_labels):
            if stats[label, cv2.CC_STAT_AREA] >= 1000:
                clean_mask[labels == label] = 255
 
        return clean_mask
 
-    def _extract_basic_features(self, mask: np.ndarray) -> Dict[str, float]:
-        """Extract basic morphological features using OpenCV."""
-        features = {}
-
-        try:
-            # Find contours
-            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-            if not contours:
-                return features
-
-            # Get the largest contour
-            largest_contour = max(contours, key=cv2.contourArea)
-
-            # Basic measurements
-            area = cv2.contourArea(largest_contour)
-            perimeter = cv2.arcLength(largest_contour, True)
-
-            # Bounding box
-            x, y, w, h = cv2.boundingRect(largest_contour)
-            bbox_area = w * h
-
-            # Ellipse fitting
-            if len(largest_contour) >= 5:
-                ellipse = cv2.fitEllipse(largest_contour)
-                (center, axes, angle) = ellipse
-                major_axis = max(axes)
-                minor_axis = min(axes)
-            else:
-                major_axis = max(w, h)
-                minor_axis = min(w, h)
-
-            # Convert to centimeters
-            features['area_cm2'] = area * (self.pixel_to_cm ** 2)
-            features['perimeter_cm'] = perimeter * self.pixel_to_cm
-            features['width_cm'] = w * self.pixel_to_cm
-            features['height_cm'] = h * self.pixel_to_cm
-            features['bbox_area_cm2'] = bbox_area * (self.pixel_to_cm ** 2)
-            features['major_axis_cm'] = major_axis * self.pixel_to_cm
-            features['minor_axis_cm'] = minor_axis * self.pixel_to_cm
-            features['aspect_ratio'] = w / h if h > 0 else 0
-            features['elongation'] = major_axis / minor_axis if minor_axis > 0 else 0
-            features['circularity'] = (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0
-            features['solidity'] = area / bbox_area if bbox_area > 0 else 0
-
-            # Convex hull
-            hull = cv2.convexHull(largest_contour)
-            hull_area = cv2.contourArea(hull)
-            features['convexity'] = area / hull_area if hull_area > 0 else 0
-
-        except Exception as e:
-            logger.error(f"Basic feature extraction failed: {e}")
-
-        return features
-
-    def _extract_skeleton_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
-        """Extract skeleton-based features using PlantCV."""
-        features = {'traits': {}, 'images': {}}
-
-        if not PLANT_CV_AVAILABLE:
-            return features
-
-        try:
-            # Suppress PlantCV output
-            with contextlib.redirect_stdout(self._FilteredStream(sys.stdout)), \
-                 contextlib.redirect_stderr(self._FilteredStream(sys.stderr)):
-
-                # Skeletonize
-                skeleton = pcv.morphology.skeletonize(mask=mask)
-                features['images']['skeleton'] = skeleton
-
-                # Prune skeleton
-                pruned_skel = skeleton
-                for size in self.prune_sizes:
-                    pruned_skel, _, _ = pcv.morphology.prune(
-                        skel_img=pruned_skel, size=size, mask=mask
-                    )
-
-                features['images']['pruned_skeleton'] = pruned_skel
-
-                # Find branch points and tips
-                branch_pts = pcv.morphology.find_branch_pts(pruned_skel, mask)
-                features['images']['branch_points'] = branch_pts
-
-                try:
-                    tip_pts = pcv.morphology.find_tips(pruned_skel, mask)
-                    features['images']['tip_points'] = tip_pts
-                except Exception as e:
-                    logger.warning(f"Tip detection failed: {e}")
-
-                # Segment objects
-                try:
-                    leaf_obj, stem_obj = pcv.morphology.segment_sort(
-                        pruned_skel, [], mask
-                    )
-                    features['traits']['num_leaves'] = len(leaf_obj)
-                    features['traits']['num_stems'] = len(stem_obj)
-                except Exception as e:
-                    logger.warning(f"Object segmentation failed: {e}")
-                    features['traits']['num_leaves'] = 0
-                    features['traits']['num_stems'] = 0
-
-                # Size analysis
-                try:
-                    labeled_mask, n_labels = pcv.create_labels(mask)
-                    size_analysis = pcv.analyze.size(image, labeled_mask, n_labels, label="default")
-                    features['images']['size_analysis'] = size_analysis
-
-                    # Get size traits
-                    obs = pcv.outputs.observations.get("default_1", {})
-                    for trait, info in obs.items():
-                        if trait not in ["in_bounds", "object_in_frame"]:
-                            val = info.get("value", None)
-                            if val is not None:
-                                if trait == "area":
-                                    val = val * (self.pixel_to_cm ** 2)
-                                elif trait in ["perimeter", "width", "height", "longest_path",
-                                               "ellipse_major_axis", "ellipse_minor_axis"]:
-                                    val = val * self.pixel_to_cm
-                                features['traits'][trait] = val
-
-                except Exception as e:
-                    logger.warning(f"Size analysis failed: {e}")
-
-        except Exception as e:
-            logger.error(f"Skeleton feature extraction failed: {e}")
-
-        return features
-
-    def _extract_opencv_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
-        """Extract features using only OpenCV (fallback when PlantCV is not available)."""
-        features = {'traits': {}, 'images': {}}
-
-        try:
-            # Create skeleton using OpenCV
-            skeleton = self._create_skeleton_opencv(mask)
-            features['images']['skeleton'] = skeleton
-
-            # Find branch points
-            branch_points = self._find_branch_points_opencv(skeleton)
-            features['images']['branch_points'] = branch_points
-            features['traits']['num_branches'] = len(branch_points)
-
-            # Find endpoints
-            endpoints = self._find_endpoints_opencv(skeleton)
-            features['images']['endpoints'] = endpoints
-            features['traits']['num_endpoints'] = len(endpoints)
-
-            # Skeleton length
-            skeleton_length = np.sum(skeleton > 0)
-            features['traits']['skeleton_length_pixels'] = skeleton_length
-            features['traits']['skeleton_length_cm'] = skeleton_length * self.pixel_to_cm
-
-        except Exception as e:
-            logger.error(f"OpenCV feature extraction failed: {e}")
-
-        return features
-
-    def _create_skeleton_opencv(self, mask: np.ndarray) -> np.ndarray:
-        """Create skeleton using OpenCV."""
-        # Convert to binary
-        binary = (mask > 0).astype(np.uint8)
-
-        # Create skeleton using morphological operations
-        skeleton = np.zeros_like(binary)
-        element = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
-
-        while True:
-            eroded = cv2.erode(binary, element)
-            temp = cv2.dilate(eroded, element)
-            temp = cv2.subtract(binary, temp)
-            skeleton = cv2.bitwise_or(skeleton, temp)
-            binary = eroded.copy()
-
-            if cv2.countNonZero(binary) == 0:
-                break
-
-        return skeleton * 255
-
-    def _find_branch_points_opencv(self, skeleton: np.ndarray) -> List[Tuple[int, int]]:
-        """Find branch points in skeleton using OpenCV."""
-        # Count neighbors for each pixel
-        kernel = np.ones((3, 3), dtype=np.uint8)
-        kernel[1, 1] = 0  # Don't count center pixel
-
-        neighbor_count = cv2.filter2D(skeleton, -1, kernel)
-
-        # Branch points have 3 or more neighbors
-        branch_points = np.where((skeleton > 0) & (neighbor_count >= 3))
-        return list(zip(branch_points[1], branch_points[0]))  # (x, y) format
-
-    def _find_endpoints_opencv(self, skeleton: np.ndarray) -> List[Tuple[int, int]]:
-        """Find endpoints in skeleton using OpenCV."""
-        # Count neighbors for each pixel
-        kernel = np.ones((3, 3), dtype=np.uint8)
-        kernel[1, 1] = 0  # Don't count center pixel
-
-        neighbor_count = cv2.filter2D(skeleton, -1, kernel)
-
-        # Endpoints have exactly 1 neighbor
-        endpoints = np.where((skeleton > 0) & (neighbor_count == 1))
-        return list(zip(endpoints[1], endpoints[0]))  # (x, y) format
-
     class _FilteredStream:
-        """Filter PlantCV output to reduce noise."""
+        """Filter PlantCV output."""
         def __init__(self, stream):
             self.stream = stream
 
@@ -335,46 +93,4 @@ class MorphologyExtractor:
         try:
             self.stream.flush()
         except Exception:
-            pass
-
-    def create_morphology_visualization(self, image: np.ndarray, mask: np.ndarray,
-                                        features: Dict[str, Any]) -> np.ndarray:
-        """
-        Create visualization of morphological features.
-
-        Args:
-            image: Original image
-            mask: Binary mask
-            features: Extracted features
-
-        Returns:
-            Visualization image
-        """
-        try:
-            # Create visualization
-            vis = image.copy()
-
-            # Draw mask outline
-            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-            cv2.drawContours(vis, contours, -1, (0, 255, 0), 2)
-
-            # Draw bounding box
-            if contours:
-                x, y, w, h = cv2.boundingRect(contours[0])
-                cv2.rectangle(vis, (x, y), (x + w, y + h), (255, 0, 0), 2)
-
-            # Draw skeleton if available
-            if 'skeleton' in features.get('images', {}):
-                skeleton = features['images']['skeleton']
-                vis[skeleton > 0] = [0, 0, 255]  # Red skeleton
-
-            # Draw branch points if available
-            if 'branch_points' in features.get('images', {}):
-                branch_img = features['images']['branch_points']
-                vis[branch_img > 0] = [255, 255, 0]  # Yellow branch points
-
-            return vis
-
-        except Exception as e:
-            logger.error(f"Visualization creation failed: {e}")
-            return image
+            pass
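After this commit the morphology stage only renders PlantCV's size-analysis image. A usage sketch (illustration only; it degrades gracefully when plantcv is not installed, since extract_morphology_features then returns with success=False):

import numpy as np
from sorghum_pipeline.features.morphology import MorphologyExtractor

image = np.zeros((200, 200, 3), dtype=np.uint8)
mask = np.zeros((200, 200), dtype=np.uint8)
mask[40:160, 40:160] = 255                       # one component >= 1000 px survives cleaning
out = MorphologyExtractor().extract_morphology_features(image, mask)
size_img = out['images'].get('size_analysis')    # None unless PlantCV ran successfully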
sorghum_pipeline/features/texture.py CHANGED
@@ -1,299 +1,79 @@
 """
-Texture feature extraction for the Sorghum Pipeline.
-
-This module handles extraction of texture features including:
-- Local Binary Patterns (LBP)
-- Histogram of Oriented Gradients (HOG)
-- Lacunarity features
-- Edge Histogram Descriptor (EHD)
 """
 
 import numpy as np
-import cv2
 import torch
 import torch.nn.functional as F
 from skimage.feature import local_binary_pattern, hog
 from skimage import exposure
-from scipy import ndimage, signal
-from sklearn.decomposition import PCA
-from typing import Dict, Tuple, Optional, Any
 import logging
 
 logger = logging.getLogger(__name__)
 
 
 class TextureExtractor:
-    """Extracts texture features from images."""
 
-    def __init__(self,
-                 lbp_points: int = 8,
-                 lbp_radius: int = 1,
-                 hog_orientations: int = 9,
-                 hog_pixels_per_cell: Tuple[int, int] = (8, 8),
-                 hog_cells_per_block: Tuple[int, int] = (2, 2),
-                 lacunarity_window: int = 15,
-                 ehd_threshold: float = 0.3,
-                 angle_resolution: int = 45):
-        """
-        Initialize texture extractor.
-
-        Args:
-            lbp_points: Number of points for LBP
-            lbp_radius: Radius for LBP
-            hog_orientations: Number of orientations for HOG
-            hog_pixels_per_cell: Pixels per cell for HOG
-            hog_cells_per_block: Cells per block for HOG
-            lacunarity_window: Window size for lacunarity
-            ehd_threshold: Threshold for EHD
-            angle_resolution: Angle resolution for EHD
-        """
         self.lbp_points = lbp_points
         self.lbp_radius = lbp_radius
         self.hog_orientations = hog_orientations
         self.hog_pixels_per_cell = hog_pixels_per_cell
         self.hog_cells_per_block = hog_cells_per_block
         self.lacunarity_window = lacunarity_window
-        self.ehd_threshold = ehd_threshold
-        self.angle_resolution = angle_resolution
 
     def extract_lbp(self, gray_image: np.ndarray) -> np.ndarray:
-        """
-        Extract Local Binary Pattern features.
-
-        Args:
-            gray_image: Grayscale input image
-
-        Returns:
-            LBP feature map
-        """
         try:
-            lbp = local_binary_pattern(
-                gray_image,
-                self.lbp_points,
-                self.lbp_radius,
-                method='uniform'
-            )
             return self._convert_to_uint8(lbp)
         except Exception as e:
-            logger.error(f"LBP extraction failed: {e}")
             return np.zeros_like(gray_image, dtype=np.uint8)
 
     def extract_hog(self, gray_image: np.ndarray) -> np.ndarray:
-        """
-        Extract Histogram of Oriented Gradients features.
-
-        Args:
-            gray_image: Grayscale input image
-
-        Returns:
-            HOG feature map
-        """
         try:
-            _, vis = hog(
-                gray_image,
-                orientations=self.hog_orientations,
-                pixels_per_cell=self.hog_pixels_per_cell,
-                cells_per_block=self.hog_cells_per_block,
-                visualize=True,
-                feature_vector=True
-            )
             return exposure.rescale_intensity(vis, out_range=(0, 255)).astype(np.uint8)
         except Exception as e:
-            logger.error(f"HOG extraction failed: {e}")
             return np.zeros_like(gray_image, dtype=np.uint8)
 
-    def compute_local_lacunarity(self, gray_image: np.ndarray, window_size: int) -> np.ndarray:
-        """
-        Compute local lacunarity.
-
-        Args:
-            gray_image: Grayscale input image
-            window_size: Size of the sliding window
-
-        Returns:
-            Local lacunarity map
-        """
         try:
             arr = gray_image.astype(np.float32)
-            m1 = ndimage.uniform_filter(arr, size=window_size)
-            m2 = ndimage.uniform_filter(arr * arr, size=window_size)
             var = m2 - m1 * m1
-            eps = 1e-6
-            lac = var / (m1 * m1 + eps) + 1
-            lac[m1 <= eps] = 0
-            return lac
-        except Exception as e:
-            logger.error(f"Local lacunarity computation failed: {e}")
-            return np.zeros_like(gray_image, dtype=np.float32)
-
-    def compute_lacunarity_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-        """
-        Compute three types of lacunarity features.
-
-        Args:
-            gray_image: Grayscale input image
-
-        Returns:
-            Tuple of (lac1, lac2, lac3) lacunarity maps
-        """
-        try:
-            # L1: Single window lacunarity
-            lac1 = self.compute_local_lacunarity(gray_image, self.lacunarity_window)
-
-            # L2: Multi-scale lacunarity
-            scales = [max(3, self.lacunarity_window//2), self.lacunarity_window, self.lacunarity_window*2]
-            lac2 = np.mean([
-                self.compute_local_lacunarity(gray_image, s) for s in scales
-            ], axis=0)
-
-            # L3: DBC Lacunarity (if available)
-            try:
-                from ..models.dbc_lacunarity import DBC_Lacunarity
-                x = torch.from_numpy(gray_image.astype(np.float32)/255.0)[None, None]
-                layer = DBC_Lacunarity(window_size=self.lacunarity_window).eval()
-                with torch.no_grad():
-                    lac3 = layer(x).squeeze().cpu().numpy()
-            except ImportError:
-                logger.warning("DBC Lacunarity not available, using L2 as L3")
-                lac3 = lac2.copy()
-
-            return (
-                self._convert_to_uint8(lac1),
-                self._convert_to_uint8(lac2),
-                self._convert_to_uint8(lac3)
-            )
-        except Exception as e:
-            logger.error(f"Lacunarity features computation failed: {e}")
-            empty = np.zeros_like(gray_image, dtype=np.uint8)
-            return empty, empty, empty
-
-    def generate_ehd_masks(self, mask_size: int = 3) -> np.ndarray:
-        """
-        Generate masks for Edge Histogram Descriptor.
-
-        Args:
-            mask_size: Size of the mask
-
-        Returns:
-            Array of EHD masks
-        """
-        if mask_size < 3:
-            mask_size = 3
-        if mask_size % 2 == 0:
-            mask_size += 1
-
-        # Base gradient mask
-        Gy = np.outer([1, 0, -1], [1, 2, 1])
-
-        # Expand if needed
-        if mask_size > 3:
-            expd = np.outer([1, 2, 1], [1, 2, 1])
-            for _ in range((mask_size - 3) // 2):
-                Gy = signal.convolve2d(expd, Gy, mode='full')
-
-        # Generate masks for different angles
-        angles = np.arange(0, 360, self.angle_resolution)
-        masks = np.zeros((len(angles), mask_size, mask_size), dtype=np.float32)
-
-        for i, angle in enumerate(angles):
-            masks[i] = ndimage.rotate(Gy, angle, reshape=False, mode='nearest')
-
-        return masks
-
-    def extract_ehd_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
-        """
-        Extract Edge Histogram Descriptor features.
-
-        Args:
-            gray_image: Grayscale input image
-
-        Returns:
-            Tuple of (ehd_features, ehd_map)
-        """
-        try:
-            # Generate masks
-            masks = self.generate_ehd_masks()
-
-            # Convert to tensor
-            X = torch.from_numpy(gray_image.astype(np.float32)/255.0).unsqueeze(0).unsqueeze(0)
-            masks_tensor = torch.tensor(masks).unsqueeze(1).float()
-
-            # Convolve with masks
-            edge_responses = F.conv2d(X, masks_tensor, dilation=7)
-
-            # Find maximum response
-            values, indices = torch.max(edge_responses, dim=1)
-            indices[values < self.ehd_threshold] = masks.shape[0]
-
-            # Pool features
-            feat_vect = []
-            for edge in range(masks.shape[0] + 1):
-                pooled = F.avg_pool2d(
-                    (indices == edge).unsqueeze(1).float(),
-                    kernel_size=5, stride=1, padding=2
-                )
-                feat_vect.append(pooled.squeeze(1))
-
-            ehd_features = torch.stack(feat_vect, dim=1).squeeze(0).cpu().numpy()
-            ehd_map = np.argmax(ehd_features, axis=0).astype(np.uint8)
-
-            return ehd_features, ehd_map
-
         except Exception as e:
-            logger.error(f"EHD features extraction failed: {e}")
-            empty_features = np.zeros((9, gray_image.shape[0]-4, gray_image.shape[1]-4), dtype=np.float32)
-            empty_map = np.zeros_like(gray_image, dtype=np.uint8)
-            return empty_features, empty_map
 
     def extract_all_texture_features(self, gray_image: np.ndarray) -> Dict[str, np.ndarray]:
-        """
-        Extract all texture features from a grayscale image.
-
-        Args:
-            gray_image: Grayscale input image
-
-        Returns:
-            Dictionary of texture features
-        """
-        features = {}
-
-        try:
-            # LBP
-            features['lbp'] = self.extract_lbp(gray_image)
-
-            # HOG
-            features['hog'] = self.extract_hog(gray_image)
-
-            # Lacunarity
-            lac1, lac2, lac3 = self.compute_lacunarity_features(gray_image)
-            features['lac1'] = lac1
-            features['lac2'] = lac2
-            features['lac3'] = lac3
-
-            # EHD
-            ehd_features, ehd_map = self.extract_ehd_features(gray_image)
-            features['ehd_features'] = ehd_features
-            features['ehd_map'] = ehd_map
-
-            logger.debug("All texture features extracted successfully")
-
-        except Exception as e:
-            logger.error(f"Texture feature extraction failed: {e}")
-            # Return empty features
-            features = {
-                'lbp': np.zeros_like(gray_image, dtype=np.uint8),
-                'hog': np.zeros_like(gray_image, dtype=np.uint8),
-                'lac1': np.zeros_like(gray_image, dtype=np.uint8),
-                'lac2': np.zeros_like(gray_image, dtype=np.uint8),
-                'lac3': np.zeros_like(gray_image, dtype=np.uint8),
-                'ehd_features': np.zeros((9, gray_image.shape[0]-4, gray_image.shape[1]-4), dtype=np.float32),
-                'ehd_map': np.zeros_like(gray_image, dtype=np.uint8)
-            }
-
-        return features
 
     def _convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
-        """Convert array to uint8 with proper normalization."""
         arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
         if arr.ptp() > 0:
             normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
@@ -302,72 +82,19 @@ class TextureExtractor:
         return np.clip(normalized, 0, 255).astype(np.uint8)
 
     def compute_texture_statistics(self, features: Dict[str, np.ndarray],
-                                   mask: Optional[np.ndarray] = None) -> Dict[str, Dict[str, float]]:
-        """
-        Compute statistics for texture features.
-
-        Args:
-            features: Dictionary of texture features
-            mask: Optional mask to apply
-
-        Returns:
-            Dictionary of feature statistics
-        """
         stats = {}
-
         for feature_name, feature_data in features.items():
-            if feature_name == 'ehd_features':
-                # Special handling for EHD features
-                if mask is not None:
-                    # Apply mask to each channel
-                    masked_features = []
-                    for i in range(feature_data.shape[0]):
-                        channel = feature_data[i]
-                        if mask.shape != channel.shape:
-                            # Resize mask to match channel
-                            mask_resized = cv2.resize(mask, (channel.shape[1], channel.shape[0]),
-                                                      interpolation=cv2.INTER_NEAREST)
-                            masked_channel = np.where(mask_resized > 0, channel, np.nan)
-                        else:
-                            masked_channel = np.where(mask > 0, channel, np.nan)
-                        masked_features.append(masked_channel)
-                    feature_data = np.stack(masked_features, axis=0)
-                else:
-                    feature_data = feature_data
-
-                # Compute statistics for each EHD channel
-                channel_stats = {}
-                for i in range(feature_data.shape[0]):
-                    channel = feature_data[i]
-                    valid_data = channel[~np.isnan(channel)]
-                    if len(valid_data) > 0:
-                        channel_stats[f'channel_{i}'] = {
-                            'mean': float(np.mean(valid_data)),
-                            'std': float(np.std(valid_data)),
-                            'min': float(np.min(valid_data)),
-                            'max': float(np.max(valid_data)),
-                            'median': float(np.median(valid_data))
-                        }
-                stats[feature_name] = channel_stats
             else:
-                # Regular 2D features
-                if mask is not None and mask.shape == feature_data.shape:
-                    masked_data = np.where(mask > 0, feature_data, np.nan)
-                else:
-                    masked_data = feature_data
-
-                valid_data = masked_data[~np.isnan(masked_data)]
-                if len(valid_data) > 0:
-                    stats[feature_name] = {
362
- 'mean': float(np.mean(valid_data)),
363
- 'std': float(np.std(valid_data)),
364
- 'min': float(np.min(valid_data)),
365
- 'max': float(np.max(valid_data)),
366
- 'median': float(np.median(valid_data))
367
- }
368
- else:
369
- stats[feature_name] = {
370
- 'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0, 'median': 0.0
371
- }
372
-
373
- return stats
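For reference, the `_convert_to_uint8` helper shown as context above is kept by both the old and new versions of this module; a minimal standalone sketch of its min-max scaling, with a purely synthetic input:

```python
import numpy as np

def to_uint8(arr: np.ndarray) -> np.ndarray:
    # Mirrors _convert_to_uint8: sanitize NaN/inf, then min-max scale to 0..255.
    # Note: the arr.ptp() method spelling was removed in NumPy 2.0; np.ptp(arr) is the portable form.
    arr = np.nan_to_num(arr.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    if arr.ptp() > 0:
        arr = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
    else:
        arr = np.zeros_like(arr)
    return np.clip(arr, 0, 255).astype(np.uint8)

print(to_uint8(np.array([[0.0, 0.5], [np.nan, 1.0]])))  # NaN is zeroed before scaling
```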
 
1
  """
2
+ Minimal texture feature extraction.
3
  """
4
 
5
  import numpy as np
 
6
  import torch
7
  import torch.nn.functional as F
8
  from skimage.feature import local_binary_pattern, hog
9
  from skimage import exposure
10
+ from scipy import ndimage
11
+ from typing import Dict, Tuple, Optional
 
12
  import logging
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
 
17
  class TextureExtractor:
18
+ """Minimal texture extraction (LBP, HOG, Lacunarity only)."""
19
 
20
+ def __init__(self, lbp_points: int = 8, lbp_radius: int = 1,
21
+ hog_orientations: int = 9, hog_pixels_per_cell: Tuple[int, int] = (8, 8),
22
+ hog_cells_per_block: Tuple[int, int] = (2, 2), lacunarity_window: int = 15,
23
+ ehd_threshold: float = 0.3, angle_resolution: int = 45):
24
+ """Initialize with defaults."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  self.lbp_points = lbp_points
26
  self.lbp_radius = lbp_radius
27
  self.hog_orientations = hog_orientations
28
  self.hog_pixels_per_cell = hog_pixels_per_cell
29
  self.hog_cells_per_block = hog_cells_per_block
30
  self.lacunarity_window = lacunarity_window
 
 
31
 
32
  def extract_lbp(self, gray_image: np.ndarray) -> np.ndarray:
33
+ """Extract Local Binary Pattern."""
 
 
 
 
 
 
 
 
34
  try:
35
+ lbp = local_binary_pattern(gray_image, self.lbp_points, self.lbp_radius, method='uniform')
36
  return self._convert_to_uint8(lbp)
37
  except Exception as e:
38
+ logger.error(f"LBP failed: {e}")
39
  return np.zeros_like(gray_image, dtype=np.uint8)
40
 
41
  def extract_hog(self, gray_image: np.ndarray) -> np.ndarray:
42
+ """Extract HOG features."""
 
 
 
 
 
 
 
 
43
  try:
44
+ _, vis = hog(gray_image, orientations=self.hog_orientations,
45
+ pixels_per_cell=self.hog_pixels_per_cell,
46
+ cells_per_block=self.hog_cells_per_block,
47
+ visualize=True, feature_vector=True)
48
  return exposure.rescale_intensity(vis, out_range=(0, 255)).astype(np.uint8)
49
  except Exception as e:
50
+ logger.error(f"HOG failed: {e}")
51
  return np.zeros_like(gray_image, dtype=np.uint8)
52
 
53
+ def compute_local_lacunarity(self, gray_image: np.ndarray) -> np.ndarray:
54
+ """Compute lacunarity."""
 
 
 
 
 
 
 
 
 
55
  try:
56
  arr = gray_image.astype(np.float32)
57
+ m1 = ndimage.uniform_filter(arr, size=self.lacunarity_window)
58
+ m2 = ndimage.uniform_filter(arr * arr, size=self.lacunarity_window)
59
  var = m2 - m1 * m1
60
+ lac = var / (m1 * m1 + 1e-6) + 1
61
+ lac[m1 <= 1e-6] = 0
62
+ return self._convert_to_uint8(lac)
63
  except Exception as e:
64
+ logger.error(f"Lacunarity failed: {e}")
65
+ return np.zeros_like(gray_image, dtype=np.uint8)
 
 
66
 
67
  def extract_all_texture_features(self, gray_image: np.ndarray) -> Dict[str, np.ndarray]:
68
+ """Extract LBP, HOG, and Lacunarity."""
69
+ return {
70
+ 'lbp': self.extract_lbp(gray_image),
71
+ 'hog': self.extract_hog(gray_image),
72
+ 'lac2': self.compute_local_lacunarity(gray_image)
73
+ }
74
 
75
  def _convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
76
+ """Convert to uint8."""
77
  arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
78
  if arr.ptp() > 0:
79
  normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
 
82
  return np.clip(normalized, 0, 255).astype(np.uint8)
83
 
84
  def compute_texture_statistics(self, features: Dict[str, np.ndarray],
85
+ mask: Optional[np.ndarray] = None) -> Dict[str, Dict[str, float]]:
86
+ """Compute basic statistics."""
 
 
 
 
 
 
 
 
 
87
  stats = {}
 
88
  for feature_name, feature_data in features.items():
89
+ if mask is not None and mask.shape == feature_data.shape:
90
+ masked_data = np.where(mask > 0, feature_data, np.nan)
91
  else:
92
+ masked_data = feature_data
93
+
94
+ valid_data = masked_data[~np.isnan(masked_data)]
95
+ if len(valid_data) > 0:
96
+ stats[feature_name] = {
97
+ 'mean': float(np.mean(valid_data)),
98
+ 'std': float(np.std(valid_data)),
99
+ }
100
+ return stats
 
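Assuming the import path implied by the file header above, the slimmed-down extractor can be exercised end to end; the random array below is only a stand-in for a real grayscale leaf crop:

```python
import numpy as np
from sorghum_pipeline.features.texture import TextureExtractor  # path per the header above

gray = (np.random.rand(256, 256) * 255).astype(np.uint8)  # synthetic stand-in image

extractor = TextureExtractor(lacunarity_window=15)
feats = extractor.extract_all_texture_features(gray)
# Three uint8 maps come back: 'lbp', 'hog', and 'lac2'. The lacunarity map is
# var/mean^2 + 1 over a sliding window, built from two uniform_filter passes
# (E[x^2] - E[x]^2 gives the local variance).
stats = extractor.compute_texture_statistics(feats)
for name, s in stats.items():
    print(f"{name}: mean={s['mean']:.2f} std={s['std']:.2f}")
```

The `lac2` key is kept even though the DBC variants are gone, so the output manager's `feats.get('lac2')` lookup further down still matches.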
sorghum_pipeline/features/vegetation.py CHANGED
@@ -1,308 +1,71 @@
1
  """
2
- Vegetation index extraction for the Sorghum Pipeline.
3
-
4
- This module handles extraction of various vegetation indices
5
- from multispectral data.
6
  """
7
 
8
  import numpy as np
9
- import cv2
10
- from typing import Dict, Tuple, Optional, Any
11
  import logging
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
  class VegetationIndexExtractor:
17
- """Extracts vegetation indices from spectral data."""
18
 
19
  def __init__(self, epsilon: float = 1e-10, soil_factor: float = 0.16):
20
- """
21
- Initialize vegetation index extractor.
22
-
23
- Args:
24
- epsilon: Small value to avoid division by zero
25
- soil_factor: Soil factor for certain indices
26
- """
27
- # Coerce to float in case config passed strings like "1e-10"
28
- try:
29
- self.epsilon = float(epsilon)
30
- except Exception:
31
- self.epsilon = 1e-10
32
- try:
33
- self.soil_factor = float(soil_factor)
34
- except Exception:
35
- self.soil_factor = 0.16
36
 
37
- # Define vegetation index formulas
38
  self.index_formulas = {
39
  "NDVI": lambda nir, red: (nir - red) / (nir + red + self.epsilon),
40
- "GNDVI": lambda nir, green: (nir - green) / (nir + green + self.epsilon),
41
- "NDRE": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
42
- "GRNDVI": lambda nir, green, red: (nir - (green + red)) / (nir + (green + red) + self.epsilon),
43
- "TNDVI": lambda nir, red: np.sqrt(np.clip(((nir - red) / (nir + red + self.epsilon)) + 0.5, 0, None)),
44
- "MGRVI": lambda green, red: (green**2 - red**2) / (green**2 + red**2 + self.epsilon),
45
- "GRVI": lambda nir, green: nir / (green + self.epsilon),
46
- "NGRDI": lambda green, red: (green - red) / (green + red + self.epsilon),
47
- "MSAVI": lambda nir, red: 0.5 * (2.0 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red))),
48
- "OSAVI": lambda nir, red: (nir - red) / (nir + red + self.soil_factor + self.epsilon),
49
- "TSAVI": lambda nir, red, s=0.33, a=0.5, X=1.5: (s * (nir - s * red - a)) / (a * nir + red - a * s + X * (1 + s**2) + self.epsilon),
50
- "GSAVI": lambda nir, green, l=0.5: (1 + l) * (nir - green) / (nir + green + l + self.epsilon),
51
- # Requested additions and aliases
52
- "GOSAVI": lambda nir, green: (nir - green) / (nir + green + 0.16 + self.epsilon),
53
- "GDVI": lambda nir, green: nir - green,
54
- "NDWI": lambda green, nir: (green - nir) / (green + nir + self.epsilon),
55
- "DSWI4": lambda green, red: green / (red + self.epsilon),
56
- "CIRE": lambda nir, red_edge: (nir / (red_edge + self.epsilon)) - 1.0,
57
- "LCI": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
58
- "CIgreen": lambda nir, green: (nir / (green + self.epsilon)) - 1,
59
- "MCARI": lambda red_edge, red, green: ((red_edge - red) - 0.2 * (red_edge - green)) * (red_edge / (red + self.epsilon)),
60
- "MCARI1": lambda nir, red, green: 1.2 * (2.5 * (nir - red) - 1.3 * (nir - green)),
61
- "MCARI2": lambda nir, red, green: (1.5 * (2.5 * (nir - red) - 1.3 * (nir - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon))),
62
- # MTVI variants per request
63
- "MTVI1": lambda nir, red, green: 1.2 * (1.2 * (nir - green) - 2.5 * (red - green)),
64
- "MTVI2": lambda nir, red, green: (1.5 * (1.2 * (nir - green) - 2.5 * (red - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon)) - 0.5 + self.epsilon),
65
- "CVI": lambda nir, red, green: (nir * red) / (green**2 + self.epsilon),
66
  "ARI": lambda green, red_edge: (1.0 / (green + self.epsilon)) - (1.0 / (red_edge + self.epsilon)),
67
- "ARI2": lambda nir, green, red_edge: nir * (1.0 / (green + self.epsilon)) - nir * (1.0 / (red_edge + self.epsilon)),
68
- "DVI": lambda nir, red: nir - red,
69
- "WDVI": lambda nir, red, a=0.5: nir - a * red,
70
- "SR": lambda nir, red: nir / (red + self.epsilon),
71
- "MSR": lambda nir, red: (nir / (red + self.epsilon) - 1) / np.sqrt(nir / (red + self.epsilon) + 1),
72
- "PVI": lambda nir, red, a=0.5, b=0.3: (nir - a * red - b) / (np.sqrt(1 + a**2) + self.epsilon),
73
- "GEMI": lambda nir, red: ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon)) * (1 - 0.25 * ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon))) - ((red - 0.125) / (1 - red + self.epsilon)),
74
- "ExR": lambda red, green: 1.3 * red - green,
75
- "RI": lambda red, green: (red - green) / (red + green + self.epsilon),
76
- "RRI1": lambda nir, red_edge: nir / (red_edge + self.epsilon),
77
- "RRI2": lambda red_edge, red: red_edge / (red + self.epsilon),
78
- "RRI": lambda nir, red_edge: nir / (red_edge + self.epsilon),
79
- "AVI": lambda nir, red: np.cbrt(nir * (1.0 - red) * (nir - red + self.epsilon)),
80
- "SIPI2": lambda nir, green, red: (nir - green) / (nir - red + self.epsilon),
81
- "TCARI": lambda red_edge, red, green: 3 * ((red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))),
82
- "TCARIOSAVI": lambda red_edge, red, green, nir: (3 * (red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))) / (1 + 0.16 * ((nir - red) / (nir + red + 0.16 + self.epsilon))),
83
- "CCCI": lambda nir, red_edge, red: (((nir - red_edge) * (nir + red)) / ((nir + red_edge) * (nir - red) + self.epsilon)),
84
- # Additional indices
85
- "RDVI": lambda nir, red: (nir - red) / (np.sqrt(nir + red + self.epsilon)),
86
- "NLI": lambda nir, red: ((nir**2) - red) / ((nir**2) + red + self.epsilon),
87
- "BIXS": lambda green, red: np.sqrt(((green**2) + (red**2)) / 2.0),
88
- "IPVI": lambda nir, red: nir / (nir + red + self.epsilon),
89
- "EVI2": lambda nir, red: 2.4 * (nir - red) / (nir + red + 1.0 + self.epsilon)
90
  }
91
 
92
- # Define required bands for each index
93
  self.index_bands = {
94
  "NDVI": ["nir", "red"],
95
- "GNDVI": ["nir", "green"],
96
- "NDRE": ["nir", "red_edge"],
97
- "GRNDVI": ["nir", "green", "red"],
98
- "TNDVI": ["nir", "red"],
99
- "MGRVI": ["green", "red"],
100
- "GRVI": ["nir", "green"],
101
- "NGRDI": ["green", "red"],
102
- "MSAVI": ["nir", "red"],
103
- "OSAVI": ["nir", "red"],
104
- "TSAVI": ["nir", "red"],
105
- "GSAVI": ["nir", "green"],
106
- "GOSAVI": ["nir", "green"],
107
- "GDVI": ["nir", "green"],
108
- "NDWI": ["green", "nir"],
109
- "DSWI4": ["green", "red"],
110
- "CIRE": ["nir", "red_edge"],
111
- "LCI": ["nir", "red_edge"],
112
- "CIgreen": ["nir", "green"],
113
- "MCARI": ["red_edge", "red", "green"],
114
- "MCARI1": ["nir", "red", "green"],
115
- "MCARI2": ["nir", "red", "green"],
116
- "MTVI1": ["nir", "red", "green"],
117
- "MTVI2": ["nir", "red", "green"],
118
- "CVI": ["nir", "red", "green"],
119
  "ARI": ["green", "red_edge"],
120
- "ARI2": ["nir", "green", "red_edge"],
121
- "DVI": ["nir", "red"],
122
- "WDVI": ["nir", "red"],
123
- "SR": ["nir", "red"],
124
- "MSR": ["nir", "red"],
125
- "PVI": ["nir", "red"],
126
- "GEMI": ["nir", "red"],
127
- "ExR": ["red", "green"],
128
- "RI": ["red", "green"],
129
- "RRI1": ["nir", "red_edge"],
130
- "RRI2": ["red_edge", "red"],
131
- "RRI": ["nir", "red_edge"],
132
- "AVI": ["nir", "red"],
133
- "SIPI2": ["nir", "green", "red"],
134
- "TCARI": ["red_edge", "red", "green"],
135
- "TCARIOSAVI": ["red_edge", "red", "green", "nir"],
136
- "CCCI": ["nir", "red_edge", "red"],
137
- "RDVI": ["nir", "red"],
138
- "NLI": ["nir", "red"],
139
- "BIXS": ["green", "red"],
140
- "IPVI": ["nir", "red"],
141
- "EVI2": ["nir", "red"]
142
  }
143
 
144
  def compute_vegetation_indices(self, spectral_stack: Dict[str, np.ndarray],
145
- mask: np.ndarray) -> Dict[str, Dict[str, Any]]:
146
- """
147
- Compute vegetation indices from spectral data.
148
-
149
- Args:
150
- spectral_stack: Dictionary of spectral bands
151
- mask: Binary mask for the plant
152
-
153
- Returns:
154
- Dictionary of vegetation indices with values and statistics
155
- """
156
  indices = {}
157
 
158
  for index_name, formula in self.index_formulas.items():
159
  try:
160
- # Get required bands
161
- required_bands = self.index_bands.get(index_name, [])
162
-
163
- # Check if all required bands are available
164
  if not all(band in spectral_stack for band in required_bands):
165
- logger.warning(f"Skipping {index_name}: missing required bands")
166
  continue
167
 
168
- # Extract band data as float arrays
169
  band_data = []
170
  for band in required_bands:
171
  arr = spectral_stack[band]
172
- # Ensure numeric float np.ndarray
173
  if isinstance(arr, np.ndarray):
174
  arr = arr.squeeze(-1)
175
- arr = np.asarray(arr, dtype=np.float64)
176
- band_data.append(arr)
177
 
178
- # Compute index (ensure float math)
179
  index_values = formula(*band_data).astype(np.float64)
 
 
180
 
181
- # Apply mask
182
- if mask is not None:
183
- binary_mask = (np.asarray(mask).astype(np.int32) > 0)
184
- masked_values = np.where(binary_mask, index_values, np.nan)
185
- else:
186
- masked_values = index_values
187
-
188
- # Compute statistics
189
  valid_values = masked_values[~np.isnan(masked_values)]
190
  if len(valid_values) > 0:
191
  stats = {
192
  'mean': float(np.mean(valid_values)),
193
  'std': float(np.std(valid_values)),
194
- 'min': float(np.min(valid_values)),
195
- 'max': float(np.max(valid_values)),
196
- 'median': float(np.median(valid_values)),
197
- 'q25': float(np.percentile(valid_values, 25)),
198
- 'q75': float(np.percentile(valid_values, 75)),
199
- 'nan_fraction': float(np.isnan(masked_values).sum() / masked_values.size)
200
  }
201
  else:
202
- stats = {
203
- 'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
204
- 'median': 0.0, 'q25': 0.0, 'q75': 0.0, 'nan_fraction': 1.0
205
- }
206
 
207
  indices[index_name] = {
208
  'values': masked_values,
209
  'statistics': stats
210
  }
211
 
212
- logger.debug(f"Computed {index_name}")
213
-
214
  except Exception as e:
215
  logger.error(f"Failed to compute {index_name}: {e}")
216
- continue
217
-
218
- return indices
219
-
220
- def create_vegetation_index_image(self, index_values: np.ndarray,
221
- colormap: str = 'RdYlGn',
222
- vmin: Optional[float] = None,
223
- vmax: Optional[float] = None) -> np.ndarray:
224
- """
225
- Create visualization image for vegetation index.
226
-
227
- Args:
228
- index_values: Vegetation index values
229
- colormap: Matplotlib colormap name
230
- vmin: Minimum value for normalization
231
- vmax: Maximum value for normalization
232
-
233
- Returns:
234
- RGB image array
235
- """
236
- try:
237
- import matplotlib.pyplot as plt
238
- import matplotlib.cm as cm
239
- from matplotlib.colors import Normalize
240
-
241
- # Determine value range
242
- valid_values = index_values[~np.isnan(index_values)]
243
- if len(valid_values) == 0:
244
- return np.zeros((*index_values.shape, 3), dtype=np.uint8)
245
-
246
- if vmin is None:
247
- vmin = np.min(valid_values)
248
- if vmax is None:
249
- vmax = np.max(valid_values)
250
-
251
- # Normalize values
252
- norm = Normalize(vmin=vmin, vmax=vmax)
253
- cmap = cm.get_cmap(colormap)
254
-
255
- # Apply colormap
256
- rgba_img = cmap(norm(index_values))
257
- rgba_img[np.isnan(index_values)] = [1, 1, 1, 1] # White for NaN
258
-
259
- # Convert to RGB uint8
260
- rgb_img = (rgba_img[:, :, :3] * 255).astype(np.uint8)
261
-
262
- return rgb_img
263
-
264
- except Exception as e:
265
- logger.error(f"Failed to create vegetation index image: {e}")
266
- return np.zeros((*index_values.shape, 3), dtype=np.uint8)
267
-
268
- def get_available_indices(self) -> list:
269
- """Get list of available vegetation indices."""
270
- return list(self.index_formulas.keys())
271
-
272
- def get_index_requirements(self, index_name: str) -> list:
273
- """
274
- Get required bands for a specific index.
275
-
276
- Args:
277
- index_name: Name of the vegetation index
278
-
279
- Returns:
280
- List of required band names
281
- """
282
- return self.index_bands.get(index_name, [])
283
-
284
- def validate_spectral_data(self, spectral_stack: Dict[str, np.ndarray]) -> bool:
285
- """
286
- Validate spectral data for vegetation index computation.
287
-
288
- Args:
289
- spectral_stack: Dictionary of spectral bands
290
-
291
- Returns:
292
- True if valid, False otherwise
293
- """
294
- if not spectral_stack:
295
- return False
296
-
297
- required_bands = ['nir', 'red', 'green', 'red_edge']
298
- if not all(band in spectral_stack for band in required_bands):
299
- logger.warning("Missing required spectral bands")
300
- return False
301
-
302
- # Check data shapes
303
- shapes = [arr.shape for arr in spectral_stack.values()]
304
- if not all(shape == shapes[0] for shape in shapes):
305
- logger.warning("Inconsistent spectral band shapes")
306
- return False
307
 
308
- return True
 
1
  """
2
+ Minimal vegetation index extraction (NDVI, ARI, GNDVI only).
3
  """
4
 
5
  import numpy as np
6
+ from typing import Dict, Any
 
7
  import logging
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
 
12
  class VegetationIndexExtractor:
13
+ """Minimal vegetation index extraction."""
14
 
15
  def __init__(self, epsilon: float = 1e-10, soil_factor: float = 0.16):
16
+ """Initialize with defaults."""
17
+ self.epsilon = epsilon
18
+ self.soil_factor = soil_factor
19
 
 
20
  self.index_formulas = {
21
  "NDVI": lambda nir, red: (nir - red) / (nir + red + self.epsilon),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  "ARI": lambda green, red_edge: (1.0 / (green + self.epsilon)) - (1.0 / (red_edge + self.epsilon)),
23
+ "GNDVI": lambda nir, green: (nir - green) / (nir + green + self.epsilon),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  }
25
 
 
26
  self.index_bands = {
27
  "NDVI": ["nir", "red"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  "ARI": ["green", "red_edge"],
29
+ "GNDVI": ["nir", "green"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  }
31
 
32
  def compute_vegetation_indices(self, spectral_stack: Dict[str, np.ndarray],
33
+ mask: np.ndarray) -> Dict[str, Dict[str, Any]]:
34
+ """Compute NDVI, ARI, and GNDVI."""
 
 
 
 
 
 
 
 
 
35
  indices = {}
36
 
37
  for index_name, formula in self.index_formulas.items():
38
  try:
39
+ required_bands = self.index_bands[index_name]
40
  if not all(band in spectral_stack for band in required_bands):
 
41
  continue
42
 
 
43
  band_data = []
44
  for band in required_bands:
45
  arr = spectral_stack[band]
 
46
  if isinstance(arr, np.ndarray):
47
  arr = arr.squeeze(-1)
48
+ band_data.append(np.asarray(arr, dtype=np.float64))
 
49
 
 
50
  index_values = formula(*band_data).astype(np.float64)
51
+ binary_mask = (np.asarray(mask).astype(np.int32) > 0)
52
+ masked_values = np.where(binary_mask, index_values, np.nan)
53
 
54
  valid_values = masked_values[~np.isnan(masked_values)]
55
  if len(valid_values) > 0:
56
  stats = {
57
  'mean': float(np.mean(valid_values)),
58
  'std': float(np.std(valid_values)),
59
  }
60
  else:
61
+ stats = {'mean': 0.0, 'std': 0.0}
62
 
63
  indices[index_name] = {
64
  'values': masked_values,
65
  'statistics': stats
66
  }
67
 
 
 
68
  except Exception as e:
69
  logger.error(f"Failed to compute {index_name}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ return indices
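As a sanity check on the three retained formulas, a hedged sketch with synthetic single-pixel bands (shaped (1, 1, 1) so the `squeeze(-1)` above applies; the values are illustrative):

```python
import numpy as np
from sorghum_pipeline.features.vegetation import VegetationIndexExtractor  # path per the header above

bands = {
    'nir': np.full((1, 1, 1), 0.8),
    'red': np.full((1, 1, 1), 0.2),
    'green': np.full((1, 1, 1), 0.3),
    'red_edge': np.full((1, 1, 1), 0.5),
}
mask = np.ones((1, 1), dtype=np.uint8)

out = VegetationIndexExtractor().compute_vegetation_indices(bands, mask)
print(out['NDVI']['statistics']['mean'])   # (0.8 - 0.2) / (0.8 + 0.2) = 0.6
print(out['ARI']['statistics']['mean'])    # 1/0.3 - 1/0.5 ~= 1.33
print(out['GNDVI']['statistics']['mean'])  # (0.8 - 0.3) / (0.8 + 0.3) ~= 0.45
```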
sorghum_pipeline/output/manager.py CHANGED
@@ -1,688 +1,143 @@
1
  """
2
- Output manager for the Sorghum Pipeline.
3
-
4
- This module handles saving results, generating visualizations,
5
- and creating reports.
6
  """
7
 
8
  import os
9
- import json
10
  import numpy as np
11
  import cv2
12
-
13
- # Use a non-GUI backend to avoid segmentation faults in headless runs
14
- try:
15
- import matplotlib
16
- if os.environ.get('MPLBACKEND') is None:
17
- matplotlib.use('Agg')
18
- import matplotlib.pyplot as plt
19
- import matplotlib.cm as cm
20
- from matplotlib.colors import Normalize
21
- except Exception:
22
- # Fallback safe imports (should not happen normally)
23
- import matplotlib.pyplot as plt
24
- import matplotlib.cm as cm
25
- from matplotlib.colors import Normalize
26
- from mpl_toolkits.axes_grid1 import make_axes_locatable
27
  from pathlib import Path
28
- from typing import Dict, Any, Optional, List, Tuple
29
- from concurrent.futures import ThreadPoolExecutor, as_completed
30
- import pandas as pd
31
  import logging
32
 
33
  logger = logging.getLogger(__name__)
34
 
35
 
36
  class OutputManager:
37
- """Manages output generation and saving."""
38
 
39
  def __init__(self, output_folder: str, settings: Any):
40
- """
41
- Initialize output manager.
42
-
43
- Args:
44
- output_folder: Base output folder
45
- settings: Output settings from config
46
- """
47
  self.output_folder = Path(output_folder)
48
  self.settings = settings
49
- # Fast mode and parallel save controls
50
- try:
51
- self.fast_mode: bool = bool(int(os.environ.get('FAST_OUTPUT', '0'))) or bool(getattr(settings, 'fast_mode', False))
52
- except Exception:
53
- self.fast_mode = False
54
- try:
55
- self.max_workers: int = int(os.environ.get('FAST_SAVE_WORKERS', '4'))
56
- except Exception:
57
- self.max_workers = 4
58
  try:
59
- self.png_compression: int = int(os.environ.get('PNG_COMPRESSION', '1')) # 0-9; 1 is fast
60
  except Exception:
61
- self.png_compression = 1
62
-
63
- # Reduce thread usage to lower risk of native library segfaults
64
- try:
65
- import os as _os
66
- _os.environ.setdefault('OMP_NUM_THREADS', '1')
67
- _os.environ.setdefault('OPENBLAS_NUM_THREADS', '1')
68
- _os.environ.setdefault('MKL_NUM_THREADS', '1')
69
- _os.environ.setdefault('NUMEXPR_NUM_THREADS', '1')
70
- except Exception:
71
- pass
72
- try:
73
- cv2.setNumThreads(1)
74
- except Exception:
75
- pass
76
-
77
- # Create base directories
78
  self.output_folder.mkdir(parents=True, exist_ok=True)
79
 
80
- def _imwrite_fast(self, dest: Path, img: np.ndarray) -> None:
81
- try:
82
- cv2.imwrite(str(dest), img, [cv2.IMWRITE_PNG_COMPRESSION, int(self.png_compression)])
83
- except Exception:
84
- cv2.imwrite(str(dest), img)
85
-
86
  def create_output_directories(self) -> None:
87
- """Ensure base output directory exists.
88
-
89
- Note: Do NOT create subdirectories at the root (e.g., 'analysis').
90
- Subdirectories are created within each plant's directory only.
91
- """
92
  self.output_folder.mkdir(parents=True, exist_ok=True)
93
 
94
  def save_plant_results(self, plant_key: str, plant_data: Dict[str, Any]) -> None:
95
- """
96
- Save all results for a single plant.
97
-
98
- Args:
99
- plant_key: Plant identifier (e.g., "2025_02_05_plant1_frame8")
100
- plant_data: Plant data dictionary
101
- """
102
- try:
103
- # Parse plant key
104
- parts = plant_key.split('_')
105
- date_key = "_".join(parts[:3])
106
- plant_name = parts[3]
107
- frame_key = parts[4] if len(parts) > 4 else "frame0"
108
-
109
- # Create plant-specific directory
110
- plant_dir = self.output_folder / date_key / plant_name
111
- plant_dir.mkdir(parents=True, exist_ok=True)
112
-
113
- # Save segmentation results
114
- self._save_segmentation_results(plant_dir, plant_name, plant_data)
115
-
116
- # Save texture features
117
- self._save_texture_features(plant_dir, plant_data)
118
-
119
- # Save vegetation indices
120
- self._save_vegetation_indices(plant_dir, plant_data)
121
-
122
- # Save morphology features
123
- self._save_morphology_features(plant_dir, plant_data)
124
-
125
- # Save analysis plots
126
- self._save_analysis_plots(plant_dir, plant_data)
127
-
128
- # Save metadata
129
- self._save_metadata(plant_dir, plant_key, plant_data)
130
-
131
- logger.debug(f"Results saved for {plant_key}")
132
-
133
- except Exception as e:
134
- logger.error(f"Failed to save results for {plant_key}: {e}")
135
-
136
- def _save_segmentation_results(self, plant_dir: Path, plant_name: str, plant_data: Dict[str, Any]) -> None:
137
- """Save segmentation results."""
138
- if not self.settings.save_images:
139
  return
140
 
141
- seg_dir = plant_dir / self.settings.segmentation_dir
142
- seg_dir.mkdir(exist_ok=True)
143
-
144
- try:
145
- tasks: List[Tuple[Path, np.ndarray]] = []
146
- # Choose which base image to present in original/overlay
147
- use_feature_image = False
148
- try:
149
- # Allow env override, and special-case plants 13-16 per user requirement
150
- use_feature_image = bool(int(os.environ.get('OUTPUT_USE_FEATURE_IMAGE', '0'))) or plant_name in { 'plant13','plant14','plant15','plant16' }
151
- except Exception:
152
- use_feature_image = plant_name in { 'plant13','plant14','plant15','plant16' }
153
- if use_feature_image:
154
- base_image = plant_data.get('composite', plant_data.get('segmentation_composite'))
155
- else:
156
- base_image = plant_data.get('segmentation_composite', plant_data.get('composite'))
157
- if base_image is not None:
158
- tasks.append((seg_dir / 'original.png', base_image))
159
- if 'mask' in plant_data:
160
- tasks.append((seg_dir / 'mask.png', plant_data['mask']))
161
- if 'mask3' in plant_data and isinstance(plant_data['mask3'], np.ndarray):
162
- tasks.append((seg_dir / 'mask3.png', plant_data['mask3']))
163
- # Save the BRIA-generated mask (if present before overrides) as mask2.png
164
- if 'original_mask' in plant_data and isinstance(plant_data['original_mask'], np.ndarray):
165
- tasks.append((seg_dir / 'mask2.png', plant_data['original_mask']))
166
- if base_image is not None and 'mask' in plant_data:
167
- overlay = self._create_overlay(base_image, plant_data['mask'])
168
- tasks.append((seg_dir / 'overlay.png', overlay))
169
- if 'masked_composite' in plant_data:
170
- tasks.append((seg_dir / 'masked_composite.png', plant_data['masked_composite']))
171
 
172
- # Create white-background maskouts
173
- try:
174
- if base_image is not None and 'mask' in plant_data:
175
- maskout_external = self._create_maskout_white_background(base_image, plant_data['mask'])
176
- tasks.append((seg_dir / 'maskout_external.png', maskout_external))
177
- # BRIA-only maskout directly on original composite
178
- if base_image is not None and 'original_mask' in plant_data and isinstance(plant_data['original_mask'], np.ndarray):
179
- maskout_bria = self._create_maskout_white_background(base_image, plant_data['original_mask'])
180
- tasks.append((seg_dir / 'maskout_bria.png', maskout_bria))
181
- # mask3 maskout on original composite
182
- if base_image is not None and 'mask3' in plant_data and isinstance(plant_data['mask3'], np.ndarray):
183
- maskout_mask3 = self._create_maskout_white_background(base_image, plant_data['mask3'])
184
- tasks.append((seg_dir / 'maskout_mask3.png', maskout_mask3))
185
- except Exception as _e:
186
- logger.debug(f"Failed to create double maskouts: {_e}")
187
 
188
- if self.max_workers > 1 and len(tasks) > 1:
189
- with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
190
- futures = [ex.submit(self._imwrite_fast, p, img) for p, img in tasks]
191
- for _ in as_completed(futures):
192
- pass
193
- else:
194
- for p, img in tasks:
195
- self._imwrite_fast(p, img)
196
  except Exception as e:
197
- logger.error(f"Failed to save segmentation results: {e}")
198
-
199
- def _save_texture_features(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
200
- """Save texture features."""
201
- if not self.settings.save_images or 'texture_features' not in plant_data:
202
- return
203
-
204
- texture_dir = plant_dir / self.settings.texture_dir
205
- texture_dir.mkdir(exist_ok=True)
206
-
207
- def save_feature_png(feature_name: str, values: Any, dest: Path, cmap_name: str = 'viridis') -> None:
208
- try:
209
- arr = np.asarray(values)
210
- if arr.ndim == 3 and arr.shape[-1] == 3:
211
- self._imwrite_fast(dest, cv2.cvtColor(arr.astype(np.uint8), cv2.COLOR_RGB2BGR))
212
- return
213
- if self.fast_mode:
214
- # Fast path: simple normalization, no matplotlib
215
- normalized = self._normalize_to_uint8(np.nan_to_num(arr.astype(np.float64), nan=0.0))
216
- self._imwrite_fast(dest, normalized)
217
- else:
218
- arr = arr.astype(np.float64)
219
- masked = np.ma.masked_invalid(arr)
220
- fig, ax = plt.subplots(figsize=(5, 5))
221
- ax.set_axis_off()
222
- ax.set_facecolor('white')
223
- im = ax.imshow(masked, cmap=cmap_name)
224
- divider = make_axes_locatable(ax)
225
- cax = divider.append_axes("right", size="2%", pad=0.02)
226
- cbar = plt.colorbar(im, cax=cax, orientation='vertical')
227
- cbar.set_label(feature_name, fontsize=7)
228
- cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
229
- if hasattr(cbar, 'outline') and cbar.outline is not None:
230
- cbar.outline.set_linewidth(0.5)
231
- plt.tight_layout()
232
- plt.savefig(dest, dpi=self.settings.plot_dpi, bbox_inches='tight')
233
- plt.close(fig)
234
- except Exception as e:
235
- logger.error(f"Failed to save texture feature image for {feature_name}: {e}")
236
- try:
237
- normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
238
- self._imwrite_fast(dest, normalized)
239
- except Exception:
240
- pass
241
 
 
242
  try:
243
- texture_features = plant_data['texture_features']
244
-
245
- for band, band_data in texture_features.items():
246
- if 'features' not in band_data:
247
- continue
248
-
249
- band_dir = texture_dir / band
250
- band_dir.mkdir(exist_ok=True)
251
-
252
- features = band_data['features']
253
-
254
- # Save individual feature maps (optionally in parallel)
255
- items: List[Tuple[str, np.ndarray, Path, str]] = []
256
- for feature_name, feature_map in features.items():
257
- if feature_name == 'ehd_features':
258
- for i in range(feature_map.shape[0]):
259
- channel = feature_map[i]
260
- if isinstance(channel, np.ndarray) and channel.size > 0:
261
- items.append((f'ehd_channel_{i}', channel, band_dir / f'ehd_channel_{i}.png', 'magma'))
262
- else:
263
- if isinstance(feature_map, np.ndarray) and feature_map.size > 0:
264
- cmap_choice = 'gray' if feature_name in ('lbp', 'hog') else 'plasma' if feature_name.startswith('lac') else 'viridis'
265
- items.append((feature_name, feature_map, band_dir / f'{feature_name}.png', cmap_choice))
266
-
267
- if self.max_workers > 1 and len(items) > 1:
268
- with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
269
- futures = [ex.submit(save_feature_png, n, m, p, c) for (n, m, p, c) in items]
270
- for _ in as_completed(futures):
271
- pass
272
- else:
273
- for (n, m, p, c) in items:
274
- save_feature_png(n, m, p, c)
275
-
276
- # Create feature summary plot
277
- self._create_texture_summary_plot(band_dir, features, band)
278
-
279
- # Save texture statistics if available
280
- if 'statistics' in band_data and isinstance(band_data['statistics'], dict):
281
- try:
282
- with open(band_dir / 'texture_statistics.json', 'w') as f:
283
- json.dump(band_data['statistics'], f, indent=2)
284
- except Exception as e:
285
- logger.error(f"Failed to save texture statistics for {band}: {e}")
286
-
287
  except Exception as e:
288
- logger.error(f"Failed to save texture features: {e}")
289
-
290
- def _save_vegetation_indices(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
291
- """Save vegetation indices."""
292
- if not self.settings.save_images or 'vegetation_indices' not in plant_data:
293
- return
294
-
295
- veg_dir = plant_dir / self.settings.vegetation_dir
296
- veg_dir.mkdir(exist_ok=True)
297
-
298
- # Colormap and range settings per index
299
- index_cmap_settings = {
300
- "NDVI": (cm.RdYlGn, -1, 1),
301
- "GNDVI": (cm.RdYlGn, -1, 1),
302
- "NDRE": (cm.RdYlGn, -1, 1),
303
- "GRNDVI": (cm.RdYlGn, -1, 1),
304
- "TNDVI": (cm.RdYlGn, -1, 1),
305
- "MGRVI": (cm.RdYlGn, -1, 1),
306
- "GRVI": (cm.RdYlGn, -1, 1),
307
- "NGRDI": (cm.RdYlGn, -1, 1),
308
- "MSAVI": (cm.YlGn, 0, 1),
309
- "OSAVI": (cm.YlGn, 0, 1),
310
- "TSAVI": (cm.YlGn, 0, 1),
311
- "GSAVI": (cm.YlGn, 0, 1),
312
- "NDWI": (cm.Blues, -1, 1),
313
- "DSWI4": (cm.Blues, -1, 1),
314
- "CIRE": (cm.viridis, 0, 10),
315
- "LCI": (cm.viridis, 0, 5),
316
- "CIgreen": (cm.viridis, 0, 5),
317
- "MCARI": (cm.viridis, 0, 1.5),
318
- "MCARI1": (cm.viridis, 0, 1.5),
319
- "MCARI2": (cm.viridis, 0, 1.5),
320
- "CVI": (cm.plasma, 0, 10),
321
- "TCARI": (cm.viridis, 0, 1),
322
- "TCARIOSAVI": (cm.viridis, 0, 1),
323
- "AVI": (cm.magma, 0, 1),
324
- "SIPI2": (cm.inferno, 0, 1),
325
- "ARI": (cm.magma, 0, 1),
326
- "ARI2": (cm.magma, 0, 1),
327
- "DVI": (cm.Greens, 0, None),
328
- "WDVI": (cm.Greens, 0, None),
329
- "SR": (cm.viridis, 0, 10),
330
- "MSR": (cm.viridis, 0, 10),
331
- "PVI": (cm.cividis, None, None),
332
- "GEMI": (cm.cividis, 0, 1),
333
- "ExR": (cm.Reds, -1, 1),
334
- "RI": (cm.Reds, 0, None),
335
- "RRI1": (cm.Reds, 0, 1)
336
- }
337
-
338
- def save_index_png(index_name: str, values: Any, dest: Path) -> None:
339
- try:
340
- arr = values
341
- if not isinstance(arr, (list, tuple,)) and isinstance(arr, (float, int)):
342
- return
343
- arr = np.asarray(arr, dtype=np.float64)
344
- if self.fast_mode:
345
- normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
346
- self._imwrite_fast(dest, normalized)
347
- else:
348
- cmap, vmin, vmax = index_cmap_settings.get(index_name, (cm.viridis, np.nanmin(arr), np.nanmax(arr)))
349
- if vmin is None:
350
- vmin = np.nanmin(arr)
351
- if vmax is None:
352
- vmax = np.nanmax(arr)
353
- if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
354
- vmin, vmax = 0.0, 1.0
355
- masked = np.ma.masked_invalid(arr)
356
- fig, ax = plt.subplots(figsize=(5, 5))
357
- ax.set_axis_off()
358
- ax.set_facecolor('white')
359
- im = ax.imshow(masked, cmap=cmap, vmin=vmin, vmax=vmax)
360
- divider = make_axes_locatable(ax)
361
- cax = divider.append_axes("right", size="2%", pad=0.02)
362
- cbar = plt.colorbar(im, cax=cax, orientation='vertical')
363
- cbar.set_label(index_name, fontsize=7)
364
- cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
365
- if hasattr(cbar, 'outline') and cbar.outline is not None:
366
- cbar.outline.set_linewidth(0.5)
367
- plt.tight_layout()
368
- plt.savefig(dest, dpi=self.settings.plot_dpi, bbox_inches='tight')
369
- plt.close(fig)
370
- except Exception as e:
371
- logger.error(f"Failed to save vegetation index image for {index_name}: {e}")
372
- try:
373
- # Fallback simple normalization
374
- normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
375
- self._imwrite_fast(dest, normalized)
376
- except Exception:
377
- pass
378
 
 
379
  try:
380
- vegetation_indices = plant_data['vegetation_indices']
381
-
382
- items_png: List[Tuple[str, np.ndarray, Path]] = []
383
- items_stats: List[Tuple[Path, Dict[str, Any]]] = []
384
- for index_name, index_data in vegetation_indices.items():
385
- if isinstance(index_data, dict) and 'values' in index_data:
386
- values = index_data['values']
387
- if isinstance(values, np.ndarray) and values.size > 0:
388
- items_png.append((index_name, values, veg_dir / f'{index_name}.png'))
389
- stats = index_data.get('statistics')
390
- if isinstance(stats, dict):
391
- items_stats.append((veg_dir / f'{index_name}_stats.json', stats))
392
-
393
- # Save sequentially to avoid matplotlib thread-safety issues
394
- for (name, arr, dest) in items_png:
395
- save_index_png(name, arr, dest)
396
- for (path, stats) in items_stats:
397
- try:
398
- with open(path, 'w') as f:
399
- json.dump(stats, f, indent=2)
400
- except Exception as e:
401
- logger.error(f"Failed to save stats for {path.name.split('.')[0]}: {e}")
402
-
403
- # Create vegetation index summary (skip in fast mode)
404
- if not self.fast_mode:
405
- self._create_vegetation_summary_plot(veg_dir, vegetation_indices)
406
-
407
- # Save aggregated vegetation statistics
408
- try:
409
- all_stats = {k: v.get('statistics', {}) for k, v in vegetation_indices.items() if isinstance(v, dict)}
410
- with open(veg_dir / 'vegetation_statistics.json', 'w') as f:
411
- json.dump(all_stats, f, indent=2)
412
- except Exception as e:
413
- logger.error(f"Failed to save aggregated vegetation statistics: {e}")
414
-
415
  except Exception as e:
416
  logger.error(f"Failed to save vegetation indices: {e}")
417
-
418
- def _save_morphology_features(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
419
- """Save morphological features."""
420
- if not self.settings.save_images or 'morphology_features' not in plant_data:
421
- return
422
-
423
- morph_dir = plant_dir / self.settings.morphology_dir
424
- morph_dir.mkdir(exist_ok=True)
425
-
426
  try:
427
- morphology_features = plant_data['morphology_features']
 
 
428
 
429
- # Save morphological images
430
- if 'images' in morphology_features:
431
- for image_name, image_data in morphology_features['images'].items():
432
- if isinstance(image_data, np.ndarray) and image_data.size > 0:
433
- cv2.imwrite(str(morph_dir / f'{image_name}.png'), image_data)
434
 
435
- # Save morphological data
436
- if 'traits' in morphology_features:
437
- traits = morphology_features['traits']
438
- with open(morph_dir / 'traits.json', 'w') as f:
439
- json.dump(traits, f, indent=2)
440
441
  except Exception as e:
442
- logger.error(f"Failed to save morphology features: {e}")
443
-
444
- def _save_analysis_plots(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
445
- """Save analysis plots."""
446
- if not self.settings.save_plots or self.fast_mode:
447
- return
448
-
449
- analysis_dir = plant_dir / self.settings.analysis_dir
450
- analysis_dir.mkdir(exist_ok=True)
451
-
452
- try:
453
- # Create comprehensive analysis plot
454
- self._create_comprehensive_analysis_plot(analysis_dir, plant_data)
455
-
456
- except Exception as e:
457
- logger.error(f"Failed to save analysis plots: {e}")
458
-
459
- def _save_metadata(self, plant_dir: Path, plant_key: str, plant_data: Dict[str, Any]) -> None:
460
- """Save metadata for the plant."""
461
- if not self.settings.save_metadata:
462
- return
463
-
464
  try:
465
- metadata = {
466
- 'plant_key': plant_key,
467
- 'timestamp': pd.Timestamp.now().isoformat(),
468
- 'image_shape': plant_data.get('composite', np.array([])).shape if 'composite' in plant_data else None,
469
- 'has_mask': 'mask' in plant_data and plant_data['mask'] is not None,
470
- 'features_available': {
471
- 'texture': 'texture_features' in plant_data,
472
- 'vegetation': 'vegetation_indices' in plant_data,
473
- 'morphology': 'morphology_features' in plant_data
474
- }
475
- }
476
-
477
- with open(plant_dir / 'metadata.json', 'w') as f:
478
- json.dump(metadata, f, indent=2)
479
-
480
  except Exception as e:
481
- logger.error(f"Failed to save metadata: {e}")
482
 
483
- def _create_overlay(self, image: np.ndarray, mask: np.ndarray,
484
- color: Tuple[int, int, int] = (0, 255, 0),
485
- alpha: float = 0.5) -> np.ndarray:
486
- """Return a strictly masked image: pixels where mask>0 keep original; others set to 0."""
487
  if mask is None:
488
  return image
489
- # Resize mask to image size if needed
490
  if mask.shape[:2] != image.shape[:2]:
491
- try:
492
- mask = cv2.resize(mask.astype(np.uint8), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
493
- except Exception:
494
- pass
495
  binary = (mask.astype(np.int32) > 0).astype(np.uint8) * 255
496
  return cv2.bitwise_and(image, image, mask=binary)
497
 
498
- def _create_maskout_white_background(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
499
- """Create maskout image with white background."""
500
- # Create white background
501
- white_background = np.full_like(image, 255, dtype=np.uint8)
502
-
503
- # Apply mask to original image (keep only masked regions)
504
- masked_image = image.copy()
505
- masked_image[mask == 0] = 0 # Set non-masked regions to black
506
-
507
- # Combine: white background + masked image
508
- result = white_background.copy()
509
- result[mask > 0] = masked_image[mask > 0]
510
-
511
- return result
512
-
513
  def _normalize_to_uint8(self, arr: np.ndarray) -> np.ndarray:
514
- """Normalize array to uint8 range."""
515
- if arr.size == 0:
516
- return arr.astype(np.uint8)
517
-
518
  arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
519
-
520
  if arr.ptp() > 0:
521
  normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
522
  else:
523
  normalized = np.zeros_like(arr)
524
-
525
- return np.clip(normalized, 0, 255).astype(np.uint8)
526
-
527
- def _create_texture_summary_plot(self, output_dir: Path, features: Dict[str, np.ndarray], band: str) -> None:
528
- """Create texture feature summary plot."""
529
- try:
530
- # Get available features
531
- available_features = [k for k, v in features.items()
532
- if isinstance(v, np.ndarray) and v.size > 0 and k != 'ehd_features']
533
-
534
- if not available_features:
535
- return
536
-
537
- # Create subplot
538
- n_features = len(available_features)
539
- cols = min(3, n_features)
540
- rows = (n_features + cols - 1) // cols
541
-
542
- fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
543
- if n_features == 1:
544
- axes = [axes]
545
- elif rows == 1:
546
- axes = axes.reshape(1, -1)
547
-
548
- for i, feature_name in enumerate(available_features):
549
- row, col = divmod(i, cols)
550
- ax = axes[row, col] if rows > 1 else axes[col]
551
-
552
- feature_map = features[feature_name]
553
- ax.imshow(feature_map, cmap='viridis')
554
- ax.set_title(f'{band.upper()} - {feature_name.upper()}')
555
- ax.axis('off')
556
-
557
- # Hide unused subplots
558
- for i in range(n_features, rows * cols):
559
- row, col = divmod(i, cols)
560
- ax = axes[row, col] if rows > 1 else axes[col]
561
- ax.axis('off')
562
-
563
- plt.tight_layout()
564
- plt.savefig(output_dir / f'{band}_texture_summary.png',
565
- dpi=self.settings.plot_dpi, bbox_inches='tight')
566
- plt.close()
567
-
568
- except Exception as e:
569
- logger.error(f"Failed to create texture summary plot: {e}")
570
-
571
- def _create_vegetation_summary_plot(self, output_dir: Path, vegetation_indices: Dict[str, Any]) -> None:
572
- """Create vegetation index summary plot."""
573
- try:
574
- # Get available indices
575
- available_indices = [k for k, v in vegetation_indices.items()
576
- if isinstance(v, dict) and 'values' in v and isinstance(v['values'], np.ndarray)]
577
-
578
- if not available_indices:
579
- return
580
-
581
- # Create subplot
582
- n_indices = len(available_indices)
583
- cols = min(3, n_indices)
584
- rows = (n_indices + cols - 1) // cols
585
-
586
- fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
587
- if n_indices == 1:
588
- axes = [axes]
589
- elif rows == 1:
590
- axes = axes.reshape(1, -1)
591
-
592
- for i, index_name in enumerate(available_indices):
593
- row, col = divmod(i, cols)
594
- ax = axes[row, col] if rows > 1 else axes[col]
595
-
596
- values = vegetation_indices[index_name]['values']
597
- im = ax.imshow(values, cmap='RdYlGn')
598
- ax.set_title(f'{index_name}')
599
- ax.axis('off')
600
- divider = make_axes_locatable(ax)
601
- cax = divider.append_axes("right", size="2%", pad=0.02)
602
- cbar = plt.colorbar(im, cax=cax, orientation='vertical')
603
- cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
604
- if hasattr(cbar, 'outline') and cbar.outline is not None:
605
- cbar.outline.set_linewidth(0.5)
606
-
607
- # Hide unused subplots
608
- for i in range(n_indices, rows * cols):
609
- row, col = divmod(i, cols)
610
- ax = axes[row, col] if rows > 1 else axes[col]
611
- ax.axis('off')
612
-
613
- plt.tight_layout()
614
- plt.savefig(output_dir / 'vegetation_indices_summary.png',
615
- dpi=self.settings.plot_dpi, bbox_inches='tight')
616
- plt.close()
617
-
618
- except Exception as e:
619
- logger.error(f"Failed to create vegetation summary plot: {e}")
620
-
621
- def _create_comprehensive_analysis_plot(self, output_dir: Path, plant_data: Dict[str, Any]) -> None:
622
- """Create comprehensive analysis plot."""
623
- try:
624
- fig, axes = plt.subplots(2, 3, figsize=(15, 10))
625
-
626
- # Original image
627
- if 'composite' in plant_data:
628
- axes[0, 0].imshow(cv2.cvtColor(plant_data['composite'], cv2.COLOR_BGR2RGB))
629
- axes[0, 0].set_title('Original Composite')
630
- axes[0, 0].axis('off')
631
-
632
- # Mask
633
- if 'mask' in plant_data:
634
- axes[0, 1].imshow(plant_data['mask'], cmap='gray')
635
- axes[0, 1].set_title('Segmentation Mask')
636
- axes[0, 1].axis('off')
637
-
638
- # Overlay
639
- if 'composite' in plant_data and 'mask' in plant_data:
640
- overlay = self._create_overlay(plant_data['composite'], plant_data['mask'])
641
- axes[0, 2].imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
642
- axes[0, 2].set_title('Overlay')
643
- axes[0, 2].axis('off')
644
-
645
- # Texture features (if available)
646
- if 'texture_features' in plant_data and 'color' in plant_data['texture_features']:
647
- color_features = plant_data['texture_features']['color'].get('features', {})
648
- if 'lbp' in color_features:
649
- axes[1, 0].imshow(color_features['lbp'], cmap='viridis')
650
- axes[1, 0].set_title('LBP Texture')
651
- axes[1, 0].axis('off')
652
-
653
- # Vegetation indices (if available)
654
- if 'vegetation_indices' in plant_data:
655
- veg_indices = plant_data['vegetation_indices']
656
- if 'NDVI' in veg_indices and 'values' in veg_indices['NDVI']:
657
- axes[1, 1].imshow(veg_indices['NDVI']['values'], cmap='RdYlGn')
658
- axes[1, 1].set_title('NDVI')
659
- axes[1, 1].axis('off')
660
-
661
- # Morphology (if available)
662
- if 'morphology_features' in plant_data and 'images' in plant_data['morphology_features']:
663
- morph_images = plant_data['morphology_features']['images']
664
- if 'skeleton' in morph_images:
665
- axes[1, 2].imshow(morph_images['skeleton'], cmap='gray')
666
- axes[1, 2].set_title('Skeleton')
667
- axes[1, 2].axis('off')
668
-
669
- plt.tight_layout()
670
- plt.savefig(output_dir / 'comprehensive_analysis.png',
671
- dpi=min(getattr(self.settings, 'plot_dpi', 100), 100), bbox_inches='tight')
672
- plt.close()
673
-
674
- except Exception as e:
675
- logger.error(f"Failed to create comprehensive analysis plot: {e}")
676
-
677
- def create_pipeline_summary(self, results: Dict[str, Any]) -> None:
678
- """Create a summary of the entire pipeline run."""
679
- try:
680
- summary_file = self.output_folder / 'pipeline_summary.json'
681
-
682
- with open(summary_file, 'w') as f:
683
- json.dump(results['summary'], f, indent=2)
684
-
685
- logger.info(f"Pipeline summary saved to {summary_file}")
686
-
687
- except Exception as e:
688
- logger.error(f"Failed to create pipeline summary: {e}")
 
1
  """
2
+ Minimal output manager for demo (saves only 7 required images).
3
  """
4
 
5
  import os
 
6
  import numpy as np
7
  import cv2
8
+ import matplotlib
9
+ if os.environ.get('MPLBACKEND') is None:
10
+ matplotlib.use('Agg')
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.cm as cm
13
  from pathlib import Path
14
+ from typing import Dict, Any
 
 
15
  import logging
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
 
20
  class OutputManager:
21
+ """Minimal output manager for demo."""
22
 
23
  def __init__(self, output_folder: str, settings: Any):
24
+ """Initialize output manager."""
 
 
 
 
 
 
25
  self.output_folder = Path(output_folder)
26
  self.settings = settings
27
  try:
28
+ self.minimal_demo: bool = bool(int(os.environ.get('MINIMAL_DEMO', '0')))
29
  except Exception:
30
+ self.minimal_demo = False
31
  self.output_folder.mkdir(parents=True, exist_ok=True)
32
 
33
  def create_output_directories(self) -> None:
34
+ """Create output directories."""
 
 
 
 
35
  self.output_folder.mkdir(parents=True, exist_ok=True)
36
 
37
  def save_plant_results(self, plant_key: str, plant_data: Dict[str, Any]) -> None:
38
+ """Save minimal demo outputs only."""
39
+ if not self.minimal_demo:
40
+ logger.warning("OutputManager configured for minimal demo only")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return
42
 
43
+ self._save_minimal_demo_outputs(plant_data)
44
 
45
+ def _save_minimal_demo_outputs(self, plant_data: Dict[str, Any]) -> None:
46
+ """Save only the 7 required images."""
47
+ results_dir = self.output_folder / 'results'
48
+ veg_dir = self.output_folder / 'Vegetation_indices_images'
49
+ tex_dir = self.output_folder / 'texture_output'
50
+ results_dir.mkdir(parents=True, exist_ok=True)
51
+ veg_dir.mkdir(parents=True, exist_ok=True)
52
+ tex_dir.mkdir(parents=True, exist_ok=True)
53
 
54
+ # 1. Mask
55
+ try:
56
+ mask = plant_data.get('mask')
57
+ if isinstance(mask, np.ndarray):
58
+ cv2.imwrite(str(results_dir / 'mask.png'), mask)
59
  except Exception as e:
60
+ logger.error(f"Failed to save mask: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # 2. Overlay
63
  try:
64
+ base_image = plant_data.get('composite')
65
+ mask = plant_data.get('mask')
66
+ if isinstance(base_image, np.ndarray) and isinstance(mask, np.ndarray):
67
+ overlay = self._create_overlay(base_image, mask)
68
+ cv2.imwrite(str(results_dir / 'overlay.png'), overlay)
69
  except Exception as e:
70
+ logger.error(f"Failed to save overlay: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ # 3-5. Vegetation indices (NDVI, ARI, GNDVI)
73
  try:
74
+ veg = plant_data.get('vegetation_indices', {})
75
+ for name in ['NDVI', 'ARI', 'GNDVI']:
76
+ data = veg.get(name, {})
77
+ values = data.get('values') if isinstance(data, dict) else None
78
+ if isinstance(values, np.ndarray) and values.size > 0:
79
+ try:
80
+ cmap = cm.RdYlGn if name in ['NDVI', 'GNDVI'] else cm.magma
81
+ vmin, vmax = (-1, 1) if name in ['NDVI', 'GNDVI'] else (0, 1)
82
+
83
+ masked = np.ma.masked_invalid(values.astype(np.float64))
84
+ fig, ax = plt.subplots(figsize=(5, 5))
85
+ ax.set_axis_off()
86
+ ax.set_facecolor('white')
87
+ ax.imshow(masked, cmap=cmap, vmin=vmin, vmax=vmax)
88
+ plt.tight_layout()
89
+ plt.savefig(veg_dir / f"{name.lower()}.png", dpi=100, bbox_inches='tight')
90
+ plt.close(fig)
91
+ except Exception as e:
92
+ logger.error(f"Failed to save {name}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  except Exception as e:
94
  logger.error(f"Failed to save vegetation indices: {e}")
95
+
96
+ # 6-8. Texture features (LBP, HOG, Lacunarity)
97
  try:
98
+ tex = plant_data.get('texture_features', {})
99
+ color_band = tex.get('color', {})
100
+ feats = color_band.get('features', {})
101
 
102
+ if isinstance(feats.get('lbp'), np.ndarray) and feats['lbp'].size > 0:
103
+ cv2.imwrite(str(tex_dir / 'lbp.png'), feats['lbp'].astype(np.uint8))
104
 
105
+ if isinstance(feats.get('hog'), np.ndarray) and feats['hog'].size > 0:
106
+ cv2.imwrite(str(tex_dir / 'hog.png'), feats['hog'].astype(np.uint8))
107
 
108
+ lac = feats.get('lac2')
109
+ if isinstance(lac, np.ndarray) and lac.size > 0:
110
+ if lac.dtype != np.uint8:
111
+ lac = self._normalize_to_uint8(lac.astype(np.float64))
112
+ cv2.imwrite(str(tex_dir / 'lacunarity.png'), lac)
113
  except Exception as e:
114
+ logger.error(f"Failed to save texture: {e}")
115
+
116
+ # 9. Morphology size analysis
117
  try:
118
+ morph = plant_data.get('morphology_features', {})
119
+ images = morph.get('images', {})
120
+ size_img = images.get('size_analysis')
121
+ if isinstance(size_img, np.ndarray) and size_img.size > 0:
122
+ cv2.imwrite(str(results_dir / 'size.size_analysis.png'), size_img)
123
  except Exception as e:
124
+ logger.error(f"Failed to save size analysis: {e}")
125
 
126
+ def _create_overlay(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
127
+ """Create overlay (masked pixels only)."""
 
 
128
  if mask is None:
129
  return image
 
130
  if mask.shape[:2] != image.shape[:2]:
131
+ mask = cv2.resize(mask.astype(np.uint8), (image.shape[1], image.shape[0]),
132
+ interpolation=cv2.INTER_NEAREST)
 
 
133
  binary = (mask.astype(np.int32) > 0).astype(np.uint8) * 255
134
  return cv2.bitwise_and(image, image, mask=binary)
135
 
136
  def _normalize_to_uint8(self, arr: np.ndarray) -> np.ndarray:
137
+ """Normalize to uint8."""
 
 
 
138
  arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
 
139
  if arr.ptp() > 0:
140
  normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
141
  else:
142
  normalized = np.zeros_like(arr)
143
+ return np.clip(normalized, 0, 255).astype(np.uint8)
 
sorghum_pipeline/pipeline.py CHANGED
@@ -1,620 +1,110 @@
  """
  Main pipeline class for the Sorghum Plant Phenotyping Pipeline.
 
- This module orchestrates the entire pipeline from data loading
- to feature extraction and result output.
  """
 
  import os
- import subprocess
  import logging
  from pathlib import Path
- from typing import Dict, Any, Optional, List, Set
  import numpy as np
  import cv2
- import torch
- from torchvision import transforms
- from transformers import AutoModelForImageSegmentation
  from sklearn.decomposition import PCA
- try:
- from tqdm import tqdm
- except Exception:
- tqdm = None
 
  from .config import Config
- from .data import DataLoader, ImagePreprocessor, MaskHandler
  from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
  from .output import OutputManager
  from .segmentation import SegmentationManager
- # Make occlusion handling optional if the module is not present
- try:
- from .segmentation.occlusion_handler import OcclusionHandler # type: ignore
- except Exception:
- OcclusionHandler = None # type: ignore
 
 
  class SorghumPipeline:
- """
- Main pipeline class for sorghum plant phenotyping.
-
- This class orchestrates the entire pipeline from data loading
- to feature extraction and result output.
- """
 
- def __init__(self, config_path: Optional[str] = None, config: Optional[Config] = None, include_ignored: bool = False, enable_occlusion_handling: bool = False, enable_instance_integration: bool = False, strict_loader: bool = False, excluded_dates: Optional[List[str]] = None):
- """
- Initialize the pipeline.
-
- Args:
- config_path: Path to configuration file
- config: Configuration object (if not using file)
- include_ignored: Whether to include ignored plants
- enable_occlusion_handling: Whether to enable SAM2Long occlusion handling
- """
- # Setup logging
  self._setup_logging()
-
- # Load configuration
- if config is not None:
- self.config = config
- elif config_path is not None:
- self.config = Config(config_path)
- else:
- raise ValueError("Either config_path or config must be provided")
-
- # Validate configuration
  self.config.validate()
-
- # Store settings
- self.enable_occlusion_handling = enable_occlusion_handling
- self.enable_instance_integration = enable_instance_integration
- self.strict_loader = strict_loader
- self.excluded_dates = excluded_dates or []
-
- # Initialize components
- self._initialize_components(include_ignored)
-
- logger.info("Sorghum Pipeline initialized successfully")
 
  def _setup_logging(self):
  """Setup logging configuration."""
  logging.basicConfig(
  level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.StreamHandler(),
- logging.FileHandler('sorghum_pipeline.log')
- ]
  )
-
- global logger
- logger = logging.getLogger(__name__)
 
- def _initialize_components(self, include_ignored: bool = False):
- """Initialize all pipeline components."""
- # Data components
- self.data_loader = DataLoader(
- input_folder=self.config.paths.input_folder,
- debug=True,
- include_ignored=include_ignored,
- strict_loader=self.strict_loader,
- excluded_dates=self.excluded_dates,
- )
- self.preprocessor = ImagePreprocessor(
- target_size=self.config.processing.target_size
- )
- self.mask_handler = MaskHandler(
- min_area=self.config.processing.min_component_area,
- kernel_size=self.config.processing.morphology_kernel_size
- )
-
- # Feature extractors
- self.texture_extractor = TextureExtractor(
- lbp_points=self.config.processing.lbp_points,
- lbp_radius=self.config.processing.lbp_radius,
- hog_orientations=self.config.processing.hog_orientations,
- hog_pixels_per_cell=self.config.processing.hog_pixels_per_cell,
- hog_cells_per_block=self.config.processing.hog_cells_per_block,
- lacunarity_window=self.config.processing.lacunarity_window,
- ehd_threshold=self.config.processing.ehd_threshold,
- angle_resolution=self.config.processing.angle_resolution
- )
-
- self.vegetation_extractor = VegetationIndexExtractor(
- epsilon=self.config.processing.epsilon,
- soil_factor=self.config.processing.soil_factor
- )
-
- self.morphology_extractor = MorphologyExtractor(
- pixel_to_cm=self.config.processing.pixel_to_cm,
- prune_sizes=self.config.processing.prune_sizes
- )
-
- # Segmentation
  self.segmentation_manager = SegmentationManager(
- model_name=self.config.model.model_name,
  device=self.config.get_device(),
- threshold=self.config.processing.segmentation_threshold,
- trust_remote_code=self.config.model.trust_remote_code,
- cache_dir=self.config.model.cache_dir if getattr(self.config.model, 'cache_dir', '') else None,
- local_files_only=getattr(self.config.model, 'local_files_only', False),
  )
-
- # Occlusion handling (optional)
- self.occlusion_handler = None
- if self.enable_occlusion_handling and OcclusionHandler is not None:
- try:
- self.occlusion_handler = OcclusionHandler(
- device=self.config.get_device(),
- model="tiny", # Can be made configurable
- confidence_threshold=0.5,
- iou_threshold=0.1
- )
- logger.info("Occlusion handler initialized successfully")
- except Exception as e:
- logger.warning(f"Failed to initialize occlusion handler: {e}")
- logger.warning("Continuing without occlusion handling")
- self.occlusion_handler = None
- elif self.enable_occlusion_handling and OcclusionHandler is None:
- logger.warning("Occlusion handler module not found; continuing without occlusion handling")
-
- # Output manager
  self.output_manager = OutputManager(
  output_folder=self.config.paths.output_folder,
  settings=self.config.output
  )
 
- def _free_gpu_memory_before_instance(self) -> None:
- """Attempt to free GPU memory prior to running SAM2Long in a subprocess.
-
- - Moves BRIA segmentation model to CPU if present
- - Deletes the model reference to release VRAM
- - Calls torch.cuda.empty_cache()
  """
- try:
- import torch as _torch # type: ignore
- # Move BRIA model to CPU and drop reference
- try:
- if getattr(self, 'segmentation_manager', None) is not None:
- mdl = getattr(self.segmentation_manager, 'model', None)
- if mdl is not None:
- try:
- mdl.to('cpu')
- except Exception:
- pass
- try:
- delattr(self.segmentation_manager, 'model')
- except Exception:
- pass
- # Ensure attribute exists but is None for future checks
- try:
- self.segmentation_manager.model = None # type: ignore
- except Exception:
- pass
- except Exception:
- pass
- # Free CUDA cache
- try:
- if _torch.cuda.is_available():
- _torch.cuda.empty_cache()
- except Exception:
- pass
- logger.info("Freed GPU memory before SAM2Long invocation (moved BRIA to CPU and emptied cache)")
- except Exception as e:
- logger.warning(f"Failed to free GPU memory before instance segmentation: {e}")
-
- def run(self, load_all_frames: bool = False, segmentation_only: bool = False, filter_plants: Optional[List[str]] = None, filter_frames: Optional[List[str]] = None, run_instance_segmentation: bool = False, features_frame_only: Optional[int] = None, reuse_instance_results: bool = False, instance_mapping_path: Optional[str] = None, force_reprocess: bool = False, respect_instance_frame_rules_for_features: bool = False, substitute_feature_image_from_instance_src: bool = False) -> Dict[str, Any]:
- """
- Run the complete pipeline.
 
  Args:
- load_all_frames: Whether to load all frames or selected frames
- segmentation_only: If True, run segmentation only and skip feature extraction
 
  Returns:
- Dictionary containing all results
  """
- logger.info("Starting Sorghum Pipeline...")
 
  try:
  import time
  total_start = time.perf_counter()
- # Step 1: Load data
- logger.info("Step 1/6: Loading data...")
- # In reuse mode we need all frames to select the mapped frame per plant
- if reuse_instance_results:
- plants = self.data_loader.load_all_frames()
- else:
- # If specific frames are requested, we must load all frames to filter correctly
- if load_all_frames or (filter_frames is not None and len(filter_frames) > 0):
- plants = self.data_loader.load_all_frames()
- else:
- plants = self.data_loader.load_selected_frames()
 
- # Optional filter by specific plant names (e.g., ["plant1"])
- if filter_plants:
- allowed = set(filter_plants)
- plants = {
- key: pdata for key, pdata in plants.items()
- if len(key.split('_')) > 3 and key.split('_')[3] in allowed
  }
-
- # Optional filter by specific frame numbers (e.g., ["9"] or ["frame9"])
- if filter_frames:
- # Normalize to 'frameX' tokens
- wanted = set(
- [f if str(f).startswith('frame') else f"frame{str(f)}" for f in filter_frames]
- )
- plants = {
- key: pdata for key, pdata in plants.items()
- if key.split('_')[-1] in wanted
- }
-
- if not plants:
- raise ValueError("No plant data loaded")
-
- logger.info(f"Loaded {len(plants)} plants")
-
- # If reusing instance results with mapping, restrict to exactly the mapped frame per plant (default frame8)
- if reuse_instance_results:
- try:
- import json as _json
- if instance_mapping_path is None:
- raise ValueError("instance_mapping_path is required in reuse mode")
- _map = _json.load(open(instance_mapping_path, 'r'))
- # Normalize mapping plant keys and compute target frame (default 8)
- target_frame_by_plant = {}
- for pk, pv in _map.items():
- k_norm = pk if str(pk).startswith('plant') else f"plant{int(pk)}" if str(pk).isdigit() else str(pk)
- try:
- target_frame_by_plant[k_norm] = int(pv.get('frame', 8))
- except Exception:
- target_frame_by_plant[k_norm] = 8
- before = len(plants)
- plants = {
- key: pdata for key, pdata in plants.items()
- if (len(key.split('_')) > 3 and key.split('_')[3] in target_frame_by_plant
- and key.split('_')[-1] == f"frame{target_frame_by_plant[key.split('_')[3]]}")
- }
- logger.info(f"Restricted loaded data by mapping frames: {before} -> {len(plants)} items")
- except Exception as e:
- logger.warning(f"Failed to restrict loaded data by mapping frames: {e}")
-
- # Skip plants that already have saved results (unless force_reprocess)
- if not force_reprocess:
- try:
- before = len(plants)
- filtered = {}
- for key, pdata in plants.items():
- parts = key.split('_')
- if len(parts) < 5:
- filtered[key] = pdata
- continue
- date_key = "_".join(parts[:3])
- plant_name = parts[3]
- plant_dir = Path(self.config.paths.output_folder) / date_key / plant_name
- meta_ok = (plant_dir / 'metadata.json').exists()
- seg_mask_ok = (plant_dir / self.config.output.segmentation_dir / 'mask.png').exists()
- if meta_ok or seg_mask_ok:
- continue
- filtered[key] = pdata
- plants = filtered
- logger.info(f"Skip-existing filter: {before} -> {len(plants)} items to process")
- except Exception as e:
- logger.warning(f"Skip-existing filter failed: {e}")
 
- # Pre-segmentation borrowing: use plant12 images for plant13 from the start
- try:
- rewired = 0
- borrow_map: Dict[str, str] = {
- 'plant13': 'plant12',
- 'plant14': 'plant13',
- 'plant15': 'plant14',
- 'plant16': 'plant15',
- }
- for _k in list(plants.keys()):
- _parts = _k.split('_')
- # Expect keys like YYYY_MM_DD_plantX_frameY
- if len(_parts) < 5:
- continue
- _date_key = "_".join(_parts[:3])
- _plant_name = _parts[3]
- _frame_token = _parts[4]
- # Do NOT borrow on 2025_05_08
- if _date_key == '2025_05_08':
- continue
- if _plant_name not in borrow_map:
- continue
- _src_plant = borrow_map[_plant_name]
- _src_key = f"{_date_key}_{_src_plant}_{_frame_token}"
- _src = plants.get(_src_key)
- if not _src:
- # Fallback: load raw image for source plant directly from disk
- try:
- from PIL import Image as _Image
- _date_folder = _date_key.replace('_', '-')
- _frame_num = int(_frame_token.replace('frame', ''))
- _date_dir = Path(self.config.paths.input_folder)
- # If input folder is a parent of dates, append date folder
- if _date_dir.name != _date_folder:
- _date_dir = _date_dir / _date_folder
- _frame_path = _date_dir / _src_plant / f"{_src_plant}_frame{_frame_num}.tif"
- if _frame_path.exists():
- _img = _Image.open(str(_frame_path))
- _src = {"raw_image": (_img, _frame_path.name), "plant_name": _plant_name, "file_path": str(_frame_path)}
- else:
- _src = None
- except Exception:
- _src = None
- if not _src:
- continue
- _tgt = plants[_k]
- # Preserve original raw image once
- if 'raw_image' in _tgt and 'raw_image_original' not in _tgt:
- _tgt['raw_image_original'] = _tgt['raw_image']
- if 'raw_image' in _src:
- _tgt['raw_image'] = _src['raw_image']
- _tgt['borrowed_from'] = _src_plant
- rewired += 1
- if rewired > 0:
- logger.info(f"Pre-seg borrowing applied: rewired {rewired} frames for plants 13/14/15/16")
- except Exception as e:
- logger.warning(f"Pre-seg borrowing failed: {e}")
-
- # Step 2: Create composites
- logger.info("Step 2/6: Creating composites...")
- step_start = time.perf_counter()
  plants = self.preprocessor.create_composites(plants)
- logger.info(f"Composites done in {(time.perf_counter()-step_start):.2f}s")
 
- # Step 3: Segment plants (optionally with bounding boxes)
- logger.info("Step 3/6: Segmenting plants...")
- step_start = time.perf_counter()
- bbox_lookup = None
- try:
- bbox_dir = getattr(self.config.paths, 'boundingbox_dir', None)
- # Default to project BoundingBox dir if unset or falsy
- if not bbox_dir:
- try:
- self.config.paths.boundingbox_dir = "/home/grads/f/fahimehorvatinia/Documents/my_full_project/BoundingBox"
- bbox_dir = self.config.paths.boundingbox_dir
- except Exception:
- bbox_dir = None
- if bbox_dir:
- bbox_lookup = self.data_loader.load_bounding_boxes(bbox_dir)
- logger.info(f"Loaded bounding boxes from {bbox_dir}")
- except Exception as e:
- logger.warning(f"Failed to load bounding boxes: {e}")
- bbox_lookup = None
- plants = self._segment_plants(plants, bbox_lookup)
- logger.info(f"Segmentation done in {(time.perf_counter()-step_start):.2f}s")
 
- # Step 3.5: Handle occlusion if enabled
- if self.enable_occlusion_handling and self.occlusion_handler is not None:
- logger.info("Step 3.5/6: Handling occlusion with SAM2Long...")
- step_start = time.perf_counter()
- plants = self._handle_occlusion(plants)
- logger.info(f"Occlusion handling done in {(time.perf_counter()-step_start):.2f}s")
-
- # Optional: Export RMBG maskouts with white background and run instance segmentation
- if (run_instance_segmentation or self.enable_instance_integration) and not reuse_instance_results:
- if not load_all_frames:
- logger.warning("Instance segmentation expects all 13 frames; consider running with load_all_frames=True.")
- logger.info("Step 3.6: Exporting white-background RMBG images for instance segmentation...")
- # Derive date-specific export/result directories when a single date is present
- date_keys = set()
- try:
- for _k in plants.keys():
- _p = _k.split('_')
- if len(_p) >= 3:
- date_keys.add("_".join(_p[:3]))
- except Exception:
- pass
- if len(date_keys) == 1:
- date_key = next(iter(date_keys))
- base_dir = Path(self.config.paths.output_folder) / date_key
- export_dir = base_dir / "instance_input_maskouts"
- instance_results_dir = base_dir / "instance_results"
- else:
- export_dir = Path(self.config.paths.output_folder) / "instance_input_maskouts"
- instance_results_dir = Path(self.config.paths.output_folder) / "instance_results"
- export_dir.mkdir(parents=True, exist_ok=True)
- instance_results_dir.mkdir(parents=True, exist_ok=True)
- self._export_white_background_maskouts(plants, export_dir)
-
- logger.info("Invoking final SAM2Long instance segmentation on exported images...")
- # Free GPU memory before launching SAM2Long to avoid CUDA OOM
- self._free_gpu_memory_before_instance()
- env = os.environ.copy()
- env["SAM2LONG_IMAGES_DIR"] = str(export_dir)
- env["SAM2LONG_RESULTS_DIR"] = str(instance_results_dir)
- # Ensure instance outputs include all frames for all dates
- try:
- env.pop("INSTANCE_OUTPUT_FRAMES", None)
- except Exception:
- pass
- script_path = "/home/grads/f/fahimehorvatinia/Documents/my_full_project/Experiments3_code/sam2long_instance_integration.py"
- try:
- subprocess.run(["python", script_path], check=True, env=env)
- except subprocess.CalledProcessError as e:
- logger.error(f"Instance segmentation failed: {e}")
- else:
- # Integrate instance masks (track_0 as target) into pdata before feature extraction
- try:
- self._apply_instance_masks(plants, instance_results_dir)
- logger.info("Applied instance segmentation masks to pipeline data")
- except Exception as e:
- logger.warning(f"Failed to apply instance masks: {e}")
- elif reuse_instance_results:
- # Reuse existing instance masks from mapping file
- if instance_mapping_path is None:
- raise ValueError("reuse_instance_results=True requires instance_mapping_path to be provided")
- try:
- self._apply_instance_masks_from_mapping(plants, Path(instance_mapping_path))
- logger.info("Applied instance masks from mapping file")
- except Exception as e:
- logger.error(f"Failed to apply instance masks from mapping: {e}")
 
- if not segmentation_only:
- # If reusing instance results with a mapping, restrict features to mapped frames per plant
- if reuse_instance_results and instance_mapping_path is not None:
- try:
- import json as _json
- _map = _json.load(open(instance_mapping_path, 'r'))
- # Normalize map
- _norm = {}
- for pk, pv in _map.items():
- k_norm = pk if str(pk).startswith('plant') else f"plant{int(pk)}" if str(pk).isdigit() else str(pk)
- _norm[k_norm] = int(pv.get('frame', 8))
- before = len(plants)
- plants = {
- k: v for k, v in plants.items()
- if len(k.split('_')) > 3 and k.split('_')[3] in _norm and k.split('_')[-1] == f"frame{_norm[k.split('_')[3]]}"
- }
- logger.info(f"Restricted feature extraction by mapping: {before} -> {len(plants)} items")
- except Exception as e:
- logger.warning(f"Failed to restrict by mapping frames: {e}")
- # Optional: restrict features to per-plant preferred frame using internal frame rules
- if respect_instance_frame_rules_for_features:
- try:
- # Keep this in sync with _apply_instance_masks frame_rules
- frame_rules: Dict[str, int] = {
- "plant33": 2,
- "plant16": 4,
- "plant19": 5,
- "plant26": 8,
- "plant27": 8,
- "plant29": 8,
- "plant35": 7,
- "plant36": 6,
- "plant37": 2,
- "plant45": 5,
- }
- before = len(plants)
- def _keep(k: str) -> bool:
- parts = k.split('_')
- if len(parts) < 2:
- return False
- plant_name = parts[-2]
- frame_token = parts[-1]
- if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
- return False
- desired = frame_rules.get(plant_name, 8)
- return frame_token == f"frame{desired}"
- plants = {k: v for k, v in plants.items() if _keep(k)}
- logger.info(f"Restricted feature extraction by per-plant frame rules: {before} -> {len(plants)} items")
- except Exception as e:
- logger.warning(f"Failed to apply per-plant frame restriction for features: {e}")
-
- # Optional: if features_frame_only set, keep only that frame's entries (global single frame)
- if features_frame_only is not None:
- frame_token = f"frame{features_frame_only}"
- plants = {k: v for k, v in plants.items() if k.split('_')[-1] == frame_token}
- logger.info(f"Restricted feature extraction to {len(plants)} items for {frame_token}")
-
- # Optional: substitute feature input image from instance src_rules mapping (e.g., plant14 <- plant13)
- if substitute_feature_image_from_instance_src:
- try:
- src_rules: Dict[str, str] = {
- "plant13": "plant12",
- "plant14": "plant13",
- "plant15": "plant14",
- "plant16": "plant15",
- }
- switched = 0
- for key in list(plants.keys()):
- parts = key.split('_')
- if len(parts) < 5:
- continue
- date_key = "_".join(parts[:3])
- plant_name = parts[3]
- frame_token = parts[-1]
- if plant_name not in src_rules:
- continue
- src_plant = src_rules[plant_name]
- src_key = f"{date_key}_{src_plant}_{frame_token}"
- if src_key not in plants:
- continue
- src_pdata = plants[src_key]
- tgt_pdata = plants[key]
- # Preserve the original composite used for segmentation for correct overlays later
- try:
- if 'composite' in tgt_pdata and 'segmentation_composite' not in tgt_pdata:
- tgt_pdata['segmentation_composite'] = tgt_pdata['composite']
- except Exception:
- pass
- # Swap feature inputs: composite and spectral bands
- if 'composite' in src_pdata:
- tgt_pdata['composite'] = src_pdata['composite']
- if 'spectral_stack' in src_pdata:
- tgt_pdata['spectral_stack'] = src_pdata['spectral_stack']
- # Ensure mask aligns with substituted composite; resize if needed
- try:
- import cv2 as _cv2
- import numpy as _np
- comp = tgt_pdata.get('composite')
- msk = tgt_pdata.get('mask')
- if comp is not None and msk is not None:
- ch, cw = comp.shape[:2]
- mh, mw = msk.shape[:2]
- if (mh, mw) != (ch, cw):
- resized = _cv2.resize(msk.astype('uint8'), (cw, ch), interpolation=_cv2.INTER_NEAREST)
- tgt_pdata['mask'] = resized
- if 'soft_mask' in tgt_pdata and isinstance(tgt_pdata['soft_mask'], _np.ndarray):
- tgt_pdata['soft_mask'] = (resized > 0).astype(_np.float32)
- # Precompute masked composite with white background for saving
- white = _np.full_like(comp, 255, dtype=_np.uint8)
- result = white.copy()
- result[tgt_pdata['mask'] > 0] = comp[tgt_pdata['mask'] > 0]
- tgt_pdata['masked_composite'] = result
- except Exception:
- pass
- switched += 1
- if switched > 0:
- logger.info(f"Substituted feature images from src_rules for {switched} items")
- except Exception as e:
- logger.warning(f"Failed feature-image substitution via src_rules: {e}")
- # Step 4: Extract features
- logger.info("Step 4/6: Extracting features...")
- step_start = time.perf_counter()
- # Stream-save mode: save outputs immediately after each plant's features when fast output is enabled
- stream_save = False
- try:
- import os as _os
- stream_save = bool(int(_os.environ.get('STREAM_SAVE', '0'))) or bool(getattr(self.output_manager, 'fast_mode', False))
- except Exception:
- stream_save = False
-
- plants = self._extract_features(plants, stream_save=stream_save)
- logger.info(f"Features done in {(time.perf_counter()-step_start):.2f}s")
-
- # Step 5: Generate outputs (skip if already stream-saved)
- if not stream_save:
- logger.info("Step 5/6: Generating outputs...")
- step_start = time.perf_counter()
- self._generate_outputs(plants)
- logger.info(f"Outputs done in {(time.perf_counter()-step_start):.2f}s")
-
- # Step 6: Create summary
- logger.info("Step 6/6: Creating summary...")
- summary = self._create_summary(plants)
- else:
- logger.info("Segmentation-only mode: skipping texture/vegetation/morphology features and plots")
- # Segmentation-only: generate only segmentation outputs and a minimal summary
- logger.info("Step 4/4: Generating segmentation outputs (segmentation-only mode)...")
- self._generate_outputs(plants)
- summary = {
- "total_plants": len(plants),
- "successful_plants": len(plants),
- "failed_plants": 0,
- "features_extracted": {
- "texture": 0,
- "vegetation": 0,
- "morphology": 0
- }
- }
 
  total_time = time.perf_counter() - total_start
- logger.info(f"Pipeline completed successfully in {total_time:.2f}s!")
  return {
  "plants": plants,
  "summary": summary,
@@ -626,752 +116,129 @@ class SorghumPipeline:
  logger.error(f"Pipeline failed: {e}")
  raise
 
- def _export_white_background_maskouts(self, plants: Dict[str, Any], out_dir: Path) -> None:
- """Export RMBG composites with white background using the soft/binary masks.
-
- Filenames follow: plantX_plantX_frameY_maskout.png so the final instance script can detect plants.
- """
- # Clear any previous maskouts to avoid processing stale plants
- try:
- if out_dir.exists():
- for p in out_dir.glob("*_maskout.png"):
- try:
- p.unlink()
- except Exception:
- pass
- except Exception:
- pass
- count = 0
- # Per-plant rule: use bbox-only (skip SAM2Long) for these plants on all dates except 2025_05_08
- bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant33", "plant39", "plant42", "plant44", "plant46"}
- date_exception = "2025_05_08"
  for key, pdata in plants.items():
  try:
- # key format: "YYYY_MM_DD_plantX_frameY"
- parts = key.split('_')
- if len(parts) < 3:
- continue
- plant_name = parts[-2]
- frame_token = parts[-1] # e.g., frame8
- if not plant_name.startswith('plant') or not frame_token.startswith('frame'):
- continue
- date_key = "_".join(parts[:3])
- if (plant_name in bbox_only_plants) and (date_key != date_exception):
- # Skip exporting maskouts for bbox-only plants so SAM2Long does not run on them
- continue
- # Extract frame number
- frame_num = int(frame_token.replace('frame', ''))
- composite = pdata.get('composite')
- mask = pdata.get('mask')
- if composite is None or mask is None:
- continue
- # Ensure 3-channel BGR
- if len(composite.shape) == 2:
- composite_bgr = cv2.cvtColor(composite, cv2.COLOR_GRAY2BGR)
- else:
- composite_bgr = composite
- out_img = composite_bgr.copy()
- # Set background to white where mask == 0
- out_img[mask == 0] = (255, 255, 255)
- out_path = out_dir / f"{plant_name}_{plant_name}_{frame_token}_maskout.png"
- cv2.imwrite(str(out_path), out_img)
- count += 1
- except Exception as e:
- logger.warning(f"Failed to export maskout for {key}: {e}")
- logger.info(f"Exported {count} white-background maskouts to {out_dir}")
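For context, the white-background compositing used for these maskouts reduces to a few lines of NumPy; a hedged sketch on synthetic data:

```python
# Hedged sketch: paint background pixels (mask == 0) white so a downstream
# instance segmenter sees only the plant, as the deleted exporter above did.
import numpy as np

composite = np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8)
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 255                   # "plant" region
out_img = composite.copy()
out_img[mask == 0] = (255, 255, 255)   # white background, plant untouched
```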
-
- def _segment_plants(self, plants: Dict[str, Any],
- bbox_lookup: Optional[Dict[str, tuple]]) -> Dict[str, Any]:
- """Segment plants using BRIA model.
-
- If bbox_lookup is provided and contains an entry for the plant (e.g., 'plant1'),
- the image is cropped/masked to the bounding box region before segmentation and the
- predicted mask is mapped back to the full image size. In bbox mode a largest
- connected component post-processing is applied to obtain a clean target mask.
- """
- total = len(plants)
- iterator = plants.items()
- if tqdm is not None:
- iterator = tqdm(list(plants.items()), desc="Segmenting", total=total, unit="img", leave=False)
- for idx, (key, pdata) in enumerate(iterator):
- try:
- # Get composite image
  composite = pdata['composite']
- h, w = composite.shape[:2]
-
- # Determine bbox for this plant if available
- parts = key.split('_')
- plant_name = parts[-2] if len(parts) >= 2 else None
- date_key = "_".join(parts[:3]) if len(parts) >= 3 else None # e.g., 2025_04_16
- bbox = None
- if bbox_lookup is not None and plant_name is not None:
- # keys in bbox_lookup are typically like 'plant1'
- bbox = bbox_lookup.get(plant_name)
- # For plant33, ignore any bbox and run full-image segmentation on all dates except the exception
- if plant_name == 'plant33' and date_key != '2025_05_08':
- bbox = None
-
- # Plants that should use the bounding box itself as the mask (skip model)
- bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant39", "plant42", "plant44", "plant46"}
- use_bbox_only = (plant_name in bbox_only_plants)
-
- # Do not use bounding boxes for date 2025_05_08
- if date_key == '2025_05_08':
- bbox = None
-
- if bbox is not None:
- # Clamp bbox to image
- x1, y1, x2, y2 = bbox
- x1 = max(0, min(w, int(x1)))
- x2 = max(0, min(w, int(x2)))
- y1 = max(0, min(h, int(y1)))
- y2 = max(0, min(h, int(y2)))
- if x2 <= x1 or y2 <= y1:
- x1, y1, x2, y2 = 0, 0, w, h
-
- if use_bbox_only:
- # Use the bbox as the mask directly (255 inside, 0 outside)
- soft_full = np.zeros((h, w), dtype=np.float32)
- soft_full[y1:y2, x1:x2] = 1.0
- bin_full = np.zeros((h, w), dtype=np.uint8)
- bin_full[y1:y2, x1:x2] = 255
- pdata['soft_mask'] = soft_full
- pdata['mask'] = bin_full
- else:
- # Segment inside the bbox region and map back
- crop = composite[y1:y2, x1:x2]
- soft_mask_crop = self.segmentation_manager.segment_image_soft(crop)
- soft_full = np.zeros((h, w), dtype=np.float32)
- soft_resized = cv2.resize(soft_mask_crop, (x2 - x1, y2 - y1), interpolation=cv2.INTER_LINEAR)
- soft_full[y1:y2, x1:x2] = soft_resized
- bin_full = (soft_full > 0.5).astype(np.uint8) * 255
- try:
- n_lbl, labels, stats, _ = cv2.connectedComponentsWithStats(bin_full, 8)
- if n_lbl > 1:
- largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
- bin_full = (labels == largest).astype(np.uint8) * 255
- except Exception:
- pass
- pdata['soft_mask'] = soft_full.astype(np.float32)
- pdata['mask'] = bin_full.astype(np.uint8)
- else:
- # Full-image segmentation (no bbox)
- soft_mask = self.segmentation_manager.segment_image_soft(composite)
- pdata['soft_mask'] = soft_mask
- pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)
-
- # Progress log every 25 items and for first/last
- if tqdm is None and (idx == 0 or (idx + 1) % 25 == 0 or (idx + 1) == total):
- logger.info(f"Segmented {idx + 1}/{total}: {key}")
-
  except Exception as e:
  logger.error(f"Segmentation failed for {key}: {e}")
  pdata['soft_mask'] = np.zeros(composite.shape[:2], dtype=np.float32)
  pdata['mask'] = np.zeros(composite.shape[:2], dtype=np.uint8)
-
  return plants
 
- def _handle_occlusion(self, plants: Dict[str, Any]) -> Dict[str, Any]:
- """
- Handle occlusion problems using SAM2Long.
-
- This method groups plants by their base plant ID and processes
- each plant's 13-frame sequence to differentiate target plant
- from neighboring plants.
-
- Args:
- plants: Dictionary of plant data
-
- Returns:
- Updated plant data with occlusion handling results
- """
- if self.occlusion_handler is None:
- logger.warning("Occlusion handler not available, skipping occlusion handling")
- return plants
-
- # Group plants by base plant ID (e.g., "plant1" from "plant1_plant1_frame1")
- plant_groups = {}
  for key, pdata in plants.items():
- # Extract plant ID from key like "plant1_plant1_frame1"
- parts = key.split('_')
- if len(parts) >= 3:
- plant_id = parts[0] # e.g., "plant1"
- if plant_id not in plant_groups:
- plant_groups[plant_id] = []
- plant_groups[plant_id].append((key, pdata))
-
- logger.info(f"Processing {len(plant_groups)} plant groups for occlusion handling")
-
- # Process each plant group
- for plant_id, plant_frames in plant_groups.items():
- try:
- # Sort frames by frame number
- plant_frames.sort(key=lambda x: int(x[0].split('_')[-1].replace('frame', '')))
-
- if len(plant_frames) < 2:
- logger.warning(f"Plant {plant_id} has only {len(plant_frames)} frames, skipping")
- continue
-
- # Extract frames and keys
- frame_keys = [x[0] for x in plant_frames]
- frames = [x[1]['composite'] for x in plant_frames]
-
- logger.info(f"Processing plant {plant_id} with {len(frames)} frames")
-
- # Process with SAM2Long
- occlusion_results = self.occlusion_handler.segment_plant_sequence(
- frames=frames,
- target_plant_id=plant_id
- )
-
- # Update plant data with occlusion results
- target_masks = occlusion_results['target_masks']
- neighbor_masks = occlusion_results['neighbor_masks']
-
- for i, (key, pdata) in enumerate(plant_frames):
- if i < len(target_masks):
- # Update mask with target plant only
- pdata['original_mask'] = pdata.get('mask', np.zeros_like(target_masks[i]))
- pdata['mask'] = target_masks[i]
- pdata['neighbor_mask'] = neighbor_masks[i]
- pdata['occlusion_handled'] = True
-
- # Update soft mask as well
- pdata['original_soft_mask'] = pdata.get('soft_mask', np.zeros_like(target_masks[i], dtype=np.float32))
- pdata['soft_mask'] = (target_masks[i] / 255.0).astype(np.float32)
-
- # Calculate and store occlusion metrics
- metrics = self.occlusion_handler.get_occlusion_metrics(occlusion_results)
- for key, pdata in plant_frames:
- pdata['occlusion_metrics'] = metrics
-
- logger.info(f"Plant {plant_id} occlusion handling completed")
- logger.info(f" - Average occlusion ratio: {metrics['average_occlusion_ratio']:.3f}")
- logger.info(f" - Frames with occlusion: {metrics['frames_with_occlusion']}")
-
- except Exception as e:
- logger.error(f"Occlusion handling failed for plant {plant_id}: {e}")
- # Mark as failed but continue
- for key, pdata in plant_frames:
- pdata['occlusion_handled'] = False
- pdata['occlusion_error'] = str(e)
-
- return plants
-
- def _extract_features(self, plants: Dict[str, Any], stream_save: bool = False) -> Dict[str, Any]:
- """Extract all features from plants.
-
- If stream_save is True, save outputs for each plant immediately after
- its features are computed to improve throughput and reduce peak memory.
- """
- total = len(plants)
- logger.info(f"Extracting features for {total} plants...")
- iterator = plants.items()
- if tqdm is not None:
- iterator = tqdm(list(plants.items()), desc="Extracting features", total=total, unit="img", leave=False)
-
- # Prepare output directories once if we're streaming saves
- if stream_save:
  try:
- self.output_manager.create_output_directories()
- except Exception:
- pass
-
- for idx, (key, pdata) in enumerate(iterator):
- try:
- logger.debug(f"Extracting features for {key}")
-
- # Extract texture features
  pdata['texture_features'] = self._extract_texture_features(pdata)
-
- # Extract vegetation indices
  pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)
-
- # Extract morphological features
  pdata['morphology_features'] = self._extract_morphology_features(pdata)
-
- # Immediately save outputs for this plant if streaming is enabled
- if stream_save:
- try:
- self.output_manager.save_plant_results(key, pdata)
- except Exception as _e:
- logger.error(f"Stream-save failed for {key}: {_e}")
-
- logger.debug(f"Features extracted for {key}")
- if tqdm is None and (idx == 0 or (idx + 1) % 25 == 0 or (idx + 1) == total):
- logger.info(f"Extracted features for {idx + 1}/{total}: {key}")
-
  except Exception as e:
  logger.error(f"Feature extraction failed for {key}: {e}")
- # Add empty features
  pdata['texture_features'] = {}
  pdata['vegetation_indices'] = {}
  pdata['morphology_features'] = {}
-
  return plants
 
  def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
- """Extract texture features for a single plant."""
  features = {}
-
- # Get bands to process
- bands = ['color', 'nir', 'red_edge', 'red', 'green', 'pca']
-
- for band in bands:
- try:
- # Prepare grayscale image
- gray_image = self._prepare_band_image(pdata, band)
-
- # Extract texture features
- band_features = self.texture_extractor.extract_all_texture_features(gray_image)
-
- # Compute statistics using mask3 → features_mask → mask
- mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
- stats = self.texture_extractor.compute_texture_statistics(band_features, mask)
-
- features[band] = {
- 'features': band_features,
- 'statistics': stats
- }
-
- except Exception as e:
- logger.error(f"Texture extraction failed for band {band}: {e}")
- features[band] = {'features': {}, 'statistics': {}}
 
  return features
 
  def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
- """Extract vegetation indices for a single plant."""
  try:
  spectral_stack = pdata.get('spectral_stack', {})
- # Prefer mask3 → features_mask → mask
- mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
-
  if not spectral_stack or mask is None:
  return {}
 
- return self.vegetation_extractor.compute_vegetation_indices(
- spectral_stack, mask
- )
-
  except Exception as e:
  logger.error(f"Vegetation index extraction failed: {e}")
  return {}
 
  def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
- """Extract morphological features for a single plant."""
  try:
  composite = pdata.get('composite')
- # Prefer mask3 → features_mask → mask
- mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
-
  if composite is None or mask is None:
  return {}
-
- return self.morphology_extractor.extract_morphology_features(
- composite, mask
- )
-
  except Exception as e:
- logger.error(f"Morphology feature extraction failed: {e}")
  return {}
 
- def _prepare_band_image(self, pdata: Dict[str, Any], band: str) -> np.ndarray:
- """Prepare grayscale image for a specific band."""
- if band == 'color':
- composite = pdata['composite']
- # Prefer mask3 → features_mask → mask
- mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
- if mask is not None:
- masked = self.mask_handler.apply_mask_to_image(composite, mask)
- return cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
- else:
- return cv2.cvtColor(composite, cv2.COLOR_BGR2GRAY)
-
- elif band == 'pca':
- # Create PCA from spectral bands
- spectral_stack = pdata.get('spectral_stack', {})
- # Prefer mask3 → features_mask → mask
- mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
-
- if not spectral_stack:
- return np.zeros((512, 512), dtype=np.uint8)
-
- # Stack bands
- bands_data = []
- for b in ['nir', 'red_edge', 'red', 'green']:
- if b in spectral_stack:
- arr = spectral_stack[b].squeeze(-1).astype(float)
- if mask is not None:
- arr = np.where(mask > 0, arr, np.nan)
- bands_data.append(arr)
-
- if not bands_data:
- return np.zeros((512, 512), dtype=np.uint8)
-
- # Create PCA
- full_stack = np.stack(bands_data, axis=-1)
- h, w, c = full_stack.shape
- flat = full_stack.reshape(-1, c)
- valid = ~np.isnan(flat).any(axis=1)
-
- if valid.sum() == 0:
- return np.zeros((h, w), dtype=np.uint8)
-
- vec = np.zeros(h * w)
- vec[valid] = PCA(n_components=1, whiten=True).fit_transform(
- flat[valid]
- ).squeeze()
-
- gray_f = vec.reshape(h, w)
- if mask is not None:
- m, M = gray_f[mask > 0].min(), gray_f[mask > 0].max()
- else:
- m, M = gray_f.min(), gray_f.max()
-
- if M > m:
- gray = ((gray_f - m) / (M - m) * 255).astype(np.uint8)
- else:
- gray = np.zeros_like(gray_f, dtype=np.uint8)
-
- return gray
-
- else:
- # Individual spectral band
- spectral_stack = pdata.get('spectral_stack', {})
- # Prefer mask3 → features_mask → mask
- mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
-
- if band not in spectral_stack:
- return np.zeros((512, 512), dtype=np.uint8)
-
- arr = spectral_stack[band].squeeze(-1).astype(float)
- if mask is not None:
- arr = np.where(mask > 0, arr, np.nan)
-
- if mask is not None:
- m, M = np.nanmin(arr), np.nanmax(arr)
- else:
- m, M = arr.min(), arr.max()
-
- if M > m:
- gray = ((np.nan_to_num(arr, nan=m) - m) / (M - m) * 255).astype(np.uint8)
- else:
- gray = np.zeros_like(arr, dtype=np.uint8)
-
- return gray
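The deleted 'pca' branch projects the spectral bands onto their first principal component and rescales to uint8; here is a minimal sketch of the same idea on a synthetic 4-band stack (not the repo's exact code, which also masks invalid pixels):

```python
# Hedged sketch: 1-component PCA over per-pixel band vectors, then min-max
# scaling to a uint8 grayscale image, mirroring the deleted logic above.
import numpy as np
from sklearn.decomposition import PCA

h, w, c = 32, 32, 4                       # stand-in for nir/red_edge/red/green
flat = np.random.rand(h * w, c)
vec = PCA(n_components=1, whiten=True).fit_transform(flat).squeeze()
gray_f = vec.reshape(h, w)
gray = ((gray_f - gray_f.min()) / (np.ptp(gray_f) + 1e-6) * 255).astype(np.uint8)
```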
-
  def _generate_outputs(self, plants: Dict[str, Any]) -> None:
- """Generate all output files and visualizations."""
  self.output_manager.create_output_directories()
-
  for key, pdata in plants.items():
  try:
- logger.debug(f"Generating outputs for {key}")
  self.output_manager.save_plant_results(key, pdata)
  except Exception as e:
  logger.error(f"Output generation failed for {key}: {e}")
 
  def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
- """Create summary of pipeline results."""
- summary = {
  "total_plants": len(plants),
- "successful_plants": 0,
- "failed_plants": 0,
  "features_extracted": {
- "texture": 0,
- "vegetation": 0,
- "morphology": 0
  }
- }
-
- for key, pdata in plants.items():
- try:
- # Check if features were extracted
- if pdata.get('texture_features'):
- summary["features_extracted"]["texture"] += 1
- if pdata.get('vegetation_indices'):
- summary["features_extracted"]["vegetation"] += 1
- if pdata.get('morphology_features'):
- summary["features_extracted"]["morphology"] += 1
-
- summary["successful_plants"] += 1
-
- except Exception:
- summary["failed_plants"] += 1
-
- return summary
-
- def _apply_instance_masks(self, plants: Dict[str, Any], instance_results_dir: Path) -> None:
- """Replace segmentation masks with SAM2Long instance masks using track_1.
-
- Expects files under instance_results_dir/plantX/track_1/frame_YY_mask.png.
- """
- # Default and per-plant overrides for source plant, track and preferred frame
- default_track = "track_0"
- src_rules: Dict[str, str] = {
- "plant13": "plant12",
- "plant14": "plant13",
- "plant15": "plant14",
- "plant16": "plant15",
- }
- track_rules: Dict[str, str] = {
- # explicit track rules
- "plant1": "track_0",
- "plant4": "track_0",
- "plant9": "track_3",
- "plant13": "track_1",
- "plant14": "track_0",
- "plant15": "track_0",
- "plant16": "track_0",
- "plant18": "track_0",
- "plant19": "track_0",
- "plant23": "track_1",
- "plant26": "track_0",
- "plant27": "track_0",
- "plant29": "track_0",
- "plant31": "track_1",
- "plant34": "track_1",
- "plant35": "track_1",
- "plant36": "track_0",
- "plant37": "track_1",
- "plant38": "track_0",
- "plant39": "track_1",
- "plant40": "track_0",
- "plant41": "track_1",
- "plant42": "track_0",
- "plant43": "track_0",
- "plant45": "track_0",
- }
- frame_rules: Dict[str, int] = {
- # preferred frame overrides (1-based)
- "plant13": 8,
- "plant14": 8,
- "plant15": 8,
- "plant33": 2,
- "plant16": 4,
- "plant19": 5,
- "plant26": 8,
- "plant27": 8,
- "plant29": 8,
- "plant35": 7,
- "plant36": 6,
- "plant37": 2,
- "plant45": 5,
- }
- # Per-plant rule: skip applying instance masks (keep bbox/BRIA mask) on all dates except 2025_05_08
- bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant33", "plant39", "plant42", "plant44", "plant46"}
- date_exception = "2025_05_08"
-
- for key, pdata in plants.items():
- try:
- parts = key.split('_')
- if len(parts) < 3:
- continue
- plant_name = parts[-2]
- frame_token = parts[-1] # frame8
- if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
- continue
- date_key = "_".join(parts[:3])
- if (plant_name in bbox_only_plants) and (date_key != date_exception):
- # Do not override masks for bbox-only plants
- continue
- frame_num = int(frame_token.replace('frame', ''))
- # Resolve source plant, track and desired frame
- src_plant = src_rules.get(plant_name, plant_name)
- track_name = track_rules.get(plant_name, default_track)
- desired_frame = frame_rules.get(plant_name, frame_num)
- plant_dir = Path(instance_results_dir) / src_plant / track_name
- mask_path = plant_dir / f"frame_{desired_frame:02d}_mask.png"
- if not mask_path.exists():
- # Fallback to current frame if override not found
- fallback = plant_dir / f"frame_{frame_num:02d}_mask.png"
- if fallback.exists():
- mask_path = fallback
- else:
- # Last-resort: pick any available frame mask in the track directory
- try:
- candidates = sorted(plant_dir.glob("frame_*_mask.png"))
- if len(candidates) > 0:
- mask_path = candidates[0]
- else:
- continue
- except Exception:
- continue
- inst_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
- if inst_mask is None:
- continue
- # Ensure binary uint8 0/255
- inst_mask_bin = (inst_mask > 0).astype(np.uint8) * 255
- pdata['original_mask'] = pdata.get('mask', inst_mask_bin.copy())
- pdata['mask'] = inst_mask_bin
- pdata['original_soft_mask'] = pdata.get('soft_mask', (inst_mask_bin / 255.0).astype(np.float32))
- pdata['soft_mask'] = (inst_mask_bin / 255.0).astype(np.float32)
- pdata['instance_applied'] = True
-
- # Build mask3 = external(mask) AND BRIA(original_mask)
- try:
- _m1 = pdata.get('mask')
- _m2 = pdata.get('original_mask')
- if isinstance(_m1, np.ndarray) and isinstance(_m2, np.ndarray):
- _m1b = (_m1.astype(np.uint8) > 0)
- _m2b = (_m2.astype(np.uint8) > 0)
- mask3 = (_m1b & _m2b).astype(np.uint8) * 255
- pdata['mask3'] = mask3
- pdata['features_mask'] = mask3
- except Exception:
- pass
-
- # After applying instance masks, also overwrite the composite and spectral stack
- # with the source plant's raw image (desired frame preferred) so that
- # feature extraction and saved originals/overlays are consistent with the mask source.
- try:
- if plant_name in src_rules:
- date_key = "_".join(parts[:3])
- src_key_desired = f"{date_key}_{src_plant}_frame{desired_frame}"
- src_key_same = f"{date_key}_{src_plant}_{frame_token}"
- copy_from = plants.get(src_key_desired) or plants.get(src_key_same)
- if copy_from is None:
- # Fallback: load source composite from filesystem if not present in plants dict
- try:
- from PIL import Image as _Image
- _date_folder = date_key.replace('_', '-')
- _date_dir = Path(self.config.paths.input_folder)
- if _date_dir.name != _date_folder:
- _date_dir = _date_dir / _date_folder
- _frame_path = _date_dir / src_plant / f"{src_plant}_frame{desired_frame}.tif"
- if not _frame_path.exists():
- _frame_path = _date_dir / src_plant / f"{src_plant}_frame{frame_num}.tif"
- if _frame_path.exists():
- _img = _Image.open(str(_frame_path))
- # Process to composite using preprocessor
- comp, spec = self.preprocessor.process_raw_image(_img)
- copy_from = {"composite": comp, "spectral_stack": spec}
- except Exception:
- copy_from = None
- if copy_from is not None:
- # Preserve the segmentation-time composite once
- if 'composite' in pdata and 'segmentation_composite' not in pdata:
- pdata['segmentation_composite'] = pdata['composite']
- if 'composite' in copy_from:
- pdata['composite'] = copy_from['composite']
- if 'spectral_stack' in copy_from:
- pdata['spectral_stack'] = copy_from['spectral_stack']
- # Ensure mask size matches the copied composite
- ch, cw = pdata['composite'].shape[:2]
- mh, mw = pdata['mask'].shape[:2]
- if (mh, mw) != (ch, cw):
- pdata['mask'] = cv2.resize(pdata['mask'].astype('uint8'), (cw, ch), interpolation=cv2.INTER_NEAREST)
- pdata['soft_mask'] = (pdata['mask'] > 0).astype(np.float32)
- except Exception:
- pass
  except Exception as e:
- logger.debug(f"Instance mask apply failed for {key}: {e}")
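The mask3 construction above is a plain per-pixel AND of the two binary masks, re-encoded as 0/255; a tiny worked example:

```python
# Hedged example of the mask3 intersection: only pixels accepted by both the
# instance mask and the original BRIA mask survive.
import numpy as np

m1 = np.array([[0, 255], [255, 255]], dtype=np.uint8)   # instance mask
m2 = np.array([[255, 255], [0, 255]], dtype=np.uint8)   # BRIA mask
mask3 = ((m1 > 0) & (m2 > 0)).astype(np.uint8) * 255
print(mask3)  # -> [[  0 255] [  0 255]]
```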
-
- def _apply_instance_masks_from_mapping(self, plants: Dict[str, Any], mapping_file: Path) -> None:
- """Apply instance masks using an explicit mapping file with absolute paths.
-
- mapping JSON structure:
- {
- "plant1": {"frame": 8, "mask_path": "/abs/path/to/plant1/track_X/frame_08_mask.png"},
- "plant2": {"frame": 8, "mask_path": "/abs/path/.../frame_08_mask.png"},
- ...
- }
- If a plant's mapping specifies a different frame, only entries matching that frame are updated.
- """
- import json
- if not mapping_file.exists():
- raise FileNotFoundError(f"Mapping file not found: {mapping_file}")
- with open(mapping_file, "r") as f:
- mapping = json.load(f)
- # Normalize mapping plant keys to names like 'plantX'
- norm_map = {}
- for k, v in mapping.items():
- k_norm = k if str(k).startswith("plant") else f"plant{int(k)}" if str(k).isdigit() else str(k)
- norm_map[k_norm] = v
-
- for key, pdata in plants.items():
- try:
- parts = key.split('_')
- if len(parts) < 3:
- continue
- plant_name = parts[-2]
- frame_token = parts[-1]
- if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
- continue
- frame_num = int(frame_token.replace('frame', ''))
- if plant_name not in norm_map:
- continue
- entry = norm_map[plant_name]
- target_frame = int(entry.get("frame", frame_num))
- if frame_num != target_frame:
- # Only update the designated frame for this plant
- continue
- mask_path_str = entry.get("mask_path")
- if not mask_path_str:
- continue
- mask_path = Path(mask_path_str)
- if not mask_path.exists():
- logger.warning(f"Mask path not found for {plant_name} {frame_token}: {mask_path}")
- continue
- inst_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
- if inst_mask is None:
- continue
- inst_mask_bin = (inst_mask > 0).astype(np.uint8) * 255
- pdata['original_mask'] = pdata.get('mask', inst_mask_bin.copy())
- pdata['mask'] = inst_mask_bin
- pdata['original_soft_mask'] = pdata.get('soft_mask', (inst_mask_bin / 255.0).astype(np.float32))
- pdata['soft_mask'] = (inst_mask_bin / 255.0).astype(np.float32)
- pdata['instance_applied'] = True
-
- # Build mask3 = external(mask) AND BRIA(original_mask)
- try:
- _m1 = pdata.get('mask')
- _m2 = pdata.get('original_mask')
- if isinstance(_m1, np.ndarray) and isinstance(_m2, np.ndarray):
- _m1b = (_m1.astype(np.uint8) > 0)
- _m2b = (_m2.astype(np.uint8) > 0)
- mask3 = (_m1b & _m2b).astype(np.uint8) * 255
- pdata['mask3'] = mask3
- pdata['features_mask'] = mask3
- except Exception:
- pass
- except Exception as e:
- logger.debug(f"Instance mapping apply failed for {key}: {e}")
-
-
- def run_pipeline(config_path: str, load_all_frames: bool = False, segmentation_only: bool = False, filter_plants: Optional[List[str]] = None) -> Dict[str, Any]:
- """
- Convenience function to run the pipeline.
-
- Args:
- config_path: Path to configuration file
- load_all_frames: Whether to load all frames or selected frames
- segmentation_only: If True, run segmentation only and skip feature extraction
-
- Returns:
- Pipeline results
- """
- pipeline = SorghumPipeline(config_path)
- return pipeline.run(load_all_frames, segmentation_only, filter_plants)
-
-
- if __name__ == "__main__":
- import sys
-
- config_path = sys.argv[1] if len(sys.argv) > 1 else "config.yml"
- load_all = "--all" in sys.argv
- seg_only = "--seg-only" in sys.argv
- # Basic arg parse for --plant=<name>
- plant_filter = None
- for arg in sys.argv[1:]:
- if arg.startswith("--plant="):
- plant_filter = [arg.split("=", 1)[1]]
-
- try:
- results = run_pipeline(config_path, load_all, seg_only, plant_filter)
- print("Pipeline completed successfully!")
- print(f"Processed {results['summary']['total_plants']} plants")
- except Exception as e:
- print(f"Pipeline failed: {e}")
- sys.exit(1)
 
1
  """
2
  Main pipeline class for the Sorghum Plant Phenotyping Pipeline.
3
 
4
+ Minimal single-image version for Hugging Face demo.
 
5
  """
6
 
7
  import os
 
8
  import logging
9
  from pathlib import Path
10
+ from typing import Dict, Any, Optional
11
  import numpy as np
12
  import cv2
 
 
 
13
  from sklearn.decomposition import PCA
 
 
 
 
14
 
15
  from .config import Config
16
+ from .data import ImagePreprocessor, MaskHandler
17
  from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
18
  from .output import OutputManager
19
  from .segmentation import SegmentationManager
20
+
21
+ logger = logging.getLogger(__name__)
 
 
 
22
 
23
 
24
  class SorghumPipeline:
25
+ """Minimal pipeline for single-image plant phenotyping."""
 
 
 
 
 
26
 
27
+ def __init__(self, config: Config):
28
+ """Initialize the minimal pipeline."""
 
 
 
 
 
 
 
 
 
29
  self._setup_logging()
30
+ self.config = config
 
 
 
 
 
 
 
 
 
31
  self.config.validate()
32
+ self._initialize_components()
33
+ logger.info("Sorghum Pipeline initialized")
 
 
 
 
 
 
 
 
 
34
 
35
  def _setup_logging(self):
36
  """Setup logging configuration."""
37
  logging.basicConfig(
38
  level=logging.INFO,
39
+ format='%(asctime)s - %(levelname)s - %(message)s',
40
+ handlers=[logging.StreamHandler()]
 
 
 
41
  )
 
 
42
 
43
+ def _initialize_components(self):
44
+ """Initialize pipeline components."""
45
+ self.preprocessor = ImagePreprocessor(target_size=None)
46
+ self.mask_handler = MaskHandler(min_area=1000, kernel_size=7)
47
+ self.texture_extractor = TextureExtractor()
48
+ self.vegetation_extractor = VegetationIndexExtractor()
49
+ self.morphology_extractor = MorphologyExtractor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  self.segmentation_manager = SegmentationManager(
51
+ model_name="briaai/RMBG-2.0",
52
  device=self.config.get_device(),
53
+ threshold=0.5,
54
+ trust_remote_code=True
 
 
55
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  self.output_manager = OutputManager(
57
  output_folder=self.config.paths.output_folder,
58
  settings=self.config.output
59
  )
60
 
61
+ def run(self, single_image_path: str) -> Dict[str, Any]:
62
  """
63
+ Run the minimal pipeline on a single image.
64
 
65
  Args:
66
+ single_image_path: Path to input image
 
67
 
68
  Returns:
69
+ Dictionary containing results
70
  """
71
+ logger.info("Starting minimal single-image pipeline...")
72
 
73
  try:
74
  import time
75
+ from PIL import Image as _Image
76
+
77
  total_start = time.perf_counter()
78
 
79
+ # Load single image
80
+ _p = Path(single_image_path)
81
+ _img = _Image.open(str(_p))
82
+ plants = {
83
+ "demo_demo_frame1": {
84
+ "raw_image": (_img, _p.name),
85
+ "plant_name": "demo",
86
+ "file_path": str(_p)
87
  }
88
+ }
89
 
90
+ # Create composite
91
  plants = self.preprocessor.create_composites(plants)
 
92
 
93
+ # Segment
94
+ plants = self._segment_plants(plants)
95
 
96
+ # Extract features
97
+ plants = self._extract_features(plants)
98
 
99
+ # Generate outputs
100
+ self._generate_outputs(plants)
101
+
102
+ # Summary
103
+ summary = self._create_summary(plants)
104
 
105
  total_time = time.perf_counter() - total_start
106
+ logger.info(f"Pipeline completed in {total_time:.2f}s")
107
+
108
  return {
109
  "plants": plants,
110
  "summary": summary,
 
116
  logger.error(f"Pipeline failed: {e}")
117
  raise
118
 
119
+ def _segment_plants(self, plants: Dict[str, Any]) -> Dict[str, Any]:
120
+ """Segment plants using BRIA model (full image)."""
121
  for key, pdata in plants.items():
122
  try:
123
  composite = pdata['composite']
124
+ soft_mask = self.segmentation_manager.segment_image_soft(composite)
125
+ pdata['soft_mask'] = soft_mask
126
+ pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)
127
+ logger.info(f"Segmented {key}")
128
  except Exception as e:
129
  logger.error(f"Segmentation failed for {key}: {e}")
130
  pdata['soft_mask'] = np.zeros(composite.shape[:2], dtype=np.float32)
131
  pdata['mask'] = np.zeros(composite.shape[:2], dtype=np.uint8)
 
132
  return plants
133
 
134
+ def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]:
135
+ """Extract features from plants."""
136
  for key, pdata in plants.items():
137
  try:
138
  pdata['texture_features'] = self._extract_texture_features(pdata)
 
 
139
  pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)
 
 
140
  pdata['morphology_features'] = self._extract_morphology_features(pdata)
141
+ logger.info(f"Features extracted for {key}")
142
  except Exception as e:
143
  logger.error(f"Feature extraction failed for {key}: {e}")
 
144
  pdata['texture_features'] = {}
145
  pdata['vegetation_indices'] = {}
146
  pdata['morphology_features'] = {}
 
147
  return plants
148
 
149
  def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
150
+ """Extract texture features from pseudo-color image only."""
151
  features = {}
152
+ try:
153
+ # Only process pseudo-color composite
154
+ composite = pdata['composite']
155
+ mask = pdata.get('mask')
156
+ if mask is not None:
157
+ masked = self.mask_handler.apply_mask_to_image(composite, mask)
158
+ gray_image = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
159
+ else:
160
+ gray_image = cv2.cvtColor(composite, cv2.COLOR_BGR2GRAY)
161
+
162
+ band_features = self.texture_extractor.extract_all_texture_features(gray_image)
163
+ stats = self.texture_extractor.compute_texture_statistics(band_features, mask)
164
+
165
+ features['color'] = {
166
+ 'features': band_features,
167
+ 'statistics': stats
168
+ }
169
+ except Exception as e:
170
+ logger.error(f"Texture extraction failed: {e}")
171
+ features['color'] = {'features': {}, 'statistics': {}}
172
 
173
  return features
174
 
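Aside: `TextureExtractor.extract_all_texture_features` is not part of this commit. For intuition, the LBP channel it feeds into the statistics step can be approximated with scikit-image. This is an illustrative sketch under assumed parameters (8 points, radius 1) and a placeholder image path, not the extractor's actual code:

import cv2
import numpy as np
from skimage.feature import local_binary_pattern

gray = cv2.cvtColor(cv2.imread("plant.png"), cv2.COLOR_BGR2GRAY)    # placeholder input
lbp = local_binary_pattern(gray, P=8, R=1, method="uniform")        # assumed LBP parameters
hist, _ = np.histogram(lbp, bins=int(lbp.max()) + 1, density=True)  # per-image LBP histogram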
175
  def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
176
+ """Extract vegetation indices (NDVI, ARI, GNDVI only)."""
177
  try:
178
  spectral_stack = pdata.get('spectral_stack', {})
179
+ mask = pdata.get('mask')
 
 
180
  if not spectral_stack or mask is None:
181
  return {}
182
 
183
+ out: Dict[str, Any] = {}
184
+ for name in ("NDVI", "ARI", "GNDVI"):
185
+ bands = self.vegetation_extractor.index_bands.get(name, [])
186
+ if not all(b in spectral_stack for b in bands):
187
+ continue
188
+ arrays = []
189
+ for b in bands:
190
+ arr = spectral_stack[b]
191
+ if isinstance(arr, np.ndarray) and arr.ndim == 3:
192
+ arr = arr.squeeze(-1)
193
+ arrays.append(np.asarray(arr, dtype=np.float64))
194
+
195
+ values = self.vegetation_extractor.index_formulas[name](*arrays).astype(np.float64)
196
+ binary_mask = (np.asarray(mask).astype(np.int32) > 0)
197
+ masked_values = np.where(binary_mask, values, np.nan)
198
+ valid = masked_values[~np.isnan(masked_values)]
199
+
200
+ stats = {
201
+ 'mean': float(np.mean(valid)) if valid.size else 0.0,
202
+ 'std': float(np.std(valid)) if valid.size else 0.0,
203
+ 'min': float(np.min(valid)) if valid.size else 0.0,
204
+ 'max': float(np.max(valid)) if valid.size else 0.0,
205
+ 'median': float(np.median(valid)) if valid.size else 0.0,
206
+ }
207
+ out[name] = {'values': masked_values, 'statistics': stats}
208
+ return out
209
  except Exception as e:
210
  logger.error(f"Vegetation index extraction failed: {e}")
211
  return {}
212
 
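Aside: the `index_formulas` themselves are outside this diff. Under the standard definitions, and assuming band names like `nir`, `red`, `green`, and `red_edge` plus a small epsilon to guard division, they would look like this sketch:

import numpy as np

EPS = 1e-10  # assumed guard against division by zero

def ndvi(nir: np.ndarray, red: np.ndarray) -> np.ndarray:
    # Normalized Difference Vegetation Index
    return (nir - red) / (nir + red + EPS)

def gndvi(nir: np.ndarray, green: np.ndarray) -> np.ndarray:
    # Green NDVI: the green band replaces red
    return (nir - green) / (nir + green + EPS)

def ari(green: np.ndarray, red_edge: np.ndarray) -> np.ndarray:
    # Anthocyanin Reflectance Index: difference of reciprocal reflectances
    return 1.0 / (green + EPS) - 1.0 / (red_edge + EPS)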
213
  def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
214
+ """Extract morphological features."""
215
  try:
216
  composite = pdata.get('composite')
217
+ mask = pdata.get('mask')
 
 
218
  if composite is None or mask is None:
219
  return {}
220
+ return self.morphology_extractor.extract_morphology_features(composite, mask)
221
  except Exception as e:
222
+ logger.error(f"Morphology extraction failed: {e}")
223
  return {}
224
225
  def _generate_outputs(self, plants: Dict[str, Any]) -> None:
226
+ """Generate output files."""
227
  self.output_manager.create_output_directories()
 
228
  for key, pdata in plants.items():
229
  try:
 
230
  self.output_manager.save_plant_results(key, pdata)
231
  except Exception as e:
232
  logger.error(f"Output generation failed for {key}: {e}")
233
 
234
  def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
235
+ """Create summary of results."""
236
+ return {
237
  "total_plants": len(plants),
238
+ "successful_plants": sum(1 for p in plants.values() if p.get('texture_features')),
 
239
  "features_extracted": {
240
+ "texture": sum(1 for p in plants.values() if p.get('texture_features')),
241
+ "vegetation": sum(1 for p in plants.values() if p.get('vegetation_indices')),
242
+ "morphology": sum(1 for p in plants.values() if p.get('morphology_features'))
243
  }
244
+ }
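For callers such as `wrapper.py` below, the dictionary returned by `run()` now contains at least the following (the diff elides a few lines of the return statement; counts are illustrative for one successful image):

result = {
    "plants": {"demo_demo_frame1": {...}},  # raw image, composite, masks, features
    "summary": {
        "total_plants": 1,
        "successful_plants": 1,
        "features_extracted": {"texture": 1, "vegetation": 1, "morphology": 1},
    },
}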
sorghum_pipeline/segmentation/manager.py CHANGED
@@ -1,8 +1,5 @@
1
  """
2
- Segmentation manager for the Sorghum Pipeline.
3
-
4
- This module handles image segmentation using the BRIA model
5
- and provides post-processing capabilities.
6
  """
7
 
8
  import numpy as np
@@ -11,299 +8,51 @@ import torch
11
  from PIL import Image
12
  from torchvision import transforms
13
  from transformers import AutoModelForImageSegmentation
14
- from typing import Optional, Tuple
15
  import logging
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
 
20
  class SegmentationManager:
21
- """Manages image segmentation using BRIA model."""
22
 
23
- def __init__(self,
24
- model_name: str = "briaai/RMBG-2.0",
25
- device: str = "auto",
26
- threshold: float = 0.5,
27
- trust_remote_code: bool = True,
28
- cache_dir: Optional[str] = None,
29
- local_files_only: bool = False):
30
- """
31
- Initialize segmentation manager.
32
-
33
- Args:
34
- model_name: Name of the BRIA model
35
- device: Device to run model on ("auto", "cpu", "cuda")
36
- threshold: Segmentation threshold
37
- trust_remote_code: Whether to trust remote code
38
- cache_dir: Hugging Face cache directory for model weights
39
- local_files_only: If True, only load from local cache
40
- """
41
  self.model_name = model_name
42
  self.threshold = threshold
43
- self.trust_remote_code = trust_remote_code
44
- self.cache_dir = cache_dir
45
- self.local_files_only = local_files_only
46
-
47
- # Determine device
48
- if device == "auto":
49
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
50
- else:
51
- self.device = device
52
-
53
- # Initialize model
54
- self.model = None
55
- self.transform = None
56
- self._load_model()
 
 
57
 
58
- def _load_model(self):
59
- """Load the BRIA segmentation model."""
60
- try:
61
- logger.info(f"Loading BRIA model: {self.model_name}")
62
-
63
- self.model = AutoModelForImageSegmentation.from_pretrained(
64
- self.model_name,
65
- trust_remote_code=self.trust_remote_code,
66
- cache_dir=self.cache_dir if self.cache_dir else None,
67
- local_files_only=self.local_files_only,
68
- ).eval().to(self.device)
69
-
70
- # Define image transform
71
- self.transform = transforms.Compose([
72
- transforms.Resize((1024, 1024)),
73
- transforms.ToTensor(),
74
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
75
- ])
76
-
77
- logger.info("BRIA model loaded successfully")
78
-
79
- except Exception as e:
80
- logger.error(f"Failed to load BRIA model: {e}")
81
- raise
82
-
83
- def segment_image(self, image: np.ndarray) -> np.ndarray:
84
- """
85
- Segment an image using the BRIA model.
86
-
87
- Args:
88
- image: Input image (BGR format)
89
-
90
- Returns:
91
- Binary mask (0/255)
92
- """
93
- if self.model is None:
94
- raise RuntimeError("Model not loaded")
95
-
96
- try:
97
- # Convert BGR to RGB
98
- rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
99
- pil_image = Image.fromarray(rgb_image)
100
-
101
- # Apply transform
102
- input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
103
-
104
- # Run inference
105
- with torch.no_grad():
106
- predictions = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
107
-
108
- # Apply threshold
109
- mask = (predictions > self.threshold).astype(np.uint8) * 255
110
-
111
- # Resize back to original size
112
- original_size = (image.shape[1], image.shape[0]) # (width, height)
113
- mask_resized = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)
114
-
115
- return mask_resized
116
-
117
- except Exception as e:
118
- logger.error(f"Segmentation failed: {e}")
119
- # Return empty mask
120
- return np.zeros(image.shape[:2], dtype=np.uint8)
121
-
122
  def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
123
- """
124
- Segment an image and return a soft mask in [0, 1] resized to original size.
125
- No thresholding or post-processing is applied.
126
-
127
- Args:
128
- image: Input image (BGR format)
129
-
130
- Returns:
131
- Float mask in [0,1] with shape (H, W)
132
- """
133
- if self.model is None:
134
- raise RuntimeError("Model not loaded")
135
  try:
136
  rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
137
  pil_image = Image.fromarray(rgb_image)
138
  input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
 
139
  with torch.no_grad():
140
  preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
 
141
  original_size = (image.shape[1], image.shape[0])
142
  soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
143
  return np.clip(soft_mask, 0.0, 1.0)
144
  except Exception as e:
145
- logger.error(f"Soft segmentation failed: {e}")
146
- return np.zeros(image.shape[:2], dtype=np.float32)
147
-
148
- def post_process_mask(self, mask: np.ndarray,
149
- min_area: int = 1000,
150
- kernel_size: int = 5) -> np.ndarray:
151
- """
152
- Post-process segmentation mask.
153
-
154
- Args:
155
- mask: Input mask
156
- min_area: Minimum area for connected components
157
- kernel_size: Kernel size for morphological operations
158
-
159
- Returns:
160
- Post-processed mask
161
- """
162
- try:
163
- # Morphological opening to remove noise
164
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
165
- opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
166
-
167
- # Remove small connected components
168
- num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
169
- opened, connectivity=8
170
- )
171
-
172
- processed_mask = np.zeros_like(opened)
173
- for label in range(1, num_labels): # Skip background
174
- if stats[label, cv2.CC_STAT_AREA] >= min_area:
175
- processed_mask[labels == label] = 255
176
-
177
- return processed_mask
178
-
179
- except Exception as e:
180
- logger.error(f"Mask post-processing failed: {e}")
181
- return mask
182
-
183
- def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
184
- """
185
- Keep only the largest connected component.
186
-
187
- Args:
188
- mask: Input mask
189
-
190
- Returns:
191
- Mask with only the largest component
192
- """
193
- try:
194
- num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 8)
195
-
196
- if num_labels <= 1:
197
- return mask
198
-
199
- # Find the largest component (excluding background)
200
- areas = stats[1:, cv2.CC_STAT_AREA]
201
- largest_label = 1 + np.argmax(areas)
202
-
203
- # Create mask with only the largest component
204
- largest_mask = (labels == largest_label).astype(np.uint8) * 255
205
-
206
- return largest_mask
207
-
208
- except Exception as e:
209
- logger.error(f"Largest component extraction failed: {e}")
210
- return mask
211
-
212
- def validate_mask(self, mask: np.ndarray) -> bool:
213
- """
214
- Validate segmentation mask.
215
-
216
- Args:
217
- mask: Mask to validate
218
-
219
- Returns:
220
- True if valid, False otherwise
221
- """
222
- if mask is None:
223
- return False
224
-
225
- if not isinstance(mask, np.ndarray):
226
- return False
227
-
228
- if mask.ndim != 2:
229
- return False
230
-
231
- if mask.dtype not in [np.uint8, np.bool_]:
232
- return False
233
-
234
- # Check if mask has any foreground pixels
235
- if np.sum(mask > 0) == 0:
236
- logger.warning("Mask has no foreground pixels")
237
- return False
238
-
239
- return True
240
-
241
- def get_mask_properties(self, mask: np.ndarray) -> dict:
242
- """
243
- Get properties of the segmentation mask.
244
-
245
- Args:
246
- mask: Binary mask
247
-
248
- Returns:
249
- Dictionary of mask properties
250
- """
251
- if not self.validate_mask(mask):
252
- return {}
253
-
254
- try:
255
- # Convert to binary
256
- binary_mask = (mask > 127).astype(np.uint8)
257
-
258
- # Calculate properties
259
- area = np.sum(binary_mask)
260
- perimeter = 0
261
-
262
- # Find contours
263
- contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
264
- if contours:
265
- perimeter = cv2.arcLength(contours[0], True)
266
-
267
- # Bounding box
268
- x, y, w, h = cv2.boundingRect(contours[0])
269
- bbox_area = w * h
270
- aspect_ratio = w / h if h > 0 else 0
271
- else:
272
- bbox_area = 0
273
- aspect_ratio = 0
274
-
275
- return {
276
- "area": int(area),
277
- "perimeter": float(perimeter),
278
- "bbox_area": int(bbox_area),
279
- "aspect_ratio": float(aspect_ratio),
280
- "coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0,
281
- "num_components": len(contours)
282
- }
283
-
284
- except Exception as e:
285
- logger.error(f"Mask property calculation failed: {e}")
286
- return {}
287
-
288
- def create_overlay(self, image: np.ndarray, mask: np.ndarray,
289
- color: Tuple[int, int, int] = (0, 255, 0),
290
- alpha: float = 0.5) -> np.ndarray:
291
- """
292
- Create overlay of mask on image.
293
-
294
- Args:
295
- image: Base image
296
- mask: Binary mask
297
- color: Overlay color (B, G, R)
298
- alpha: Overlay transparency
299
-
300
- Returns:
301
- Image with mask overlay
302
- """
303
- try:
304
- overlay = image.copy()
305
- overlay[mask == 255] = color
306
- return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)
307
- except Exception as e:
308
- logger.error(f"Overlay creation failed: {e}")
309
- return image
 
1
  """
2
+ Minimal segmentation manager.
 
 
 
3
  """
4
 
5
  import numpy as np
 
8
  from PIL import Image
9
  from torchvision import transforms
10
  from transformers import AutoModelForImageSegmentation
11
+ from typing import Optional
12
  import logging
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
 
17
  class SegmentationManager:
18
+ """Minimal BRIA segmentation."""
19
 
20
+ def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto",
21
+ threshold: float = 0.5, trust_remote_code: bool = True,
22
+ cache_dir: Optional[str] = None, local_files_only: bool = False):
23
+ """Initialize segmentation."""
  self.model_name = model_name
25
  self.threshold = threshold
26
+ self.device = "cuda" if device == "auto" and torch.cuda.is_available() else device
27
+
28
+ logger.info(f"Loading BRIA model: {model_name}")
29
+ self.model = AutoModelForImageSegmentation.from_pretrained(
30
+ model_name,
31
+ trust_remote_code=trust_remote_code,
32
+ cache_dir=cache_dir if cache_dir else None,
33
+ local_files_only=local_files_only,
34
+ ).eval().to(self.device)
35
+
36
+ self.transform = transforms.Compose([
37
+ transforms.Resize((1024, 1024)),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
40
+ ])
41
+ logger.info("BRIA model loaded")
42
43
  def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
44
+ """Segment image and return soft mask [0,1]."""
45
  try:
46
  rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
47
  pil_image = Image.fromarray(rgb_image)
48
  input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
49
+
50
  with torch.no_grad():
51
  preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
52
+
53
  original_size = (image.shape[1], image.shape[0])
54
  soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
55
  return np.clip(soft_mask, 0.0, 1.0)
56
  except Exception as e:
57
+ logger.error(f"Segmentation failed: {e}")
58
+ return np.zeros(image.shape[:2], dtype=np.float32)
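With `segment_image`, `post_process_mask`, and the other helpers removed, callers that need a hard mask threshold the soft output themselves. A minimal sketch, assuming a placeholder image path (the 0.5 threshold matches the constructor default):

import cv2
import numpy as np
from sorghum_pipeline.segmentation import SegmentationManager

seg = SegmentationManager(device="cpu")                 # downloads briaai/RMBG-2.0 on first use
image_bgr = cv2.imread("plant.png")                     # placeholder input
soft = seg.segment_image_soft(image_bgr)                # float32 mask in [0, 1]
hard = (soft > seg.threshold).astype(np.uint8) * 255    # binary 0/255 mask
alpha = (soft * 255.0).astype(np.uint8)                 # 8-bit soft mask, as the pipeline stores it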
wrapper.py CHANGED
@@ -3,9 +3,10 @@ from typing import Dict
3
  import shutil
4
  from PIL import Image
5
  import glob
6
- import tempfile
7
 
8
  from sorghum_pipeline.pipeline import SorghumPipeline
 
9
 
10
 
11
  def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True) -> Dict[str, str]:
@@ -21,34 +22,38 @@ def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts:
21
  input_copy = work / Path(input_image_path).name
22
  shutil.copy(input_image_path, input_copy)
23
 
24
- # Initialize pipeline with config
25
- # adjust this if you have a YAML config file (e.g., "configs/demo.yaml")
26
- pipeline = SorghumPipeline(
27
- config_path=str(Path("sorghum_pipeline/config.py")),
28
- enable_occlusion_handling=False,
29
- enable_instance_integration=False
30
  )
 
31
 
32
- # Run the pipeline (single image, no frames, no SAM2Long)
33
- results = pipeline.run(
34
- load_all_frames=False,
35
- segmentation_only=False,
36
- run_instance_segmentation=False,
37
- features_frame_only=None
38
- )
39
 
40
  # Collect outputs
41
  outputs: Dict[str, str] = {}
42
 
43
- # Save original for reference
44
- original = work / "original.png"
45
- Image.open(input_copy).convert("RGB").save(original)
46
- outputs["Original"] = str(original)
47
-
48
- # Gather all PNG files created by OutputManager
49
- for f in glob.glob(str(work / "**/*.png"), recursive=True):
50
- name = Path(f).stem
51
- if name.lower() not in outputs: # avoid duplicate "Original"
52
- outputs[name] = f
53
 
54
  return outputs
 
3
  import shutil
4
  from PIL import Image
5
  import glob
6
+ import os
7
 
8
  from sorghum_pipeline.pipeline import SorghumPipeline
9
+ from sorghum_pipeline.config import Config, Paths
10
 
11
 
12
  def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True) -> Dict[str, str]:
 
22
  input_copy = work / Path(input_image_path).name
23
  shutil.copy(input_image_path, input_copy)
24
 
25
+ # Build in-memory config pointing input/output to the working directory
26
+ cfg = Config()
27
+ cfg.paths = Paths(
28
+ input_folder=str(work),
29
+ output_folder=str(work),
30
+ boundingbox_dir=str(work)
31
  )
32
+ pipeline = SorghumPipeline(config=cfg)
33
 
34
+ # Run the pipeline (minimal single-image demo)
35
+ os.environ['MINIMAL_DEMO'] = '1'
36
+ os.environ['FAST_OUTPUT'] = '1'
37
+ results = pipeline.run(single_image_path=str(input_copy))
 
 
 
38
 
39
  # Collect outputs
40
  outputs: Dict[str, str] = {}
41
 
42
+ # Return only the requested 7 images with fixed keys
43
+ wanted = [
44
+ work / 'Vegetation_indices_images/ndvi.png',
45
+ work / 'Vegetation_indices_images/ari.png',
46
+ work / 'Vegetation_indices_images/gndvi.png',
47
+ work / 'texture_output/lbp.png',
48
+ work / 'texture_output/hog.png',
49
+ work / 'texture_output/lacunarity.png',
50
+ work / 'results/size.size_analysis.png',
51
+ ]
52
+ labels = [
53
+ 'NDVI', 'ARI', 'GNDVI', 'LBP', 'HOG', 'Lacunarity', 'SizeAnalysis'
54
+ ]
55
+ for label, path in zip(labels, wanted):
56
+ if path.exists():
57
+ outputs[label] = str(path)
58
 
59
  return outputs
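A quick smoke test of the wrapper outside Gradio, assuming a local sample image:

import tempfile
from wrapper import run_pipeline_on_image

with tempfile.TemporaryDirectory() as tmp:
    outputs = run_pipeline_on_image("sample_plant.png", tmp, save_artifacts=True)
    for label in ("NDVI", "ARI", "GNDVI", "LBP", "HOG", "Lacunarity", "SizeAnalysis"):
        print(label, outputs.get(label, "<missing>"))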