Fahimeh Orvati Nia commited on
Commit
91a7a12
·
1 Parent(s): dd1d7f5

minimal pipeline

Browse files
requirements.txt CHANGED
@@ -1,48 +1,12 @@
1
- # --- Core demo UI ---
2
  gradio
3
  pillow
4
-
5
- # --- Scientific / image processing ---
6
  numpy
7
- scipy
8
- matplotlib
9
- scikit-image
10
- opencv-python-headless
11
- tifffile
12
-
13
- # --- Machine learning / deep learning ---
14
  torch
15
  torchvision
16
- timm # for pretrained backbones
17
- segmentation-models-pytorch
18
- ultralytics # YOLO models (if you extend later)
19
-
20
- # --- Plant phenotyping ---
21
- plantcv==4.6
22
-
23
- # --- Data handling & utils ---
24
- pandas
25
- tqdm
26
- pyyaml
27
- joblib
28
-
29
- # --- Geometry / remote sensing ---
30
- shapely
31
- rasterio
32
- fiona
33
-
34
- # --- For morphology / texture analysis ---
35
  scikit-learn
36
- seaborn
37
- networkx
38
- skan # skeleton analysis
39
-
40
- # --- For model configs & logging ---
41
- omegaconf
42
- hydra-core
43
- loguru
44
-
45
- # --- Optional: segmentation research tools ---
46
- # (comment these out if not needed to reduce build time)
47
- segment-anything
48
- git+https://github.com/facebookresearch/segment-anything-2.git@2b90b9f5ceec907a1c18123530e92e794ad901a4
 
 
1
  gradio
2
  pillow
 
 
3
  numpy
4
+ opencv-python
 
 
 
 
 
 
5
  torch
6
  torchvision
7
+ transformers
8
+ scikit-image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  scikit-learn
10
+ scipy
11
+ matplotlib
12
+ plantcv
 
 
 
 
 
 
 
 
 
 
sorghum_pipeline/__init__.py CHANGED
@@ -1,31 +1,11 @@
1
  """
2
- Sorghum Plant Phenotyping Pipeline
3
-
4
- A comprehensive pipeline for analyzing sorghum plant images including:
5
- - Data loading and preprocessing
6
- - Image segmentation and masking
7
- - Feature extraction (texture, morphology, vegetation indices)
8
- - Results visualization and export
9
-
10
- Author: Fahime Horvatinia
11
- Version: 2.0.0
12
  """
13
 
14
  __version__ = "2.0.0"
15
- __author__ = "Fahime Horvatinia"
16
 
17
  from .pipeline import SorghumPipeline
18
  from .config import Config
19
- from .data import DataLoader
20
- from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
21
- from .output import OutputManager
22
 
23
- __all__ = [
24
- "SorghumPipeline",
25
- "Config",
26
- "DataLoader",
27
- "TextureExtractor",
28
- "VegetationIndexExtractor",
29
- "MorphologyExtractor",
30
- "OutputManager"
31
- ]
 
1
  """
2
+ Minimal Sorghum Plant Phenotyping Pipeline for Hugging Face Demo.
 
 
 
 
 
 
 
 
 
3
  """
4
 
5
  __version__ = "2.0.0"
6
+ __author__ = "Fahimeh Orvati Nia"
7
 
8
  from .pipeline import SorghumPipeline
9
  from .config import Config
 
 
 
10
 
11
+ __all__ = ["SorghumPipeline", "Config"]
 
 
 
 
 
 
 
 
sorghum_pipeline/config.py CHANGED
@@ -1,76 +1,42 @@
1
- """
2
- Minimal configuration for the Sorghum Pipeline.
3
- """
4
 
5
  import os
6
- from pathlib import Path
7
  from dataclasses import dataclass
8
 
9
 
10
  @dataclass
11
  class Paths:
12
- """Configuration for file paths."""
13
  input_folder: str
14
  output_folder: str
15
  boundingbox_dir: str = ""
16
 
17
  def __post_init__(self):
18
- """Ensure paths are absolute."""
19
  self.input_folder = os.path.abspath(self.input_folder)
20
  self.output_folder = os.path.abspath(self.output_folder)
21
 
22
 
23
- @dataclass
24
- class ProcessingParams:
25
- """Minimal processing parameters."""
26
- target_size: tuple = None
27
- min_component_area: int = 1000
28
- morphology_kernel_size: int = 7
29
- segmentation_threshold: float = 0.5
30
-
31
-
32
  @dataclass
33
  class OutputSettings:
34
  """Output settings."""
35
  save_images: bool = True
36
- save_plots: bool = False
37
- save_metadata: bool = False
38
  plot_dpi: int = 100
39
- segmentation_dir: str = "results"
40
- texture_dir: str = "texture_output"
41
- morphology_dir: str = "results"
42
- vegetation_dir: str = "Vegetation_indices_images"
43
-
44
-
45
- @dataclass
46
- class ModelSettings:
47
- """Model settings."""
48
- device: str = "auto"
49
- model_name: str = "briaai/RMBG-2.0"
50
- trust_remote_code: bool = True
51
- cache_dir: str = ""
52
- local_files_only: bool = False
53
 
54
 
55
  class Config:
56
- """Minimal configuration class."""
57
 
58
  def __init__(self):
59
- """Initialize with defaults."""
60
- self.paths = Paths(input_folder="", output_folder="", boundingbox_dir="")
61
- self.processing = ProcessingParams()
62
  self.output = OutputSettings()
63
- self.model = ModelSettings()
64
 
65
  def get_device(self) -> str:
66
- """Get processing device."""
67
- if self.model.device == "auto":
68
- import torch
69
- return "cuda" if torch.cuda.is_available() else "cpu"
70
- return self.model.device
71
 
72
  def validate(self) -> bool:
73
- """Validate configuration."""
74
  if self.paths.input_folder and not os.path.exists(self.paths.input_folder):
75
- raise FileNotFoundError(f"Input folder does not exist: {self.paths.input_folder}")
76
  return True
 
1
+ """Minimal configuration."""
 
 
2
 
3
  import os
 
4
  from dataclasses import dataclass
5
 
6
 
7
  @dataclass
8
  class Paths:
9
+ """File paths."""
10
  input_folder: str
11
  output_folder: str
12
  boundingbox_dir: str = ""
13
 
14
  def __post_init__(self):
 
15
  self.input_folder = os.path.abspath(self.input_folder)
16
  self.output_folder = os.path.abspath(self.output_folder)
17
 
18
 
 
 
 
 
 
 
 
 
 
19
  @dataclass
20
  class OutputSettings:
21
  """Output settings."""
22
  save_images: bool = True
 
 
23
  plot_dpi: int = 100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  class Config:
27
+ """Minimal config."""
28
 
29
  def __init__(self):
30
+ self.paths = Paths(input_folder="", output_folder="")
 
 
31
  self.output = OutputSettings()
 
32
 
33
  def get_device(self) -> str:
34
+ """Get device."""
35
+ import torch
36
+ return "cuda" if torch.cuda.is_available() else "cpu"
 
 
37
 
38
  def validate(self) -> bool:
39
+ """Validate."""
40
  if self.paths.input_folder and not os.path.exists(self.paths.input_folder):
41
+ raise FileNotFoundError(f"Input folder not found: {self.paths.input_folder}")
42
  return True
sorghum_pipeline/data/__init__.py CHANGED
@@ -1,15 +1,6 @@
1
- """
2
- Data loading and preprocessing modules.
3
 
4
- This package contains all data-related functionality including:
5
- - Raw image loading
6
- - Data preprocessing
7
- - Mask handling
8
- - Data validation
9
- """
10
-
11
- from .loader import DataLoader
12
  from .preprocessor import ImagePreprocessor
13
  from .mask_handler import MaskHandler
14
 
15
- __all__ = ["DataLoader", "ImagePreprocessor", "MaskHandler"]
 
1
+ """Data preprocessing modules."""
 
2
 
 
 
 
 
 
 
 
 
3
  from .preprocessor import ImagePreprocessor
4
  from .mask_handler import MaskHandler
5
 
6
+ __all__ = ["ImagePreprocessor", "MaskHandler"]
sorghum_pipeline/data/loader.py CHANGED
@@ -1,444 +1,34 @@
1
  """
2
- Data loading functionality for the Sorghum Pipeline.
3
-
4
- This module handles loading raw images, managing plant data,
5
- and organizing data according to the pipeline requirements.
6
  """
7
 
8
- import os
9
- import glob
10
- import json
11
  from pathlib import Path
12
- from typing import Dict, List, Tuple, Optional, Any
13
  from PIL import Image
14
- import numpy as np
15
  import logging
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
 
20
  class DataLoader:
21
- """Handles loading and organizing plant image data."""
22
-
23
- # Plants to ignore completely (empty by default)
24
- IGNORE_PLANTS = set()
25
-
26
- # Plants where you want exactly one frame from their own folder
27
- EXACT_FRAME = {
28
- 4: 7, 5: 5, 7: 5, 12: 5, 13: 5, 18: 7, 19: 2, 20: 3,
29
- 24: 6, 25: 5, 26: 5, 30: 8, 37: 7
30
- }
31
-
32
- # Plants where you want to borrow a frame from a different plant folder
33
- BORROW_FRAME = {
34
- 14: (13, 5), 15: (14, 5), 16: (15, 5), 33: (34, 7),
35
- 34: (35, 7), 35: (35, 8), 36: (36, 6)
36
- }
37
-
38
- # Overrides provided by user: preferred frame per target plant name
39
- FRAME_OVERRIDE_BY_NAME = {
40
- 'plant1': 9, 'plant2': 10, 'plant3': 9, 'plant5': 7, 'plant6': 9, 'plant8': 5,
41
- 'plant7': 9, 'plant10': 9, 'plant11': 9, 'plant12': 9,
42
- 'plant13': 10, 'plant14': 8, 'plant15': 11, 'plant19': 4, 'plant20': 7,
43
- 'plant21': 9, 'plant22': 10, 'plant25': 4, 'plant26': 2, 'plant27': 10, 'plant28': 9, 'plant29': 2,
44
- 'plant30': 9, 'plant31': 10, 'plant32': 9, 'plant33': 8,
45
- 'plant35': 9, 'plant36': 4, 'plant38': 9, 'plant39': 9, 'plant41': 9,
46
- 'plant42': 6, 'plant43': 10, 'plant44': 9, 'plant45': 7,
47
- 'plant47': 10, 'plant48': 11,
48
- }
49
-
50
- # Substitutes provided by user: map target plant name -> source plant name
51
- PLANT_SUBSTITUTES_BY_NAME = {
52
- 'plant16': 'plant15', 'plant15': 'plant14', 'plant14': 'plant13',
53
- 'plant13': 'plant12', 'plant33': 'plant34', 'plant34': 'plant35',
54
- 'plant24': 'plant25', 'plant25': 'plant25', 'plant35': 'plant36',
55
- 'plant36': 'plant37', 'plant37': 'plant37', 'plant44': 'plant43',
56
- 'plant45': 'plant44',
57
- }
58
 
59
- def __init__(self, input_folder: str, debug: bool = False, include_ignored: bool = False, strict_loader: bool = False, excluded_dates: Optional[List[str]] = None):
60
- """
61
- Initialize the data loader.
62
-
63
- Args:
64
- input_folder: Path to the input dataset folder
65
- debug: Enable debug logging
66
- """
67
  self.input_folder = Path(input_folder)
68
  self.debug = debug
69
- self.include_ignored = include_ignored
70
- self.strict_loader = strict_loader
71
 
72
  if not self.input_folder.exists():
73
  raise FileNotFoundError(f"Input folder does not exist: {input_folder}")
74
- # Normalize excluded dates as a set of folder names (with dashes)
75
- self.excluded_dates = set(excluded_dates or [])
76
 
77
  def load_selected_frames(self) -> Dict[str, Dict[str, Any]]:
78
- """
79
- Load selected frames according to predefined rules.
80
- If strict_loader is True, load only frame numbers from the plant's own folder (no borrowing/special picks).
81
-
82
- Returns:
83
- Dictionary with plant data organized by key format: "YYYY_MM_DD_plantX_frameY"
84
- """
85
- logger.info("Loading selected frames from dataset...")
86
- plants = {}
87
-
88
- # Detect if input folder is a direct date folder (contains plant folders)
89
- first_items = list(self.input_folder.iterdir())
90
- has_plant_folders = any(item.is_dir() and item.name.startswith('plant') for item in first_items)
91
-
92
- def choose_frame_and_source(pid: int) -> Tuple[int, str]:
93
- if self.strict_loader:
94
- # In strict mode, honor explicit frame overrides AND substitution of source plant
95
- plant_name_local = f"plant{pid}"
96
- frame_num = self.FRAME_OVERRIDE_BY_NAME.get(
97
- plant_name_local,
98
- self.EXACT_FRAME.get(pid, 8)
99
- )
100
- source_plant = self.PLANT_SUBSTITUTES_BY_NAME.get(plant_name_local, plant_name_local)
101
- return frame_num, source_plant
102
- # Original behavior
103
- frame_num = self._get_frame_number(pid)
104
- source_plant = self._get_source_plant(pid)
105
- return frame_num, source_plant
106
-
107
- if has_plant_folders:
108
- # Direct date folder structure
109
- date_name = self.input_folder.name
110
- date_path = self.input_folder
111
- for plant_name in sorted(os.listdir(date_path)):
112
- plant_path = date_path / plant_name
113
- if not plant_path.is_dir():
114
- continue
115
- try:
116
- plant_id = int(plant_name.replace("plant", ""))
117
- except ValueError:
118
- continue
119
- if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
120
- if self.debug:
121
- logger.debug(f"Ignoring plant {plant_id}")
122
- continue
123
- frame_num, source_plant = choose_frame_and_source(plant_id)
124
- frame_data = self._load_single_frame(date_path, source_plant, frame_num, plant_name)
125
- if frame_data:
126
- key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_num}"
127
- plants[key] = frame_data
128
- logger.debug(f"Loaded {key}")
129
- else:
130
- # Parent folder structure with date subfolders
131
- for date_name in sorted(os.listdir(self.input_folder)):
132
- date_path = self.input_folder / date_name
133
- if not date_path.is_dir():
134
- continue
135
- if date_name in self.excluded_dates:
136
- logger.info(f"Skipping excluded date: {date_name}")
137
- continue
138
- for plant_name in sorted(os.listdir(date_path)):
139
- plant_path = date_path / plant_name
140
- if not plant_path.is_dir():
141
- continue
142
- try:
143
- plant_id = int(plant_name.replace("plant", ""))
144
- except ValueError:
145
- continue
146
- if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
147
- if self.debug:
148
- logger.debug(f"Ignoring plant {plant_id}")
149
- continue
150
- frame_num, source_plant = choose_frame_and_source(plant_id)
151
- frame_data = self._load_single_frame(date_path, source_plant, frame_num, plant_name)
152
- if frame_data:
153
- key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_num}"
154
- plants[key] = frame_data
155
- logger.debug(f"Loaded {key}")
156
-
157
- logger.info(f"Successfully loaded {len(plants)} plant frames")
158
- return plants
159
 
160
  def load_all_frames(self) -> Dict[str, Dict[str, Any]]:
161
- """
162
- Load all available frames for each plant.
163
-
164
- Returns:
165
- Dictionary with all plant frames
166
- """
167
- logger.info("Loading all frames from dataset...")
168
- plants = {}
169
-
170
- # Check if we're directly in a date folder (contains plant folders)
171
- # or in a parent folder (contains date folders)
172
- first_items = list(self.input_folder.iterdir())
173
- has_plant_folders = any(item.is_dir() and item.name.startswith('plant') for item in first_items)
174
-
175
- if has_plant_folders:
176
- # We're directly in a date folder
177
- logger.info("Detected direct date folder structure")
178
- date_name = self.input_folder.name
179
- self._load_plants_from_date_folder(self.input_folder, date_name, plants)
180
- else:
181
- # We're in a parent folder with date subfolders
182
- logger.info("Detected parent folder structure")
183
- for date_name in sorted(os.listdir(self.input_folder)):
184
- date_path = self.input_folder / date_name
185
- if not date_path.is_dir():
186
- continue
187
- if date_name in self.excluded_dates:
188
- logger.info(f"Skipping excluded date: {date_name}")
189
- continue
190
-
191
- logger.info(f"Processing date: {date_name}")
192
- self._load_plants_from_date_folder(date_path, date_name, plants)
193
-
194
- logger.info(f"Successfully loaded {len(plants)} plant frames")
195
- return plants
196
-
197
- def _load_plants_from_date_folder(self, date_path: Path, date_name: str, plants: Dict[str, Dict[str, Any]]) -> None:
198
- """Load plants from a date folder."""
199
- for plant_name in sorted(os.listdir(date_path)):
200
- plant_path = date_path / plant_name
201
- if not plant_path.is_dir():
202
- continue
203
-
204
- # Extract plant ID
205
- try:
206
- plant_id = int(plant_name.replace("plant", ""))
207
- except ValueError:
208
- logger.warning(f"Could not extract plant ID from {plant_name}")
209
- continue
210
-
211
- # Skip ignored plants
212
- if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
213
- logger.info(f"Skipping ignored plant {plant_id}")
214
- continue
215
-
216
- logger.info(f"Processing plant {plant_id}")
217
-
218
- # Load all frames for this plant
219
- pattern = str(plant_path / f"{plant_name}_frame*.tif")
220
- frame_files = sorted(glob.glob(pattern))
221
- logger.info(f"Found {len(frame_files)} frame files for {plant_name}")
222
-
223
- for frame_path in frame_files:
224
- frame_data = self._load_frame_from_path(frame_path, plant_name)
225
- if frame_data:
226
- frame_id = Path(frame_path).stem.split("_frame")[-1]
227
- key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_id}"
228
- plants[key] = frame_data
229
- logger.debug(f"Loaded frame: {key}")
230
- else:
231
- logger.warning(f"Failed to load frame: {frame_path}")
232
-
233
- def load_single_plant(self, date: str, plant: str, frame: int) -> Optional[Dict[str, Any]]:
234
- """
235
- Load a specific plant frame.
236
-
237
- Args:
238
- date: Date string (e.g., "2025-02-05")
239
- plant: Plant name (e.g., "plant1")
240
- frame: Frame number
241
-
242
- Returns:
243
- Plant data dictionary or None if not found
244
- """
245
- date_path = self.input_folder / date
246
- if not date_path.exists():
247
- logger.error(f"Date folder not found: {date}")
248
- return None
249
-
250
- plant_path = date_path / plant
251
- if not plant_path.exists():
252
- logger.error(f"Plant folder not found: {plant}")
253
- return None
254
-
255
- filename = f"{plant}_frame{frame}.tif"
256
- frame_path = plant_path / filename
257
-
258
- return self._load_frame_from_path(str(frame_path), plant)
259
-
260
- def _get_frame_number(self, plant_id: int) -> int:
261
- """Get the frame number for a plant ID."""
262
- plant_name = f"plant{plant_id}"
263
- # Highest priority: explicit overrides by plant name
264
- if plant_name in self.FRAME_OVERRIDE_BY_NAME:
265
- return int(self.FRAME_OVERRIDE_BY_NAME[plant_name])
266
- # Next: original exact/borrrow rules
267
- if plant_id in self.EXACT_FRAME:
268
- return self.EXACT_FRAME[plant_id]
269
- elif plant_id in self.BORROW_FRAME:
270
- return self.BORROW_FRAME[plant_id][1]
271
- else:
272
- return 8 # Default frame
273
-
274
- def _get_source_plant(self, plant_id: int) -> str:
275
- """Get the source plant name for a plant ID."""
276
- plant_name = f"plant{plant_id}"
277
- # Highest priority: explicit substitutes by plant name
278
- if plant_name in self.PLANT_SUBSTITUTES_BY_NAME:
279
- return self.PLANT_SUBSTITUTES_BY_NAME[plant_name]
280
- # Next: original borrow rules
281
- if plant_id in self.BORROW_FRAME:
282
- source_id = self.BORROW_FRAME[plant_id][0]
283
- return f"plant{source_id}"
284
- else:
285
- return f"plant{plant_id}"
286
-
287
- def _load_single_frame(self, date_path: Path, source_plant: str,
288
- frame_num: int, target_plant: str) -> Optional[Dict[str, Any]]:
289
- """Load a single frame from the specified path."""
290
- filename = f"{source_plant}_frame{frame_num}.tif"
291
- frame_path = date_path / source_plant / filename
292
-
293
- if not frame_path.exists():
294
- if self.debug:
295
- logger.warning(f"Frame not found: {frame_path}")
296
- return None
297
-
298
- return self._load_frame_from_path(str(frame_path), target_plant)
299
-
300
- def _load_frame_from_path(self, frame_path: str, plant_name: str) -> Optional[Dict[str, Any]]:
301
- """Load frame data from a file path."""
302
- try:
303
- logger.debug(f"Attempting to load: {frame_path}")
304
- image = Image.open(frame_path)
305
- filename = Path(frame_path).name
306
- logger.debug(f"Successfully loaded image: {filename}, size: {image.size}")
307
-
308
- return {
309
- "raw_image": (image, filename),
310
- "plant_name": plant_name,
311
- "file_path": frame_path
312
- }
313
- except Exception as e:
314
- logger.error(f"Failed to load {frame_path}: {e}")
315
- return None
316
-
317
- def load_bounding_boxes(self, bbox_dir: str) -> Dict[str, Tuple[int, int, int, int]]:
318
- """
319
- Load bounding box data from JSON files.
320
-
321
- Args:
322
- bbox_dir: Directory containing bounding box JSON files
323
-
324
- Returns:
325
- Dictionary mapping plant names to bounding box coordinates
326
- """
327
- bbox_path = Path(bbox_dir)
328
- if not bbox_path.exists():
329
- raise FileNotFoundError(f"Bounding box directory not found: {bbox_dir}")
330
-
331
- bbox_lookup = {}
332
-
333
- for json_file in bbox_path.glob("*.json"):
334
- stem = json_file.stem
335
- # Normalize stems like plant_33_new -> plant33
336
- if stem.startswith('plant_'):
337
- parts = stem.split('_')
338
- try:
339
- idx = next(i for i,p in enumerate(parts) if p.isdigit())
340
- plant_id = f"plant{parts[idx]}"
341
- except Exception:
342
- plant_id = stem.replace('_', '')
343
- else:
344
- plant_id = stem
345
- try:
346
- with open(json_file, 'r') as f:
347
- data = json.load(f)
348
-
349
- shapes = data.get('shapes', [])
350
- # Prefer rectangle labeled 'sorghum' (case-insensitive), else first rectangle
351
- def _is_sorghum_label(s: dict) -> bool:
352
- for key in ('label', 'name', 'text'):
353
- val = s.get(key)
354
- if isinstance(val, str) and val.lower() == 'sorghum':
355
- return True
356
- return False
357
- rect = next((s for s in shapes if s.get('shape_type') == 'rectangle' and _is_sorghum_label(s)), None)
358
- if rect is None:
359
- rect = next((s for s in shapes if s.get('shape_type') == 'rectangle'), None)
360
-
361
- if rect:
362
- (x1, y1), (x2, y2) = rect['points']
363
- bbox_lookup[plant_id] = (
364
- int(max(0, x1)),
365
- int(max(0, y1)),
366
- int(min(1e9, x2)),
367
- int(min(1e9, y2))
368
- )
369
- else:
370
- bbox_lookup[plant_id] = None
371
-
372
- except Exception as e:
373
- logger.error(f"Failed to load bounding box {json_file}: {e}")
374
-
375
- logger.info(f"Loaded {len(bbox_lookup)} bounding boxes")
376
- return bbox_lookup
377
-
378
- def load_hand_labels(self, labels_dir: str) -> Dict[str, np.ndarray]:
379
- """
380
- Load hand-labeled masks from JSON files.
381
-
382
- Args:
383
- labels_dir: Directory containing label JSON files
384
-
385
- Returns:
386
- Dictionary mapping plant names to mask arrays
387
- """
388
- labels_path = Path(labels_dir)
389
- if not labels_path.exists():
390
- logger.warning(f"Labels directory not found: {labels_dir}")
391
- return {}
392
-
393
- masks = {}
394
-
395
- for json_file in labels_path.glob("*.json"):
396
- plant_id = json_file.stem
397
- try:
398
- with open(json_file, 'r') as f:
399
- data = json.load(f)
400
-
401
- # Create mask from shapes (assuming we have image dimensions)
402
- # This would need to be adapted based on your label format
403
- mask = self._create_mask_from_shapes(data)
404
- if mask is not None:
405
- masks[plant_id] = mask
406
-
407
- except Exception as e:
408
- logger.error(f"Failed to load label {json_file}: {e}")
409
-
410
- logger.info(f"Loaded {len(masks)} hand labels")
411
- return masks
412
-
413
- def _create_mask_from_shapes(self, data: Dict) -> Optional[np.ndarray]:
414
- """Create a mask array from shape data."""
415
- # This is a placeholder - implement based on your label format
416
- # For now, return None
417
- return None
418
-
419
- def validate_data(self, plants: Dict[str, Dict[str, Any]]) -> bool:
420
- """
421
- Validate loaded plant data.
422
-
423
- Args:
424
- plants: Dictionary of plant data
425
-
426
- Returns:
427
- True if data is valid, False otherwise
428
- """
429
- if not plants:
430
- logger.error("No plant data loaded")
431
- return False
432
-
433
- for key, data in plants.items():
434
- if "raw_image" not in data:
435
- logger.error(f"Missing raw_image in {key}")
436
- return False
437
-
438
- image, filename = data["raw_image"]
439
- if not isinstance(image, Image.Image):
440
- logger.error(f"Invalid image type in {key}")
441
- return False
442
-
443
- logger.info("Data validation passed")
444
- return True
 
1
  """
2
+ Minimal data loading (not used in single-image demo mode).
 
 
 
3
  """
4
 
 
 
 
5
  from pathlib import Path
6
+ from typing import Dict, List, Optional, Any
7
  from PIL import Image
 
8
  import logging
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
 
13
  class DataLoader:
14
+ """Minimal data loader (placeholder - not used in demo)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ def __init__(self, input_folder: str, debug: bool = False,
17
+ include_ignored: bool = False, strict_loader: bool = False,
18
+ excluded_dates: Optional[List[str]] = None):
19
+ """Initialize data loader."""
 
 
 
 
20
  self.input_folder = Path(input_folder)
21
  self.debug = debug
 
 
22
 
23
  if not self.input_folder.exists():
24
  raise FileNotFoundError(f"Input folder does not exist: {input_folder}")
 
 
25
 
26
  def load_selected_frames(self) -> Dict[str, Dict[str, Any]]:
27
+ """Load selected frames (not used in minimal demo)."""
28
+ logger.warning("DataLoader not used in minimal demo mode")
29
+ return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def load_all_frames(self) -> Dict[str, Dict[str, Any]]:
32
+ """Load all frames (not used in minimal demo)."""
33
+ logger.warning("DataLoader not used in minimal demo mode")
34
+ return {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sorghum_pipeline/data/mask_handler.py CHANGED
@@ -1,19 +1,13 @@
1
- """
2
- Minimal mask handling for the Sorghum Pipeline.
3
- """
4
 
5
  import numpy as np
6
  import cv2
7
- import logging
8
-
9
- logger = logging.getLogger(__name__)
10
 
11
 
12
  class MaskHandler:
13
- """Minimal mask handling."""
14
 
15
  def __init__(self, min_area: int = 1000, kernel_size: int = 7):
16
- """Initialize mask handler."""
17
  self.min_area = min_area
18
  self.kernel_size = kernel_size
19
 
@@ -22,7 +16,6 @@ class MaskHandler:
22
  if mask is None:
23
  return image
24
  if mask.shape[:2] != image.shape[:2]:
25
- mask = cv2.resize(mask.astype(np.uint8), (image.shape[1], image.shape[0]),
26
- interpolation=cv2.INTER_NEAREST)
27
- binary = (mask.astype(np.int32) > 0).astype(np.uint8) * 255
28
  return cv2.bitwise_and(image, image, mask=binary)
 
1
+ """Minimal mask handling."""
 
 
2
 
3
  import numpy as np
4
  import cv2
 
 
 
5
 
6
 
7
  class MaskHandler:
8
+ """Minimal mask operations."""
9
 
10
  def __init__(self, min_area: int = 1000, kernel_size: int = 7):
 
11
  self.min_area = min_area
12
  self.kernel_size = kernel_size
13
 
 
16
  if mask is None:
17
  return image
18
  if mask.shape[:2] != image.shape[:2]:
19
+ mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
20
+ binary = (mask > 0).astype(np.uint8) * 255
 
21
  return cv2.bitwise_and(image, image, mask=binary)
sorghum_pipeline/data/preprocessor.py CHANGED
@@ -1,26 +1,19 @@
1
- """
2
- Minimal image preprocessing for the Sorghum Pipeline.
3
- """
4
 
5
  import numpy as np
6
- import cv2
7
  from PIL import Image
8
  from typing import Dict, Tuple, Any
9
  from itertools import product
10
- import logging
11
-
12
- logger = logging.getLogger(__name__)
13
 
14
 
15
  class ImagePreprocessor:
16
- """Minimal image preprocessing."""
17
 
18
  def __init__(self, target_size=None):
19
- """Initialize preprocessor."""
20
  self.target_size = target_size
21
 
22
  def convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
23
- """Convert array to uint8."""
24
  arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
25
  if arr.ptp() > 0:
26
  normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
@@ -29,38 +22,24 @@ class ImagePreprocessor:
29
  return np.clip(normalized, 0, 255).astype(np.uint8)
30
 
31
  def process_raw_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
32
- """Process 4-band image into composite and spectral bands."""
33
  d = pil_img.size[0] // 2
34
- boxes = [
35
- (j, i, j + d, i + d)
36
- for i, j in product(range(0, pil_img.height, d), range(0, pil_img.width, d))
37
- ]
38
-
39
  stack = np.stack([np.array(pil_img.crop(box), dtype=float) for box in boxes], axis=-1)
40
  green, red, red_edge, nir = np.split(stack, 4, axis=-1)
41
 
42
- # Pseudo-RGB composite: (green, red_edge, red)
43
  composite = np.concatenate([green, red_edge, red], axis=-1)
44
  composite_uint8 = self.convert_to_uint8(composite)
45
 
46
- spectral_bands = {
47
- "green": green,
48
- "red": red,
49
- "red_edge": red_edge,
50
- "nir": nir
51
- }
52
-
53
  return composite_uint8, spectral_bands
54
 
55
  def create_composites(self, plants: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
56
- """Create composites for all plants."""
57
  for key, pdata in plants.items():
58
- try:
59
- if "raw_image" in pdata:
60
- image, _ = pdata["raw_image"]
61
- composite, spectral_stack = self.process_raw_image(image)
62
- pdata["composite"] = composite
63
- pdata["spectral_stack"] = spectral_stack
64
- except Exception as e:
65
- logger.error(f"Failed to create composite for {key}: {e}")
66
  return plants
 
1
+ """Minimal image preprocessing."""
 
 
2
 
3
  import numpy as np
 
4
  from PIL import Image
5
  from typing import Dict, Tuple, Any
6
  from itertools import product
 
 
 
7
 
8
 
9
  class ImagePreprocessor:
10
+ """Minimal preprocessor."""
11
 
12
  def __init__(self, target_size=None):
 
13
  self.target_size = target_size
14
 
15
  def convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
16
+ """Convert to uint8."""
17
  arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
18
  if arr.ptp() > 0:
19
  normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
 
22
  return np.clip(normalized, 0, 255).astype(np.uint8)
23
 
24
  def process_raw_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
25
+ """Process 4-band to composite + spectral."""
26
  d = pil_img.size[0] // 2
27
+ boxes = [(j, i, j + d, i + d) for i, j in product(range(0, pil_img.height, d), range(0, pil_img.width, d))]
 
 
 
 
28
  stack = np.stack([np.array(pil_img.crop(box), dtype=float) for box in boxes], axis=-1)
29
  green, red, red_edge, nir = np.split(stack, 4, axis=-1)
30
 
 
31
  composite = np.concatenate([green, red_edge, red], axis=-1)
32
  composite_uint8 = self.convert_to_uint8(composite)
33
 
34
+ spectral_bands = {"green": green, "red": red, "red_edge": red_edge, "nir": nir}
 
 
 
 
 
 
35
  return composite_uint8, spectral_bands
36
 
37
  def create_composites(self, plants: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
38
+ """Create composites."""
39
  for key, pdata in plants.items():
40
+ if "raw_image" in pdata:
41
+ image, _ = pdata["raw_image"]
42
+ composite, spectral_stack = self.process_raw_image(image)
43
+ pdata["composite"] = composite
44
+ pdata["spectral_stack"] = spectral_stack
 
 
 
45
  return plants
sorghum_pipeline/features/__init__.py CHANGED
@@ -1,21 +1,7 @@
1
- """
2
- Feature extraction modules for the Sorghum Pipeline.
3
-
4
- This package contains all feature extraction functionality including:
5
- - Texture features (LBP, HOG, Lacunarity, EHD)
6
- - Vegetation indices
7
- - Morphological features
8
- - Spectral features
9
- """
10
 
11
  from .texture import TextureExtractor
12
  from .vegetation import VegetationIndexExtractor
13
  from .morphology import MorphologyExtractor
14
- from .spectral import SpectralExtractor
15
 
16
- __all__ = [
17
- "TextureExtractor",
18
- "VegetationIndexExtractor",
19
- "MorphologyExtractor",
20
- "SpectralExtractor"
21
- ]
 
1
+ """Feature extraction modules."""
 
 
 
 
 
 
 
 
2
 
3
  from .texture import TextureExtractor
4
  from .vegetation import VegetationIndexExtractor
5
  from .morphology import MorphologyExtractor
 
6
 
7
+ __all__ = ["TextureExtractor", "VegetationIndexExtractor", "MorphologyExtractor"]
 
 
 
 
 
sorghum_pipeline/output/__init__.py CHANGED
@@ -1,13 +1,5 @@
1
- """
2
- Output management modules for the Sorghum Pipeline.
3
-
4
- This package contains output functionality including:
5
- - Result saving
6
- - Visualization generation
7
- - Report creation
8
- - Data export
9
- """
10
 
11
  from .manager import OutputManager
12
 
13
- __all__ = ["OutputManager"]
 
1
+ """Output management modules."""
 
 
 
 
 
 
 
 
2
 
3
  from .manager import OutputManager
4
 
5
+ __all__ = ["OutputManager"]
sorghum_pipeline/pipeline.py CHANGED
@@ -1,16 +1,12 @@
1
  """
2
- Main pipeline class for the Sorghum Plant Phenotyping Pipeline.
3
-
4
- Minimal single-image version for Hugging Face demo.
5
  """
6
 
7
- import os
8
  import logging
9
  from pathlib import Path
10
- from typing import Dict, Any, Optional
11
  import numpy as np
12
  import cv2
13
- from sklearn.decomposition import PCA
14
 
15
  from .config import Config
16
  from .data import ImagePreprocessor, MaskHandler
@@ -22,223 +18,112 @@ logger = logging.getLogger(__name__)
22
 
23
 
24
  class SorghumPipeline:
25
- """Minimal pipeline for single-image plant phenotyping."""
26
 
27
  def __init__(self, config: Config):
28
- """Initialize the minimal pipeline."""
29
- self._setup_logging()
30
  self.config = config
31
  self.config.validate()
32
- self._initialize_components()
33
- logger.info("Sorghum Pipeline initialized")
34
-
35
- def _setup_logging(self):
36
- """Setup logging configuration."""
37
- logging.basicConfig(
38
- level=logging.INFO,
39
- format='%(asctime)s - %(levelname)s - %(message)s',
40
- handlers=[logging.StreamHandler()]
41
- )
42
-
43
- def _initialize_components(self):
44
- """Initialize pipeline components."""
45
- self.preprocessor = ImagePreprocessor(target_size=None)
46
- self.mask_handler = MaskHandler(min_area=1000, kernel_size=7)
47
  self.texture_extractor = TextureExtractor()
48
  self.vegetation_extractor = VegetationIndexExtractor()
49
  self.morphology_extractor = MorphologyExtractor()
50
  self.segmentation_manager = SegmentationManager(
51
- model_name="briaai/RMBG-2.0",
52
  device=self.config.get_device(),
53
- threshold=0.5,
54
  trust_remote_code=True
55
  )
56
  self.output_manager = OutputManager(
57
  output_folder=self.config.paths.output_folder,
58
  settings=self.config.output
59
  )
 
60
 
61
  def run(self, single_image_path: str) -> Dict[str, Any]:
62
- """
63
- Run minimal pipeline on single image.
64
 
65
- Args:
66
- single_image_path: Path to input image
67
-
68
- Returns:
69
- Dictionary containing results
70
- """
71
- logger.info("Starting minimal single-image pipeline...")
72
 
73
- try:
74
- import time
75
- from PIL import Image as _Image
76
-
77
- total_start = time.perf_counter()
78
-
79
- # Load single image
80
- _p = Path(single_image_path)
81
- _img = _Image.open(str(_p))
82
- plants = {
83
- "demo_demo_frame1": {
84
- "raw_image": (_img, _p.name),
85
- "plant_name": "demo",
86
- "file_path": str(_p)
87
- }
88
- }
89
-
90
- # Create composite
91
- plants = self.preprocessor.create_composites(plants)
92
-
93
- # Segment
94
- plants = self._segment_plants(plants)
95
-
96
- # Extract features
97
- plants = self._extract_features(plants)
98
-
99
- # Generate outputs
100
- self._generate_outputs(plants)
101
-
102
- # Summary
103
- summary = self._create_summary(plants)
104
-
105
- total_time = time.perf_counter() - total_start
106
- logger.info(f"Pipeline completed in {total_time:.2f}s")
107
-
108
- return {
109
- "plants": plants,
110
- "summary": summary,
111
- "config": self.config,
112
- "timing_seconds": total_time
113
  }
114
-
115
- except Exception as e:
116
- logger.error(f"Pipeline failed: {e}")
117
- raise
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- def _segment_plants(self, plants: Dict[str, Any]) -> Dict[str, Any]:
120
- """Segment plants using BRIA model (full image)."""
121
  for key, pdata in plants.items():
122
- try:
123
- composite = pdata['composite']
124
- soft_mask = self.segmentation_manager.segment_image_soft(composite)
125
- pdata['soft_mask'] = soft_mask
126
- pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)
127
- logger.info(f"Segmented {key}")
128
- except Exception as e:
129
- logger.error(f"Segmentation failed for {key}: {e}")
130
- pdata['soft_mask'] = np.zeros(composite.shape[:2], dtype=np.float32)
131
- pdata['mask'] = np.zeros(composite.shape[:2], dtype=np.uint8)
132
  return plants
133
 
134
  def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]:
135
- """Extract features from plants."""
136
  for key, pdata in plants.items():
137
- try:
138
- pdata['texture_features'] = self._extract_texture_features(pdata)
139
- pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)
140
- pdata['morphology_features'] = self._extract_morphology_features(pdata)
141
- logger.info(f"Features extracted for {key}")
142
- except Exception as e:
143
- logger.error(f"Feature extraction failed for {key}: {e}")
144
- pdata['texture_features'] = {}
145
- pdata['vegetation_indices'] = {}
146
- pdata['morphology_features'] = {}
147
- return plants
148
-
149
- def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
150
- """Extract texture features from pseudo-color image only."""
151
- features = {}
152
- try:
153
- # Only process pseudo-color composite
154
  composite = pdata['composite']
155
  mask = pdata.get('mask')
156
- if mask is not None:
157
- masked = self.mask_handler.apply_mask_to_image(composite, mask)
158
- gray_image = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
159
- else:
160
- gray_image = cv2.cvtColor(composite, cv2.COLOR_BGR2GRAY)
161
 
162
- band_features = self.texture_extractor.extract_all_texture_features(gray_image)
163
- stats = self.texture_extractor.compute_texture_statistics(band_features, mask)
 
164
 
165
- features['color'] = {
166
- 'features': band_features,
167
- 'statistics': stats
168
- }
169
- except Exception as e:
170
- logger.error(f"Texture extraction failed: {e}")
171
- features['color'] = {'features': {}, 'statistics': {}}
 
 
172
 
173
- return features
174
 
175
- def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
176
- """Extract vegetation indices (NDVI, ARI, GNDVI only)."""
177
- try:
178
- spectral_stack = pdata.get('spectral_stack', {})
179
- mask = pdata.get('mask')
180
- if not spectral_stack or mask is None:
181
- return {}
182
 
183
- out: Dict[str, Any] = {}
184
- for name in ("NDVI", "ARI", "GNDVI"):
185
- bands = self.vegetation_extractor.index_bands.get(name, [])
186
- if not all(b in spectral_stack for b in bands):
187
- continue
188
- arrays = []
189
- for b in bands:
190
- arr = spectral_stack[b]
191
- if isinstance(arr, np.ndarray):
192
- arr = arr.squeeze(-1)
193
- arrays.append(np.asarray(arr, dtype=np.float64))
194
-
195
- values = self.vegetation_extractor.index_formulas[name](*arrays).astype(np.float64)
196
- binary_mask = (np.asarray(mask).astype(np.int32) > 0)
197
- masked_values = np.where(binary_mask, values, np.nan)
198
- valid = masked_values[~np.isnan(masked_values)]
199
-
200
- stats = {
201
- 'mean': float(np.mean(valid)) if valid.size else 0.0,
202
- 'std': float(np.std(valid)) if valid.size else 0.0,
203
- 'min': float(np.min(valid)) if valid.size else 0.0,
204
- 'max': float(np.max(valid)) if valid.size else 0.0,
205
- 'median': float(np.median(valid)) if valid.size else 0.0,
206
- }
207
- out[name] = {'values': masked_values, 'statistics': stats}
208
- return out
209
- except Exception as e:
210
- logger.error(f"Vegetation index extraction failed: {e}")
211
- return {}
212
-
213
- def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
214
- """Extract morphological features."""
215
- try:
216
- composite = pdata.get('composite')
217
- mask = pdata.get('mask')
218
- if composite is None or mask is None:
219
- return {}
220
- return self.morphology_extractor.extract_morphology_features(composite, mask)
221
- except Exception as e:
222
- logger.error(f"Morphology extraction failed: {e}")
223
- return {}
224
-
225
- def _generate_outputs(self, plants: Dict[str, Any]) -> None:
226
- """Generate output files."""
227
- self.output_manager.create_output_directories()
228
- for key, pdata in plants.items():
229
- try:
230
- self.output_manager.save_plant_results(key, pdata)
231
- except Exception as e:
232
- logger.error(f"Output generation failed for {key}: {e}")
233
-
234
- def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
235
- """Create summary of results."""
236
- return {
237
- "total_plants": len(plants),
238
- "successful_plants": sum(1 for p in plants.values() if p.get('texture_features')),
239
- "features_extracted": {
240
- "texture": sum(1 for p in plants.values() if p.get('texture_features')),
241
- "vegetation": sum(1 for p in plants.values() if p.get('vegetation_indices')),
242
- "morphology": sum(1 for p in plants.values() if p.get('morphology_features'))
243
  }
244
- }
 
 
1
  """
2
+ Minimal single-image pipeline for Hugging Face demo.
 
 
3
  """
4
 
 
5
  import logging
6
  from pathlib import Path
7
+ from typing import Dict, Any
8
  import numpy as np
9
  import cv2
 
10
 
11
  from .config import Config
12
  from .data import ImagePreprocessor, MaskHandler
 
18
 
19
 
20
  class SorghumPipeline:
21
+ """Minimal pipeline for single-image processing."""
22
 
23
  def __init__(self, config: Config):
24
+ """Initialize pipeline."""
25
+ logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
26
  self.config = config
27
  self.config.validate()
28
+
29
+ # Initialize components with defaults
30
+ self.preprocessor = ImagePreprocessor()
31
+ self.mask_handler = MaskHandler()
 
 
 
 
 
 
 
 
 
 
 
32
  self.texture_extractor = TextureExtractor()
33
  self.vegetation_extractor = VegetationIndexExtractor()
34
  self.morphology_extractor = MorphologyExtractor()
35
  self.segmentation_manager = SegmentationManager(
 
36
  device=self.config.get_device(),
 
37
  trust_remote_code=True
38
  )
39
  self.output_manager = OutputManager(
40
  output_folder=self.config.paths.output_folder,
41
  settings=self.config.output
42
  )
43
+ logger.info("Pipeline initialized")
44
 
45
  def run(self, single_image_path: str) -> Dict[str, Any]:
46
+ """Run pipeline on single image."""
47
+ logger.info("Processing single image...")
48
 
49
+ from PIL import Image
50
+ import time
 
 
 
 
 
51
 
52
+ start = time.perf_counter()
53
+
54
+ # Load image
55
+ img = Image.open(single_image_path)
56
+ plants = {
57
+ "demo": {
58
+ "raw_image": (img, Path(single_image_path).name),
59
+ "plant_name": "demo",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  }
61
+ }
62
+
63
+ # Process: composite → segment → features → save
64
+ plants = self.preprocessor.create_composites(plants)
65
+ plants = self._segment(plants)
66
+ plants = self._extract_features(plants)
67
+ self.output_manager.create_output_directories()
68
+
69
+ for key, pdata in plants.items():
70
+ self.output_manager.save_plant_results(key, pdata)
71
+
72
+ elapsed = time.perf_counter() - start
73
+ logger.info(f"Completed in {elapsed:.2f}s")
74
+
75
+ return {"plants": plants, "timing": elapsed}
76
 
77
+ def _segment(self, plants: Dict[str, Any]) -> Dict[str, Any]:
78
+ """Segment using BRIA."""
79
  for key, pdata in plants.items():
80
+ composite = pdata['composite']
81
+ soft_mask = self.segmentation_manager.segment_image_soft(composite)
82
+ pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)
 
 
 
 
 
 
 
83
  return plants
84
 
85
  def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]:
86
+ """Extract texture, vegetation, and morphology features."""
87
  for key, pdata in plants.items():
88
+ # Texture: LBP, HOG, Lacunarity from pseudo-color
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  composite = pdata['composite']
90
  mask = pdata.get('mask')
91
+ masked = self.mask_handler.apply_mask_to_image(composite, mask) if mask is not None else composite
92
+ gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
 
 
 
93
 
94
+ feats = self.texture_extractor.extract_all_texture_features(gray)
95
+ stats = self.texture_extractor.compute_texture_statistics(feats, mask)
96
+ pdata['texture_features'] = {'color': {'features': feats, 'statistics': stats}}
97
 
98
+ # Vegetation: NDVI, ARI, GNDVI
99
+ spectral = pdata.get('spectral_stack', {})
100
+ if spectral and mask is not None:
101
+ pdata['vegetation_indices'] = self._compute_vegetation(spectral, mask)
102
+ else:
103
+ pdata['vegetation_indices'] = {}
104
+
105
+ # Morphology: PlantCV size analysis
106
+ pdata['morphology_features'] = self.morphology_extractor.extract_morphology_features(composite, mask)
107
 
108
+ return plants
109
 
110
+ def _compute_vegetation(self, spectral: Dict[str, np.ndarray], mask: np.ndarray) -> Dict[str, Any]:
111
+ """Compute NDVI, ARI, GNDVI only."""
112
+ out = {}
113
+ for name in ("NDVI", "ARI", "GNDVI"):
114
+ bands = self.vegetation_extractor.index_bands.get(name, [])
115
+ if not all(b in spectral for b in bands):
116
+ continue
117
 
118
+ arrays = [np.asarray(spectral[b].squeeze(-1), dtype=np.float64) for b in bands]
119
+ values = self.vegetation_extractor.index_formulas[name](*arrays).astype(np.float64)
120
+ binary_mask = (mask > 0)
121
+ masked_values = np.where(binary_mask, values, np.nan)
122
+ valid = masked_values[~np.isnan(masked_values)]
123
+
124
+ stats = {
125
+ 'mean': float(np.mean(valid)) if valid.size else 0.0,
126
+ 'std': float(np.std(valid)) if valid.size else 0.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  }
128
+ out[name] = {'values': masked_values, 'statistics': stats}
129
+ return out
sorghum_pipeline/segmentation/__init__.py CHANGED
@@ -1,12 +1,5 @@
1
- """
2
- Segmentation modules for the Sorghum Pipeline.
3
-
4
- This package contains segmentation functionality including:
5
- - BRIA model integration
6
- - Mask post-processing
7
- - Segmentation validation
8
- """
9
 
10
  from .manager import SegmentationManager
11
 
12
- __all__ = ["SegmentationManager"]
 
1
+ """Segmentation modules."""
 
 
 
 
 
 
 
2
 
3
  from .manager import SegmentationManager
4
 
5
+ __all__ = ["SegmentationManager"]