Fahimeh Orvati Nia
commited on
Commit
·
dd1d7f5
1
Parent(s):
96f1578
make pipeline minimal
Browse files- app.py +7 -2
- sorghum_pipeline/config.py +25 -198
- sorghum_pipeline/data/mask_handler.py +9 -277
- sorghum_pipeline/data/preprocessor.py +15 -228
- sorghum_pipeline/features/morphology.py +25 -309
- sorghum_pipeline/features/texture.py +47 -320
- sorghum_pipeline/features/vegetation.py +16 -253
- sorghum_pipeline/output/manager.py +86 -631
- sorghum_pipeline/pipeline.py +121 -1254
- sorghum_pipeline/segmentation/manager.py +28 -279
- wrapper.py +29 -24
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import tempfile
|
|
|
|
| 3 |
from wrapper import run_pipeline_on_image
|
| 4 |
|
| 5 |
def process(image):
|
|
@@ -10,7 +11,11 @@ def process(image):
|
|
| 10 |
img_path = Path(tmpdir) / "input.png"
|
| 11 |
image.save(img_path)
|
| 12 |
outputs = run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True)
|
| 13 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
with gr.Blocks() as demo:
|
| 16 |
gr.Markdown("# 🌿 Sorghum Single-Image Demo")
|
|
@@ -20,4 +25,4 @@ with gr.Blocks() as demo:
|
|
| 20 |
run.click(process, inputs=inp, outputs=gallery)
|
| 21 |
|
| 22 |
if __name__ == "__main__":
|
| 23 |
-
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import tempfile
|
| 3 |
+
from pathlib import Path
|
| 4 |
from wrapper import run_pipeline_on_image
|
| 5 |
|
| 6 |
def process(image):
|
|
|
|
| 11 |
img_path = Path(tmpdir) / "input.png"
|
| 12 |
image.save(img_path)
|
| 13 |
outputs = run_pipeline_on_image(str(img_path), tmpdir, save_artifacts=True)
|
| 14 |
+
# Keep order consistent: return exactly the 7 images
|
| 15 |
+
order = [
|
| 16 |
+
'NDVI', 'ARI', 'GNDVI', 'LBP', 'HOG', 'Lacunarity', 'SizeAnalysis'
|
| 17 |
+
]
|
| 18 |
+
return [outputs[k] for k in order if k in outputs]
|
| 19 |
|
| 20 |
with gr.Blocks() as demo:
|
| 21 |
gr.Markdown("# 🌿 Sorghum Single-Image Demo")
|
|
|
|
| 25 |
run.click(process, inputs=inp, outputs=gallery)
|
| 26 |
|
| 27 |
if __name__ == "__main__":
|
| 28 |
+
demo.launch()
|
sorghum_pipeline/config.py
CHANGED
|
@@ -1,249 +1,76 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This module handles all configuration settings, paths, and parameters
|
| 5 |
-
used throughout the pipeline.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
| 9 |
-
import yaml
|
| 10 |
from pathlib import Path
|
| 11 |
-
from
|
| 12 |
-
from dataclasses import dataclass, field
|
| 13 |
|
| 14 |
|
| 15 |
@dataclass
|
| 16 |
class Paths:
|
| 17 |
-
"""Configuration for
|
| 18 |
input_folder: str
|
| 19 |
output_folder: str
|
| 20 |
-
boundingbox_dir:
|
| 21 |
-
labels_folder: Optional[str] = None
|
| 22 |
|
| 23 |
def __post_init__(self):
|
| 24 |
-
"""Ensure
|
| 25 |
self.input_folder = os.path.abspath(self.input_folder)
|
| 26 |
self.output_folder = os.path.abspath(self.output_folder)
|
| 27 |
-
if self.boundingbox_dir:
|
| 28 |
-
self.boundingbox_dir = os.path.abspath(self.boundingbox_dir)
|
| 29 |
-
if self.labels_folder:
|
| 30 |
-
self.labels_folder = os.path.abspath(self.labels_folder)
|
| 31 |
|
| 32 |
|
| 33 |
@dataclass
|
| 34 |
class ProcessingParams:
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
target_size: tuple = (1024, 1024)
|
| 38 |
-
gaussian_blur_kernel: int = 5
|
| 39 |
-
morphology_kernel_size: int = 7
|
| 40 |
min_component_area: int = 1000
|
| 41 |
-
|
| 42 |
-
# Segmentation
|
| 43 |
segmentation_threshold: float = 0.5
|
| 44 |
-
max_components: int = 10
|
| 45 |
-
|
| 46 |
-
# Texture analysis
|
| 47 |
-
lbp_points: int = 8
|
| 48 |
-
lbp_radius: int = 1
|
| 49 |
-
hog_orientations: int = 9
|
| 50 |
-
hog_pixels_per_cell: tuple = (8, 8)
|
| 51 |
-
hog_cells_per_block: tuple = (2, 2)
|
| 52 |
-
lacunarity_window: int = 15
|
| 53 |
-
ehd_threshold: float = 0.3
|
| 54 |
-
angle_resolution: int = 45
|
| 55 |
-
|
| 56 |
-
# Vegetation indices
|
| 57 |
-
epsilon: float = 1e-10
|
| 58 |
-
soil_factor: float = 0.16
|
| 59 |
-
|
| 60 |
-
# Morphology
|
| 61 |
-
pixel_to_cm: float = 0.1099609375
|
| 62 |
-
prune_sizes: list = field(default_factory=lambda: [200, 100, 50, 30, 10])
|
| 63 |
|
| 64 |
|
| 65 |
@dataclass
|
| 66 |
class OutputSettings:
|
| 67 |
-
"""
|
| 68 |
save_images: bool = True
|
| 69 |
-
save_plots: bool =
|
| 70 |
-
save_metadata: bool =
|
| 71 |
-
image_dpi: int = 150
|
| 72 |
plot_dpi: int = 100
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
features_dir: str = "features"
|
| 78 |
-
texture_dir: str = "texture"
|
| 79 |
-
morphology_dir: str = "morphology"
|
| 80 |
-
vegetation_dir: str = "vegetation_indices"
|
| 81 |
-
analysis_dir: str = "analysis"
|
| 82 |
|
| 83 |
|
| 84 |
@dataclass
|
| 85 |
class ModelSettings:
|
| 86 |
-
"""
|
| 87 |
-
device: str = "auto"
|
| 88 |
model_name: str = "briaai/RMBG-2.0"
|
| 89 |
-
batch_size: int = 1
|
| 90 |
trust_remote_code: bool = True
|
| 91 |
cache_dir: str = ""
|
| 92 |
local_files_only: bool = False
|
| 93 |
|
| 94 |
|
| 95 |
class Config:
|
| 96 |
-
"""
|
| 97 |
|
| 98 |
-
def __init__(self
|
| 99 |
-
"""
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
Args:
|
| 103 |
-
config_path: Path to YAML configuration file. If None, uses defaults.
|
| 104 |
-
"""
|
| 105 |
-
self.paths = Paths(
|
| 106 |
-
input_folder="",
|
| 107 |
-
output_folder="",
|
| 108 |
-
boundingbox_dir=""
|
| 109 |
-
)
|
| 110 |
self.processing = ProcessingParams()
|
| 111 |
self.output = OutputSettings()
|
| 112 |
self.model = ModelSettings()
|
| 113 |
-
|
| 114 |
-
if config_path:
|
| 115 |
-
self.load_from_file(config_path)
|
| 116 |
-
|
| 117 |
-
def load_from_file(self, config_path: str) -> None:
|
| 118 |
-
"""Load configuration from YAML file."""
|
| 119 |
-
config_path = Path(config_path)
|
| 120 |
-
if not config_path.exists():
|
| 121 |
-
raise FileNotFoundError(f"Configuration file not found: {config_path}")
|
| 122 |
-
|
| 123 |
-
with open(config_path, 'r') as f:
|
| 124 |
-
config_data = yaml.safe_load(f)
|
| 125 |
-
|
| 126 |
-
# Update paths
|
| 127 |
-
if 'paths' in config_data:
|
| 128 |
-
self.paths = Paths(**config_data['paths'])
|
| 129 |
-
|
| 130 |
-
# Update processing parameters
|
| 131 |
-
if 'processing' in config_data:
|
| 132 |
-
for key, value in config_data['processing'].items():
|
| 133 |
-
if hasattr(self.processing, key):
|
| 134 |
-
setattr(self.processing, key, value)
|
| 135 |
-
|
| 136 |
-
# Update output settings
|
| 137 |
-
if 'output' in config_data:
|
| 138 |
-
for key, value in config_data['output'].items():
|
| 139 |
-
if hasattr(self.output, key):
|
| 140 |
-
setattr(self.output, key, value)
|
| 141 |
-
|
| 142 |
-
# Update model settings
|
| 143 |
-
if 'model' in config_data:
|
| 144 |
-
for key, value in config_data['model'].items():
|
| 145 |
-
if hasattr(self.model, key):
|
| 146 |
-
setattr(self.model, key, value)
|
| 147 |
-
|
| 148 |
-
def save_to_file(self, config_path: str) -> None:
|
| 149 |
-
"""Save current configuration to YAML file."""
|
| 150 |
-
config_data = {
|
| 151 |
-
'paths': {
|
| 152 |
-
'input_folder': self.paths.input_folder,
|
| 153 |
-
'output_folder': self.paths.output_folder,
|
| 154 |
-
'boundingbox_dir': self.paths.boundingbox_dir,
|
| 155 |
-
'labels_folder': self.paths.labels_folder
|
| 156 |
-
},
|
| 157 |
-
'processing': {
|
| 158 |
-
'target_size': self.processing.target_size,
|
| 159 |
-
'gaussian_blur_kernel': self.processing.gaussian_blur_kernel,
|
| 160 |
-
'morphology_kernel_size': self.processing.morphology_kernel_size,
|
| 161 |
-
'min_component_area': self.processing.min_component_area,
|
| 162 |
-
'segmentation_threshold': self.processing.segmentation_threshold,
|
| 163 |
-
'max_components': self.processing.max_components,
|
| 164 |
-
'lbp_points': self.processing.lbp_points,
|
| 165 |
-
'lbp_radius': self.processing.lbp_radius,
|
| 166 |
-
'hog_orientations': self.processing.hog_orientations,
|
| 167 |
-
'hog_pixels_per_cell': self.processing.hog_pixels_per_cell,
|
| 168 |
-
'hog_cells_per_block': self.processing.hog_cells_per_block,
|
| 169 |
-
'lacunarity_window': self.processing.lacunarity_window,
|
| 170 |
-
'ehd_threshold': self.processing.ehd_threshold,
|
| 171 |
-
'angle_resolution': self.processing.angle_resolution,
|
| 172 |
-
'epsilon': self.processing.epsilon,
|
| 173 |
-
'soil_factor': self.processing.soil_factor,
|
| 174 |
-
'pixel_to_cm': self.processing.pixel_to_cm,
|
| 175 |
-
'prune_sizes': self.processing.prune_sizes
|
| 176 |
-
},
|
| 177 |
-
'output': {
|
| 178 |
-
'save_images': self.output.save_images,
|
| 179 |
-
'save_plots': self.output.save_plots,
|
| 180 |
-
'save_metadata': self.output.save_metadata,
|
| 181 |
-
'image_dpi': self.output.image_dpi,
|
| 182 |
-
'plot_dpi': self.output.plot_dpi,
|
| 183 |
-
'image_format': self.output.image_format,
|
| 184 |
-
'segmentation_dir': self.output.segmentation_dir,
|
| 185 |
-
'features_dir': self.output.features_dir,
|
| 186 |
-
'texture_dir': self.output.texture_dir,
|
| 187 |
-
'morphology_dir': self.output.morphology_dir,
|
| 188 |
-
'vegetation_dir': self.output.vegetation_dir,
|
| 189 |
-
'analysis_dir': self.output.analysis_dir
|
| 190 |
-
},
|
| 191 |
-
'model': {
|
| 192 |
-
'device': self.model.device,
|
| 193 |
-
'model_name': self.model.model_name,
|
| 194 |
-
'batch_size': self.model.batch_size,
|
| 195 |
-
'trust_remote_code': self.model.trust_remote_code,
|
| 196 |
-
'cache_dir': self.model.cache_dir,
|
| 197 |
-
'local_files_only': self.model.local_files_only,
|
| 198 |
-
}
|
| 199 |
-
}
|
| 200 |
-
|
| 201 |
-
with open(config_path, 'w') as f:
|
| 202 |
-
yaml.dump(config_data, f, default_flow_style=False, indent=2)
|
| 203 |
|
| 204 |
def get_device(self) -> str:
|
| 205 |
-
"""Get
|
| 206 |
if self.model.device == "auto":
|
| 207 |
import torch
|
| 208 |
return "cuda" if torch.cuda.is_available() else "cpu"
|
| 209 |
return self.model.device
|
| 210 |
|
| 211 |
-
def create_output_directories(self, base_path: str) -> None:
|
| 212 |
-
"""Ensure base output directory exists only.
|
| 213 |
-
|
| 214 |
-
Subdirectories are created per plant in the output manager.
|
| 215 |
-
"""
|
| 216 |
-
base_path = Path(base_path)
|
| 217 |
-
base_path.mkdir(parents=True, exist_ok=True)
|
| 218 |
-
|
| 219 |
def validate(self) -> bool:
|
| 220 |
-
"""Validate configuration
|
| 221 |
-
|
| 222 |
-
if not os.path.exists(self.paths.input_folder):
|
| 223 |
raise FileNotFoundError(f"Input folder does not exist: {self.paths.input_folder}")
|
| 224 |
-
|
| 225 |
-
# Check if bounding box directory exists (optional)
|
| 226 |
-
if hasattr(self.paths, 'boundingbox_dir') and self.paths.boundingbox_dir and not os.path.exists(self.paths.boundingbox_dir):
|
| 227 |
-
raise FileNotFoundError(f"Bounding box directory does not exist: {self.paths.boundingbox_dir}")
|
| 228 |
-
|
| 229 |
-
# Validate processing parameters
|
| 230 |
-
if self.processing.target_size[0] <= 0 or self.processing.target_size[1] <= 0:
|
| 231 |
-
raise ValueError("Target size must be positive")
|
| 232 |
-
|
| 233 |
-
if self.processing.segmentation_threshold < 0 or self.processing.segmentation_threshold > 1:
|
| 234 |
-
raise ValueError("Segmentation threshold must be between 0 and 1")
|
| 235 |
-
|
| 236 |
-
return True
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
def create_default_config(output_path: str) -> None:
|
| 240 |
-
"""Create a default configuration file."""
|
| 241 |
-
config = Config()
|
| 242 |
-
config.paths = Paths(
|
| 243 |
-
input_folder="Sorghum_dataset",
|
| 244 |
-
output_folder="Sorghum_pipeline_Results",
|
| 245 |
-
boundingbox_dir="boundingbox",
|
| 246 |
-
labels_folder="labels"
|
| 247 |
-
)
|
| 248 |
-
config.save_to_file(output_path)
|
| 249 |
-
print(f"Default configuration created at: {output_path}")
|
|
|
|
| 1 |
"""
|
| 2 |
+
Minimal configuration for the Sorghum Pipeline.
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
+
from dataclasses import dataclass
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
@dataclass
|
| 11 |
class Paths:
|
| 12 |
+
"""Configuration for file paths."""
|
| 13 |
input_folder: str
|
| 14 |
output_folder: str
|
| 15 |
+
boundingbox_dir: str = ""
|
|
|
|
| 16 |
|
| 17 |
def __post_init__(self):
|
| 18 |
+
"""Ensure paths are absolute."""
|
| 19 |
self.input_folder = os.path.abspath(self.input_folder)
|
| 20 |
self.output_folder = os.path.abspath(self.output_folder)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
@dataclass
|
| 24 |
class ProcessingParams:
|
| 25 |
+
"""Minimal processing parameters."""
|
| 26 |
+
target_size: tuple = None
|
|
|
|
|
|
|
|
|
|
| 27 |
min_component_area: int = 1000
|
| 28 |
+
morphology_kernel_size: int = 7
|
|
|
|
| 29 |
segmentation_threshold: float = 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
@dataclass
|
| 33 |
class OutputSettings:
|
| 34 |
+
"""Output settings."""
|
| 35 |
save_images: bool = True
|
| 36 |
+
save_plots: bool = False
|
| 37 |
+
save_metadata: bool = False
|
|
|
|
| 38 |
plot_dpi: int = 100
|
| 39 |
+
segmentation_dir: str = "results"
|
| 40 |
+
texture_dir: str = "texture_output"
|
| 41 |
+
morphology_dir: str = "results"
|
| 42 |
+
vegetation_dir: str = "Vegetation_indices_images"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
@dataclass
|
| 46 |
class ModelSettings:
|
| 47 |
+
"""Model settings."""
|
| 48 |
+
device: str = "auto"
|
| 49 |
model_name: str = "briaai/RMBG-2.0"
|
|
|
|
| 50 |
trust_remote_code: bool = True
|
| 51 |
cache_dir: str = ""
|
| 52 |
local_files_only: bool = False
|
| 53 |
|
| 54 |
|
| 55 |
class Config:
|
| 56 |
+
"""Minimal configuration class."""
|
| 57 |
|
| 58 |
+
def __init__(self):
|
| 59 |
+
"""Initialize with defaults."""
|
| 60 |
+
self.paths = Paths(input_folder="", output_folder="", boundingbox_dir="")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
self.processing = ProcessingParams()
|
| 62 |
self.output = OutputSettings()
|
| 63 |
self.model = ModelSettings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
def get_device(self) -> str:
|
| 66 |
+
"""Get processing device."""
|
| 67 |
if self.model.device == "auto":
|
| 68 |
import torch
|
| 69 |
return "cuda" if torch.cuda.is_available() else "cpu"
|
| 70 |
return self.model.device
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def validate(self) -> bool:
|
| 73 |
+
"""Validate configuration."""
|
| 74 |
+
if self.paths.input_folder and not os.path.exists(self.paths.input_folder):
|
|
|
|
| 75 |
raise FileNotFoundError(f"Input folder does not exist: {self.paths.input_folder}")
|
| 76 |
+
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sorghum_pipeline/data/mask_handler.py
CHANGED
|
@@ -1,296 +1,28 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This module handles mask creation, processing, and validation
|
| 5 |
-
for plant segmentation tasks.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import cv2
|
| 10 |
-
from typing import Dict, Tuple, Optional, List
|
| 11 |
import logging
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
|
| 16 |
class MaskHandler:
|
| 17 |
-
"""
|
| 18 |
|
| 19 |
def __init__(self, min_area: int = 1000, kernel_size: int = 7):
|
| 20 |
-
"""
|
| 21 |
-
Initialize the mask handler.
|
| 22 |
-
|
| 23 |
-
Args:
|
| 24 |
-
min_area: Minimum area for connected components
|
| 25 |
-
kernel_size: Kernel size for morphological operations
|
| 26 |
-
"""
|
| 27 |
self.min_area = min_area
|
| 28 |
self.kernel_size = kernel_size
|
| 29 |
|
| 30 |
-
def create_bounding_box_mask(self, image_shape: Tuple[int, int],
|
| 31 |
-
bbox: Tuple[int, int, int, int]) -> np.ndarray:
|
| 32 |
-
"""
|
| 33 |
-
Create a mask from bounding box coordinates.
|
| 34 |
-
|
| 35 |
-
Args:
|
| 36 |
-
image_shape: Shape of the image (height, width)
|
| 37 |
-
bbox: Bounding box coordinates (x1, y1, x2, y2)
|
| 38 |
-
|
| 39 |
-
Returns:
|
| 40 |
-
Binary mask array
|
| 41 |
-
"""
|
| 42 |
-
h, w = image_shape[:2]
|
| 43 |
-
mask = np.zeros((h, w), dtype=np.uint8)
|
| 44 |
-
|
| 45 |
-
x1, y1, x2, y2 = bbox
|
| 46 |
-
# Clamp coordinates to image bounds
|
| 47 |
-
x1 = max(0, min(w, x1))
|
| 48 |
-
y1 = max(0, min(h, y1))
|
| 49 |
-
x2 = max(0, min(w, x2))
|
| 50 |
-
y2 = max(0, min(h, y2))
|
| 51 |
-
|
| 52 |
-
mask[y1:y2, x1:x2] = 255
|
| 53 |
-
return mask
|
| 54 |
-
|
| 55 |
-
def preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
|
| 56 |
-
"""
|
| 57 |
-
Preprocess mask by cleaning and filtering.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
mask: Input mask
|
| 61 |
-
|
| 62 |
-
Returns:
|
| 63 |
-
Cleaned mask
|
| 64 |
-
"""
|
| 65 |
-
if mask is None:
|
| 66 |
-
return None
|
| 67 |
-
|
| 68 |
-
# Convert to binary if needed
|
| 69 |
-
if isinstance(mask, tuple):
|
| 70 |
-
mask = mask[0]
|
| 71 |
-
|
| 72 |
-
# Ensure binary format
|
| 73 |
-
mask = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255
|
| 74 |
-
|
| 75 |
-
# Morphological opening to remove noise
|
| 76 |
-
kernel = cv2.getStructuringElement(
|
| 77 |
-
cv2.MORPH_ELLIPSE,
|
| 78 |
-
(self.kernel_size, self.kernel_size)
|
| 79 |
-
)
|
| 80 |
-
opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
| 81 |
-
|
| 82 |
-
# Remove small connected components
|
| 83 |
-
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
| 84 |
-
opened, connectivity=8
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
clean_mask = np.zeros_like(opened)
|
| 88 |
-
for label in range(1, num_labels): # Skip background
|
| 89 |
-
if stats[label, cv2.CC_STAT_AREA] >= self.min_area:
|
| 90 |
-
clean_mask[labels == label] = 255
|
| 91 |
-
|
| 92 |
-
return clean_mask
|
| 93 |
-
|
| 94 |
-
def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
|
| 95 |
-
"""
|
| 96 |
-
Keep only the largest connected component in the mask.
|
| 97 |
-
|
| 98 |
-
Args:
|
| 99 |
-
mask: Input mask
|
| 100 |
-
|
| 101 |
-
Returns:
|
| 102 |
-
Mask with only the largest component
|
| 103 |
-
"""
|
| 104 |
-
if mask is None:
|
| 105 |
-
return None
|
| 106 |
-
|
| 107 |
-
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 8)
|
| 108 |
-
|
| 109 |
-
if num_labels <= 1:
|
| 110 |
-
return mask
|
| 111 |
-
|
| 112 |
-
# Find the largest component (excluding background)
|
| 113 |
-
areas = stats[1:, cv2.CC_STAT_AREA]
|
| 114 |
-
largest_label = 1 + np.argmax(areas)
|
| 115 |
-
|
| 116 |
-
# Create mask with only the largest component
|
| 117 |
-
largest_mask = (labels == largest_label).astype(np.uint8) * 255
|
| 118 |
-
|
| 119 |
-
return largest_mask
|
| 120 |
-
|
| 121 |
def apply_mask_to_image(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 122 |
-
"""
|
| 123 |
-
Apply mask to image.
|
| 124 |
-
|
| 125 |
-
Args:
|
| 126 |
-
image: Input image
|
| 127 |
-
mask: Binary mask
|
| 128 |
-
|
| 129 |
-
Returns:
|
| 130 |
-
Masked image
|
| 131 |
-
"""
|
| 132 |
if mask is None:
|
| 133 |
return image
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
alpha: float = 0.5) -> np.ndarray:
|
| 140 |
-
"""
|
| 141 |
-
Create overlay of mask on image.
|
| 142 |
-
|
| 143 |
-
Args:
|
| 144 |
-
image: Base image
|
| 145 |
-
mask: Binary mask
|
| 146 |
-
color: Overlay color (B, G, R)
|
| 147 |
-
alpha: Overlay transparency
|
| 148 |
-
|
| 149 |
-
Returns:
|
| 150 |
-
Image with mask overlay
|
| 151 |
-
"""
|
| 152 |
-
overlay = image.copy()
|
| 153 |
-
overlay[mask == 255] = color
|
| 154 |
-
return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)
|
| 155 |
-
|
| 156 |
-
def get_mask_properties(self, mask: np.ndarray) -> Dict[str, float]:
|
| 157 |
-
"""
|
| 158 |
-
Get properties of the mask.
|
| 159 |
-
|
| 160 |
-
Args:
|
| 161 |
-
mask: Binary mask
|
| 162 |
-
|
| 163 |
-
Returns:
|
| 164 |
-
Dictionary of mask properties
|
| 165 |
-
"""
|
| 166 |
-
if mask is None:
|
| 167 |
-
return {}
|
| 168 |
-
|
| 169 |
-
# Convert to binary
|
| 170 |
-
binary_mask = (mask > 127).astype(np.uint8)
|
| 171 |
-
|
| 172 |
-
# Calculate properties
|
| 173 |
-
area = np.sum(binary_mask)
|
| 174 |
-
perimeter = cv2.arcLength(
|
| 175 |
-
cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0][0],
|
| 176 |
-
True
|
| 177 |
-
) if len(cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]) > 0 else 0
|
| 178 |
-
|
| 179 |
-
# Bounding box
|
| 180 |
-
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 181 |
-
if contours:
|
| 182 |
-
x, y, w, h = cv2.boundingRect(contours[0])
|
| 183 |
-
bbox_area = w * h
|
| 184 |
-
aspect_ratio = w / h if h > 0 else 0
|
| 185 |
-
else:
|
| 186 |
-
bbox_area = 0
|
| 187 |
-
aspect_ratio = 0
|
| 188 |
-
|
| 189 |
-
return {
|
| 190 |
-
"area": float(area),
|
| 191 |
-
"perimeter": float(perimeter),
|
| 192 |
-
"bbox_area": float(bbox_area),
|
| 193 |
-
"aspect_ratio": float(aspect_ratio),
|
| 194 |
-
"coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0
|
| 195 |
-
}
|
| 196 |
-
|
| 197 |
-
def validate_mask(self, mask: np.ndarray) -> bool:
|
| 198 |
-
"""
|
| 199 |
-
Validate mask format and content.
|
| 200 |
-
|
| 201 |
-
Args:
|
| 202 |
-
mask: Mask to validate
|
| 203 |
-
|
| 204 |
-
Returns:
|
| 205 |
-
True if valid, False otherwise
|
| 206 |
-
"""
|
| 207 |
-
if mask is None:
|
| 208 |
-
return False
|
| 209 |
-
|
| 210 |
-
if not isinstance(mask, np.ndarray):
|
| 211 |
-
return False
|
| 212 |
-
|
| 213 |
-
if mask.ndim != 2:
|
| 214 |
-
return False
|
| 215 |
-
|
| 216 |
-
if mask.dtype not in [np.uint8, np.bool_]:
|
| 217 |
-
return False
|
| 218 |
-
|
| 219 |
-
# Check if mask has any foreground pixels
|
| 220 |
-
if np.sum(mask > 0) == 0:
|
| 221 |
-
logger.warning("Mask has no foreground pixels")
|
| 222 |
-
return False
|
| 223 |
-
|
| 224 |
-
return True
|
| 225 |
-
|
| 226 |
-
def resize_mask(self, mask: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
|
| 227 |
-
"""
|
| 228 |
-
Resize mask to target size.
|
| 229 |
-
|
| 230 |
-
Args:
|
| 231 |
-
mask: Input mask
|
| 232 |
-
target_size: Target size (width, height)
|
| 233 |
-
|
| 234 |
-
Returns:
|
| 235 |
-
Resized mask
|
| 236 |
-
"""
|
| 237 |
-
if mask is None:
|
| 238 |
-
return None
|
| 239 |
-
|
| 240 |
-
return cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
|
| 241 |
-
|
| 242 |
-
def dilate_mask(self, mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
|
| 243 |
-
"""
|
| 244 |
-
Dilate mask to expand foreground regions.
|
| 245 |
-
|
| 246 |
-
Args:
|
| 247 |
-
mask: Input mask
|
| 248 |
-
kernel_size: Size of dilation kernel
|
| 249 |
-
|
| 250 |
-
Returns:
|
| 251 |
-
Dilated mask
|
| 252 |
-
"""
|
| 253 |
-
if mask is None:
|
| 254 |
-
return None
|
| 255 |
-
|
| 256 |
-
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
| 257 |
-
return cv2.dilate(mask, kernel, iterations=1)
|
| 258 |
-
|
| 259 |
-
def erode_mask(self, mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
|
| 260 |
-
"""
|
| 261 |
-
Erode mask to shrink foreground regions.
|
| 262 |
-
|
| 263 |
-
Args:
|
| 264 |
-
mask: Input mask
|
| 265 |
-
kernel_size: Size of erosion kernel
|
| 266 |
-
|
| 267 |
-
Returns:
|
| 268 |
-
Eroded mask
|
| 269 |
-
"""
|
| 270 |
-
if mask is None:
|
| 271 |
-
return None
|
| 272 |
-
|
| 273 |
-
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
| 274 |
-
return cv2.erode(mask, kernel, iterations=1)
|
| 275 |
-
|
| 276 |
-
def fill_holes(self, mask: np.ndarray) -> np.ndarray:
|
| 277 |
-
"""
|
| 278 |
-
Fill holes in the mask.
|
| 279 |
-
|
| 280 |
-
Args:
|
| 281 |
-
mask: Input mask
|
| 282 |
-
|
| 283 |
-
Returns:
|
| 284 |
-
Mask with filled holes
|
| 285 |
-
"""
|
| 286 |
-
if mask is None:
|
| 287 |
-
return None
|
| 288 |
-
|
| 289 |
-
# Find contours
|
| 290 |
-
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 291 |
-
|
| 292 |
-
# Create filled mask
|
| 293 |
-
filled_mask = np.zeros_like(mask)
|
| 294 |
-
cv2.fillPoly(filled_mask, contours, 255)
|
| 295 |
-
|
| 296 |
-
return filled_mask
|
|
|
|
| 1 |
"""
|
| 2 |
+
Minimal mask handling for the Sorghum Pipeline.
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import cv2
|
|
|
|
| 7 |
import logging
|
| 8 |
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
|
| 12 |
class MaskHandler:
|
| 13 |
+
"""Minimal mask handling."""
|
| 14 |
|
| 15 |
def __init__(self, min_area: int = 1000, kernel_size: int = 7):
|
| 16 |
+
"""Initialize mask handler."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
self.min_area = min_area
|
| 18 |
self.kernel_size = kernel_size
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def apply_mask_to_image(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 21 |
+
"""Apply mask to image."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
if mask is None:
|
| 23 |
return image
|
| 24 |
+
if mask.shape[:2] != image.shape[:2]:
|
| 25 |
+
mask = cv2.resize(mask.astype(np.uint8), (image.shape[1], image.shape[0]),
|
| 26 |
+
interpolation=cv2.INTER_NEAREST)
|
| 27 |
+
binary = (mask.astype(np.int32) > 0).astype(np.uint8) * 255
|
| 28 |
+
return cv2.bitwise_and(image, image, mask=binary)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sorghum_pipeline/data/preprocessor.py
CHANGED
|
@@ -1,14 +1,11 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This module handles image preprocessing, composite creation,
|
| 5 |
-
and basic image transformations.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import cv2
|
| 10 |
from PIL import Image
|
| 11 |
-
from typing import Dict, Tuple, Any
|
| 12 |
from itertools import product
|
| 13 |
import logging
|
| 14 |
|
|
@@ -16,72 +13,36 @@ logger = logging.getLogger(__name__)
|
|
| 16 |
|
| 17 |
|
| 18 |
class ImagePreprocessor:
|
| 19 |
-
"""
|
| 20 |
|
| 21 |
-
def __init__(self, target_size
|
| 22 |
-
"""
|
| 23 |
-
Initialize the image preprocessor.
|
| 24 |
-
|
| 25 |
-
Args:
|
| 26 |
-
target_size: Target size for image resizing (width, height)
|
| 27 |
-
"""
|
| 28 |
self.target_size = target_size
|
| 29 |
|
| 30 |
def convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
|
| 31 |
-
"""
|
| 32 |
-
Convert array to uint8 format with proper normalization.
|
| 33 |
-
|
| 34 |
-
Args:
|
| 35 |
-
arr: Input array
|
| 36 |
-
|
| 37 |
-
Returns:
|
| 38 |
-
Normalized uint8 array
|
| 39 |
-
"""
|
| 40 |
-
# Handle NaN and infinite values
|
| 41 |
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
| 42 |
-
|
| 43 |
-
# Normalize to 0-255 range
|
| 44 |
if arr.ptp() > 0:
|
| 45 |
normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
|
| 46 |
else:
|
| 47 |
normalized = np.zeros_like(arr)
|
| 48 |
-
|
| 49 |
return np.clip(normalized, 0, 255).astype(np.uint8)
|
| 50 |
|
| 51 |
def process_raw_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
|
| 52 |
-
"""
|
| 53 |
-
Process raw 4-band image into composite and spectral bands.
|
| 54 |
-
|
| 55 |
-
Args:
|
| 56 |
-
pil_img: PIL Image object containing 4-band data
|
| 57 |
-
|
| 58 |
-
Returns:
|
| 59 |
-
Tuple of (composite_image, spectral_bands_dict)
|
| 60 |
-
"""
|
| 61 |
-
# Split the 4-band RAW into tiles and stack them
|
| 62 |
d = pil_img.size[0] // 2
|
| 63 |
boxes = [
|
| 64 |
(j, i, j + d, i + d)
|
| 65 |
-
for i, j in product(
|
| 66 |
-
range(0, pil_img.height, d),
|
| 67 |
-
range(0, pil_img.width, d)
|
| 68 |
-
)
|
| 69 |
]
|
| 70 |
|
| 71 |
-
|
| 72 |
-
stack = np.stack([
|
| 73 |
-
np.array(pil_img.crop(box), dtype=float)
|
| 74 |
-
for box in boxes
|
| 75 |
-
], axis=-1)
|
| 76 |
-
|
| 77 |
-
# Bands come in order: [green, red, red_edge, nir]
|
| 78 |
green, red, red_edge, nir = np.split(stack, 4, axis=-1)
|
| 79 |
|
| 80 |
-
#
|
| 81 |
composite = np.concatenate([green, red_edge, red], axis=-1)
|
| 82 |
composite_uint8 = self.convert_to_uint8(composite)
|
| 83 |
|
| 84 |
-
# Prepare spectral stack
|
| 85 |
spectral_bands = {
|
| 86 |
"green": green,
|
| 87 |
"red": red,
|
|
@@ -92,188 +53,14 @@ class ImagePreprocessor:
|
|
| 92 |
return composite_uint8, spectral_bands
|
| 93 |
|
| 94 |
def create_composites(self, plants: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
| 95 |
-
"""
|
| 96 |
-
Create composites for all plants in the dataset.
|
| 97 |
-
|
| 98 |
-
Args:
|
| 99 |
-
plants: Dictionary of plant data
|
| 100 |
-
|
| 101 |
-
Returns:
|
| 102 |
-
Updated plant data with composites and spectral stacks
|
| 103 |
-
"""
|
| 104 |
-
logger.info("Creating composites for all plants...")
|
| 105 |
-
|
| 106 |
for key, pdata in plants.items():
|
| 107 |
try:
|
| 108 |
-
# Find the PIL Image
|
| 109 |
if "raw_image" in pdata:
|
| 110 |
image, _ = pdata["raw_image"]
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
logger.warning(f"No raw image found for {key}")
|
| 115 |
-
continue
|
| 116 |
-
|
| 117 |
-
# Process the image
|
| 118 |
-
composite, spectral_stack = self.process_raw_image(image)
|
| 119 |
-
|
| 120 |
-
# Store results
|
| 121 |
-
pdata["composite"] = composite
|
| 122 |
-
pdata["spectral_stack"] = spectral_stack
|
| 123 |
-
|
| 124 |
-
logger.debug(f"Created composite for {key}")
|
| 125 |
-
|
| 126 |
except Exception as e:
|
| 127 |
logger.error(f"Failed to create composite for {key}: {e}")
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
logger.info("Composite creation completed")
|
| 131 |
-
return plants
|
| 132 |
-
|
| 133 |
-
def resize_image(self, image: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
|
| 134 |
-
"""
|
| 135 |
-
Resize image to target size.
|
| 136 |
-
|
| 137 |
-
Args:
|
| 138 |
-
image: Input image
|
| 139 |
-
target_size: Target size (width, height). If None, uses self.target_size
|
| 140 |
-
|
| 141 |
-
Returns:
|
| 142 |
-
Resized image
|
| 143 |
-
"""
|
| 144 |
-
if target_size is None:
|
| 145 |
-
target_size = self.target_size
|
| 146 |
-
|
| 147 |
-
if target_size is None:
|
| 148 |
-
return image
|
| 149 |
-
|
| 150 |
-
return cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
|
| 151 |
-
|
| 152 |
-
def normalize_image(self, image: np.ndarray, method: str = "minmax") -> np.ndarray:
|
| 153 |
-
"""
|
| 154 |
-
Normalize image using specified method.
|
| 155 |
-
|
| 156 |
-
Args:
|
| 157 |
-
image: Input image
|
| 158 |
-
method: Normalization method ("minmax", "zscore", "robust")
|
| 159 |
-
|
| 160 |
-
Returns:
|
| 161 |
-
Normalized image
|
| 162 |
-
"""
|
| 163 |
-
if method == "minmax":
|
| 164 |
-
if image.dtype == np.uint8:
|
| 165 |
-
return image.astype(np.float32) / 255.0
|
| 166 |
-
else:
|
| 167 |
-
img_min, img_max = image.min(), image.max()
|
| 168 |
-
if img_max > img_min:
|
| 169 |
-
return (image - img_min) / (img_max - img_min)
|
| 170 |
-
else:
|
| 171 |
-
return np.zeros_like(image, dtype=np.float32)
|
| 172 |
-
|
| 173 |
-
elif method == "zscore":
|
| 174 |
-
mean, std = image.mean(), image.std()
|
| 175 |
-
if std > 0:
|
| 176 |
-
return (image - mean) / std
|
| 177 |
-
else:
|
| 178 |
-
return np.zeros_like(image, dtype=np.float32)
|
| 179 |
-
|
| 180 |
-
elif method == "robust":
|
| 181 |
-
q25, q75 = np.percentile(image, [25, 75])
|
| 182 |
-
if q75 > q25:
|
| 183 |
-
return (image - q25) / (q75 - q25)
|
| 184 |
-
else:
|
| 185 |
-
return np.zeros_like(image, dtype=np.float32)
|
| 186 |
-
|
| 187 |
-
else:
|
| 188 |
-
raise ValueError(f"Unknown normalization method: {method}")
|
| 189 |
-
|
| 190 |
-
def apply_gaussian_blur(self, image: np.ndarray, kernel_size: int = 5) -> np.ndarray:
|
| 191 |
-
"""
|
| 192 |
-
Apply Gaussian blur to image.
|
| 193 |
-
|
| 194 |
-
Args:
|
| 195 |
-
image: Input image
|
| 196 |
-
kernel_size: Size of Gaussian kernel
|
| 197 |
-
|
| 198 |
-
Returns:
|
| 199 |
-
Blurred image
|
| 200 |
-
"""
|
| 201 |
-
if kernel_size % 2 == 0:
|
| 202 |
-
kernel_size += 1
|
| 203 |
-
|
| 204 |
-
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
| 205 |
-
|
| 206 |
-
def apply_sharpening(self, image: np.ndarray) -> np.ndarray:
|
| 207 |
-
"""
|
| 208 |
-
Apply sharpening filter to image.
|
| 209 |
-
|
| 210 |
-
Args:
|
| 211 |
-
image: Input image
|
| 212 |
-
|
| 213 |
-
Returns:
|
| 214 |
-
Sharpened image
|
| 215 |
-
"""
|
| 216 |
-
kernel = np.array([
|
| 217 |
-
[0, -1, 0],
|
| 218 |
-
[-1, 5, -1],
|
| 219 |
-
[0, -1, 0]
|
| 220 |
-
])
|
| 221 |
-
|
| 222 |
-
return cv2.filter2D(image, -1, kernel)
|
| 223 |
-
|
| 224 |
-
def enhance_contrast(self, image: np.ndarray, alpha: float = 1.2, beta: int = 15) -> np.ndarray:
|
| 225 |
-
"""
|
| 226 |
-
Enhance image contrast.
|
| 227 |
-
|
| 228 |
-
Args:
|
| 229 |
-
image: Input image
|
| 230 |
-
alpha: Contrast control (1.0 = no change)
|
| 231 |
-
beta: Brightness control (0 = no change)
|
| 232 |
-
|
| 233 |
-
Returns:
|
| 234 |
-
Enhanced image
|
| 235 |
-
"""
|
| 236 |
-
return cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
|
| 237 |
-
|
| 238 |
-
def create_overlay(self, base_image: np.ndarray, mask: np.ndarray,
|
| 239 |
-
color: Tuple[int, int, int] = (0, 255, 0),
|
| 240 |
-
alpha: float = 0.5) -> np.ndarray:
|
| 241 |
-
"""
|
| 242 |
-
Create overlay of mask on base image.
|
| 243 |
-
|
| 244 |
-
Args:
|
| 245 |
-
base_image: Base image
|
| 246 |
-
mask: Binary mask
|
| 247 |
-
color: Overlay color (B, G, R)
|
| 248 |
-
alpha: Overlay transparency
|
| 249 |
-
|
| 250 |
-
Returns:
|
| 251 |
-
Image with overlay
|
| 252 |
-
"""
|
| 253 |
-
overlay = base_image.copy()
|
| 254 |
-
overlay[mask == 255] = color
|
| 255 |
-
return cv2.addWeighted(base_image, 1.0 - alpha, overlay, alpha, 0)
|
| 256 |
-
|
| 257 |
-
def validate_composite(self, composite: np.ndarray) -> bool:
|
| 258 |
-
"""
|
| 259 |
-
Validate composite image.
|
| 260 |
-
|
| 261 |
-
Args:
|
| 262 |
-
composite: Composite image to validate
|
| 263 |
-
|
| 264 |
-
Returns:
|
| 265 |
-
True if valid, False otherwise
|
| 266 |
-
"""
|
| 267 |
-
if composite is None:
|
| 268 |
-
return False
|
| 269 |
-
|
| 270 |
-
if not isinstance(composite, np.ndarray):
|
| 271 |
-
return False
|
| 272 |
-
|
| 273 |
-
if composite.ndim != 3 or composite.shape[2] != 3:
|
| 274 |
-
return False
|
| 275 |
-
|
| 276 |
-
if composite.dtype != np.uint8:
|
| 277 |
-
return False
|
| 278 |
-
|
| 279 |
-
return True
|
|
|
|
| 1 |
"""
|
| 2 |
+
Minimal image preprocessing for the Sorghum Pipeline.
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import cv2
|
| 7 |
from PIL import Image
|
| 8 |
+
from typing import Dict, Tuple, Any
|
| 9 |
from itertools import product
|
| 10 |
import logging
|
| 11 |
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class ImagePreprocessor:
|
| 16 |
+
"""Minimal image preprocessing."""
|
| 17 |
|
| 18 |
+
def __init__(self, target_size=None):
|
| 19 |
+
"""Initialize preprocessor."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
self.target_size = target_size
|
| 21 |
|
| 22 |
def convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
|
| 23 |
+
"""Convert array to uint8."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
|
|
|
|
|
|
| 25 |
if arr.ptp() > 0:
|
| 26 |
normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
|
| 27 |
else:
|
| 28 |
normalized = np.zeros_like(arr)
|
|
|
|
| 29 |
return np.clip(normalized, 0, 255).astype(np.uint8)
|
| 30 |
|
| 31 |
def process_raw_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
|
| 32 |
+
"""Process 4-band image into composite and spectral bands."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
d = pil_img.size[0] // 2
|
| 34 |
boxes = [
|
| 35 |
(j, i, j + d, i + d)
|
| 36 |
+
for i, j in product(range(0, pil_img.height, d), range(0, pil_img.width, d))
|
|
|
|
|
|
|
|
|
|
| 37 |
]
|
| 38 |
|
| 39 |
+
stack = np.stack([np.array(pil_img.crop(box), dtype=float) for box in boxes], axis=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
green, red, red_edge, nir = np.split(stack, 4, axis=-1)
|
| 41 |
|
| 42 |
+
# Pseudo-RGB composite: (green, red_edge, red)
|
| 43 |
composite = np.concatenate([green, red_edge, red], axis=-1)
|
| 44 |
composite_uint8 = self.convert_to_uint8(composite)
|
| 45 |
|
|
|
|
| 46 |
spectral_bands = {
|
| 47 |
"green": green,
|
| 48 |
"red": red,
|
|
|
|
| 53 |
return composite_uint8, spectral_bands
|
| 54 |
|
| 55 |
def create_composites(self, plants: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
| 56 |
+
"""Create composites for all plants."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
for key, pdata in plants.items():
|
| 58 |
try:
|
|
|
|
| 59 |
if "raw_image" in pdata:
|
| 60 |
image, _ = pdata["raw_image"]
|
| 61 |
+
composite, spectral_stack = self.process_raw_image(image)
|
| 62 |
+
pdata["composite"] = composite
|
| 63 |
+
pdata["spectral_stack"] = spectral_stack
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
except Exception as e:
|
| 65 |
logger.error(f"Failed to create composite for {key}: {e}")
|
| 66 |
+
return plants
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sorghum_pipeline/features/morphology.py
CHANGED
|
@@ -1,44 +1,32 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This module handles extraction of morphological features using PlantCV
|
| 5 |
-
and other computer vision techniques.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import cv2
|
| 10 |
import contextlib
|
| 11 |
import sys
|
| 12 |
-
from typing import Dict, Any,
|
| 13 |
import logging
|
| 14 |
|
| 15 |
-
# Try to import PlantCV, but don't fail if not available
|
| 16 |
try:
|
| 17 |
from plantcv import plantcv as pcv
|
| 18 |
PLANT_CV_AVAILABLE = True
|
| 19 |
except ImportError:
|
| 20 |
PLANT_CV_AVAILABLE = False
|
| 21 |
-
logger.warning("PlantCV not available. Morphological features will be limited.")
|
| 22 |
|
| 23 |
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
|
| 26 |
class MorphologyExtractor:
|
| 27 |
-
"""
|
| 28 |
|
| 29 |
def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None):
|
| 30 |
-
"""
|
| 31 |
-
Initialize morphology extractor.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
pixel_to_cm: Conversion factor from pixels to centimeters
|
| 35 |
-
prune_sizes: List of pruning sizes for skeleton processing
|
| 36 |
-
"""
|
| 37 |
self.pixel_to_cm = pixel_to_cm
|
| 38 |
self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
|
| 39 |
|
| 40 |
if PLANT_CV_AVAILABLE:
|
| 41 |
-
# Configure PlantCV
|
| 42 |
pcv.params.debug = None
|
| 43 |
pcv.params.text_size = 0.7
|
| 44 |
pcv.params.text_thickness = 2
|
|
@@ -46,283 +34,53 @@ class MorphologyExtractor:
|
|
| 46 |
pcv.params.dpi = 100
|
| 47 |
|
| 48 |
def extract_morphology_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
|
| 49 |
-
"""
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
Returns:
|
| 57 |
-
Dictionary containing morphological features and images
|
| 58 |
-
"""
|
| 59 |
-
features = {
|
| 60 |
-
'traits': {},
|
| 61 |
-
'images': {},
|
| 62 |
-
'success': False
|
| 63 |
-
}
|
| 64 |
|
| 65 |
try:
|
| 66 |
-
# Preprocess mask
|
| 67 |
clean_mask = self._preprocess_mask(mask)
|
| 68 |
if clean_mask is None:
|
| 69 |
-
logger.warning("Failed to preprocess mask")
|
| 70 |
return features
|
| 71 |
|
| 72 |
-
#
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
# Fallback to basic OpenCV features
|
| 83 |
-
cv_features = self._extract_opencv_features(image, clean_mask)
|
| 84 |
-
features['traits'].update(cv_features['traits'])
|
| 85 |
-
features['images'].update(cv_features['images'])
|
| 86 |
-
|
| 87 |
-
features['success'] = True
|
| 88 |
-
logger.debug("Morphological features extracted successfully")
|
| 89 |
|
| 90 |
except Exception as e:
|
| 91 |
-
logger.error(f"
|
| 92 |
|
| 93 |
return features
|
| 94 |
|
| 95 |
-
def _preprocess_mask(self, mask: np.ndarray) ->
|
| 96 |
-
"""Preprocess mask
|
| 97 |
if mask is None:
|
| 98 |
return None
|
| 99 |
-
|
| 100 |
-
# Convert to binary if needed
|
| 101 |
-
if isinstance(mask, tuple):
|
| 102 |
-
mask = mask[0]
|
| 103 |
-
|
| 104 |
-
# Ensure binary format
|
| 105 |
mask = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255
|
| 106 |
-
|
| 107 |
-
# Morphological opening to remove noise
|
| 108 |
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 109 |
opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
| 110 |
|
| 111 |
-
# Remove small connected components
|
| 112 |
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(opened, connectivity=8)
|
| 113 |
clean_mask = np.zeros_like(opened)
|
| 114 |
|
| 115 |
-
for label in range(1, num_labels):
|
| 116 |
if stats[label, cv2.CC_STAT_AREA] >= 1000:
|
| 117 |
clean_mask[labels == label] = 255
|
| 118 |
|
| 119 |
return clean_mask
|
| 120 |
|
| 121 |
-
def _extract_basic_features(self, mask: np.ndarray) -> Dict[str, float]:
|
| 122 |
-
"""Extract basic morphological features using OpenCV."""
|
| 123 |
-
features = {}
|
| 124 |
-
|
| 125 |
-
try:
|
| 126 |
-
# Find contours
|
| 127 |
-
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 128 |
-
|
| 129 |
-
if not contours:
|
| 130 |
-
return features
|
| 131 |
-
|
| 132 |
-
# Get the largest contour
|
| 133 |
-
largest_contour = max(contours, key=cv2.contourArea)
|
| 134 |
-
|
| 135 |
-
# Basic measurements
|
| 136 |
-
area = cv2.contourArea(largest_contour)
|
| 137 |
-
perimeter = cv2.arcLength(largest_contour, True)
|
| 138 |
-
|
| 139 |
-
# Bounding box
|
| 140 |
-
x, y, w, h = cv2.boundingRect(largest_contour)
|
| 141 |
-
bbox_area = w * h
|
| 142 |
-
|
| 143 |
-
# Ellipse fitting
|
| 144 |
-
if len(largest_contour) >= 5:
|
| 145 |
-
ellipse = cv2.fitEllipse(largest_contour)
|
| 146 |
-
(center, axes, angle) = ellipse
|
| 147 |
-
major_axis = max(axes)
|
| 148 |
-
minor_axis = min(axes)
|
| 149 |
-
else:
|
| 150 |
-
major_axis = max(w, h)
|
| 151 |
-
minor_axis = min(w, h)
|
| 152 |
-
|
| 153 |
-
# Convert to centimeters
|
| 154 |
-
features['area_cm2'] = area * (self.pixel_to_cm ** 2)
|
| 155 |
-
features['perimeter_cm'] = perimeter * self.pixel_to_cm
|
| 156 |
-
features['width_cm'] = w * self.pixel_to_cm
|
| 157 |
-
features['height_cm'] = h * self.pixel_to_cm
|
| 158 |
-
features['bbox_area_cm2'] = bbox_area * (self.pixel_to_cm ** 2)
|
| 159 |
-
features['major_axis_cm'] = major_axis * self.pixel_to_cm
|
| 160 |
-
features['minor_axis_cm'] = minor_axis * self.pixel_to_cm
|
| 161 |
-
features['aspect_ratio'] = w / h if h > 0 else 0
|
| 162 |
-
features['elongation'] = major_axis / minor_axis if minor_axis > 0 else 0
|
| 163 |
-
features['circularity'] = (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0
|
| 164 |
-
features['solidity'] = area / bbox_area if bbox_area > 0 else 0
|
| 165 |
-
|
| 166 |
-
# Convex hull
|
| 167 |
-
hull = cv2.convexHull(largest_contour)
|
| 168 |
-
hull_area = cv2.contourArea(hull)
|
| 169 |
-
features['convexity'] = area / hull_area if hull_area > 0 else 0
|
| 170 |
-
|
| 171 |
-
except Exception as e:
|
| 172 |
-
logger.error(f"Basic feature extraction failed: {e}")
|
| 173 |
-
|
| 174 |
-
return features
|
| 175 |
-
|
| 176 |
-
def _extract_skeleton_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
|
| 177 |
-
"""Extract skeleton-based features using PlantCV."""
|
| 178 |
-
features = {'traits': {}, 'images': {}}
|
| 179 |
-
|
| 180 |
-
if not PLANT_CV_AVAILABLE:
|
| 181 |
-
return features
|
| 182 |
-
|
| 183 |
-
try:
|
| 184 |
-
# Suppress PlantCV output
|
| 185 |
-
with contextlib.redirect_stdout(self._FilteredStream(sys.stdout)), \
|
| 186 |
-
contextlib.redirect_stderr(self._FilteredStream(sys.stderr)):
|
| 187 |
-
|
| 188 |
-
# Skeletonize
|
| 189 |
-
skeleton = pcv.morphology.skeletonize(mask=mask)
|
| 190 |
-
features['images']['skeleton'] = skeleton
|
| 191 |
-
|
| 192 |
-
# Prune skeleton
|
| 193 |
-
pruned_skel = skeleton
|
| 194 |
-
for size in self.prune_sizes:
|
| 195 |
-
pruned_skel, _, _ = pcv.morphology.prune(
|
| 196 |
-
skel_img=pruned_skel, size=size, mask=mask
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
features['images']['pruned_skeleton'] = pruned_skel
|
| 200 |
-
|
| 201 |
-
# Find branch points and tips
|
| 202 |
-
branch_pts = pcv.morphology.find_branch_pts(pruned_skel, mask)
|
| 203 |
-
features['images']['branch_points'] = branch_pts
|
| 204 |
-
|
| 205 |
-
try:
|
| 206 |
-
tip_pts = pcv.morphology.find_tips(pruned_skel, mask)
|
| 207 |
-
features['images']['tip_points'] = tip_pts
|
| 208 |
-
except Exception as e:
|
| 209 |
-
logger.warning(f"Tip detection failed: {e}")
|
| 210 |
-
|
| 211 |
-
# Segment objects
|
| 212 |
-
try:
|
| 213 |
-
leaf_obj, stem_obj = pcv.morphology.segment_sort(
|
| 214 |
-
pruned_skel, [], mask
|
| 215 |
-
)
|
| 216 |
-
features['traits']['num_leaves'] = len(leaf_obj)
|
| 217 |
-
features['traits']['num_stems'] = len(stem_obj)
|
| 218 |
-
except Exception as e:
|
| 219 |
-
logger.warning(f"Object segmentation failed: {e}")
|
| 220 |
-
features['traits']['num_leaves'] = 0
|
| 221 |
-
features['traits']['num_stems'] = 0
|
| 222 |
-
|
| 223 |
-
# Size analysis
|
| 224 |
-
try:
|
| 225 |
-
labeled_mask, n_labels = pcv.create_labels(mask)
|
| 226 |
-
size_analysis = pcv.analyze.size(image, labeled_mask, n_labels, label="default")
|
| 227 |
-
features['images']['size_analysis'] = size_analysis
|
| 228 |
-
|
| 229 |
-
# Get size traits
|
| 230 |
-
obs = pcv.outputs.observations.get("default_1", {})
|
| 231 |
-
for trait, info in obs.items():
|
| 232 |
-
if trait not in ["in_bounds", "object_in_frame"]:
|
| 233 |
-
val = info.get("value", None)
|
| 234 |
-
if val is not None:
|
| 235 |
-
if trait == "area":
|
| 236 |
-
val = val * (self.pixel_to_cm ** 2)
|
| 237 |
-
elif trait in ["perimeter", "width", "height", "longest_path",
|
| 238 |
-
"ellipse_major_axis", "ellipse_minor_axis"]:
|
| 239 |
-
val = val * self.pixel_to_cm
|
| 240 |
-
features['traits'][trait] = val
|
| 241 |
-
|
| 242 |
-
except Exception as e:
|
| 243 |
-
logger.warning(f"Size analysis failed: {e}")
|
| 244 |
-
|
| 245 |
-
except Exception as e:
|
| 246 |
-
logger.error(f"Skeleton feature extraction failed: {e}")
|
| 247 |
-
|
| 248 |
-
return features
|
| 249 |
-
|
| 250 |
-
def _extract_opencv_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
|
| 251 |
-
"""Extract features using only OpenCV (fallback when PlantCV is not available)."""
|
| 252 |
-
features = {'traits': {}, 'images': {}}
|
| 253 |
-
|
| 254 |
-
try:
|
| 255 |
-
# Create skeleton using OpenCV
|
| 256 |
-
skeleton = self._create_skeleton_opencv(mask)
|
| 257 |
-
features['images']['skeleton'] = skeleton
|
| 258 |
-
|
| 259 |
-
# Find branch points
|
| 260 |
-
branch_points = self._find_branch_points_opencv(skeleton)
|
| 261 |
-
features['images']['branch_points'] = branch_points
|
| 262 |
-
features['traits']['num_branches'] = len(branch_points)
|
| 263 |
-
|
| 264 |
-
# Find endpoints
|
| 265 |
-
endpoints = self._find_endpoints_opencv(skeleton)
|
| 266 |
-
features['images']['endpoints'] = endpoints
|
| 267 |
-
features['traits']['num_endpoints'] = len(endpoints)
|
| 268 |
-
|
| 269 |
-
# Skeleton length
|
| 270 |
-
skeleton_length = np.sum(skeleton > 0)
|
| 271 |
-
features['traits']['skeleton_length_pixels'] = skeleton_length
|
| 272 |
-
features['traits']['skeleton_length_cm'] = skeleton_length * self.pixel_to_cm
|
| 273 |
-
|
| 274 |
-
except Exception as e:
|
| 275 |
-
logger.error(f"OpenCV feature extraction failed: {e}")
|
| 276 |
-
|
| 277 |
-
return features
|
| 278 |
-
|
| 279 |
-
def _create_skeleton_opencv(self, mask: np.ndarray) -> np.ndarray:
|
| 280 |
-
"""Create skeleton using OpenCV."""
|
| 281 |
-
# Convert to binary
|
| 282 |
-
binary = (mask > 0).astype(np.uint8)
|
| 283 |
-
|
| 284 |
-
# Create skeleton using morphological operations
|
| 285 |
-
skeleton = np.zeros_like(binary)
|
| 286 |
-
element = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
|
| 287 |
-
|
| 288 |
-
while True:
|
| 289 |
-
eroded = cv2.erode(binary, element)
|
| 290 |
-
temp = cv2.dilate(eroded, element)
|
| 291 |
-
temp = cv2.subtract(binary, temp)
|
| 292 |
-
skeleton = cv2.bitwise_or(skeleton, temp)
|
| 293 |
-
binary = eroded.copy()
|
| 294 |
-
|
| 295 |
-
if cv2.countNonZero(binary) == 0:
|
| 296 |
-
break
|
| 297 |
-
|
| 298 |
-
return skeleton * 255
|
| 299 |
-
|
| 300 |
-
def _find_branch_points_opencv(self, skeleton: np.ndarray) -> List[Tuple[int, int]]:
|
| 301 |
-
"""Find branch points in skeleton using OpenCV."""
|
| 302 |
-
# Count neighbors for each pixel
|
| 303 |
-
kernel = np.ones((3, 3), dtype=np.uint8)
|
| 304 |
-
kernel[1, 1] = 0 # Don't count center pixel
|
| 305 |
-
|
| 306 |
-
neighbor_count = cv2.filter2D(skeleton, -1, kernel)
|
| 307 |
-
|
| 308 |
-
# Branch points have 3 or more neighbors
|
| 309 |
-
branch_points = np.where((skeleton > 0) & (neighbor_count >= 3))
|
| 310 |
-
return list(zip(branch_points[1], branch_points[0])) # (x, y) format
|
| 311 |
-
|
| 312 |
-
def _find_endpoints_opencv(self, skeleton: np.ndarray) -> List[Tuple[int, int]]:
|
| 313 |
-
"""Find endpoints in skeleton using OpenCV."""
|
| 314 |
-
# Count neighbors for each pixel
|
| 315 |
-
kernel = np.ones((3, 3), dtype=np.uint8)
|
| 316 |
-
kernel[1, 1] = 0 # Don't count center pixel
|
| 317 |
-
|
| 318 |
-
neighbor_count = cv2.filter2D(skeleton, -1, kernel)
|
| 319 |
-
|
| 320 |
-
# Endpoints have exactly 1 neighbor
|
| 321 |
-
endpoints = np.where((skeleton > 0) & (neighbor_count == 1))
|
| 322 |
-
return list(zip(endpoints[1], endpoints[0])) # (x, y) format
|
| 323 |
-
|
| 324 |
class _FilteredStream:
|
| 325 |
-
"""Filter PlantCV output
|
| 326 |
def __init__(self, stream):
|
| 327 |
self.stream = stream
|
| 328 |
|
|
@@ -335,46 +93,4 @@ class MorphologyExtractor:
|
|
| 335 |
try:
|
| 336 |
self.stream.flush()
|
| 337 |
except Exception:
|
| 338 |
-
pass
|
| 339 |
-
|
| 340 |
-
def create_morphology_visualization(self, image: np.ndarray, mask: np.ndarray,
|
| 341 |
-
features: Dict[str, Any]) -> np.ndarray:
|
| 342 |
-
"""
|
| 343 |
-
Create visualization of morphological features.
|
| 344 |
-
|
| 345 |
-
Args:
|
| 346 |
-
image: Original image
|
| 347 |
-
mask: Binary mask
|
| 348 |
-
features: Extracted features
|
| 349 |
-
|
| 350 |
-
Returns:
|
| 351 |
-
Visualization image
|
| 352 |
-
"""
|
| 353 |
-
try:
|
| 354 |
-
# Create visualization
|
| 355 |
-
vis = image.copy()
|
| 356 |
-
|
| 357 |
-
# Draw mask outline
|
| 358 |
-
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 359 |
-
cv2.drawContours(vis, contours, -1, (0, 255, 0), 2)
|
| 360 |
-
|
| 361 |
-
# Draw bounding box
|
| 362 |
-
if contours:
|
| 363 |
-
x, y, w, h = cv2.boundingRect(contours[0])
|
| 364 |
-
cv2.rectangle(vis, (x, y), (x + w, y + h), (255, 0, 0), 2)
|
| 365 |
-
|
| 366 |
-
# Draw skeleton if available
|
| 367 |
-
if 'skeleton' in features.get('images', {}):
|
| 368 |
-
skeleton = features['images']['skeleton']
|
| 369 |
-
vis[skeleton > 0] = [0, 0, 255] # Red skeleton
|
| 370 |
-
|
| 371 |
-
# Draw branch points if available
|
| 372 |
-
if 'branch_points' in features.get('images', {}):
|
| 373 |
-
branch_img = features['images']['branch_points']
|
| 374 |
-
vis[branch_img > 0] = [255, 255, 0] # Yellow branch points
|
| 375 |
-
|
| 376 |
-
return vis
|
| 377 |
-
|
| 378 |
-
except Exception as e:
|
| 379 |
-
logger.error(f"Visualization creation failed: {e}")
|
| 380 |
-
return image
|
|
|
|
| 1 |
"""
|
| 2 |
+
Minimal morphological feature extraction (PlantCV size analysis only).
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import cv2
|
| 7 |
import contextlib
|
| 8 |
import sys
|
| 9 |
+
from typing import Dict, Any, List
|
| 10 |
import logging
|
| 11 |
|
|
|
|
| 12 |
try:
|
| 13 |
from plantcv import plantcv as pcv
|
| 14 |
PLANT_CV_AVAILABLE = True
|
| 15 |
except ImportError:
|
| 16 |
PLANT_CV_AVAILABLE = False
|
|
|
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
|
| 21 |
class MorphologyExtractor:
|
| 22 |
+
"""Minimal morphology extraction (PlantCV size analysis)."""
|
| 23 |
|
| 24 |
def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None):
|
| 25 |
+
"""Initialize."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
self.pixel_to_cm = pixel_to_cm
|
| 27 |
self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
|
| 28 |
|
| 29 |
if PLANT_CV_AVAILABLE:
|
|
|
|
| 30 |
pcv.params.debug = None
|
| 31 |
pcv.params.text_size = 0.7
|
| 32 |
pcv.params.text_thickness = 2
|
|
|
|
| 34 |
pcv.params.dpi = 100
|
| 35 |
|
| 36 |
def extract_morphology_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
|
| 37 |
+
"""Extract only PlantCV size analysis image."""
|
| 38 |
+
features = {'traits': {}, 'images': {}, 'success': False}
|
| 39 |
|
| 40 |
+
if not PLANT_CV_AVAILABLE:
|
| 41 |
+
logger.warning("PlantCV not available")
|
| 42 |
+
return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
try:
|
|
|
|
| 45 |
clean_mask = self._preprocess_mask(mask)
|
| 46 |
if clean_mask is None:
|
|
|
|
| 47 |
return features
|
| 48 |
|
| 49 |
+
# Size analysis only
|
| 50 |
+
with contextlib.redirect_stdout(self._FilteredStream(sys.stdout)), \
|
| 51 |
+
contextlib.redirect_stderr(self._FilteredStream(sys.stderr)):
|
| 52 |
+
try:
|
| 53 |
+
labeled_mask, n_labels = pcv.create_labels(clean_mask)
|
| 54 |
+
size_analysis = pcv.analyze.size(image, labeled_mask, n_labels, label="default")
|
| 55 |
+
features['images']['size_analysis'] = size_analysis
|
| 56 |
+
features['success'] = True
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.warning(f"Size analysis failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
except Exception as e:
|
| 61 |
+
logger.error(f"Morphology extraction failed: {e}")
|
| 62 |
|
| 63 |
return features
|
| 64 |
|
| 65 |
+
def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
|
| 66 |
+
"""Preprocess mask."""
|
| 67 |
if mask is None:
|
| 68 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
mask = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255
|
|
|
|
|
|
|
| 70 |
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 71 |
opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
| 72 |
|
|
|
|
| 73 |
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(opened, connectivity=8)
|
| 74 |
clean_mask = np.zeros_like(opened)
|
| 75 |
|
| 76 |
+
for label in range(1, num_labels):
|
| 77 |
if stats[label, cv2.CC_STAT_AREA] >= 1000:
|
| 78 |
clean_mask[labels == label] = 255
|
| 79 |
|
| 80 |
return clean_mask
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
class _FilteredStream:
|
| 83 |
+
"""Filter PlantCV output."""
|
| 84 |
def __init__(self, stream):
|
| 85 |
self.stream = stream
|
| 86 |
|
|
|
|
| 93 |
try:
|
| 94 |
self.stream.flush()
|
| 95 |
except Exception:
|
| 96 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sorghum_pipeline/features/texture.py
CHANGED
|
@@ -1,299 +1,79 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This module handles extraction of texture features including:
|
| 5 |
-
- Local Binary Patterns (LBP)
|
| 6 |
-
- Histogram of Oriented Gradients (HOG)
|
| 7 |
-
- Lacunarity features
|
| 8 |
-
- Edge Histogram Descriptor (EHD)
|
| 9 |
"""
|
| 10 |
|
| 11 |
import numpy as np
|
| 12 |
-
import cv2
|
| 13 |
import torch
|
| 14 |
import torch.nn.functional as F
|
| 15 |
from skimage.feature import local_binary_pattern, hog
|
| 16 |
from skimage import exposure
|
| 17 |
-
from scipy import ndimage
|
| 18 |
-
from
|
| 19 |
-
from typing import Dict, Tuple, Optional, Any
|
| 20 |
import logging
|
| 21 |
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
|
| 25 |
class TextureExtractor:
|
| 26 |
-
"""
|
| 27 |
|
| 28 |
-
def __init__(self,
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
hog_cells_per_block: Tuple[int, int] = (2, 2),
|
| 34 |
-
lacunarity_window: int = 15,
|
| 35 |
-
ehd_threshold: float = 0.3,
|
| 36 |
-
angle_resolution: int = 45):
|
| 37 |
-
"""
|
| 38 |
-
Initialize texture extractor.
|
| 39 |
-
|
| 40 |
-
Args:
|
| 41 |
-
lbp_points: Number of points for LBP
|
| 42 |
-
lbp_radius: Radius for LBP
|
| 43 |
-
hog_orientations: Number of orientations for HOG
|
| 44 |
-
hog_pixels_per_cell: Pixels per cell for HOG
|
| 45 |
-
hog_cells_per_block: Cells per block for HOG
|
| 46 |
-
lacunarity_window: Window size for lacunarity
|
| 47 |
-
ehd_threshold: Threshold for EHD
|
| 48 |
-
angle_resolution: Angle resolution for EHD
|
| 49 |
-
"""
|
| 50 |
self.lbp_points = lbp_points
|
| 51 |
self.lbp_radius = lbp_radius
|
| 52 |
self.hog_orientations = hog_orientations
|
| 53 |
self.hog_pixels_per_cell = hog_pixels_per_cell
|
| 54 |
self.hog_cells_per_block = hog_cells_per_block
|
| 55 |
self.lacunarity_window = lacunarity_window
|
| 56 |
-
self.ehd_threshold = ehd_threshold
|
| 57 |
-
self.angle_resolution = angle_resolution
|
| 58 |
|
| 59 |
def extract_lbp(self, gray_image: np.ndarray) -> np.ndarray:
|
| 60 |
-
"""
|
| 61 |
-
Extract Local Binary Pattern features.
|
| 62 |
-
|
| 63 |
-
Args:
|
| 64 |
-
gray_image: Grayscale input image
|
| 65 |
-
|
| 66 |
-
Returns:
|
| 67 |
-
LBP feature map
|
| 68 |
-
"""
|
| 69 |
try:
|
| 70 |
-
lbp = local_binary_pattern(
|
| 71 |
-
gray_image,
|
| 72 |
-
self.lbp_points,
|
| 73 |
-
self.lbp_radius,
|
| 74 |
-
method='uniform'
|
| 75 |
-
)
|
| 76 |
return self._convert_to_uint8(lbp)
|
| 77 |
except Exception as e:
|
| 78 |
-
logger.error(f"LBP
|
| 79 |
return np.zeros_like(gray_image, dtype=np.uint8)
|
| 80 |
|
| 81 |
def extract_hog(self, gray_image: np.ndarray) -> np.ndarray:
|
| 82 |
-
"""
|
| 83 |
-
Extract Histogram of Oriented Gradients features.
|
| 84 |
-
|
| 85 |
-
Args:
|
| 86 |
-
gray_image: Grayscale input image
|
| 87 |
-
|
| 88 |
-
Returns:
|
| 89 |
-
HOG feature map
|
| 90 |
-
"""
|
| 91 |
try:
|
| 92 |
-
_, vis = hog(
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
cells_per_block=self.hog_cells_per_block,
|
| 97 |
-
visualize=True,
|
| 98 |
-
feature_vector=True
|
| 99 |
-
)
|
| 100 |
return exposure.rescale_intensity(vis, out_range=(0, 255)).astype(np.uint8)
|
| 101 |
except Exception as e:
|
| 102 |
-
logger.error(f"HOG
|
| 103 |
return np.zeros_like(gray_image, dtype=np.uint8)
|
| 104 |
|
| 105 |
-
def compute_local_lacunarity(self, gray_image: np.ndarray
|
| 106 |
-
"""
|
| 107 |
-
Compute local lacunarity.
|
| 108 |
-
|
| 109 |
-
Args:
|
| 110 |
-
gray_image: Grayscale input image
|
| 111 |
-
window_size: Size of the sliding window
|
| 112 |
-
|
| 113 |
-
Returns:
|
| 114 |
-
Local lacunarity map
|
| 115 |
-
"""
|
| 116 |
try:
|
| 117 |
arr = gray_image.astype(np.float32)
|
| 118 |
-
m1 = ndimage.uniform_filter(arr, size=
|
| 119 |
-
m2 = ndimage.uniform_filter(arr * arr, size=
|
| 120 |
var = m2 - m1 * m1
|
| 121 |
-
|
| 122 |
-
lac
|
| 123 |
-
lac
|
| 124 |
-
return lac
|
| 125 |
-
except Exception as e:
|
| 126 |
-
logger.error(f"Local lacunarity computation failed: {e}")
|
| 127 |
-
return np.zeros_like(gray_image, dtype=np.float32)
|
| 128 |
-
|
| 129 |
-
def compute_lacunarity_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 130 |
-
"""
|
| 131 |
-
Compute three types of lacunarity features.
|
| 132 |
-
|
| 133 |
-
Args:
|
| 134 |
-
gray_image: Grayscale input image
|
| 135 |
-
|
| 136 |
-
Returns:
|
| 137 |
-
Tuple of (lac1, lac2, lac3) lacunarity maps
|
| 138 |
-
"""
|
| 139 |
-
try:
|
| 140 |
-
# L1: Single window lacunarity
|
| 141 |
-
lac1 = self.compute_local_lacunarity(gray_image, self.lacunarity_window)
|
| 142 |
-
|
| 143 |
-
# L2: Multi-scale lacunarity
|
| 144 |
-
scales = [max(3, self.lacunarity_window//2), self.lacunarity_window, self.lacunarity_window*2]
|
| 145 |
-
lac2 = np.mean([
|
| 146 |
-
self.compute_local_lacunarity(gray_image, s) for s in scales
|
| 147 |
-
], axis=0)
|
| 148 |
-
|
| 149 |
-
# L3: DBC Lacunarity (if available)
|
| 150 |
-
try:
|
| 151 |
-
from ..models.dbc_lacunarity import DBC_Lacunarity
|
| 152 |
-
x = torch.from_numpy(gray_image.astype(np.float32)/255.0)[None, None]
|
| 153 |
-
layer = DBC_Lacunarity(window_size=self.lacunarity_window).eval()
|
| 154 |
-
with torch.no_grad():
|
| 155 |
-
lac3 = layer(x).squeeze().cpu().numpy()
|
| 156 |
-
except ImportError:
|
| 157 |
-
logger.warning("DBC Lacunarity not available, using L2 as L3")
|
| 158 |
-
lac3 = lac2.copy()
|
| 159 |
-
|
| 160 |
-
return (
|
| 161 |
-
self._convert_to_uint8(lac1),
|
| 162 |
-
self._convert_to_uint8(lac2),
|
| 163 |
-
self._convert_to_uint8(lac3)
|
| 164 |
-
)
|
| 165 |
-
except Exception as e:
|
| 166 |
-
logger.error(f"Lacunarity features computation failed: {e}")
|
| 167 |
-
empty = np.zeros_like(gray_image, dtype=np.uint8)
|
| 168 |
-
return empty, empty, empty
|
| 169 |
-
|
| 170 |
-
def generate_ehd_masks(self, mask_size: int = 3) -> np.ndarray:
|
| 171 |
-
"""
|
| 172 |
-
Generate masks for Edge Histogram Descriptor.
|
| 173 |
-
|
| 174 |
-
Args:
|
| 175 |
-
mask_size: Size of the mask
|
| 176 |
-
|
| 177 |
-
Returns:
|
| 178 |
-
Array of EHD masks
|
| 179 |
-
"""
|
| 180 |
-
if mask_size < 3:
|
| 181 |
-
mask_size = 3
|
| 182 |
-
if mask_size % 2 == 0:
|
| 183 |
-
mask_size += 1
|
| 184 |
-
|
| 185 |
-
# Base gradient mask
|
| 186 |
-
Gy = np.outer([1, 0, -1], [1, 2, 1])
|
| 187 |
-
|
| 188 |
-
# Expand if needed
|
| 189 |
-
if mask_size > 3:
|
| 190 |
-
expd = np.outer([1, 2, 1], [1, 2, 1])
|
| 191 |
-
for _ in range((mask_size - 3) // 2):
|
| 192 |
-
Gy = signal.convolve2d(expd, Gy, mode='full')
|
| 193 |
-
|
| 194 |
-
# Generate masks for different angles
|
| 195 |
-
angles = np.arange(0, 360, self.angle_resolution)
|
| 196 |
-
masks = np.zeros((len(angles), mask_size, mask_size), dtype=np.float32)
|
| 197 |
-
|
| 198 |
-
for i, angle in enumerate(angles):
|
| 199 |
-
masks[i] = ndimage.rotate(Gy, angle, reshape=False, mode='nearest')
|
| 200 |
-
|
| 201 |
-
return masks
|
| 202 |
-
|
| 203 |
-
def extract_ehd_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 204 |
-
"""
|
| 205 |
-
Extract Edge Histogram Descriptor features.
|
| 206 |
-
|
| 207 |
-
Args:
|
| 208 |
-
gray_image: Grayscale input image
|
| 209 |
-
|
| 210 |
-
Returns:
|
| 211 |
-
Tuple of (ehd_features, ehd_map)
|
| 212 |
-
"""
|
| 213 |
-
try:
|
| 214 |
-
# Generate masks
|
| 215 |
-
masks = self.generate_ehd_masks()
|
| 216 |
-
|
| 217 |
-
# Convert to tensor
|
| 218 |
-
X = torch.from_numpy(gray_image.astype(np.float32)/255.0).unsqueeze(0).unsqueeze(0)
|
| 219 |
-
masks_tensor = torch.tensor(masks).unsqueeze(1).float()
|
| 220 |
-
|
| 221 |
-
# Convolve with masks
|
| 222 |
-
edge_responses = F.conv2d(X, masks_tensor, dilation=7)
|
| 223 |
-
|
| 224 |
-
# Find maximum response
|
| 225 |
-
values, indices = torch.max(edge_responses, dim=1)
|
| 226 |
-
indices[values < self.ehd_threshold] = masks.shape[0]
|
| 227 |
-
|
| 228 |
-
# Pool features
|
| 229 |
-
feat_vect = []
|
| 230 |
-
for edge in range(masks.shape[0] + 1):
|
| 231 |
-
pooled = F.avg_pool2d(
|
| 232 |
-
(indices == edge).unsqueeze(1).float(),
|
| 233 |
-
kernel_size=5, stride=1, padding=2
|
| 234 |
-
)
|
| 235 |
-
feat_vect.append(pooled.squeeze(1))
|
| 236 |
-
|
| 237 |
-
ehd_features = torch.stack(feat_vect, dim=1).squeeze(0).cpu().numpy()
|
| 238 |
-
ehd_map = np.argmax(ehd_features, axis=0).astype(np.uint8)
|
| 239 |
-
|
| 240 |
-
return ehd_features, ehd_map
|
| 241 |
-
|
| 242 |
except Exception as e:
|
| 243 |
-
logger.error(f"
|
| 244 |
-
|
| 245 |
-
empty_map = np.zeros_like(gray_image, dtype=np.uint8)
|
| 246 |
-
return empty_features, empty_map
|
| 247 |
|
| 248 |
def extract_all_texture_features(self, gray_image: np.ndarray) -> Dict[str, np.ndarray]:
|
| 249 |
-
"""
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
Returns:
|
| 256 |
-
Dictionary of texture features
|
| 257 |
-
"""
|
| 258 |
-
features = {}
|
| 259 |
-
|
| 260 |
-
try:
|
| 261 |
-
# LBP
|
| 262 |
-
features['lbp'] = self.extract_lbp(gray_image)
|
| 263 |
-
|
| 264 |
-
# HOG
|
| 265 |
-
features['hog'] = self.extract_hog(gray_image)
|
| 266 |
-
|
| 267 |
-
# Lacunarity
|
| 268 |
-
lac1, lac2, lac3 = self.compute_lacunarity_features(gray_image)
|
| 269 |
-
features['lac1'] = lac1
|
| 270 |
-
features['lac2'] = lac2
|
| 271 |
-
features['lac3'] = lac3
|
| 272 |
-
|
| 273 |
-
# EHD
|
| 274 |
-
ehd_features, ehd_map = self.extract_ehd_features(gray_image)
|
| 275 |
-
features['ehd_features'] = ehd_features
|
| 276 |
-
features['ehd_map'] = ehd_map
|
| 277 |
-
|
| 278 |
-
logger.debug("All texture features extracted successfully")
|
| 279 |
-
|
| 280 |
-
except Exception as e:
|
| 281 |
-
logger.error(f"Texture feature extraction failed: {e}")
|
| 282 |
-
# Return empty features
|
| 283 |
-
features = {
|
| 284 |
-
'lbp': np.zeros_like(gray_image, dtype=np.uint8),
|
| 285 |
-
'hog': np.zeros_like(gray_image, dtype=np.uint8),
|
| 286 |
-
'lac1': np.zeros_like(gray_image, dtype=np.uint8),
|
| 287 |
-
'lac2': np.zeros_like(gray_image, dtype=np.uint8),
|
| 288 |
-
'lac3': np.zeros_like(gray_image, dtype=np.uint8),
|
| 289 |
-
'ehd_features': np.zeros((9, gray_image.shape[0]-4, gray_image.shape[1]-4), dtype=np.float32),
|
| 290 |
-
'ehd_map': np.zeros_like(gray_image, dtype=np.uint8)
|
| 291 |
-
}
|
| 292 |
-
|
| 293 |
-
return features
|
| 294 |
|
| 295 |
def _convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
|
| 296 |
-
"""Convert
|
| 297 |
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
| 298 |
if arr.ptp() > 0:
|
| 299 |
normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
|
|
@@ -302,72 +82,19 @@ class TextureExtractor:
|
|
| 302 |
return np.clip(normalized, 0, 255).astype(np.uint8)
|
| 303 |
|
| 304 |
def compute_texture_statistics(self, features: Dict[str, np.ndarray],
|
| 305 |
-
|
| 306 |
-
"""
|
| 307 |
-
Compute statistics for texture features.
|
| 308 |
-
|
| 309 |
-
Args:
|
| 310 |
-
features: Dictionary of texture features
|
| 311 |
-
mask: Optional mask to apply
|
| 312 |
-
|
| 313 |
-
Returns:
|
| 314 |
-
Dictionary of feature statistics
|
| 315 |
-
"""
|
| 316 |
stats = {}
|
| 317 |
-
|
| 318 |
for feature_name, feature_data in features.items():
|
| 319 |
-
if
|
| 320 |
-
|
| 321 |
-
if mask is not None:
|
| 322 |
-
# Apply mask to each channel
|
| 323 |
-
masked_features = []
|
| 324 |
-
for i in range(feature_data.shape[0]):
|
| 325 |
-
channel = feature_data[i]
|
| 326 |
-
if mask.shape != channel.shape:
|
| 327 |
-
# Resize mask to match channel
|
| 328 |
-
mask_resized = cv2.resize(mask, (channel.shape[1], channel.shape[0]),
|
| 329 |
-
interpolation=cv2.INTER_NEAREST)
|
| 330 |
-
masked_channel = np.where(mask_resized > 0, channel, np.nan)
|
| 331 |
-
else:
|
| 332 |
-
masked_channel = np.where(mask > 0, channel, np.nan)
|
| 333 |
-
masked_features.append(masked_channel)
|
| 334 |
-
feature_data = np.stack(masked_features, axis=0)
|
| 335 |
-
else:
|
| 336 |
-
feature_data = feature_data
|
| 337 |
-
|
| 338 |
-
# Compute statistics for each EHD channel
|
| 339 |
-
channel_stats = {}
|
| 340 |
-
for i in range(feature_data.shape[0]):
|
| 341 |
-
channel = feature_data[i]
|
| 342 |
-
valid_data = channel[~np.isnan(channel)]
|
| 343 |
-
if len(valid_data) > 0:
|
| 344 |
-
channel_stats[f'channel_{i}'] = {
|
| 345 |
-
'mean': float(np.mean(valid_data)),
|
| 346 |
-
'std': float(np.std(valid_data)),
|
| 347 |
-
'min': float(np.min(valid_data)),
|
| 348 |
-
'max': float(np.max(valid_data)),
|
| 349 |
-
'median': float(np.median(valid_data))
|
| 350 |
-
}
|
| 351 |
-
stats[feature_name] = channel_stats
|
| 352 |
else:
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
'mean': float(np.mean(valid_data)),
|
| 363 |
-
'std': float(np.std(valid_data)),
|
| 364 |
-
'min': float(np.min(valid_data)),
|
| 365 |
-
'max': float(np.max(valid_data)),
|
| 366 |
-
'median': float(np.median(valid_data))
|
| 367 |
-
}
|
| 368 |
-
else:
|
| 369 |
-
stats[feature_name] = {
|
| 370 |
-
'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0, 'median': 0.0
|
| 371 |
-
}
|
| 372 |
-
|
| 373 |
-
return stats
|
|
|
|
| 1 |
"""
|
| 2 |
+
Minimal texture feature extraction.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from skimage.feature import local_binary_pattern, hog
|
| 9 |
from skimage import exposure
|
| 10 |
+
from scipy import ndimage
|
| 11 |
+
from typing import Dict, Tuple, Optional
|
|
|
|
| 12 |
import logging
|
| 13 |
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
|
| 17 |
class TextureExtractor:
|
| 18 |
+
"""Minimal texture extraction (LBP, HOG, Lacunarity only)."""
|
| 19 |
|
| 20 |
+
def __init__(self, lbp_points: int = 8, lbp_radius: int = 1,
|
| 21 |
+
hog_orientations: int = 9, hog_pixels_per_cell: Tuple[int, int] = (8, 8),
|
| 22 |
+
hog_cells_per_block: Tuple[int, int] = (2, 2), lacunarity_window: int = 15,
|
| 23 |
+
ehd_threshold: float = 0.3, angle_resolution: int = 45):
|
| 24 |
+
"""Initialize with defaults."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
self.lbp_points = lbp_points
|
| 26 |
self.lbp_radius = lbp_radius
|
| 27 |
self.hog_orientations = hog_orientations
|
| 28 |
self.hog_pixels_per_cell = hog_pixels_per_cell
|
| 29 |
self.hog_cells_per_block = hog_cells_per_block
|
| 30 |
self.lacunarity_window = lacunarity_window
|
|
|
|
|
|
|
| 31 |
|
| 32 |
def extract_lbp(self, gray_image: np.ndarray) -> np.ndarray:
|
| 33 |
+
"""Extract Local Binary Pattern."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
try:
|
| 35 |
+
lbp = local_binary_pattern(gray_image, self.lbp_points, self.lbp_radius, method='uniform')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
return self._convert_to_uint8(lbp)
|
| 37 |
except Exception as e:
|
| 38 |
+
logger.error(f"LBP failed: {e}")
|
| 39 |
return np.zeros_like(gray_image, dtype=np.uint8)
|
| 40 |
|
| 41 |
def extract_hog(self, gray_image: np.ndarray) -> np.ndarray:
|
| 42 |
+
"""Extract HOG features."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
try:
|
| 44 |
+
_, vis = hog(gray_image, orientations=self.hog_orientations,
|
| 45 |
+
pixels_per_cell=self.hog_pixels_per_cell,
|
| 46 |
+
cells_per_block=self.hog_cells_per_block,
|
| 47 |
+
visualize=True, feature_vector=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
return exposure.rescale_intensity(vis, out_range=(0, 255)).astype(np.uint8)
|
| 49 |
except Exception as e:
|
| 50 |
+
logger.error(f"HOG failed: {e}")
|
| 51 |
return np.zeros_like(gray_image, dtype=np.uint8)
|
| 52 |
|
| 53 |
+
def compute_local_lacunarity(self, gray_image: np.ndarray) -> np.ndarray:
|
| 54 |
+
"""Compute lacunarity."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
try:
|
| 56 |
arr = gray_image.astype(np.float32)
|
| 57 |
+
m1 = ndimage.uniform_filter(arr, size=self.lacunarity_window)
|
| 58 |
+
m2 = ndimage.uniform_filter(arr * arr, size=self.lacunarity_window)
|
| 59 |
var = m2 - m1 * m1
|
| 60 |
+
lac = var / (m1 * m1 + 1e-6) + 1
|
| 61 |
+
lac[m1 <= 1e-6] = 0
|
| 62 |
+
return self._convert_to_uint8(lac)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
except Exception as e:
|
| 64 |
+
logger.error(f"Lacunarity failed: {e}")
|
| 65 |
+
return np.zeros_like(gray_image, dtype=np.uint8)
|
|
|
|
|
|
|
| 66 |
|
| 67 |
def extract_all_texture_features(self, gray_image: np.ndarray) -> Dict[str, np.ndarray]:
|
| 68 |
+
"""Extract LBP, HOG, and Lacunarity."""
|
| 69 |
+
return {
|
| 70 |
+
'lbp': self.extract_lbp(gray_image),
|
| 71 |
+
'hog': self.extract_hog(gray_image),
|
| 72 |
+
'lac2': self.compute_local_lacunarity(gray_image)
|
| 73 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
def _convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
|
| 76 |
+
"""Convert to uint8."""
|
| 77 |
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
| 78 |
if arr.ptp() > 0:
|
| 79 |
normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
|
|
|
|
| 82 |
return np.clip(normalized, 0, 255).astype(np.uint8)
|
| 83 |
|
| 84 |
def compute_texture_statistics(self, features: Dict[str, np.ndarray],
|
| 85 |
+
mask: Optional[np.ndarray] = None) -> Dict[str, Dict[str, float]]:
|
| 86 |
+
"""Compute basic statistics."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
stats = {}
|
|
|
|
| 88 |
for feature_name, feature_data in features.items():
|
| 89 |
+
if mask is not None and mask.shape == feature_data.shape:
|
| 90 |
+
masked_data = np.where(mask > 0, feature_data, np.nan)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
else:
|
| 92 |
+
masked_data = feature_data
|
| 93 |
+
|
| 94 |
+
valid_data = masked_data[~np.isnan(masked_data)]
|
| 95 |
+
if len(valid_data) > 0:
|
| 96 |
+
stats[feature_name] = {
|
| 97 |
+
'mean': float(np.mean(valid_data)),
|
| 98 |
+
'std': float(np.std(valid_data)),
|
| 99 |
+
}
|
| 100 |
+
return stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sorghum_pipeline/features/vegetation.py
CHANGED
|
@@ -1,308 +1,71 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This module handles extraction of various vegetation indices
|
| 5 |
-
from multispectral data.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
-
import
|
| 10 |
-
from typing import Dict, Tuple, Optional, Any
|
| 11 |
import logging
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
|
| 16 |
class VegetationIndexExtractor:
|
| 17 |
-
"""
|
| 18 |
|
| 19 |
def __init__(self, epsilon: float = 1e-10, soil_factor: float = 0.16):
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
Args:
|
| 24 |
-
epsilon: Small value to avoid division by zero
|
| 25 |
-
soil_factor: Soil factor for certain indices
|
| 26 |
-
"""
|
| 27 |
-
# Coerce to float in case config passed strings like "1e-10"
|
| 28 |
-
try:
|
| 29 |
-
self.epsilon = float(epsilon)
|
| 30 |
-
except Exception:
|
| 31 |
-
self.epsilon = 1e-10
|
| 32 |
-
try:
|
| 33 |
-
self.soil_factor = float(soil_factor)
|
| 34 |
-
except Exception:
|
| 35 |
-
self.soil_factor = 0.16
|
| 36 |
|
| 37 |
-
# Define vegetation index formulas
|
| 38 |
self.index_formulas = {
|
| 39 |
"NDVI": lambda nir, red: (nir - red) / (nir + red + self.epsilon),
|
| 40 |
-
"GNDVI": lambda nir, green: (nir - green) / (nir + green + self.epsilon),
|
| 41 |
-
"NDRE": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
|
| 42 |
-
"GRNDVI": lambda nir, green, red: (nir - (green + red)) / (nir + (green + red) + self.epsilon),
|
| 43 |
-
"TNDVI": lambda nir, red: np.sqrt(np.clip(((nir - red) / (nir + red + self.epsilon)) + 0.5, 0, None)),
|
| 44 |
-
"MGRVI": lambda green, red: (green**2 - red**2) / (green**2 + red**2 + self.epsilon),
|
| 45 |
-
"GRVI": lambda nir, green: nir / (green + self.epsilon),
|
| 46 |
-
"NGRDI": lambda green, red: (green - red) / (green + red + self.epsilon),
|
| 47 |
-
"MSAVI": lambda nir, red: 0.5 * (2.0 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red))),
|
| 48 |
-
"OSAVI": lambda nir, red: (nir - red) / (nir + red + self.soil_factor + self.epsilon),
|
| 49 |
-
"TSAVI": lambda nir, red, s=0.33, a=0.5, X=1.5: (s * (nir - s * red - a)) / (a * nir + red - a * s + X * (1 + s**2) + self.epsilon),
|
| 50 |
-
"GSAVI": lambda nir, green, l=0.5: (1 + l) * (nir - green) / (nir + green + l + self.epsilon),
|
| 51 |
-
# Requested additions and aliases
|
| 52 |
-
"GOSAVI": lambda nir, green: (nir - green) / (nir + green + 0.16 + self.epsilon),
|
| 53 |
-
"GDVI": lambda nir, green: nir - green,
|
| 54 |
-
"NDWI": lambda green, nir: (green - nir) / (green + nir + self.epsilon),
|
| 55 |
-
"DSWI4": lambda green, red: green / (red + self.epsilon),
|
| 56 |
-
"CIRE": lambda nir, red_edge: (nir / (red_edge + self.epsilon)) - 1.0,
|
| 57 |
-
"LCI": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
|
| 58 |
-
"CIgreen": lambda nir, green: (nir / (green + self.epsilon)) - 1,
|
| 59 |
-
"MCARI": lambda red_edge, red, green: ((red_edge - red) - 0.2 * (red_edge - green)) * (red_edge / (red + self.epsilon)),
|
| 60 |
-
"MCARI1": lambda nir, red, green: 1.2 * (2.5 * (nir - red) - 1.3 * (nir - green)),
|
| 61 |
-
"MCARI2": lambda nir, red, green: (1.5 * (2.5 * (nir - red) - 1.3 * (nir - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon))),
|
| 62 |
-
# MTVI variants per request
|
| 63 |
-
"MTVI1": lambda nir, red, green: 1.2 * (1.2 * (nir - green) - 2.5 * (red - green)),
|
| 64 |
-
"MTVI2": lambda nir, red, green: (1.5 * (1.2 * (nir - green) - 2.5 * (red - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon)) - 0.5 + self.epsilon),
|
| 65 |
-
"CVI": lambda nir, red, green: (nir * red) / (green**2 + self.epsilon),
|
| 66 |
"ARI": lambda green, red_edge: (1.0 / (green + self.epsilon)) - (1.0 / (red_edge + self.epsilon)),
|
| 67 |
-
"
|
| 68 |
-
"DVI": lambda nir, red: nir - red,
|
| 69 |
-
"WDVI": lambda nir, red, a=0.5: nir - a * red,
|
| 70 |
-
"SR": lambda nir, red: nir / (red + self.epsilon),
|
| 71 |
-
"MSR": lambda nir, red: (nir / (red + self.epsilon) - 1) / np.sqrt(nir / (red + self.epsilon) + 1),
|
| 72 |
-
"PVI": lambda nir, red, a=0.5, b=0.3: (nir - a * red - b) / (np.sqrt(1 + a**2) + self.epsilon),
|
| 73 |
-
"GEMI": lambda nir, red: ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon)) * (1 - 0.25 * ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon))) - ((red - 0.125) / (1 - red + self.epsilon)),
|
| 74 |
-
"ExR": lambda red, green: 1.3 * red - green,
|
| 75 |
-
"RI": lambda red, green: (red - green) / (red + green + self.epsilon),
|
| 76 |
-
"RRI1": lambda nir, red_edge: nir / (red_edge + self.epsilon),
|
| 77 |
-
"RRI2": lambda red_edge, red: red_edge / (red + self.epsilon),
|
| 78 |
-
"RRI": lambda nir, red_edge: nir / (red_edge + self.epsilon),
|
| 79 |
-
"AVI": lambda nir, red: np.cbrt(nir * (1.0 - red) * (nir - red + self.epsilon)),
|
| 80 |
-
"SIPI2": lambda nir, green, red: (nir - green) / (nir - red + self.epsilon),
|
| 81 |
-
"TCARI": lambda red_edge, red, green: 3 * ((red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))),
|
| 82 |
-
"TCARIOSAVI": lambda red_edge, red, green, nir: (3 * (red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))) / (1 + 0.16 * ((nir - red) / (nir + red + 0.16 + self.epsilon))),
|
| 83 |
-
"CCCI": lambda nir, red_edge, red: (((nir - red_edge) * (nir + red)) / ((nir + red_edge) * (nir - red) + self.epsilon)),
|
| 84 |
-
# Additional indices
|
| 85 |
-
"RDVI": lambda nir, red: (nir - red) / (np.sqrt(nir + red + self.epsilon)),
|
| 86 |
-
"NLI": lambda nir, red: ((nir**2) - red) / ((nir**2) + red + self.epsilon),
|
| 87 |
-
"BIXS": lambda green, red: np.sqrt(((green**2) + (red**2)) / 2.0),
|
| 88 |
-
"IPVI": lambda nir, red: nir / (nir + red + self.epsilon),
|
| 89 |
-
"EVI2": lambda nir, red: 2.4 * (nir - red) / (nir + red + 1.0 + self.epsilon)
|
| 90 |
}
|
| 91 |
|
| 92 |
-
# Define required bands for each index
|
| 93 |
self.index_bands = {
|
| 94 |
"NDVI": ["nir", "red"],
|
| 95 |
-
"GNDVI": ["nir", "green"],
|
| 96 |
-
"NDRE": ["nir", "red_edge"],
|
| 97 |
-
"GRNDVI": ["nir", "green", "red"],
|
| 98 |
-
"TNDVI": ["nir", "red"],
|
| 99 |
-
"MGRVI": ["green", "red"],
|
| 100 |
-
"GRVI": ["nir", "green"],
|
| 101 |
-
"NGRDI": ["green", "red"],
|
| 102 |
-
"MSAVI": ["nir", "red"],
|
| 103 |
-
"OSAVI": ["nir", "red"],
|
| 104 |
-
"TSAVI": ["nir", "red"],
|
| 105 |
-
"GSAVI": ["nir", "green"],
|
| 106 |
-
"GOSAVI": ["nir", "green"],
|
| 107 |
-
"GDVI": ["nir", "green"],
|
| 108 |
-
"NDWI": ["green", "nir"],
|
| 109 |
-
"DSWI4": ["green", "red"],
|
| 110 |
-
"CIRE": ["nir", "red_edge"],
|
| 111 |
-
"LCI": ["nir", "red_edge"],
|
| 112 |
-
"CIgreen": ["nir", "green"],
|
| 113 |
-
"MCARI": ["red_edge", "red", "green"],
|
| 114 |
-
"MCARI1": ["nir", "red", "green"],
|
| 115 |
-
"MCARI2": ["nir", "red", "green"],
|
| 116 |
-
"MTVI1": ["nir", "red", "green"],
|
| 117 |
-
"MTVI2": ["nir", "red", "green"],
|
| 118 |
-
"CVI": ["nir", "red", "green"],
|
| 119 |
"ARI": ["green", "red_edge"],
|
| 120 |
-
"
|
| 121 |
-
"DVI": ["nir", "red"],
|
| 122 |
-
"WDVI": ["nir", "red"],
|
| 123 |
-
"SR": ["nir", "red"],
|
| 124 |
-
"MSR": ["nir", "red"],
|
| 125 |
-
"PVI": ["nir", "red"],
|
| 126 |
-
"GEMI": ["nir", "red"],
|
| 127 |
-
"ExR": ["red", "green"],
|
| 128 |
-
"RI": ["red", "green"],
|
| 129 |
-
"RRI1": ["nir", "red_edge"],
|
| 130 |
-
"RRI2": ["red_edge", "red"],
|
| 131 |
-
"RRI": ["nir", "red_edge"],
|
| 132 |
-
"AVI": ["nir", "red"],
|
| 133 |
-
"SIPI2": ["nir", "green", "red"],
|
| 134 |
-
"TCARI": ["red_edge", "red", "green"],
|
| 135 |
-
"TCARIOSAVI": ["red_edge", "red", "green", "nir"],
|
| 136 |
-
"CCCI": ["nir", "red_edge", "red"],
|
| 137 |
-
"RDVI": ["nir", "red"],
|
| 138 |
-
"NLI": ["nir", "red"],
|
| 139 |
-
"BIXS": ["green", "red"],
|
| 140 |
-
"IPVI": ["nir", "red"],
|
| 141 |
-
"EVI2": ["nir", "red"]
|
| 142 |
}
|
| 143 |
|
| 144 |
def compute_vegetation_indices(self, spectral_stack: Dict[str, np.ndarray],
|
| 145 |
-
|
| 146 |
-
"""
|
| 147 |
-
Compute vegetation indices from spectral data.
|
| 148 |
-
|
| 149 |
-
Args:
|
| 150 |
-
spectral_stack: Dictionary of spectral bands
|
| 151 |
-
mask: Binary mask for the plant
|
| 152 |
-
|
| 153 |
-
Returns:
|
| 154 |
-
Dictionary of vegetation indices with values and statistics
|
| 155 |
-
"""
|
| 156 |
indices = {}
|
| 157 |
|
| 158 |
for index_name, formula in self.index_formulas.items():
|
| 159 |
try:
|
| 160 |
-
|
| 161 |
-
required_bands = self.index_bands.get(index_name, [])
|
| 162 |
-
|
| 163 |
-
# Check if all required bands are available
|
| 164 |
if not all(band in spectral_stack for band in required_bands):
|
| 165 |
-
logger.warning(f"Skipping {index_name}: missing required bands")
|
| 166 |
continue
|
| 167 |
|
| 168 |
-
# Extract band data as float arrays
|
| 169 |
band_data = []
|
| 170 |
for band in required_bands:
|
| 171 |
arr = spectral_stack[band]
|
| 172 |
-
# Ensure numeric float np.ndarray
|
| 173 |
if isinstance(arr, np.ndarray):
|
| 174 |
arr = arr.squeeze(-1)
|
| 175 |
-
|
| 176 |
-
band_data.append(arr)
|
| 177 |
|
| 178 |
-
# Compute index (ensure float math)
|
| 179 |
index_values = formula(*band_data).astype(np.float64)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
# Apply mask
|
| 182 |
-
if mask is not None:
|
| 183 |
-
binary_mask = (np.asarray(mask).astype(np.int32) > 0)
|
| 184 |
-
masked_values = np.where(binary_mask, index_values, np.nan)
|
| 185 |
-
else:
|
| 186 |
-
masked_values = index_values
|
| 187 |
-
|
| 188 |
-
# Compute statistics
|
| 189 |
valid_values = masked_values[~np.isnan(masked_values)]
|
| 190 |
if len(valid_values) > 0:
|
| 191 |
stats = {
|
| 192 |
'mean': float(np.mean(valid_values)),
|
| 193 |
'std': float(np.std(valid_values)),
|
| 194 |
-
'min': float(np.min(valid_values)),
|
| 195 |
-
'max': float(np.max(valid_values)),
|
| 196 |
-
'median': float(np.median(valid_values)),
|
| 197 |
-
'q25': float(np.percentile(valid_values, 25)),
|
| 198 |
-
'q75': float(np.percentile(valid_values, 75)),
|
| 199 |
-
'nan_fraction': float(np.isnan(masked_values).sum() / masked_values.size)
|
| 200 |
}
|
| 201 |
else:
|
| 202 |
-
stats = {
|
| 203 |
-
'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
|
| 204 |
-
'median': 0.0, 'q25': 0.0, 'q75': 0.0, 'nan_fraction': 1.0
|
| 205 |
-
}
|
| 206 |
|
| 207 |
indices[index_name] = {
|
| 208 |
'values': masked_values,
|
| 209 |
'statistics': stats
|
| 210 |
}
|
| 211 |
|
| 212 |
-
logger.debug(f"Computed {index_name}")
|
| 213 |
-
|
| 214 |
except Exception as e:
|
| 215 |
logger.error(f"Failed to compute {index_name}: {e}")
|
| 216 |
-
continue
|
| 217 |
-
|
| 218 |
-
return indices
|
| 219 |
-
|
| 220 |
-
def create_vegetation_index_image(self, index_values: np.ndarray,
|
| 221 |
-
colormap: str = 'RdYlGn',
|
| 222 |
-
vmin: Optional[float] = None,
|
| 223 |
-
vmax: Optional[float] = None) -> np.ndarray:
|
| 224 |
-
"""
|
| 225 |
-
Create visualization image for vegetation index.
|
| 226 |
-
|
| 227 |
-
Args:
|
| 228 |
-
index_values: Vegetation index values
|
| 229 |
-
colormap: Matplotlib colormap name
|
| 230 |
-
vmin: Minimum value for normalization
|
| 231 |
-
vmax: Maximum value for normalization
|
| 232 |
-
|
| 233 |
-
Returns:
|
| 234 |
-
RGB image array
|
| 235 |
-
"""
|
| 236 |
-
try:
|
| 237 |
-
import matplotlib.pyplot as plt
|
| 238 |
-
import matplotlib.cm as cm
|
| 239 |
-
from matplotlib.colors import Normalize
|
| 240 |
-
|
| 241 |
-
# Determine value range
|
| 242 |
-
valid_values = index_values[~np.isnan(index_values)]
|
| 243 |
-
if len(valid_values) == 0:
|
| 244 |
-
return np.zeros((*index_values.shape, 3), dtype=np.uint8)
|
| 245 |
-
|
| 246 |
-
if vmin is None:
|
| 247 |
-
vmin = np.min(valid_values)
|
| 248 |
-
if vmax is None:
|
| 249 |
-
vmax = np.max(valid_values)
|
| 250 |
-
|
| 251 |
-
# Normalize values
|
| 252 |
-
norm = Normalize(vmin=vmin, vmax=vmax)
|
| 253 |
-
cmap = cm.get_cmap(colormap)
|
| 254 |
-
|
| 255 |
-
# Apply colormap
|
| 256 |
-
rgba_img = cmap(norm(index_values))
|
| 257 |
-
rgba_img[np.isnan(index_values)] = [1, 1, 1, 1] # White for NaN
|
| 258 |
-
|
| 259 |
-
# Convert to RGB uint8
|
| 260 |
-
rgb_img = (rgba_img[:, :, :3] * 255).astype(np.uint8)
|
| 261 |
-
|
| 262 |
-
return rgb_img
|
| 263 |
-
|
| 264 |
-
except Exception as e:
|
| 265 |
-
logger.error(f"Failed to create vegetation index image: {e}")
|
| 266 |
-
return np.zeros((*index_values.shape, 3), dtype=np.uint8)
|
| 267 |
-
|
| 268 |
-
def get_available_indices(self) -> list:
|
| 269 |
-
"""Get list of available vegetation indices."""
|
| 270 |
-
return list(self.index_formulas.keys())
|
| 271 |
-
|
| 272 |
-
def get_index_requirements(self, index_name: str) -> list:
|
| 273 |
-
"""
|
| 274 |
-
Get required bands for a specific index.
|
| 275 |
-
|
| 276 |
-
Args:
|
| 277 |
-
index_name: Name of the vegetation index
|
| 278 |
-
|
| 279 |
-
Returns:
|
| 280 |
-
List of required band names
|
| 281 |
-
"""
|
| 282 |
-
return self.index_bands.get(index_name, [])
|
| 283 |
-
|
| 284 |
-
def validate_spectral_data(self, spectral_stack: Dict[str, np.ndarray]) -> bool:
|
| 285 |
-
"""
|
| 286 |
-
Validate spectral data for vegetation index computation.
|
| 287 |
-
|
| 288 |
-
Args:
|
| 289 |
-
spectral_stack: Dictionary of spectral bands
|
| 290 |
-
|
| 291 |
-
Returns:
|
| 292 |
-
True if valid, False otherwise
|
| 293 |
-
"""
|
| 294 |
-
if not spectral_stack:
|
| 295 |
-
return False
|
| 296 |
-
|
| 297 |
-
required_bands = ['nir', 'red', 'green', 'red_edge']
|
| 298 |
-
if not all(band in spectral_stack for band in required_bands):
|
| 299 |
-
logger.warning("Missing required spectral bands")
|
| 300 |
-
return False
|
| 301 |
-
|
| 302 |
-
# Check data shapes
|
| 303 |
-
shapes = [arr.shape for arr in spectral_stack.values()]
|
| 304 |
-
if not all(shape == shapes[0] for shape in shapes):
|
| 305 |
-
logger.warning("Inconsistent spectral band shapes")
|
| 306 |
-
return False
|
| 307 |
|
| 308 |
-
return
|
|
|
|
| 1 |
"""
|
| 2 |
+
Minimal vegetation index extraction (NDVI, ARI, GNDVI only).
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
+
from typing import Dict, Any
|
|
|
|
| 7 |
import logging
|
| 8 |
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
|
| 12 |
class VegetationIndexExtractor:
|
| 13 |
+
"""Minimal vegetation index extraction."""
|
| 14 |
|
| 15 |
def __init__(self, epsilon: float = 1e-10, soil_factor: float = 0.16):
|
| 16 |
+
"""Initialize with defaults."""
|
| 17 |
+
self.epsilon = epsilon
|
| 18 |
+
self.soil_factor = soil_factor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
|
|
|
| 20 |
self.index_formulas = {
|
| 21 |
"NDVI": lambda nir, red: (nir - red) / (nir + red + self.epsilon),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
"ARI": lambda green, red_edge: (1.0 / (green + self.epsilon)) - (1.0 / (red_edge + self.epsilon)),
|
| 23 |
+
"GNDVI": lambda nir, green: (nir - green) / (nir + green + self.epsilon),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
}
|
| 25 |
|
|
|
|
| 26 |
self.index_bands = {
|
| 27 |
"NDVI": ["nir", "red"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
"ARI": ["green", "red_edge"],
|
| 29 |
+
"GNDVI": ["nir", "green"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
}
|
| 31 |
|
| 32 |
def compute_vegetation_indices(self, spectral_stack: Dict[str, np.ndarray],
|
| 33 |
+
mask: np.ndarray) -> Dict[str, Dict[str, Any]]:
|
| 34 |
+
"""Compute NDVI, ARI, and GNDVI."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
indices = {}
|
| 36 |
|
| 37 |
for index_name, formula in self.index_formulas.items():
|
| 38 |
try:
|
| 39 |
+
required_bands = self.index_bands[index_name]
|
|
|
|
|
|
|
|
|
|
| 40 |
if not all(band in spectral_stack for band in required_bands):
|
|
|
|
| 41 |
continue
|
| 42 |
|
|
|
|
| 43 |
band_data = []
|
| 44 |
for band in required_bands:
|
| 45 |
arr = spectral_stack[band]
|
|
|
|
| 46 |
if isinstance(arr, np.ndarray):
|
| 47 |
arr = arr.squeeze(-1)
|
| 48 |
+
band_data.append(np.asarray(arr, dtype=np.float64))
|
|
|
|
| 49 |
|
|
|
|
| 50 |
index_values = formula(*band_data).astype(np.float64)
|
| 51 |
+
binary_mask = (np.asarray(mask).astype(np.int32) > 0)
|
| 52 |
+
masked_values = np.where(binary_mask, index_values, np.nan)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
valid_values = masked_values[~np.isnan(masked_values)]
|
| 55 |
if len(valid_values) > 0:
|
| 56 |
stats = {
|
| 57 |
'mean': float(np.mean(valid_values)),
|
| 58 |
'std': float(np.std(valid_values)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
}
|
| 60 |
else:
|
| 61 |
+
stats = {'mean': 0.0, 'std': 0.0}
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
indices[index_name] = {
|
| 64 |
'values': masked_values,
|
| 65 |
'statistics': stats
|
| 66 |
}
|
| 67 |
|
|
|
|
|
|
|
| 68 |
except Exception as e:
|
| 69 |
logger.error(f"Failed to compute {index_name}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
return indices
|
sorghum_pipeline/output/manager.py
CHANGED
|
@@ -1,688 +1,143 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This module handles saving results, generating visualizations,
|
| 5 |
-
and creating reports.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
| 9 |
-
import json
|
| 10 |
import numpy as np
|
| 11 |
import cv2
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
matplotlib.use('Agg')
|
| 18 |
-
import matplotlib.pyplot as plt
|
| 19 |
-
import matplotlib.cm as cm
|
| 20 |
-
from matplotlib.colors import Normalize
|
| 21 |
-
except Exception:
|
| 22 |
-
# Fallback safe imports (should not happen normally)
|
| 23 |
-
import matplotlib.pyplot as plt
|
| 24 |
-
import matplotlib.cm as cm
|
| 25 |
-
from matplotlib.colors import Normalize
|
| 26 |
-
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
| 27 |
from pathlib import Path
|
| 28 |
-
from typing import Dict, Any
|
| 29 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 30 |
-
import pandas as pd
|
| 31 |
import logging
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
| 34 |
|
| 35 |
|
| 36 |
class OutputManager:
|
| 37 |
-
"""
|
| 38 |
|
| 39 |
def __init__(self, output_folder: str, settings: Any):
|
| 40 |
-
"""
|
| 41 |
-
Initialize output manager.
|
| 42 |
-
|
| 43 |
-
Args:
|
| 44 |
-
output_folder: Base output folder
|
| 45 |
-
settings: Output settings from config
|
| 46 |
-
"""
|
| 47 |
self.output_folder = Path(output_folder)
|
| 48 |
self.settings = settings
|
| 49 |
-
# Fast mode and parallel save controls
|
| 50 |
-
try:
|
| 51 |
-
self.fast_mode: bool = bool(int(os.environ.get('FAST_OUTPUT', '0'))) or bool(getattr(settings, 'fast_mode', False))
|
| 52 |
-
except Exception:
|
| 53 |
-
self.fast_mode = False
|
| 54 |
-
try:
|
| 55 |
-
self.max_workers: int = int(os.environ.get('FAST_SAVE_WORKERS', '4'))
|
| 56 |
-
except Exception:
|
| 57 |
-
self.max_workers = 4
|
| 58 |
try:
|
| 59 |
-
self.
|
| 60 |
except Exception:
|
| 61 |
-
self.
|
| 62 |
-
|
| 63 |
-
# Reduce thread usage to lower risk of native library segfaults
|
| 64 |
-
try:
|
| 65 |
-
import os as _os
|
| 66 |
-
_os.environ.setdefault('OMP_NUM_THREADS', '1')
|
| 67 |
-
_os.environ.setdefault('OPENBLAS_NUM_THREADS', '1')
|
| 68 |
-
_os.environ.setdefault('MKL_NUM_THREADS', '1')
|
| 69 |
-
_os.environ.setdefault('NUMEXPR_NUM_THREADS', '1')
|
| 70 |
-
except Exception:
|
| 71 |
-
pass
|
| 72 |
-
try:
|
| 73 |
-
cv2.setNumThreads(1)
|
| 74 |
-
except Exception:
|
| 75 |
-
pass
|
| 76 |
-
|
| 77 |
-
# Create base directories
|
| 78 |
self.output_folder.mkdir(parents=True, exist_ok=True)
|
| 79 |
|
| 80 |
-
def _imwrite_fast(self, dest: Path, img: np.ndarray) -> None:
|
| 81 |
-
try:
|
| 82 |
-
cv2.imwrite(str(dest), img, [cv2.IMWRITE_PNG_COMPRESSION, int(self.png_compression)])
|
| 83 |
-
except Exception:
|
| 84 |
-
cv2.imwrite(str(dest), img)
|
| 85 |
-
|
| 86 |
def create_output_directories(self) -> None:
|
| 87 |
-
"""
|
| 88 |
-
|
| 89 |
-
Note: Do NOT create subdirectories at the root (e.g., 'analysis').
|
| 90 |
-
Subdirectories are created within each plant's directory only.
|
| 91 |
-
"""
|
| 92 |
self.output_folder.mkdir(parents=True, exist_ok=True)
|
| 93 |
|
| 94 |
def save_plant_results(self, plant_key: str, plant_data: Dict[str, Any]) -> None:
|
| 95 |
-
"""
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
Args:
|
| 99 |
-
plant_key: Plant identifier (e.g., "2025_02_05_plant1_frame8")
|
| 100 |
-
plant_data: Plant data dictionary
|
| 101 |
-
"""
|
| 102 |
-
try:
|
| 103 |
-
# Parse plant key
|
| 104 |
-
parts = plant_key.split('_')
|
| 105 |
-
date_key = "_".join(parts[:3])
|
| 106 |
-
plant_name = parts[3]
|
| 107 |
-
frame_key = parts[4] if len(parts) > 4 else "frame0"
|
| 108 |
-
|
| 109 |
-
# Create plant-specific directory
|
| 110 |
-
plant_dir = self.output_folder / date_key / plant_name
|
| 111 |
-
plant_dir.mkdir(parents=True, exist_ok=True)
|
| 112 |
-
|
| 113 |
-
# Save segmentation results
|
| 114 |
-
self._save_segmentation_results(plant_dir, plant_name, plant_data)
|
| 115 |
-
|
| 116 |
-
# Save texture features
|
| 117 |
-
self._save_texture_features(plant_dir, plant_data)
|
| 118 |
-
|
| 119 |
-
# Save vegetation indices
|
| 120 |
-
self._save_vegetation_indices(plant_dir, plant_data)
|
| 121 |
-
|
| 122 |
-
# Save morphology features
|
| 123 |
-
self._save_morphology_features(plant_dir, plant_data)
|
| 124 |
-
|
| 125 |
-
# Save analysis plots
|
| 126 |
-
self._save_analysis_plots(plant_dir, plant_data)
|
| 127 |
-
|
| 128 |
-
# Save metadata
|
| 129 |
-
self._save_metadata(plant_dir, plant_key, plant_data)
|
| 130 |
-
|
| 131 |
-
logger.debug(f"Results saved for {plant_key}")
|
| 132 |
-
|
| 133 |
-
except Exception as e:
|
| 134 |
-
logger.error(f"Failed to save results for {plant_key}: {e}")
|
| 135 |
-
|
| 136 |
-
def _save_segmentation_results(self, plant_dir: Path, plant_name: str, plant_data: Dict[str, Any]) -> None:
|
| 137 |
-
"""Save segmentation results."""
|
| 138 |
-
if not self.settings.save_images:
|
| 139 |
return
|
| 140 |
|
| 141 |
-
|
| 142 |
-
seg_dir.mkdir(exist_ok=True)
|
| 143 |
-
|
| 144 |
-
try:
|
| 145 |
-
tasks: List[Tuple[Path, np.ndarray]] = []
|
| 146 |
-
# Choose which base image to present in original/overlay
|
| 147 |
-
use_feature_image = False
|
| 148 |
-
try:
|
| 149 |
-
# Allow env override, and special-case plants 13-16 per user requirement
|
| 150 |
-
use_feature_image = bool(int(os.environ.get('OUTPUT_USE_FEATURE_IMAGE', '0'))) or plant_name in { 'plant13','plant14','plant15','plant16' }
|
| 151 |
-
except Exception:
|
| 152 |
-
use_feature_image = plant_name in { 'plant13','plant14','plant15','plant16' }
|
| 153 |
-
if use_feature_image:
|
| 154 |
-
base_image = plant_data.get('composite', plant_data.get('segmentation_composite'))
|
| 155 |
-
else:
|
| 156 |
-
base_image = plant_data.get('segmentation_composite', plant_data.get('composite'))
|
| 157 |
-
if base_image is not None:
|
| 158 |
-
tasks.append((seg_dir / 'original.png', base_image))
|
| 159 |
-
if 'mask' in plant_data:
|
| 160 |
-
tasks.append((seg_dir / 'mask.png', plant_data['mask']))
|
| 161 |
-
if 'mask3' in plant_data and isinstance(plant_data['mask3'], np.ndarray):
|
| 162 |
-
tasks.append((seg_dir / 'mask3.png', plant_data['mask3']))
|
| 163 |
-
# Save the BRIA-generated mask (if present before overrides) as mask2.png
|
| 164 |
-
if 'original_mask' in plant_data and isinstance(plant_data['original_mask'], np.ndarray):
|
| 165 |
-
tasks.append((seg_dir / 'mask2.png', plant_data['original_mask']))
|
| 166 |
-
if base_image is not None and 'mask' in plant_data:
|
| 167 |
-
overlay = self._create_overlay(base_image, plant_data['mask'])
|
| 168 |
-
tasks.append((seg_dir / 'overlay.png', overlay))
|
| 169 |
-
if 'masked_composite' in plant_data:
|
| 170 |
-
tasks.append((seg_dir / 'masked_composite.png', plant_data['masked_composite']))
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
tasks.append((seg_dir / 'maskout_bria.png', maskout_bria))
|
| 181 |
-
# mask3 maskout on original composite
|
| 182 |
-
if base_image is not None and 'mask3' in plant_data and isinstance(plant_data['mask3'], np.ndarray):
|
| 183 |
-
maskout_mask3 = self._create_maskout_white_background(base_image, plant_data['mask3'])
|
| 184 |
-
tasks.append((seg_dir / 'maskout_mask3.png', maskout_mask3))
|
| 185 |
-
except Exception as _e:
|
| 186 |
-
logger.debug(f"Failed to create double maskouts: {_e}")
|
| 187 |
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
else:
|
| 194 |
-
for p, img in tasks:
|
| 195 |
-
self._imwrite_fast(p, img)
|
| 196 |
except Exception as e:
|
| 197 |
-
logger.error(f"Failed to save
|
| 198 |
-
|
| 199 |
-
def _save_texture_features(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
|
| 200 |
-
"""Save texture features."""
|
| 201 |
-
if not self.settings.save_images or 'texture_features' not in plant_data:
|
| 202 |
-
return
|
| 203 |
-
|
| 204 |
-
texture_dir = plant_dir / self.settings.texture_dir
|
| 205 |
-
texture_dir.mkdir(exist_ok=True)
|
| 206 |
-
|
| 207 |
-
def save_feature_png(feature_name: str, values: Any, dest: Path, cmap_name: str = 'viridis') -> None:
|
| 208 |
-
try:
|
| 209 |
-
arr = np.asarray(values)
|
| 210 |
-
if arr.ndim == 3 and arr.shape[-1] == 3:
|
| 211 |
-
self._imwrite_fast(dest, cv2.cvtColor(arr.astype(np.uint8), cv2.COLOR_RGB2BGR))
|
| 212 |
-
return
|
| 213 |
-
if self.fast_mode:
|
| 214 |
-
# Fast path: simple normalization, no matplotlib
|
| 215 |
-
normalized = self._normalize_to_uint8(np.nan_to_num(arr.astype(np.float64), nan=0.0))
|
| 216 |
-
self._imwrite_fast(dest, normalized)
|
| 217 |
-
else:
|
| 218 |
-
arr = arr.astype(np.float64)
|
| 219 |
-
masked = np.ma.masked_invalid(arr)
|
| 220 |
-
fig, ax = plt.subplots(figsize=(5, 5))
|
| 221 |
-
ax.set_axis_off()
|
| 222 |
-
ax.set_facecolor('white')
|
| 223 |
-
im = ax.imshow(masked, cmap=cmap_name)
|
| 224 |
-
divider = make_axes_locatable(ax)
|
| 225 |
-
cax = divider.append_axes("right", size="2%", pad=0.02)
|
| 226 |
-
cbar = plt.colorbar(im, cax=cax, orientation='vertical')
|
| 227 |
-
cbar.set_label(feature_name, fontsize=7)
|
| 228 |
-
cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
|
| 229 |
-
if hasattr(cbar, 'outline') and cbar.outline is not None:
|
| 230 |
-
cbar.outline.set_linewidth(0.5)
|
| 231 |
-
plt.tight_layout()
|
| 232 |
-
plt.savefig(dest, dpi=self.settings.plot_dpi, bbox_inches='tight')
|
| 233 |
-
plt.close(fig)
|
| 234 |
-
except Exception as e:
|
| 235 |
-
logger.error(f"Failed to save texture feature image for {feature_name}: {e}")
|
| 236 |
-
try:
|
| 237 |
-
normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
|
| 238 |
-
self._imwrite_fast(dest, normalized)
|
| 239 |
-
except Exception:
|
| 240 |
-
pass
|
| 241 |
|
|
|
|
| 242 |
try:
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
band_dir = texture_dir / band
|
| 250 |
-
band_dir.mkdir(exist_ok=True)
|
| 251 |
-
|
| 252 |
-
features = band_data['features']
|
| 253 |
-
|
| 254 |
-
# Save individual feature maps (optionally in parallel)
|
| 255 |
-
items: List[Tuple[str, np.ndarray, Path, str]] = []
|
| 256 |
-
for feature_name, feature_map in features.items():
|
| 257 |
-
if feature_name == 'ehd_features':
|
| 258 |
-
for i in range(feature_map.shape[0]):
|
| 259 |
-
channel = feature_map[i]
|
| 260 |
-
if isinstance(channel, np.ndarray) and channel.size > 0:
|
| 261 |
-
items.append((f'ehd_channel_{i}', channel, band_dir / f'ehd_channel_{i}.png', 'magma'))
|
| 262 |
-
else:
|
| 263 |
-
if isinstance(feature_map, np.ndarray) and feature_map.size > 0:
|
| 264 |
-
cmap_choice = 'gray' if feature_name in ('lbp', 'hog') else 'plasma' if feature_name.startswith('lac') else 'viridis'
|
| 265 |
-
items.append((feature_name, feature_map, band_dir / f'{feature_name}.png', cmap_choice))
|
| 266 |
-
|
| 267 |
-
if self.max_workers > 1 and len(items) > 1:
|
| 268 |
-
with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
|
| 269 |
-
futures = [ex.submit(save_feature_png, n, m, p, c) for (n, m, p, c) in items]
|
| 270 |
-
for _ in as_completed(futures):
|
| 271 |
-
pass
|
| 272 |
-
else:
|
| 273 |
-
for (n, m, p, c) in items:
|
| 274 |
-
save_feature_png(n, m, p, c)
|
| 275 |
-
|
| 276 |
-
# Create feature summary plot
|
| 277 |
-
self._create_texture_summary_plot(band_dir, features, band)
|
| 278 |
-
|
| 279 |
-
# Save texture statistics if available
|
| 280 |
-
if 'statistics' in band_data and isinstance(band_data['statistics'], dict):
|
| 281 |
-
try:
|
| 282 |
-
with open(band_dir / 'texture_statistics.json', 'w') as f:
|
| 283 |
-
json.dump(band_data['statistics'], f, indent=2)
|
| 284 |
-
except Exception as e:
|
| 285 |
-
logger.error(f"Failed to save texture statistics for {band}: {e}")
|
| 286 |
-
|
| 287 |
except Exception as e:
|
| 288 |
-
logger.error(f"Failed to save
|
| 289 |
-
|
| 290 |
-
def _save_vegetation_indices(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
|
| 291 |
-
"""Save vegetation indices."""
|
| 292 |
-
if not self.settings.save_images or 'vegetation_indices' not in plant_data:
|
| 293 |
-
return
|
| 294 |
-
|
| 295 |
-
veg_dir = plant_dir / self.settings.vegetation_dir
|
| 296 |
-
veg_dir.mkdir(exist_ok=True)
|
| 297 |
-
|
| 298 |
-
# Colormap and range settings per index
|
| 299 |
-
index_cmap_settings = {
|
| 300 |
-
"NDVI": (cm.RdYlGn, -1, 1),
|
| 301 |
-
"GNDVI": (cm.RdYlGn, -1, 1),
|
| 302 |
-
"NDRE": (cm.RdYlGn, -1, 1),
|
| 303 |
-
"GRNDVI": (cm.RdYlGn, -1, 1),
|
| 304 |
-
"TNDVI": (cm.RdYlGn, -1, 1),
|
| 305 |
-
"MGRVI": (cm.RdYlGn, -1, 1),
|
| 306 |
-
"GRVI": (cm.RdYlGn, -1, 1),
|
| 307 |
-
"NGRDI": (cm.RdYlGn, -1, 1),
|
| 308 |
-
"MSAVI": (cm.YlGn, 0, 1),
|
| 309 |
-
"OSAVI": (cm.YlGn, 0, 1),
|
| 310 |
-
"TSAVI": (cm.YlGn, 0, 1),
|
| 311 |
-
"GSAVI": (cm.YlGn, 0, 1),
|
| 312 |
-
"NDWI": (cm.Blues, -1, 1),
|
| 313 |
-
"DSWI4": (cm.Blues, -1, 1),
|
| 314 |
-
"CIRE": (cm.viridis, 0, 10),
|
| 315 |
-
"LCI": (cm.viridis, 0, 5),
|
| 316 |
-
"CIgreen": (cm.viridis, 0, 5),
|
| 317 |
-
"MCARI": (cm.viridis, 0, 1.5),
|
| 318 |
-
"MCARI1": (cm.viridis, 0, 1.5),
|
| 319 |
-
"MCARI2": (cm.viridis, 0, 1.5),
|
| 320 |
-
"CVI": (cm.plasma, 0, 10),
|
| 321 |
-
"TCARI": (cm.viridis, 0, 1),
|
| 322 |
-
"TCARIOSAVI": (cm.viridis, 0, 1),
|
| 323 |
-
"AVI": (cm.magma, 0, 1),
|
| 324 |
-
"SIPI2": (cm.inferno, 0, 1),
|
| 325 |
-
"ARI": (cm.magma, 0, 1),
|
| 326 |
-
"ARI2": (cm.magma, 0, 1),
|
| 327 |
-
"DVI": (cm.Greens, 0, None),
|
| 328 |
-
"WDVI": (cm.Greens, 0, None),
|
| 329 |
-
"SR": (cm.viridis, 0, 10),
|
| 330 |
-
"MSR": (cm.viridis, 0, 10),
|
| 331 |
-
"PVI": (cm.cividis, None, None),
|
| 332 |
-
"GEMI": (cm.cividis, 0, 1),
|
| 333 |
-
"ExR": (cm.Reds, -1, 1),
|
| 334 |
-
"RI": (cm.Reds, 0, None),
|
| 335 |
-
"RRI1": (cm.Reds, 0, 1)
|
| 336 |
-
}
|
| 337 |
-
|
| 338 |
-
def save_index_png(index_name: str, values: Any, dest: Path) -> None:
|
| 339 |
-
try:
|
| 340 |
-
arr = values
|
| 341 |
-
if not isinstance(arr, (list, tuple,)) and isinstance(arr, (float, int)):
|
| 342 |
-
return
|
| 343 |
-
arr = np.asarray(arr, dtype=np.float64)
|
| 344 |
-
if self.fast_mode:
|
| 345 |
-
normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
|
| 346 |
-
self._imwrite_fast(dest, normalized)
|
| 347 |
-
else:
|
| 348 |
-
cmap, vmin, vmax = index_cmap_settings.get(index_name, (cm.viridis, np.nanmin(arr), np.nanmax(arr)))
|
| 349 |
-
if vmin is None:
|
| 350 |
-
vmin = np.nanmin(arr)
|
| 351 |
-
if vmax is None:
|
| 352 |
-
vmax = np.nanmax(arr)
|
| 353 |
-
if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
|
| 354 |
-
vmin, vmax = 0.0, 1.0
|
| 355 |
-
masked = np.ma.masked_invalid(arr)
|
| 356 |
-
fig, ax = plt.subplots(figsize=(5, 5))
|
| 357 |
-
ax.set_axis_off()
|
| 358 |
-
ax.set_facecolor('white')
|
| 359 |
-
im = ax.imshow(masked, cmap=cmap, vmin=vmin, vmax=vmax)
|
| 360 |
-
divider = make_axes_locatable(ax)
|
| 361 |
-
cax = divider.append_axes("right", size="2%", pad=0.02)
|
| 362 |
-
cbar = plt.colorbar(im, cax=cax, orientation='vertical')
|
| 363 |
-
cbar.set_label(index_name, fontsize=7)
|
| 364 |
-
cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
|
| 365 |
-
if hasattr(cbar, 'outline') and cbar.outline is not None:
|
| 366 |
-
cbar.outline.set_linewidth(0.5)
|
| 367 |
-
plt.tight_layout()
|
| 368 |
-
plt.savefig(dest, dpi=self.settings.plot_dpi, bbox_inches='tight')
|
| 369 |
-
plt.close(fig)
|
| 370 |
-
except Exception as e:
|
| 371 |
-
logger.error(f"Failed to save vegetation index image for {index_name}: {e}")
|
| 372 |
-
try:
|
| 373 |
-
# Fallback simple normalization
|
| 374 |
-
normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
|
| 375 |
-
self._imwrite_fast(dest, normalized)
|
| 376 |
-
except Exception:
|
| 377 |
-
pass
|
| 378 |
|
|
|
|
| 379 |
try:
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
json.dump(stats, f, indent=2)
|
| 400 |
-
except Exception as e:
|
| 401 |
-
logger.error(f"Failed to save stats for {path.name.split('.')[0]}: {e}")
|
| 402 |
-
|
| 403 |
-
# Create vegetation index summary (skip in fast mode)
|
| 404 |
-
if not self.fast_mode:
|
| 405 |
-
self._create_vegetation_summary_plot(veg_dir, vegetation_indices)
|
| 406 |
-
|
| 407 |
-
# Save aggregated vegetation statistics
|
| 408 |
-
try:
|
| 409 |
-
all_stats = {k: v.get('statistics', {}) for k, v in vegetation_indices.items() if isinstance(v, dict)}
|
| 410 |
-
with open(veg_dir / 'vegetation_statistics.json', 'w') as f:
|
| 411 |
-
json.dump(all_stats, f, indent=2)
|
| 412 |
-
except Exception as e:
|
| 413 |
-
logger.error(f"Failed to save aggregated vegetation statistics: {e}")
|
| 414 |
-
|
| 415 |
except Exception as e:
|
| 416 |
logger.error(f"Failed to save vegetation indices: {e}")
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
"""Save morphological features."""
|
| 420 |
-
if not self.settings.save_images or 'morphology_features' not in plant_data:
|
| 421 |
-
return
|
| 422 |
-
|
| 423 |
-
morph_dir = plant_dir / self.settings.morphology_dir
|
| 424 |
-
morph_dir.mkdir(exist_ok=True)
|
| 425 |
-
|
| 426 |
try:
|
| 427 |
-
|
|
|
|
|
|
|
| 428 |
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
for image_name, image_data in morphology_features['images'].items():
|
| 432 |
-
if isinstance(image_data, np.ndarray) and image_data.size > 0:
|
| 433 |
-
cv2.imwrite(str(morph_dir / f'{image_name}.png'), image_data)
|
| 434 |
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
traits = morphology_features['traits']
|
| 438 |
-
with open(morph_dir / 'traits.json', 'w') as f:
|
| 439 |
-
json.dump(traits, f, indent=2)
|
| 440 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
except Exception as e:
|
| 442 |
-
logger.error(f"Failed to save
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
"""Save analysis plots."""
|
| 446 |
-
if not self.settings.save_plots or self.fast_mode:
|
| 447 |
-
return
|
| 448 |
-
|
| 449 |
-
analysis_dir = plant_dir / self.settings.analysis_dir
|
| 450 |
-
analysis_dir.mkdir(exist_ok=True)
|
| 451 |
-
|
| 452 |
-
try:
|
| 453 |
-
# Create comprehensive analysis plot
|
| 454 |
-
self._create_comprehensive_analysis_plot(analysis_dir, plant_data)
|
| 455 |
-
|
| 456 |
-
except Exception as e:
|
| 457 |
-
logger.error(f"Failed to save analysis plots: {e}")
|
| 458 |
-
|
| 459 |
-
def _save_metadata(self, plant_dir: Path, plant_key: str, plant_data: Dict[str, Any]) -> None:
|
| 460 |
-
"""Save metadata for the plant."""
|
| 461 |
-
if not self.settings.save_metadata:
|
| 462 |
-
return
|
| 463 |
-
|
| 464 |
try:
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
'features_available': {
|
| 471 |
-
'texture': 'texture_features' in plant_data,
|
| 472 |
-
'vegetation': 'vegetation_indices' in plant_data,
|
| 473 |
-
'morphology': 'morphology_features' in plant_data
|
| 474 |
-
}
|
| 475 |
-
}
|
| 476 |
-
|
| 477 |
-
with open(plant_dir / 'metadata.json', 'w') as f:
|
| 478 |
-
json.dump(metadata, f, indent=2)
|
| 479 |
-
|
| 480 |
except Exception as e:
|
| 481 |
-
logger.error(f"Failed to save
|
| 482 |
|
| 483 |
-
def _create_overlay(self, image: np.ndarray, mask: np.ndarray
|
| 484 |
-
|
| 485 |
-
alpha: float = 0.5) -> np.ndarray:
|
| 486 |
-
"""Return a strictly masked image: pixels where mask>0 keep original; others set to 0."""
|
| 487 |
if mask is None:
|
| 488 |
return image
|
| 489 |
-
# Resize mask to image size if needed
|
| 490 |
if mask.shape[:2] != image.shape[:2]:
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
except Exception:
|
| 494 |
-
pass
|
| 495 |
binary = (mask.astype(np.int32) > 0).astype(np.uint8) * 255
|
| 496 |
return cv2.bitwise_and(image, image, mask=binary)
|
| 497 |
|
| 498 |
-
def _create_maskout_white_background(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 499 |
-
"""Create maskout image with white background."""
|
| 500 |
-
# Create white background
|
| 501 |
-
white_background = np.full_like(image, 255, dtype=np.uint8)
|
| 502 |
-
|
| 503 |
-
# Apply mask to original image (keep only masked regions)
|
| 504 |
-
masked_image = image.copy()
|
| 505 |
-
masked_image[mask == 0] = 0 # Set non-masked regions to black
|
| 506 |
-
|
| 507 |
-
# Combine: white background + masked image
|
| 508 |
-
result = white_background.copy()
|
| 509 |
-
result[mask > 0] = masked_image[mask > 0]
|
| 510 |
-
|
| 511 |
-
return result
|
| 512 |
-
|
| 513 |
def _normalize_to_uint8(self, arr: np.ndarray) -> np.ndarray:
|
| 514 |
-
"""Normalize
|
| 515 |
-
if arr.size == 0:
|
| 516 |
-
return arr.astype(np.uint8)
|
| 517 |
-
|
| 518 |
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
| 519 |
-
|
| 520 |
if arr.ptp() > 0:
|
| 521 |
normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
|
| 522 |
else:
|
| 523 |
normalized = np.zeros_like(arr)
|
| 524 |
-
|
| 525 |
-
return np.clip(normalized, 0, 255).astype(np.uint8)
|
| 526 |
-
|
| 527 |
-
def _create_texture_summary_plot(self, output_dir: Path, features: Dict[str, np.ndarray], band: str) -> None:
|
| 528 |
-
"""Create texture feature summary plot."""
|
| 529 |
-
try:
|
| 530 |
-
# Get available features
|
| 531 |
-
available_features = [k for k, v in features.items()
|
| 532 |
-
if isinstance(v, np.ndarray) and v.size > 0 and k != 'ehd_features']
|
| 533 |
-
|
| 534 |
-
if not available_features:
|
| 535 |
-
return
|
| 536 |
-
|
| 537 |
-
# Create subplot
|
| 538 |
-
n_features = len(available_features)
|
| 539 |
-
cols = min(3, n_features)
|
| 540 |
-
rows = (n_features + cols - 1) // cols
|
| 541 |
-
|
| 542 |
-
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
|
| 543 |
-
if n_features == 1:
|
| 544 |
-
axes = [axes]
|
| 545 |
-
elif rows == 1:
|
| 546 |
-
axes = axes.reshape(1, -1)
|
| 547 |
-
|
| 548 |
-
for i, feature_name in enumerate(available_features):
|
| 549 |
-
row, col = divmod(i, cols)
|
| 550 |
-
ax = axes[row, col] if rows > 1 else axes[col]
|
| 551 |
-
|
| 552 |
-
feature_map = features[feature_name]
|
| 553 |
-
ax.imshow(feature_map, cmap='viridis')
|
| 554 |
-
ax.set_title(f'{band.upper()} - {feature_name.upper()}')
|
| 555 |
-
ax.axis('off')
|
| 556 |
-
|
| 557 |
-
# Hide unused subplots
|
| 558 |
-
for i in range(n_features, rows * cols):
|
| 559 |
-
row, col = divmod(i, cols)
|
| 560 |
-
ax = axes[row, col] if rows > 1 else axes[col]
|
| 561 |
-
ax.axis('off')
|
| 562 |
-
|
| 563 |
-
plt.tight_layout()
|
| 564 |
-
plt.savefig(output_dir / f'{band}_texture_summary.png',
|
| 565 |
-
dpi=self.settings.plot_dpi, bbox_inches='tight')
|
| 566 |
-
plt.close()
|
| 567 |
-
|
| 568 |
-
except Exception as e:
|
| 569 |
-
logger.error(f"Failed to create texture summary plot: {e}")
|
| 570 |
-
|
| 571 |
-
def _create_vegetation_summary_plot(self, output_dir: Path, vegetation_indices: Dict[str, Any]) -> None:
|
| 572 |
-
"""Create vegetation index summary plot."""
|
| 573 |
-
try:
|
| 574 |
-
# Get available indices
|
| 575 |
-
available_indices = [k for k, v in vegetation_indices.items()
|
| 576 |
-
if isinstance(v, dict) and 'values' in v and isinstance(v['values'], np.ndarray)]
|
| 577 |
-
|
| 578 |
-
if not available_indices:
|
| 579 |
-
return
|
| 580 |
-
|
| 581 |
-
# Create subplot
|
| 582 |
-
n_indices = len(available_indices)
|
| 583 |
-
cols = min(3, n_indices)
|
| 584 |
-
rows = (n_indices + cols - 1) // cols
|
| 585 |
-
|
| 586 |
-
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
|
| 587 |
-
if n_indices == 1:
|
| 588 |
-
axes = [axes]
|
| 589 |
-
elif rows == 1:
|
| 590 |
-
axes = axes.reshape(1, -1)
|
| 591 |
-
|
| 592 |
-
for i, index_name in enumerate(available_indices):
|
| 593 |
-
row, col = divmod(i, cols)
|
| 594 |
-
ax = axes[row, col] if rows > 1 else axes[col]
|
| 595 |
-
|
| 596 |
-
values = vegetation_indices[index_name]['values']
|
| 597 |
-
im = ax.imshow(values, cmap='RdYlGn')
|
| 598 |
-
ax.set_title(f'{index_name}')
|
| 599 |
-
ax.axis('off')
|
| 600 |
-
divider = make_axes_locatable(ax)
|
| 601 |
-
cax = divider.append_axes("right", size="2%", pad=0.02)
|
| 602 |
-
cbar = plt.colorbar(im, cax=cax, orientation='vertical')
|
| 603 |
-
cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
|
| 604 |
-
if hasattr(cbar, 'outline') and cbar.outline is not None:
|
| 605 |
-
cbar.outline.set_linewidth(0.5)
|
| 606 |
-
|
| 607 |
-
# Hide unused subplots
|
| 608 |
-
for i in range(n_indices, rows * cols):
|
| 609 |
-
row, col = divmod(i, cols)
|
| 610 |
-
ax = axes[row, col] if rows > 1 else axes[col]
|
| 611 |
-
ax.axis('off')
|
| 612 |
-
|
| 613 |
-
plt.tight_layout()
|
| 614 |
-
plt.savefig(output_dir / 'vegetation_indices_summary.png',
|
| 615 |
-
dpi=self.settings.plot_dpi, bbox_inches='tight')
|
| 616 |
-
plt.close()
|
| 617 |
-
|
| 618 |
-
except Exception as e:
|
| 619 |
-
logger.error(f"Failed to create vegetation summary plot: {e}")
|
| 620 |
-
|
| 621 |
-
def _create_comprehensive_analysis_plot(self, output_dir: Path, plant_data: Dict[str, Any]) -> None:
|
| 622 |
-
"""Create comprehensive analysis plot."""
|
| 623 |
-
try:
|
| 624 |
-
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
| 625 |
-
|
| 626 |
-
# Original image
|
| 627 |
-
if 'composite' in plant_data:
|
| 628 |
-
axes[0, 0].imshow(cv2.cvtColor(plant_data['composite'], cv2.COLOR_BGR2RGB))
|
| 629 |
-
axes[0, 0].set_title('Original Composite')
|
| 630 |
-
axes[0, 0].axis('off')
|
| 631 |
-
|
| 632 |
-
# Mask
|
| 633 |
-
if 'mask' in plant_data:
|
| 634 |
-
axes[0, 1].imshow(plant_data['mask'], cmap='gray')
|
| 635 |
-
axes[0, 1].set_title('Segmentation Mask')
|
| 636 |
-
axes[0, 1].axis('off')
|
| 637 |
-
|
| 638 |
-
# Overlay
|
| 639 |
-
if 'composite' in plant_data and 'mask' in plant_data:
|
| 640 |
-
overlay = self._create_overlay(plant_data['composite'], plant_data['mask'])
|
| 641 |
-
axes[0, 2].imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
|
| 642 |
-
axes[0, 2].set_title('Overlay')
|
| 643 |
-
axes[0, 2].axis('off')
|
| 644 |
-
|
| 645 |
-
# Texture features (if available)
|
| 646 |
-
if 'texture_features' in plant_data and 'color' in plant_data['texture_features']:
|
| 647 |
-
color_features = plant_data['texture_features']['color'].get('features', {})
|
| 648 |
-
if 'lbp' in color_features:
|
| 649 |
-
axes[1, 0].imshow(color_features['lbp'], cmap='viridis')
|
| 650 |
-
axes[1, 0].set_title('LBP Texture')
|
| 651 |
-
axes[1, 0].axis('off')
|
| 652 |
-
|
| 653 |
-
# Vegetation indices (if available)
|
| 654 |
-
if 'vegetation_indices' in plant_data:
|
| 655 |
-
veg_indices = plant_data['vegetation_indices']
|
| 656 |
-
if 'NDVI' in veg_indices and 'values' in veg_indices['NDVI']:
|
| 657 |
-
axes[1, 1].imshow(veg_indices['NDVI']['values'], cmap='RdYlGn')
|
| 658 |
-
axes[1, 1].set_title('NDVI')
|
| 659 |
-
axes[1, 1].axis('off')
|
| 660 |
-
|
| 661 |
-
# Morphology (if available)
|
| 662 |
-
if 'morphology_features' in plant_data and 'images' in plant_data['morphology_features']:
|
| 663 |
-
morph_images = plant_data['morphology_features']['images']
|
| 664 |
-
if 'skeleton' in morph_images:
|
| 665 |
-
axes[1, 2].imshow(morph_images['skeleton'], cmap='gray')
|
| 666 |
-
axes[1, 2].set_title('Skeleton')
|
| 667 |
-
axes[1, 2].axis('off')
|
| 668 |
-
|
| 669 |
-
plt.tight_layout()
|
| 670 |
-
plt.savefig(output_dir / 'comprehensive_analysis.png',
|
| 671 |
-
dpi=min(getattr(self.settings, 'plot_dpi', 100), 100), bbox_inches='tight')
|
| 672 |
-
plt.close()
|
| 673 |
-
|
| 674 |
-
except Exception as e:
|
| 675 |
-
logger.error(f"Failed to create comprehensive analysis plot: {e}")
|
| 676 |
-
|
| 677 |
-
def create_pipeline_summary(self, results: Dict[str, Any]) -> None:
|
| 678 |
-
"""Create a summary of the entire pipeline run."""
|
| 679 |
-
try:
|
| 680 |
-
summary_file = self.output_folder / 'pipeline_summary.json'
|
| 681 |
-
|
| 682 |
-
with open(summary_file, 'w') as f:
|
| 683 |
-
json.dump(results['summary'], f, indent=2)
|
| 684 |
-
|
| 685 |
-
logger.info(f"Pipeline summary saved to {summary_file}")
|
| 686 |
-
|
| 687 |
-
except Exception as e:
|
| 688 |
-
logger.error(f"Failed to create pipeline summary: {e}")
|
|
|
|
| 1 |
"""
|
| 2 |
+
Minimal output manager for demo (saves only 7 required images).
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import cv2
|
| 8 |
+
import matplotlib
|
| 9 |
+
if os.environ.get('MPLBACKEND') is None:
|
| 10 |
+
matplotlib.use('Agg')
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import matplotlib.cm as cm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from pathlib import Path
|
| 14 |
+
from typing import Dict, Any
|
|
|
|
|
|
|
| 15 |
import logging
|
| 16 |
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
| 19 |
|
| 20 |
class OutputManager:
|
| 21 |
+
"""Minimal output manager for demo."""
|
| 22 |
|
| 23 |
def __init__(self, output_folder: str, settings: Any):
|
| 24 |
+
"""Initialize output manager."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
self.output_folder = Path(output_folder)
|
| 26 |
self.settings = settings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
try:
|
| 28 |
+
self.minimal_demo: bool = bool(int(os.environ.get('MINIMAL_DEMO', '0')))
|
| 29 |
except Exception:
|
| 30 |
+
self.minimal_demo = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
self.output_folder.mkdir(parents=True, exist_ok=True)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def create_output_directories(self) -> None:
|
| 34 |
+
"""Create output directories."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
self.output_folder.mkdir(parents=True, exist_ok=True)
|
| 36 |
|
| 37 |
def save_plant_results(self, plant_key: str, plant_data: Dict[str, Any]) -> None:
|
| 38 |
+
"""Save minimal demo outputs only."""
|
| 39 |
+
if not self.minimal_demo:
|
| 40 |
+
logger.warning("OutputManager configured for minimal demo only")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
return
|
| 42 |
|
| 43 |
+
self._save_minimal_demo_outputs(plant_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
def _save_minimal_demo_outputs(self, plant_data: Dict[str, Any]) -> None:
|
| 46 |
+
"""Save only the 7 required images."""
|
| 47 |
+
results_dir = self.output_folder / 'results'
|
| 48 |
+
veg_dir = self.output_folder / 'Vegetation_indices_images'
|
| 49 |
+
tex_dir = self.output_folder / 'texture_output'
|
| 50 |
+
results_dir.mkdir(parents=True, exist_ok=True)
|
| 51 |
+
veg_dir.mkdir(parents=True, exist_ok=True)
|
| 52 |
+
tex_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
+
# 1. Mask
|
| 55 |
+
try:
|
| 56 |
+
mask = plant_data.get('mask')
|
| 57 |
+
if isinstance(mask, np.ndarray):
|
| 58 |
+
cv2.imwrite(str(results_dir / 'mask.png'), mask)
|
|
|
|
|
|
|
|
|
|
| 59 |
except Exception as e:
|
| 60 |
+
logger.error(f"Failed to save mask: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
# 2. Overlay
|
| 63 |
try:
|
| 64 |
+
base_image = plant_data.get('composite')
|
| 65 |
+
mask = plant_data.get('mask')
|
| 66 |
+
if isinstance(base_image, np.ndarray) and isinstance(mask, np.ndarray):
|
| 67 |
+
overlay = self._create_overlay(base_image, mask)
|
| 68 |
+
cv2.imwrite(str(results_dir / 'overlay.png'), overlay)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
except Exception as e:
|
| 70 |
+
logger.error(f"Failed to save overlay: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
# 3-5. Vegetation indices (NDVI, ARI, GNDVI)
|
| 73 |
try:
|
| 74 |
+
veg = plant_data.get('vegetation_indices', {})
|
| 75 |
+
for name in ['NDVI', 'ARI', 'GNDVI']:
|
| 76 |
+
data = veg.get(name, {})
|
| 77 |
+
values = data.get('values') if isinstance(data, dict) else None
|
| 78 |
+
if isinstance(values, np.ndarray) and values.size > 0:
|
| 79 |
+
try:
|
| 80 |
+
cmap = cm.RdYlGn if name in ['NDVI', 'GNDVI'] else cm.magma
|
| 81 |
+
vmin, vmax = (-1, 1) if name in ['NDVI', 'GNDVI'] else (0, 1)
|
| 82 |
+
|
| 83 |
+
masked = np.ma.masked_invalid(values.astype(np.float64))
|
| 84 |
+
fig, ax = plt.subplots(figsize=(5, 5))
|
| 85 |
+
ax.set_axis_off()
|
| 86 |
+
ax.set_facecolor('white')
|
| 87 |
+
ax.imshow(masked, cmap=cmap, vmin=vmin, vmax=vmax)
|
| 88 |
+
plt.tight_layout()
|
| 89 |
+
plt.savefig(veg_dir / f"{name.lower()}.png", dpi=100, bbox_inches='tight')
|
| 90 |
+
plt.close(fig)
|
| 91 |
+
except Exception as e:
|
| 92 |
+
logger.error(f"Failed to save {name}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
except Exception as e:
|
| 94 |
logger.error(f"Failed to save vegetation indices: {e}")
|
| 95 |
+
|
| 96 |
+
# 6-8. Texture features (LBP, HOG, Lacunarity)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
try:
|
| 98 |
+
tex = plant_data.get('texture_features', {})
|
| 99 |
+
color_band = tex.get('color', {})
|
| 100 |
+
feats = color_band.get('features', {})
|
| 101 |
|
| 102 |
+
if isinstance(feats.get('lbp'), np.ndarray) and feats['lbp'].size > 0:
|
| 103 |
+
cv2.imwrite(str(tex_dir / 'lbp.png'), feats['lbp'].astype(np.uint8))
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
+
if isinstance(feats.get('hog'), np.ndarray) and feats['hog'].size > 0:
|
| 106 |
+
cv2.imwrite(str(tex_dir / 'hog.png'), feats['hog'].astype(np.uint8))
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
+
lac = feats.get('lac2')
|
| 109 |
+
if isinstance(lac, np.ndarray) and lac.size > 0:
|
| 110 |
+
if lac.dtype != np.uint8:
|
| 111 |
+
lac = self._normalize_to_uint8(lac.astype(np.float64))
|
| 112 |
+
cv2.imwrite(str(tex_dir / 'lacunarity.png'), lac)
|
| 113 |
except Exception as e:
|
| 114 |
+
logger.error(f"Failed to save texture: {e}")
|
| 115 |
+
|
| 116 |
+
# 9. Morphology size analysis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
try:
|
| 118 |
+
morph = plant_data.get('morphology_features', {})
|
| 119 |
+
images = morph.get('images', {})
|
| 120 |
+
size_img = images.get('size_analysis')
|
| 121 |
+
if isinstance(size_img, np.ndarray) and size_img.size > 0:
|
| 122 |
+
cv2.imwrite(str(results_dir / 'size.size_analysis.png'), size_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
except Exception as e:
|
| 124 |
+
logger.error(f"Failed to save size analysis: {e}")
|
| 125 |
|
| 126 |
+
def _create_overlay(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 127 |
+
"""Create overlay (masked pixels only)."""
|
|
|
|
|
|
|
| 128 |
if mask is None:
|
| 129 |
return image
|
|
|
|
| 130 |
if mask.shape[:2] != image.shape[:2]:
|
| 131 |
+
mask = cv2.resize(mask.astype(np.uint8), (image.shape[1], image.shape[0]),
|
| 132 |
+
interpolation=cv2.INTER_NEAREST)
|
|
|
|
|
|
|
| 133 |
binary = (mask.astype(np.int32) > 0).astype(np.uint8) * 255
|
| 134 |
return cv2.bitwise_and(image, image, mask=binary)
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
def _normalize_to_uint8(self, arr: np.ndarray) -> np.ndarray:
|
| 137 |
+
"""Normalize to uint8."""
|
|
|
|
|
|
|
|
|
|
| 138 |
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
|
|
|
| 139 |
if arr.ptp() > 0:
|
| 140 |
normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
|
| 141 |
else:
|
| 142 |
normalized = np.zeros_like(arr)
|
| 143 |
+
return np.clip(normalized, 0, 255).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sorghum_pipeline/pipeline.py
CHANGED
|
@@ -1,620 +1,110 @@
|
|
| 1 |
"""
|
| 2 |
Main pipeline class for the Sorghum Plant Phenotyping Pipeline.
|
| 3 |
|
| 4 |
-
|
| 5 |
-
to feature extraction and result output.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
| 9 |
-
import subprocess
|
| 10 |
import logging
|
| 11 |
from pathlib import Path
|
| 12 |
-
from typing import Dict, Any, Optional
|
| 13 |
import numpy as np
|
| 14 |
import cv2
|
| 15 |
-
import torch
|
| 16 |
-
from torchvision import transforms
|
| 17 |
-
from transformers import AutoModelForImageSegmentation
|
| 18 |
from sklearn.decomposition import PCA
|
| 19 |
-
try:
|
| 20 |
-
from tqdm import tqdm
|
| 21 |
-
except Exception:
|
| 22 |
-
tqdm = None
|
| 23 |
|
| 24 |
from .config import Config
|
| 25 |
-
from .data import
|
| 26 |
from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
|
| 27 |
from .output import OutputManager
|
| 28 |
from .segmentation import SegmentationManager
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
from .segmentation.occlusion_handler import OcclusionHandler # type: ignore
|
| 32 |
-
except Exception:
|
| 33 |
-
OcclusionHandler = None # type: ignore
|
| 34 |
|
| 35 |
|
| 36 |
class SorghumPipeline:
|
| 37 |
-
"""
|
| 38 |
-
Main pipeline class for sorghum plant phenotyping.
|
| 39 |
-
|
| 40 |
-
This class orchestrates the entire pipeline from data loading
|
| 41 |
-
to feature extraction and result output.
|
| 42 |
-
"""
|
| 43 |
|
| 44 |
-
def __init__(self,
|
| 45 |
-
"""
|
| 46 |
-
Initialize the pipeline.
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
config_path: Path to configuration file
|
| 50 |
-
config: Configuration object (if not using file)
|
| 51 |
-
include_ignored: Whether to include ignored plants
|
| 52 |
-
enable_occlusion_handling: Whether to enable SAM2Long occlusion handling
|
| 53 |
-
"""
|
| 54 |
-
# Setup logging
|
| 55 |
self._setup_logging()
|
| 56 |
-
|
| 57 |
-
# Load configuration
|
| 58 |
-
if config is not None:
|
| 59 |
-
self.config = config
|
| 60 |
-
elif config_path is not None:
|
| 61 |
-
self.config = Config(config_path)
|
| 62 |
-
else:
|
| 63 |
-
raise ValueError("Either config_path or config must be provided")
|
| 64 |
-
|
| 65 |
-
# Validate configuration
|
| 66 |
self.config.validate()
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
self.enable_occlusion_handling = enable_occlusion_handling
|
| 70 |
-
self.enable_instance_integration = enable_instance_integration
|
| 71 |
-
self.strict_loader = strict_loader
|
| 72 |
-
self.excluded_dates = excluded_dates or []
|
| 73 |
-
|
| 74 |
-
# Initialize components
|
| 75 |
-
self._initialize_components(include_ignored)
|
| 76 |
-
|
| 77 |
-
logger.info("Sorghum Pipeline initialized successfully")
|
| 78 |
|
| 79 |
def _setup_logging(self):
|
| 80 |
"""Setup logging configuration."""
|
| 81 |
logging.basicConfig(
|
| 82 |
level=logging.INFO,
|
| 83 |
-
format='%(asctime)s - %(
|
| 84 |
-
handlers=[
|
| 85 |
-
logging.StreamHandler(),
|
| 86 |
-
logging.FileHandler('sorghum_pipeline.log')
|
| 87 |
-
]
|
| 88 |
)
|
| 89 |
-
global logger
|
| 90 |
-
logger = logging.getLogger(__name__)
|
| 91 |
|
| 92 |
-
def _initialize_components(self
|
| 93 |
-
"""Initialize
|
| 94 |
-
|
| 95 |
-
self.
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
strict_loader=self.strict_loader,
|
| 100 |
-
excluded_dates=self.excluded_dates,
|
| 101 |
-
)
|
| 102 |
-
self.preprocessor = ImagePreprocessor(
|
| 103 |
-
target_size=self.config.processing.target_size
|
| 104 |
-
)
|
| 105 |
-
self.mask_handler = MaskHandler(
|
| 106 |
-
min_area=self.config.processing.min_component_area,
|
| 107 |
-
kernel_size=self.config.processing.morphology_kernel_size
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
# Feature extractors
|
| 111 |
-
self.texture_extractor = TextureExtractor(
|
| 112 |
-
lbp_points=self.config.processing.lbp_points,
|
| 113 |
-
lbp_radius=self.config.processing.lbp_radius,
|
| 114 |
-
hog_orientations=self.config.processing.hog_orientations,
|
| 115 |
-
hog_pixels_per_cell=self.config.processing.hog_pixels_per_cell,
|
| 116 |
-
hog_cells_per_block=self.config.processing.hog_cells_per_block,
|
| 117 |
-
lacunarity_window=self.config.processing.lacunarity_window,
|
| 118 |
-
ehd_threshold=self.config.processing.ehd_threshold,
|
| 119 |
-
angle_resolution=self.config.processing.angle_resolution
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
self.vegetation_extractor = VegetationIndexExtractor(
|
| 123 |
-
epsilon=self.config.processing.epsilon,
|
| 124 |
-
soil_factor=self.config.processing.soil_factor
|
| 125 |
-
)
|
| 126 |
-
|
| 127 |
-
self.morphology_extractor = MorphologyExtractor(
|
| 128 |
-
pixel_to_cm=self.config.processing.pixel_to_cm,
|
| 129 |
-
prune_sizes=self.config.processing.prune_sizes
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
# Segmentation
|
| 133 |
self.segmentation_manager = SegmentationManager(
|
| 134 |
-
model_name=
|
| 135 |
device=self.config.get_device(),
|
| 136 |
-
threshold=
|
| 137 |
-
trust_remote_code=
|
| 138 |
-
cache_dir=self.config.model.cache_dir if getattr(self.config.model, 'cache_dir', '') else None,
|
| 139 |
-
local_files_only=getattr(self.config.model, 'local_files_only', False),
|
| 140 |
)
|
| 141 |
-
|
| 142 |
-
# Occlusion handling (optional)
|
| 143 |
-
self.occlusion_handler = None
|
| 144 |
-
if self.enable_occlusion_handling and OcclusionHandler is not None:
|
| 145 |
-
try:
|
| 146 |
-
self.occlusion_handler = OcclusionHandler(
|
| 147 |
-
device=self.config.get_device(),
|
| 148 |
-
model="tiny", # Can be made configurable
|
| 149 |
-
confidence_threshold=0.5,
|
| 150 |
-
iou_threshold=0.1
|
| 151 |
-
)
|
| 152 |
-
logger.info("Occlusion handler initialized successfully")
|
| 153 |
-
except Exception as e:
|
| 154 |
-
logger.warning(f"Failed to initialize occlusion handler: {e}")
|
| 155 |
-
logger.warning("Continuing without occlusion handling")
|
| 156 |
-
self.occlusion_handler = None
|
| 157 |
-
elif self.enable_occlusion_handling and OcclusionHandler is None:
|
| 158 |
-
logger.warning("Occlusion handler module not found; continuing without occlusion handling")
|
| 159 |
-
|
| 160 |
-
# Output manager
|
| 161 |
self.output_manager = OutputManager(
|
| 162 |
output_folder=self.config.paths.output_folder,
|
| 163 |
settings=self.config.output
|
| 164 |
)
|
| 165 |
|
| 166 |
-
def
|
| 167 |
-
"""Attempt to free GPU memory prior to running SAM2Long in a subprocess.
|
| 168 |
-
|
| 169 |
-
- Moves BRIA segmentation model to CPU if present
|
| 170 |
-
- Deletes the model reference to release VRAM
|
| 171 |
-
- Calls torch.cuda.empty_cache()
|
| 172 |
"""
|
| 173 |
-
|
| 174 |
-
import torch as _torch # type: ignore
|
| 175 |
-
# Move BRIA model to CPU and drop reference
|
| 176 |
-
try:
|
| 177 |
-
if getattr(self, 'segmentation_manager', None) is not None:
|
| 178 |
-
mdl = getattr(self.segmentation_manager, 'model', None)
|
| 179 |
-
if mdl is not None:
|
| 180 |
-
try:
|
| 181 |
-
mdl.to('cpu')
|
| 182 |
-
except Exception:
|
| 183 |
-
pass
|
| 184 |
-
try:
|
| 185 |
-
delattr(self.segmentation_manager, 'model')
|
| 186 |
-
except Exception:
|
| 187 |
-
pass
|
| 188 |
-
# Ensure attribute exists but is None for future checks
|
| 189 |
-
try:
|
| 190 |
-
self.segmentation_manager.model = None # type: ignore
|
| 191 |
-
except Exception:
|
| 192 |
-
pass
|
| 193 |
-
except Exception:
|
| 194 |
-
pass
|
| 195 |
-
# Free CUDA cache
|
| 196 |
-
try:
|
| 197 |
-
if _torch.cuda.is_available():
|
| 198 |
-
_torch.cuda.empty_cache()
|
| 199 |
-
except Exception:
|
| 200 |
-
pass
|
| 201 |
-
logger.info("Freed GPU memory before SAM2Long invocation (moved BRIA to CPU and emptied cache)")
|
| 202 |
-
except Exception as e:
|
| 203 |
-
logger.warning(f"Failed to free GPU memory before instance segmentation: {e}")
|
| 204 |
-
|
| 205 |
-
def run(self, load_all_frames: bool = False, segmentation_only: bool = False, filter_plants: Optional[List[str]] = None, filter_frames: Optional[List[str]] = None, run_instance_segmentation: bool = False, features_frame_only: Optional[int] = None, reuse_instance_results: bool = False, instance_mapping_path: Optional[str] = None, force_reprocess: bool = False, respect_instance_frame_rules_for_features: bool = False, substitute_feature_image_from_instance_src: bool = False) -> Dict[str, Any]:
|
| 206 |
-
"""
|
| 207 |
-
Run the complete pipeline.
|
| 208 |
|
| 209 |
Args:
|
| 210 |
-
|
| 211 |
-
segmentation_only: If True, run segmentation only and skip feature extraction
|
| 212 |
|
| 213 |
Returns:
|
| 214 |
-
Dictionary containing
|
| 215 |
"""
|
| 216 |
-
logger.info("Starting
|
| 217 |
|
| 218 |
try:
|
| 219 |
import time
|
|
|
|
|
|
|
| 220 |
total_start = time.perf_counter()
|
| 221 |
-
# Step 1: Load data
|
| 222 |
-
logger.info("Step 1/6: Loading data...")
|
| 223 |
-
# In reuse mode we need all frames to select the mapped frame per plant
|
| 224 |
-
if reuse_instance_results:
|
| 225 |
-
plants = self.data_loader.load_all_frames()
|
| 226 |
-
else:
|
| 227 |
-
# If specific frames are requested, we must load all frames to filter correctly
|
| 228 |
-
if load_all_frames or (filter_frames is not None and len(filter_frames) > 0):
|
| 229 |
-
plants = self.data_loader.load_all_frames()
|
| 230 |
-
else:
|
| 231 |
-
plants = self.data_loader.load_selected_frames()
|
| 232 |
|
| 233 |
-
#
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
}
|
| 240 |
-
|
| 241 |
-
# Optional filter by specific frame numbers (e.g., ["9"] or ["frame9"])
|
| 242 |
-
if filter_frames:
|
| 243 |
-
# Normalize to 'frameX' tokens
|
| 244 |
-
wanted = set(
|
| 245 |
-
[f if str(f).startswith('frame') else f"frame{str(f)}" for f in filter_frames]
|
| 246 |
-
)
|
| 247 |
-
plants = {
|
| 248 |
-
key: pdata for key, pdata in plants.items()
|
| 249 |
-
if key.split('_')[-1] in wanted
|
| 250 |
-
}
|
| 251 |
-
|
| 252 |
-
if not plants:
|
| 253 |
-
raise ValueError("No plant data loaded")
|
| 254 |
-
|
| 255 |
-
logger.info(f"Loaded {len(plants)} plants")
|
| 256 |
-
|
| 257 |
-
# If reusing instance results with mapping, restrict to exactly the mapped frame per plant (default frame8)
|
| 258 |
-
if reuse_instance_results:
|
| 259 |
-
try:
|
| 260 |
-
import json as _json
|
| 261 |
-
if instance_mapping_path is None:
|
| 262 |
-
raise ValueError("instance_mapping_path is required in reuse mode")
|
| 263 |
-
_map = _json.load(open(instance_mapping_path, 'r'))
|
| 264 |
-
# Normalize mapping plant keys and compute target frame (default 8)
|
| 265 |
-
target_frame_by_plant = {}
|
| 266 |
-
for pk, pv in _map.items():
|
| 267 |
-
k_norm = pk if str(pk).startswith('plant') else f"plant{int(pk)}" if str(pk).isdigit() else str(pk)
|
| 268 |
-
try:
|
| 269 |
-
target_frame_by_plant[k_norm] = int(pv.get('frame', 8))
|
| 270 |
-
except Exception:
|
| 271 |
-
target_frame_by_plant[k_norm] = 8
|
| 272 |
-
before = len(plants)
|
| 273 |
-
plants = {
|
| 274 |
-
key: pdata for key, pdata in plants.items()
|
| 275 |
-
if (len(key.split('_')) > 3 and key.split('_')[3] in target_frame_by_plant
|
| 276 |
-
and key.split('_')[-1] == f"frame{target_frame_by_plant[key.split('_')[3]]}")
|
| 277 |
-
}
|
| 278 |
-
logger.info(f"Restricted loaded data by mapping frames: {before} -> {len(plants)} items")
|
| 279 |
-
except Exception as e:
|
| 280 |
-
logger.warning(f"Failed to restrict loaded data by mapping frames: {e}")
|
| 281 |
-
|
| 282 |
-
# Skip plants that already have saved results (unless force_reprocess)
|
| 283 |
-
if not force_reprocess:
|
| 284 |
-
try:
|
| 285 |
-
before = len(plants)
|
| 286 |
-
filtered = {}
|
| 287 |
-
for key, pdata in plants.items():
|
| 288 |
-
parts = key.split('_')
|
| 289 |
-
if len(parts) < 5:
|
| 290 |
-
filtered[key] = pdata
|
| 291 |
-
continue
|
| 292 |
-
date_key = "_".join(parts[:3])
|
| 293 |
-
plant_name = parts[3]
|
| 294 |
-
plant_dir = Path(self.config.paths.output_folder) / date_key / plant_name
|
| 295 |
-
meta_ok = (plant_dir / 'metadata.json').exists()
|
| 296 |
-
seg_mask_ok = (plant_dir / self.config.output.segmentation_dir / 'mask.png').exists()
|
| 297 |
-
if meta_ok or seg_mask_ok:
|
| 298 |
-
continue
|
| 299 |
-
filtered[key] = pdata
|
| 300 |
-
plants = filtered
|
| 301 |
-
logger.info(f"Skip-existing filter: {before} -> {len(plants)} items to process")
|
| 302 |
-
except Exception as e:
|
| 303 |
-
logger.warning(f"Skip-existing filter failed: {e}")
|
| 304 |
|
| 305 |
-
#
|
| 306 |
-
try:
|
| 307 |
-
rewired = 0
|
| 308 |
-
borrow_map: Dict[str, str] = {
|
| 309 |
-
'plant13': 'plant12',
|
| 310 |
-
'plant14': 'plant13',
|
| 311 |
-
'plant15': 'plant14',
|
| 312 |
-
'plant16': 'plant15',
|
| 313 |
-
}
|
| 314 |
-
for _k in list(plants.keys()):
|
| 315 |
-
_parts = _k.split('_')
|
| 316 |
-
# Expect keys like YYYY_MM_DD_plantX_frameY
|
| 317 |
-
if len(_parts) < 5:
|
| 318 |
-
continue
|
| 319 |
-
_date_key = "_".join(_parts[:3])
|
| 320 |
-
_plant_name = _parts[3]
|
| 321 |
-
_frame_token = _parts[4]
|
| 322 |
-
# Do NOT borrow on 2025_05_08
|
| 323 |
-
if _date_key == '2025_05_08':
|
| 324 |
-
continue
|
| 325 |
-
if _plant_name not in borrow_map:
|
| 326 |
-
continue
|
| 327 |
-
_src_plant = borrow_map[_plant_name]
|
| 328 |
-
_src_key = f"{_date_key}_{_src_plant}_{_frame_token}"
|
| 329 |
-
_src = plants.get(_src_key)
|
| 330 |
-
if not _src:
|
| 331 |
-
# Fallback: load raw image for source plant directly from disk
|
| 332 |
-
try:
|
| 333 |
-
from PIL import Image as _Image
|
| 334 |
-
_date_folder = _date_key.replace('_', '-')
|
| 335 |
-
_frame_num = int(_frame_token.replace('frame', ''))
|
| 336 |
-
_date_dir = Path(self.config.paths.input_folder)
|
| 337 |
-
# If input folder is a parent of dates, append date folder
|
| 338 |
-
if _date_dir.name != _date_folder:
|
| 339 |
-
_date_dir = _date_dir / _date_folder
|
| 340 |
-
_frame_path = _date_dir / _src_plant / f"{_src_plant}_frame{_frame_num}.tif"
|
| 341 |
-
if _frame_path.exists():
|
| 342 |
-
_img = _Image.open(str(_frame_path))
|
| 343 |
-
_src = {"raw_image": (_img, _frame_path.name), "plant_name": _plant_name, "file_path": str(_frame_path)}
|
| 344 |
-
else:
|
| 345 |
-
_src = None
|
| 346 |
-
except Exception:
|
| 347 |
-
_src = None
|
| 348 |
-
if not _src:
|
| 349 |
-
continue
|
| 350 |
-
_tgt = plants[_k]
|
| 351 |
-
# Preserve original raw image once
|
| 352 |
-
if 'raw_image' in _tgt and 'raw_image_original' not in _tgt:
|
| 353 |
-
_tgt['raw_image_original'] = _tgt['raw_image']
|
| 354 |
-
if 'raw_image' in _src:
|
| 355 |
-
_tgt['raw_image'] = _src['raw_image']
|
| 356 |
-
_tgt['borrowed_from'] = _src_plant
|
| 357 |
-
rewired += 1
|
| 358 |
-
if rewired > 0:
|
| 359 |
-
logger.info(f"Pre-seg borrowing applied: rewired {rewired} frames for plants 13/14/15/16")
|
| 360 |
-
except Exception as e:
|
| 361 |
-
logger.warning(f"Pre-seg borrowing failed: {e}")
|
| 362 |
-
|
| 363 |
-
# Step 2: Create composites
|
| 364 |
-
logger.info("Step 2/6: Creating composites...")
|
| 365 |
-
step_start = time.perf_counter()
|
| 366 |
plants = self.preprocessor.create_composites(plants)
|
| 367 |
-
logger.info(f"Composites done in {(time.perf_counter()-step_start):.2f}s")
|
| 368 |
|
| 369 |
-
#
|
| 370 |
-
|
| 371 |
-
step_start = time.perf_counter()
|
| 372 |
-
bbox_lookup = None
|
| 373 |
-
try:
|
| 374 |
-
bbox_dir = getattr(self.config.paths, 'boundingbox_dir', None)
|
| 375 |
-
# Default to project BoundingBox dir if unset or falsy
|
| 376 |
-
if not bbox_dir:
|
| 377 |
-
try:
|
| 378 |
-
self.config.paths.boundingbox_dir = "/home/grads/f/fahimehorvatinia/Documents/my_full_project/BoundingBox"
|
| 379 |
-
bbox_dir = self.config.paths.boundingbox_dir
|
| 380 |
-
except Exception:
|
| 381 |
-
bbox_dir = None
|
| 382 |
-
if bbox_dir:
|
| 383 |
-
bbox_lookup = self.data_loader.load_bounding_boxes(bbox_dir)
|
| 384 |
-
logger.info(f"Loaded bounding boxes from {bbox_dir}")
|
| 385 |
-
except Exception as e:
|
| 386 |
-
logger.warning(f"Failed to load bounding boxes: {e}")
|
| 387 |
-
bbox_lookup = None
|
| 388 |
-
plants = self._segment_plants(plants, bbox_lookup)
|
| 389 |
-
logger.info(f"Segmentation done in {(time.perf_counter()-step_start):.2f}s")
|
| 390 |
|
| 391 |
-
#
|
| 392 |
-
|
| 393 |
-
logger.info("Step 3.5/6: Handling occlusion with SAM2Long...")
|
| 394 |
-
step_start = time.perf_counter()
|
| 395 |
-
plants = self._handle_occlusion(plants)
|
| 396 |
-
logger.info(f"Occlusion handling done in {(time.perf_counter()-step_start):.2f}s")
|
| 397 |
-
|
| 398 |
-
# Optional: Export RMBG maskouts with white background and run instance segmentation
|
| 399 |
-
if (run_instance_segmentation or self.enable_instance_integration) and not reuse_instance_results:
|
| 400 |
-
if not load_all_frames:
|
| 401 |
-
logger.warning("Instance segmentation expects all 13 frames; consider running with load_all_frames=True.")
|
| 402 |
-
logger.info("Step 3.6: Exporting white-background RMBG images for instance segmentation...")
|
| 403 |
-
# Derive date-specific export/result directories when a single date is present
|
| 404 |
-
date_keys = set()
|
| 405 |
-
try:
|
| 406 |
-
for _k in plants.keys():
|
| 407 |
-
_p = _k.split('_')
|
| 408 |
-
if len(_p) >= 3:
|
| 409 |
-
date_keys.add("_".join(_p[:3]))
|
| 410 |
-
except Exception:
|
| 411 |
-
pass
|
| 412 |
-
if len(date_keys) == 1:
|
| 413 |
-
date_key = next(iter(date_keys))
|
| 414 |
-
base_dir = Path(self.config.paths.output_folder) / date_key
|
| 415 |
-
export_dir = base_dir / "instance_input_maskouts"
|
| 416 |
-
instance_results_dir = base_dir / "instance_results"
|
| 417 |
-
else:
|
| 418 |
-
export_dir = Path(self.config.paths.output_folder) / "instance_input_maskouts"
|
| 419 |
-
instance_results_dir = Path(self.config.paths.output_folder) / "instance_results"
|
| 420 |
-
export_dir.mkdir(parents=True, exist_ok=True)
|
| 421 |
-
instance_results_dir.mkdir(parents=True, exist_ok=True)
|
| 422 |
-
self._export_white_background_maskouts(plants, export_dir)
|
| 423 |
-
|
| 424 |
-
logger.info("Invoking final SAM2Long instance segmentation on exported images...")
|
| 425 |
-
# Free GPU memory before launching SAM2Long to avoid CUDA OOM
|
| 426 |
-
self._free_gpu_memory_before_instance()
|
| 427 |
-
env = os.environ.copy()
|
| 428 |
-
env["SAM2LONG_IMAGES_DIR"] = str(export_dir)
|
| 429 |
-
env["SAM2LONG_RESULTS_DIR"] = str(instance_results_dir)
|
| 430 |
-
# Ensure instance outputs include all frames for all dates
|
| 431 |
-
try:
|
| 432 |
-
env.pop("INSTANCE_OUTPUT_FRAMES", None)
|
| 433 |
-
except Exception:
|
| 434 |
-
pass
|
| 435 |
-
script_path = "/home/grads/f/fahimehorvatinia/Documents/my_full_project/Experiments3_code/sam2long_instance_integration.py"
|
| 436 |
-
try:
|
| 437 |
-
subprocess.run(["python", script_path], check=True, env=env)
|
| 438 |
-
except subprocess.CalledProcessError as e:
|
| 439 |
-
logger.error(f"Instance segmentation failed: {e}")
|
| 440 |
-
else:
|
| 441 |
-
# Integrate instance masks (track_0 as target) into pdata before feature extraction
|
| 442 |
-
try:
|
| 443 |
-
self._apply_instance_masks(plants, instance_results_dir)
|
| 444 |
-
logger.info("Applied instance segmentation masks to pipeline data")
|
| 445 |
-
except Exception as e:
|
| 446 |
-
logger.warning(f"Failed to apply instance masks: {e}")
|
| 447 |
-
elif reuse_instance_results:
|
| 448 |
-
# Reuse existing instance masks from mapping file
|
| 449 |
-
if instance_mapping_path is None:
|
| 450 |
-
raise ValueError("reuse_instance_results=True requires instance_mapping_path to be provided")
|
| 451 |
-
try:
|
| 452 |
-
self._apply_instance_masks_from_mapping(plants, Path(instance_mapping_path))
|
| 453 |
-
logger.info("Applied instance masks from mapping file")
|
| 454 |
-
except Exception as e:
|
| 455 |
-
logger.error(f"Failed to apply instance masks from mapping: {e}")
|
| 456 |
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
_map = _json.load(open(instance_mapping_path, 'r'))
|
| 463 |
-
# Normalize map
|
| 464 |
-
_norm = {}
|
| 465 |
-
for pk, pv in _map.items():
|
| 466 |
-
k_norm = pk if str(pk).startswith('plant') else f"plant{int(pk)}" if str(pk).isdigit() else str(pk)
|
| 467 |
-
_norm[k_norm] = int(pv.get('frame', 8))
|
| 468 |
-
before = len(plants)
|
| 469 |
-
plants = {
|
| 470 |
-
k: v for k, v in plants.items()
|
| 471 |
-
if len(k.split('_')) > 3 and k.split('_')[3] in _norm and k.split('_')[-1] == f"frame{_norm[k.split('_')[3]]}"
|
| 472 |
-
}
|
| 473 |
-
logger.info(f"Restricted feature extraction by mapping: {before} -> {len(plants)} items")
|
| 474 |
-
except Exception as e:
|
| 475 |
-
logger.warning(f"Failed to restrict by mapping frames: {e}")
|
| 476 |
-
# Optional: restrict features to per-plant preferred frame using internal frame rules
|
| 477 |
-
if respect_instance_frame_rules_for_features:
|
| 478 |
-
try:
|
| 479 |
-
# Keep this in sync with _apply_instance_masks frame_rules
|
| 480 |
-
frame_rules: Dict[str, int] = {
|
| 481 |
-
"plant33": 2,
|
| 482 |
-
"plant16": 4,
|
| 483 |
-
"plant19": 5,
|
| 484 |
-
"plant26": 8,
|
| 485 |
-
"plant27": 8,
|
| 486 |
-
"plant29": 8,
|
| 487 |
-
"plant35": 7,
|
| 488 |
-
"plant36": 6,
|
| 489 |
-
"plant37": 2,
|
| 490 |
-
"plant45": 5,
|
| 491 |
-
}
|
| 492 |
-
before = len(plants)
|
| 493 |
-
def _keep(k: str) -> bool:
|
| 494 |
-
parts = k.split('_')
|
| 495 |
-
if len(parts) < 2:
|
| 496 |
-
return False
|
| 497 |
-
plant_name = parts[-2]
|
| 498 |
-
frame_token = parts[-1]
|
| 499 |
-
if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
|
| 500 |
-
return False
|
| 501 |
-
desired = frame_rules.get(plant_name, 8)
|
| 502 |
-
return frame_token == f"frame{desired}"
|
| 503 |
-
plants = {k: v for k, v in plants.items() if _keep(k)}
|
| 504 |
-
logger.info(f"Restricted feature extraction by per-plant frame rules: {before} -> {len(plants)} items")
|
| 505 |
-
except Exception as e:
|
| 506 |
-
logger.warning(f"Failed to apply per-plant frame restriction for features: {e}")
|
| 507 |
-
|
| 508 |
-
# Optional: if features_frame_only set, keep only that frame's entries (global single frame)
|
| 509 |
-
if features_frame_only is not None:
|
| 510 |
-
frame_token = f"frame{features_frame_only}"
|
| 511 |
-
plants = {k: v for k, v in plants.items() if k.split('_')[-1] == frame_token}
|
| 512 |
-
logger.info(f"Restricted feature extraction to {len(plants)} items for {frame_token}")
|
| 513 |
-
|
| 514 |
-
# Optional: substitute feature input image from instance src_rules mapping (e.g., plant14 <- plant13)
|
| 515 |
-
if substitute_feature_image_from_instance_src:
|
| 516 |
-
try:
|
| 517 |
-
src_rules: Dict[str, str] = {
|
| 518 |
-
"plant13": "plant12",
|
| 519 |
-
"plant14": "plant13",
|
| 520 |
-
"plant15": "plant14",
|
| 521 |
-
"plant16": "plant15",
|
| 522 |
-
}
|
| 523 |
-
switched = 0
|
| 524 |
-
for key in list(plants.keys()):
|
| 525 |
-
parts = key.split('_')
|
| 526 |
-
if len(parts) < 5:
|
| 527 |
-
continue
|
| 528 |
-
date_key = "_".join(parts[:3])
|
| 529 |
-
plant_name = parts[3]
|
| 530 |
-
frame_token = parts[-1]
|
| 531 |
-
if plant_name not in src_rules:
|
| 532 |
-
continue
|
| 533 |
-
src_plant = src_rules[plant_name]
|
| 534 |
-
src_key = f"{date_key}_{src_plant}_{frame_token}"
|
| 535 |
-
if src_key not in plants:
|
| 536 |
-
continue
|
| 537 |
-
src_pdata = plants[src_key]
|
| 538 |
-
tgt_pdata = plants[key]
|
| 539 |
-
# Preserve the original composite used for segmentation for correct overlays later
|
| 540 |
-
try:
|
| 541 |
-
if 'composite' in tgt_pdata and 'segmentation_composite' not in tgt_pdata:
|
| 542 |
-
tgt_pdata['segmentation_composite'] = tgt_pdata['composite']
|
| 543 |
-
except Exception:
|
| 544 |
-
pass
|
| 545 |
-
# Swap feature inputs: composite and spectral bands
|
| 546 |
-
if 'composite' in src_pdata:
|
| 547 |
-
tgt_pdata['composite'] = src_pdata['composite']
|
| 548 |
-
if 'spectral_stack' in src_pdata:
|
| 549 |
-
tgt_pdata['spectral_stack'] = src_pdata['spectral_stack']
|
| 550 |
-
# Ensure mask aligns with substituted composite; resize if needed
|
| 551 |
-
try:
|
| 552 |
-
import cv2 as _cv2
|
| 553 |
-
import numpy as _np
|
| 554 |
-
comp = tgt_pdata.get('composite')
|
| 555 |
-
msk = tgt_pdata.get('mask')
|
| 556 |
-
if comp is not None and msk is not None:
|
| 557 |
-
ch, cw = comp.shape[:2]
|
| 558 |
-
mh, mw = msk.shape[:2]
|
| 559 |
-
if (mh, mw) != (ch, cw):
|
| 560 |
-
resized = _cv2.resize(msk.astype('uint8'), (cw, ch), interpolation=_cv2.INTER_NEAREST)
|
| 561 |
-
tgt_pdata['mask'] = resized
|
| 562 |
-
if 'soft_mask' in tgt_pdata and isinstance(tgt_pdata['soft_mask'], _np.ndarray):
|
| 563 |
-
tgt_pdata['soft_mask'] = (resized > 0).astype(_np.float32)
|
| 564 |
-
# Precompute masked composite with white background for saving
|
| 565 |
-
white = _np.full_like(comp, 255, dtype=_np.uint8)
|
| 566 |
-
result = white.copy()
|
| 567 |
-
result[tgt_pdata['mask'] > 0] = comp[tgt_pdata['mask'] > 0]
|
| 568 |
-
tgt_pdata['masked_composite'] = result
|
| 569 |
-
except Exception:
|
| 570 |
-
pass
|
| 571 |
-
switched += 1
|
| 572 |
-
if switched > 0:
|
| 573 |
-
logger.info(f"Substituted feature images from src_rules for {switched} items")
|
| 574 |
-
except Exception as e:
|
| 575 |
-
logger.warning(f"Failed feature-image substitution via src_rules: {e}")
|
| 576 |
-
# Step 4: Extract features
|
| 577 |
-
logger.info("Step 4/6: Extracting features...")
|
| 578 |
-
step_start = time.perf_counter()
|
| 579 |
-
# Stream-save mode: save outputs immediately after each plant's features when fast output is enabled
|
| 580 |
-
stream_save = False
|
| 581 |
-
try:
|
| 582 |
-
import os as _os
|
| 583 |
-
stream_save = bool(int(_os.environ.get('STREAM_SAVE', '0'))) or bool(getattr(self.output_manager, 'fast_mode', False))
|
| 584 |
-
except Exception:
|
| 585 |
-
stream_save = False
|
| 586 |
-
|
| 587 |
-
plants = self._extract_features(plants, stream_save=stream_save)
|
| 588 |
-
logger.info(f"Features done in {(time.perf_counter()-step_start):.2f}s")
|
| 589 |
-
|
| 590 |
-
# Step 5: Generate outputs (skip if already stream-saved)
|
| 591 |
-
if not stream_save:
|
| 592 |
-
logger.info("Step 5/6: Generating outputs...")
|
| 593 |
-
step_start = time.perf_counter()
|
| 594 |
-
self._generate_outputs(plants)
|
| 595 |
-
logger.info(f"Outputs done in {(time.perf_counter()-step_start):.2f}s")
|
| 596 |
-
|
| 597 |
-
# Step 6: Create summary
|
| 598 |
-
logger.info("Step 6/6: Creating summary...")
|
| 599 |
-
summary = self._create_summary(plants)
|
| 600 |
-
else:
|
| 601 |
-
logger.info("Segmentation-only mode: skipping texture/vegetation/morphology features and plots")
|
| 602 |
-
# Segmentation-only: generate only segmentation outputs and a minimal summary
|
| 603 |
-
logger.info("Step 4/4: Generating segmentation outputs (segmentation-only mode)...")
|
| 604 |
-
self._generate_outputs(plants)
|
| 605 |
-
summary = {
|
| 606 |
-
"total_plants": len(plants),
|
| 607 |
-
"successful_plants": len(plants),
|
| 608 |
-
"failed_plants": 0,
|
| 609 |
-
"features_extracted": {
|
| 610 |
-
"texture": 0,
|
| 611 |
-
"vegetation": 0,
|
| 612 |
-
"morphology": 0
|
| 613 |
-
}
|
| 614 |
-
}
|
| 615 |
|
| 616 |
total_time = time.perf_counter() - total_start
|
| 617 |
-
logger.info(f"Pipeline completed
|
|
|
|
| 618 |
return {
|
| 619 |
"plants": plants,
|
| 620 |
"summary": summary,
|
|
@@ -626,752 +116,129 @@ class SorghumPipeline:
|
|
| 626 |
logger.error(f"Pipeline failed: {e}")
|
| 627 |
raise
|
| 628 |
|
| 629 |
-
def
|
| 630 |
-
"""
|
| 631 |
-
|
| 632 |
-
Filenames follow: plantX_plantX_frameY_maskout.png so the final instance script can detect plants.
|
| 633 |
-
"""
|
| 634 |
-
# Clear any previous maskouts to avoid processing stale plants
|
| 635 |
-
try:
|
| 636 |
-
if out_dir.exists():
|
| 637 |
-
for p in out_dir.glob("*_maskout.png"):
|
| 638 |
-
try:
|
| 639 |
-
p.unlink()
|
| 640 |
-
except Exception:
|
| 641 |
-
pass
|
| 642 |
-
except Exception:
|
| 643 |
-
pass
|
| 644 |
-
count = 0
|
| 645 |
-
# Per-plant rule: use bbox-only (skip SAM2Long) for these plants on all dates except 2025_05_08
|
| 646 |
-
bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant33", "plant39", "plant42", "plant44", "plant46"}
|
| 647 |
-
date_exception = "2025_05_08"
|
| 648 |
for key, pdata in plants.items():
|
| 649 |
try:
|
| 650 |
-
# key format: "YYYY_MM_DD_plantX_frameY"
|
| 651 |
-
parts = key.split('_')
|
| 652 |
-
if len(parts) < 3:
|
| 653 |
-
continue
|
| 654 |
-
plant_name = parts[-2]
|
| 655 |
-
frame_token = parts[-1] # e.g., frame8
|
| 656 |
-
if not plant_name.startswith('plant') or not frame_token.startswith('frame'):
|
| 657 |
-
continue
|
| 658 |
-
date_key = "_".join(parts[:3])
|
| 659 |
-
if (plant_name in bbox_only_plants) and (date_key != date_exception):
|
| 660 |
-
# Skip exporting maskouts for bbox-only plants so SAM2Long does not run on them
|
| 661 |
-
continue
|
| 662 |
-
# Extract frame number
|
| 663 |
-
frame_num = int(frame_token.replace('frame', ''))
|
| 664 |
-
composite = pdata.get('composite')
|
| 665 |
-
mask = pdata.get('mask')
|
| 666 |
-
if composite is None or mask is None:
|
| 667 |
-
continue
|
| 668 |
-
# Ensure 3-channel BGR
|
| 669 |
-
if len(composite.shape) == 2:
|
| 670 |
-
composite_bgr = cv2.cvtColor(composite, cv2.COLOR_GRAY2BGR)
|
| 671 |
-
else:
|
| 672 |
-
composite_bgr = composite
|
| 673 |
-
out_img = composite_bgr.copy()
|
| 674 |
-
# Set background to white where mask == 0
|
| 675 |
-
out_img[mask == 0] = (255, 255, 255)
|
| 676 |
-
out_path = out_dir / f"{plant_name}_{plant_name}_{frame_token}_maskout.png"
|
| 677 |
-
cv2.imwrite(str(out_path), out_img)
|
| 678 |
-
count += 1
|
| 679 |
-
except Exception as e:
|
| 680 |
-
logger.warning(f"Failed to export maskout for {key}: {e}")
|
| 681 |
-
logger.info(f"Exported {count} white-background maskouts to {out_dir}")
|
| 682 |
-
|
| 683 |
-
def _segment_plants(self, plants: Dict[str, Any],
|
| 684 |
-
bbox_lookup: Optional[Dict[str, tuple]]) -> Dict[str, Any]:
|
| 685 |
-
"""Segment plants using BRIA model.
|
| 686 |
-
|
| 687 |
-
If bbox_lookup is provided and contains an entry for the plant (e.g., 'plant1'),
|
| 688 |
-
the image is cropped/masked to the bounding box region before segmentation and the
|
| 689 |
-
predicted mask is mapped back to the full image size. In bbox mode a largest
|
| 690 |
-
connected component post-processing is applied to obtain a clean target mask.
|
| 691 |
-
"""
|
| 692 |
-
total = len(plants)
|
| 693 |
-
iterator = plants.items()
|
| 694 |
-
if tqdm is not None:
|
| 695 |
-
iterator = tqdm(list(plants.items()), desc="Segmenting", total=total, unit="img", leave=False)
|
| 696 |
-
for idx, (key, pdata) in enumerate(iterator):
|
| 697 |
-
try:
|
| 698 |
-
# Get composite image
|
| 699 |
composite = pdata['composite']
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
plant_name = parts[-2] if len(parts) >= 2 else None
|
| 705 |
-
date_key = "_".join(parts[:3]) if len(parts) >= 3 else None # e.g., 2025_04_16
|
| 706 |
-
bbox = None
|
| 707 |
-
if bbox_lookup is not None and plant_name is not None:
|
| 708 |
-
# keys in bbox_lookup are typically like 'plant1'
|
| 709 |
-
bbox = bbox_lookup.get(plant_name)
|
| 710 |
-
# For plant33, ignore any bbox and run full-image segmentation on all dates except the exception
|
| 711 |
-
if plant_name == 'plant33' and date_key != '2025_05_08':
|
| 712 |
-
bbox = None
|
| 713 |
-
|
| 714 |
-
# Plants that should use the bounding box itself as the mask (skip model)
|
| 715 |
-
bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant39", "plant42", "plant44", "plant46"}
|
| 716 |
-
use_bbox_only = (plant_name in bbox_only_plants)
|
| 717 |
-
|
| 718 |
-
# Do not use bounding boxes for date 2025_05_08
|
| 719 |
-
if date_key == '2025_05_08':
|
| 720 |
-
bbox = None
|
| 721 |
-
|
| 722 |
-
if bbox is not None:
|
| 723 |
-
# Clamp bbox to image
|
| 724 |
-
x1, y1, x2, y2 = bbox
|
| 725 |
-
x1 = max(0, min(w, int(x1)))
|
| 726 |
-
x2 = max(0, min(w, int(x2)))
|
| 727 |
-
y1 = max(0, min(h, int(y1)))
|
| 728 |
-
y2 = max(0, min(h, int(y2)))
|
| 729 |
-
if x2 <= x1 or y2 <= y1:
|
| 730 |
-
x1, y1, x2, y2 = 0, 0, w, h
|
| 731 |
-
|
| 732 |
-
if use_bbox_only:
|
| 733 |
-
# Use the bbox as the mask directly (255 inside, 0 outside)
|
| 734 |
-
soft_full = np.zeros((h, w), dtype=np.float32)
|
| 735 |
-
soft_full[y1:y2, x1:x2] = 1.0
|
| 736 |
-
bin_full = np.zeros((h, w), dtype=np.uint8)
|
| 737 |
-
bin_full[y1:y2, x1:x2] = 255
|
| 738 |
-
pdata['soft_mask'] = soft_full
|
| 739 |
-
pdata['mask'] = bin_full
|
| 740 |
-
else:
|
| 741 |
-
# Segment inside the bbox region and map back
|
| 742 |
-
crop = composite[y1:y2, x1:x2]
|
| 743 |
-
soft_mask_crop = self.segmentation_manager.segment_image_soft(crop)
|
| 744 |
-
soft_full = np.zeros((h, w), dtype=np.float32)
|
| 745 |
-
soft_resized = cv2.resize(soft_mask_crop, (x2 - x1, y2 - y1), interpolation=cv2.INTER_LINEAR)
|
| 746 |
-
soft_full[y1:y2, x1:x2] = soft_resized
|
| 747 |
-
bin_full = (soft_full > 0.5).astype(np.uint8) * 255
|
| 748 |
-
try:
|
| 749 |
-
n_lbl, labels, stats, _ = cv2.connectedComponentsWithStats(bin_full, 8)
|
| 750 |
-
if n_lbl > 1:
|
| 751 |
-
largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
|
| 752 |
-
bin_full = (labels == largest).astype(np.uint8) * 255
|
| 753 |
-
except Exception:
|
| 754 |
-
pass
|
| 755 |
-
pdata['soft_mask'] = soft_full.astype(np.float32)
|
| 756 |
-
pdata['mask'] = bin_full.astype(np.uint8)
|
| 757 |
-
else:
|
| 758 |
-
# Full-image segmentation (no bbox)
|
| 759 |
-
soft_mask = self.segmentation_manager.segment_image_soft(composite)
|
| 760 |
-
pdata['soft_mask'] = soft_mask
|
| 761 |
-
pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)
|
| 762 |
-
|
| 763 |
-
# Progress log every 25 items and for first/last
|
| 764 |
-
if tqdm is None and (idx == 0 or (idx + 1) % 25 == 0 or (idx + 1) == total):
|
| 765 |
-
logger.info(f"Segmented {idx + 1}/{total}: {key}")
|
| 766 |
-
|
| 767 |
except Exception as e:
|
| 768 |
logger.error(f"Segmentation failed for {key}: {e}")
|
| 769 |
pdata['soft_mask'] = np.zeros(composite.shape[:2], dtype=np.float32)
|
| 770 |
pdata['mask'] = np.zeros(composite.shape[:2], dtype=np.uint8)
|
| 771 |
-
|
| 772 |
return plants
|
| 773 |
|
| 774 |
-
def
|
| 775 |
-
"""
|
| 776 |
-
Handle occlusion problems using SAM2Long.
|
| 777 |
-
|
| 778 |
-
This method groups plants by their base plant ID and processes
|
| 779 |
-
each plant's 13-frame sequence to differentiate target plant
|
| 780 |
-
from neighboring plants.
|
| 781 |
-
|
| 782 |
-
Args:
|
| 783 |
-
plants: Dictionary of plant data
|
| 784 |
-
|
| 785 |
-
Returns:
|
| 786 |
-
Updated plant data with occlusion handling results
|
| 787 |
-
"""
|
| 788 |
-
if self.occlusion_handler is None:
|
| 789 |
-
logger.warning("Occlusion handler not available, skipping occlusion handling")
|
| 790 |
-
return plants
|
| 791 |
-
|
| 792 |
-
# Group plants by base plant ID (e.g., "plant1" from "plant1_plant1_frame1")
|
| 793 |
-
plant_groups = {}
|
| 794 |
for key, pdata in plants.items():
|
| 795 |
-
# Extract plant ID from key like "plant1_plant1_frame1"
|
| 796 |
-
parts = key.split('_')
|
| 797 |
-
if len(parts) >= 3:
|
| 798 |
-
plant_id = parts[0] # e.g., "plant1"
|
| 799 |
-
if plant_id not in plant_groups:
|
| 800 |
-
plant_groups[plant_id] = []
|
| 801 |
-
plant_groups[plant_id].append((key, pdata))
|
| 802 |
-
|
| 803 |
-
logger.info(f"Processing {len(plant_groups)} plant groups for occlusion handling")
|
| 804 |
-
|
| 805 |
-
# Process each plant group
|
| 806 |
-
for plant_id, plant_frames in plant_groups.items():
|
| 807 |
-
try:
|
| 808 |
-
# Sort frames by frame number
|
| 809 |
-
plant_frames.sort(key=lambda x: int(x[0].split('_')[-1].replace('frame', '')))
|
| 810 |
-
|
| 811 |
-
if len(plant_frames) < 2:
|
| 812 |
-
logger.warning(f"Plant {plant_id} has only {len(plant_frames)} frames, skipping")
|
| 813 |
-
continue
|
| 814 |
-
|
| 815 |
-
# Extract frames and keys
|
| 816 |
-
frame_keys = [x[0] for x in plant_frames]
|
| 817 |
-
frames = [x[1]['composite'] for x in plant_frames]
|
| 818 |
-
|
| 819 |
-
logger.info(f"Processing plant {plant_id} with {len(frames)} frames")
|
| 820 |
-
|
| 821 |
-
# Process with SAM2Long
|
| 822 |
-
occlusion_results = self.occlusion_handler.segment_plant_sequence(
|
| 823 |
-
frames=frames,
|
| 824 |
-
target_plant_id=plant_id
|
| 825 |
-
)
|
| 826 |
-
|
| 827 |
-
# Update plant data with occlusion results
|
| 828 |
-
target_masks = occlusion_results['target_masks']
|
| 829 |
-
neighbor_masks = occlusion_results['neighbor_masks']
|
| 830 |
-
|
| 831 |
-
for i, (key, pdata) in enumerate(plant_frames):
|
| 832 |
-
if i < len(target_masks):
|
| 833 |
-
# Update mask with target plant only
|
| 834 |
-
pdata['original_mask'] = pdata.get('mask', np.zeros_like(target_masks[i]))
|
| 835 |
-
pdata['mask'] = target_masks[i]
|
| 836 |
-
pdata['neighbor_mask'] = neighbor_masks[i]
|
| 837 |
-
pdata['occlusion_handled'] = True
|
| 838 |
-
|
| 839 |
-
# Update soft mask as well
|
| 840 |
-
pdata['original_soft_mask'] = pdata.get('soft_mask', np.zeros_like(target_masks[i], dtype=np.float32))
|
| 841 |
-
pdata['soft_mask'] = (target_masks[i] / 255.0).astype(np.float32)
|
| 842 |
-
|
| 843 |
-
# Calculate and store occlusion metrics
|
| 844 |
-
metrics = self.occlusion_handler.get_occlusion_metrics(occlusion_results)
|
| 845 |
-
for key, pdata in plant_frames:
|
| 846 |
-
pdata['occlusion_metrics'] = metrics
|
| 847 |
-
|
| 848 |
-
logger.info(f"Plant {plant_id} occlusion handling completed")
|
| 849 |
-
logger.info(f" - Average occlusion ratio: {metrics['average_occlusion_ratio']:.3f}")
|
| 850 |
-
logger.info(f" - Frames with occlusion: {metrics['frames_with_occlusion']}")
|
| 851 |
-
|
| 852 |
-
except Exception as e:
|
| 853 |
-
logger.error(f"Occlusion handling failed for plant {plant_id}: {e}")
|
| 854 |
-
# Mark as failed but continue
|
| 855 |
-
for key, pdata in plant_frames:
|
| 856 |
-
pdata['occlusion_handled'] = False
|
| 857 |
-
pdata['occlusion_error'] = str(e)
|
| 858 |
-
|
| 859 |
-
return plants
|
| 860 |
-
|
| 861 |
-
def _extract_features(self, plants: Dict[str, Any], stream_save: bool = False) -> Dict[str, Any]:
|
| 862 |
-
"""Extract all features from plants.
|
| 863 |
-
|
| 864 |
-
If stream_save is True, save outputs for each plant immediately after
|
| 865 |
-
its features are computed to improve throughput and reduce peak memory.
|
| 866 |
-
"""
|
| 867 |
-
total = len(plants)
|
| 868 |
-
logger.info(f"Extracting features for {total} plants...")
|
| 869 |
-
iterator = plants.items()
|
| 870 |
-
if tqdm is not None:
|
| 871 |
-
iterator = tqdm(list(plants.items()), desc="Extracting features", total=total, unit="img", leave=False)
|
| 872 |
-
|
| 873 |
-
# Prepare output directories once if we're streaming saves
|
| 874 |
-
if stream_save:
|
| 875 |
try:
|
| 876 |
-
self.output_manager.create_output_directories()
|
| 877 |
-
except Exception:
|
| 878 |
-
pass
|
| 879 |
-
|
| 880 |
-
for idx, (key, pdata) in enumerate(iterator):
|
| 881 |
-
try:
|
| 882 |
-
logger.debug(f"Extracting features for {key}")
|
| 883 |
-
|
| 884 |
-
# Extract texture features
|
| 885 |
pdata['texture_features'] = self._extract_texture_features(pdata)
|
| 886 |
-
|
| 887 |
-
# Extract vegetation indices
|
| 888 |
pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)
|
| 889 |
-
|
| 890 |
-
# Extract morphological features
|
| 891 |
pdata['morphology_features'] = self._extract_morphology_features(pdata)
|
| 892 |
-
|
| 893 |
-
# Immediately save outputs for this plant if streaming is enabled
|
| 894 |
-
if stream_save:
|
| 895 |
-
try:
|
| 896 |
-
self.output_manager.save_plant_results(key, pdata)
|
| 897 |
-
except Exception as _e:
|
| 898 |
-
logger.error(f"Stream-save failed for {key}: {_e}")
|
| 899 |
-
|
| 900 |
-
logger.debug(f"Features extracted for {key}")
|
| 901 |
-
if tqdm is None and (idx == 0 or (idx + 1) % 25 == 0 or (idx + 1) == total):
|
| 902 |
-
logger.info(f"Extracted features for {idx + 1}/{total}: {key}")
|
| 903 |
-
|
| 904 |
except Exception as e:
|
| 905 |
logger.error(f"Feature extraction failed for {key}: {e}")
|
| 906 |
-
# Add empty features
|
| 907 |
pdata['texture_features'] = {}
|
| 908 |
pdata['vegetation_indices'] = {}
|
| 909 |
pdata['morphology_features'] = {}
|
| 910 |
-
|
| 911 |
return plants
|
| 912 |
|
| 913 |
def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
|
| 914 |
-
"""Extract texture features
|
| 915 |
features = {}
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
except Exception as e:
|
| 938 |
-
logger.error(f"Texture extraction failed for band {band}: {e}")
|
| 939 |
-
features[band] = {'features': {}, 'statistics': {}}
|
| 940 |
|
| 941 |
return features
|
| 942 |
|
| 943 |
def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
|
| 944 |
-
"""Extract vegetation indices
|
| 945 |
try:
|
| 946 |
spectral_stack = pdata.get('spectral_stack', {})
|
| 947 |
-
|
| 948 |
-
mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
|
| 949 |
-
|
| 950 |
if not spectral_stack or mask is None:
|
| 951 |
return {}
|
| 952 |
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 957 |
except Exception as e:
|
| 958 |
logger.error(f"Vegetation index extraction failed: {e}")
|
| 959 |
return {}
|
| 960 |
|
| 961 |
def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
|
| 962 |
-
"""Extract morphological features
|
| 963 |
try:
|
| 964 |
composite = pdata.get('composite')
|
| 965 |
-
|
| 966 |
-
mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
|
| 967 |
-
|
| 968 |
if composite is None or mask is None:
|
| 969 |
return {}
|
| 970 |
-
|
| 971 |
-
return self.morphology_extractor.extract_morphology_features(
|
| 972 |
-
composite, mask
|
| 973 |
-
)
|
| 974 |
-
|
| 975 |
except Exception as e:
|
| 976 |
-
logger.error(f"Morphology
|
| 977 |
return {}
|
| 978 |
|
| 979 |
-
def _prepare_band_image(self, pdata: Dict[str, Any], band: str) -> np.ndarray:
|
| 980 |
-
"""Prepare grayscale image for a specific band."""
|
| 981 |
-
if band == 'color':
|
| 982 |
-
composite = pdata['composite']
|
| 983 |
-
# Prefer mask3 → features_mask → mask
|
| 984 |
-
mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
|
| 985 |
-
if mask is not None:
|
| 986 |
-
masked = self.mask_handler.apply_mask_to_image(composite, mask)
|
| 987 |
-
return cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
|
| 988 |
-
else:
|
| 989 |
-
return cv2.cvtColor(composite, cv2.COLOR_BGR2GRAY)
|
| 990 |
-
|
| 991 |
-
elif band == 'pca':
|
| 992 |
-
# Create PCA from spectral bands
|
| 993 |
-
spectral_stack = pdata.get('spectral_stack', {})
|
| 994 |
-
# Prefer mask3 → features_mask → mask
|
| 995 |
-
mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
|
| 996 |
-
|
| 997 |
-
if not spectral_stack:
|
| 998 |
-
return np.zeros((512, 512), dtype=np.uint8)
|
| 999 |
-
|
| 1000 |
-
# Stack bands
|
| 1001 |
-
bands_data = []
|
| 1002 |
-
for b in ['nir', 'red_edge', 'red', 'green']:
|
| 1003 |
-
if b in spectral_stack:
|
| 1004 |
-
arr = spectral_stack[b].squeeze(-1).astype(float)
|
| 1005 |
-
if mask is not None:
|
| 1006 |
-
arr = np.where(mask > 0, arr, np.nan)
|
| 1007 |
-
bands_data.append(arr)
|
| 1008 |
-
|
| 1009 |
-
if not bands_data:
|
| 1010 |
-
return np.zeros((512, 512), dtype=np.uint8)
|
| 1011 |
-
|
| 1012 |
-
# Create PCA
|
| 1013 |
-
full_stack = np.stack(bands_data, axis=-1)
|
| 1014 |
-
h, w, c = full_stack.shape
|
| 1015 |
-
flat = full_stack.reshape(-1, c)
|
| 1016 |
-
valid = ~np.isnan(flat).any(axis=1)
|
| 1017 |
-
|
| 1018 |
-
if valid.sum() == 0:
|
| 1019 |
-
return np.zeros((h, w), dtype=np.uint8)
|
| 1020 |
-
|
| 1021 |
-
vec = np.zeros(h * w)
|
| 1022 |
-
vec[valid] = PCA(n_components=1, whiten=True).fit_transform(
|
| 1023 |
-
flat[valid]
|
| 1024 |
-
).squeeze()
|
| 1025 |
-
|
| 1026 |
-
gray_f = vec.reshape(h, w)
|
| 1027 |
-
if mask is not None:
|
| 1028 |
-
m, M = gray_f[mask > 0].min(), gray_f[mask > 0].max()
|
| 1029 |
-
else:
|
| 1030 |
-
m, M = gray_f.min(), gray_f.max()
|
| 1031 |
-
|
| 1032 |
-
if M > m:
|
| 1033 |
-
gray = ((gray_f - m) / (M - m) * 255).astype(np.uint8)
|
| 1034 |
-
else:
|
| 1035 |
-
gray = np.zeros_like(gray_f, dtype=np.uint8)
|
| 1036 |
-
|
| 1037 |
-
return gray
|
| 1038 |
-
|
| 1039 |
-
else:
|
| 1040 |
-
# Individual spectral band
|
| 1041 |
-
spectral_stack = pdata.get('spectral_stack', {})
|
| 1042 |
-
# Prefer mask3 → features_mask → mask
|
| 1043 |
-
mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
|
| 1044 |
-
|
| 1045 |
-
if band not in spectral_stack:
|
| 1046 |
-
return np.zeros((512, 512), dtype=np.uint8)
|
| 1047 |
-
|
| 1048 |
-
arr = spectral_stack[band].squeeze(-1).astype(float)
|
| 1049 |
-
if mask is not None:
|
| 1050 |
-
arr = np.where(mask > 0, arr, np.nan)
|
| 1051 |
-
|
| 1052 |
-
if mask is not None:
|
| 1053 |
-
m, M = np.nanmin(arr), np.nanmax(arr)
|
| 1054 |
-
else:
|
| 1055 |
-
m, M = arr.min(), arr.max()
|
| 1056 |
-
|
| 1057 |
-
if M > m:
|
| 1058 |
-
gray = ((np.nan_to_num(arr, nan=m) - m) / (M - m) * 255).astype(np.uint8)
|
| 1059 |
-
else:
|
| 1060 |
-
gray = np.zeros_like(arr, dtype=np.uint8)
|
| 1061 |
-
|
| 1062 |
-
return gray
|
| 1063 |
-
|
| 1064 |
def _generate_outputs(self, plants: Dict[str, Any]) -> None:
|
| 1065 |
-
"""Generate
|
| 1066 |
self.output_manager.create_output_directories()
|
| 1067 |
-
|
| 1068 |
for key, pdata in plants.items():
|
| 1069 |
try:
|
| 1070 |
-
logger.debug(f"Generating outputs for {key}")
|
| 1071 |
self.output_manager.save_plant_results(key, pdata)
|
| 1072 |
except Exception as e:
|
| 1073 |
logger.error(f"Output generation failed for {key}: {e}")
|
| 1074 |
|
| 1075 |
def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
|
| 1076 |
-
"""Create summary of
|
| 1077 |
-
|
| 1078 |
"total_plants": len(plants),
|
| 1079 |
-
"successful_plants":
|
| 1080 |
-
"failed_plants": 0,
|
| 1081 |
"features_extracted": {
|
| 1082 |
-
"texture":
|
| 1083 |
-
"vegetation":
|
| 1084 |
-
"morphology":
|
| 1085 |
}
|
| 1086 |
-
}
|
| 1087 |
-
|
| 1088 |
-
for key, pdata in plants.items():
|
| 1089 |
-
try:
|
| 1090 |
-
# Check if features were extracted
|
| 1091 |
-
if pdata.get('texture_features'):
|
| 1092 |
-
summary["features_extracted"]["texture"] += 1
|
| 1093 |
-
if pdata.get('vegetation_indices'):
|
| 1094 |
-
summary["features_extracted"]["vegetation"] += 1
|
| 1095 |
-
if pdata.get('morphology_features'):
|
| 1096 |
-
summary["features_extracted"]["morphology"] += 1
|
| 1097 |
-
|
| 1098 |
-
summary["successful_plants"] += 1
|
| 1099 |
-
|
| 1100 |
-
except Exception:
|
| 1101 |
-
summary["failed_plants"] += 1
|
| 1102 |
-
|
| 1103 |
-
return summary
|
| 1104 |
-
|
| 1105 |
-
def _apply_instance_masks(self, plants: Dict[str, Any], instance_results_dir: Path) -> None:
|
| 1106 |
-
"""Replace segmentation masks with SAM2Long instance masks using track_1.
|
| 1107 |
-
|
| 1108 |
-
Expects files under instance_results_dir/plantX/track_1/frame_YY_mask.png.
|
| 1109 |
-
"""
|
| 1110 |
-
# Default and per-plant overrides for source plant, track and preferred frame
|
| 1111 |
-
default_track = "track_0"
|
| 1112 |
-
src_rules: Dict[str, str] = {
|
| 1113 |
-
"plant13": "plant12",
|
| 1114 |
-
"plant14": "plant13",
|
| 1115 |
-
"plant15": "plant14",
|
| 1116 |
-
"plant16": "plant15",
|
| 1117 |
-
}
|
| 1118 |
-
track_rules: Dict[str, str] = {
|
| 1119 |
-
# explicit track rules
|
| 1120 |
-
"plant1": "track_0",
|
| 1121 |
-
"plant4": "track_0",
|
| 1122 |
-
"plant9": "track_3",
|
| 1123 |
-
"plant13": "track_1",
|
| 1124 |
-
"plant14": "track_0",
|
| 1125 |
-
"plant15": "track_0",
|
| 1126 |
-
"plant16": "track_0",
|
| 1127 |
-
"plant18": "track_0",
|
| 1128 |
-
"plant19": "track_0",
|
| 1129 |
-
"plant23": "track_1",
|
| 1130 |
-
"plant26": "track_0",
|
| 1131 |
-
"plant27": "track_0",
|
| 1132 |
-
"plant29": "track_0",
|
| 1133 |
-
"plant31": "track_1",
|
| 1134 |
-
"plant34": "track_1",
|
| 1135 |
-
"plant35": "track_1",
|
| 1136 |
-
"plant36": "track_0",
|
| 1137 |
-
"plant37": "track_1",
|
| 1138 |
-
"plant38": "track_0",
|
| 1139 |
-
"plant39": "track_1",
|
| 1140 |
-
"plant40": "track_0",
|
| 1141 |
-
"plant41": "track_1",
|
| 1142 |
-
"plant42": "track_0",
|
| 1143 |
-
"plant43": "track_0",
|
| 1144 |
-
"plant45": "track_0",
|
| 1145 |
-
}
|
| 1146 |
-
frame_rules: Dict[str, int] = {
|
| 1147 |
-
# preferred frame overrides (1-based)
|
| 1148 |
-
"plant13": 8,
|
| 1149 |
-
"plant14": 8,
|
| 1150 |
-
"plant15": 8,
|
| 1151 |
-
"plant33": 2,
|
| 1152 |
-
"plant16": 4,
|
| 1153 |
-
"plant19": 5,
|
| 1154 |
-
"plant26": 8,
|
| 1155 |
-
"plant27": 8,
|
| 1156 |
-
"plant29": 8,
|
| 1157 |
-
"plant35": 7,
|
| 1158 |
-
"plant36": 6,
|
| 1159 |
-
"plant37": 2,
|
| 1160 |
-
"plant45": 5,
|
| 1161 |
-
}
|
| 1162 |
-
# Per-plant rule: skip applying instance masks (keep bbox/BRIA mask) on all dates except 2025_05_08
|
| 1163 |
-
bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant33", "plant39", "plant42", "plant44", "plant46"}
|
| 1164 |
-
date_exception = "2025_05_08"
|
| 1165 |
-
|
| 1166 |
-
for key, pdata in plants.items():
|
| 1167 |
-
try:
|
| 1168 |
-
parts = key.split('_')
|
| 1169 |
-
if len(parts) < 3:
|
| 1170 |
-
continue
|
| 1171 |
-
plant_name = parts[-2]
|
| 1172 |
-
frame_token = parts[-1] # frame8
|
| 1173 |
-
if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
|
| 1174 |
-
continue
|
| 1175 |
-
date_key = "_".join(parts[:3])
|
| 1176 |
-
if (plant_name in bbox_only_plants) and (date_key != date_exception):
|
| 1177 |
-
# Do not override masks for bbox-only plants
|
| 1178 |
-
continue
|
| 1179 |
-
frame_num = int(frame_token.replace('frame', ''))
|
| 1180 |
-
# Resolve source plant, track and desired frame
|
| 1181 |
-
src_plant = src_rules.get(plant_name, plant_name)
|
| 1182 |
-
track_name = track_rules.get(plant_name, default_track)
|
| 1183 |
-
desired_frame = frame_rules.get(plant_name, frame_num)
|
| 1184 |
-
plant_dir = Path(instance_results_dir) / src_plant / track_name
|
| 1185 |
-
mask_path = plant_dir / f"frame_{desired_frame:02d}_mask.png"
|
| 1186 |
-
if not mask_path.exists():
|
| 1187 |
-
# Fallback to current frame if override not found
|
| 1188 |
-
fallback = plant_dir / f"frame_{frame_num:02d}_mask.png"
|
| 1189 |
-
if fallback.exists():
|
| 1190 |
-
mask_path = fallback
|
| 1191 |
-
else:
|
| 1192 |
-
# Last-resort: pick any available frame mask in the track directory
|
| 1193 |
-
try:
|
| 1194 |
-
candidates = sorted(plant_dir.glob("frame_*_mask.png"))
|
| 1195 |
-
if len(candidates) > 0:
|
| 1196 |
-
mask_path = candidates[0]
|
| 1197 |
-
else:
|
| 1198 |
-
continue
|
| 1199 |
-
except Exception:
|
| 1200 |
-
continue
|
| 1201 |
-
inst_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
|
| 1202 |
-
if inst_mask is None:
|
| 1203 |
-
continue
|
| 1204 |
-
# Ensure binary uint8 0/255
|
| 1205 |
-
inst_mask_bin = (inst_mask > 0).astype(np.uint8) * 255
|
| 1206 |
-
pdata['original_mask'] = pdata.get('mask', inst_mask_bin.copy())
|
| 1207 |
-
pdata['mask'] = inst_mask_bin
|
| 1208 |
-
pdata['original_soft_mask'] = pdata.get('soft_mask', (inst_mask_bin / 255.0).astype(np.float32))
|
| 1209 |
-
pdata['soft_mask'] = (inst_mask_bin / 255.0).astype(np.float32)
|
| 1210 |
-
pdata['instance_applied'] = True
|
| 1211 |
-
|
| 1212 |
-
# Build mask3 = external(mask) AND BRIA(original_mask)
|
| 1213 |
-
try:
|
| 1214 |
-
_m1 = pdata.get('mask')
|
| 1215 |
-
_m2 = pdata.get('original_mask')
|
| 1216 |
-
if isinstance(_m1, np.ndarray) and isinstance(_m2, np.ndarray):
|
| 1217 |
-
_m1b = (_m1.astype(np.uint8) > 0)
|
| 1218 |
-
_m2b = (_m2.astype(np.uint8) > 0)
|
| 1219 |
-
mask3 = (_m1b & _m2b).astype(np.uint8) * 255
|
| 1220 |
-
pdata['mask3'] = mask3
|
| 1221 |
-
pdata['features_mask'] = mask3
|
| 1222 |
-
except Exception:
|
| 1223 |
-
pass
|
| 1224 |
-
|
| 1225 |
-
# After applying instance masks, also overwrite the composite and spectral stack
|
| 1226 |
-
# with the source plant's raw image (desired frame preferred) so that
|
| 1227 |
-
# feature extraction and saved originals/overlays are consistent with the mask source.
|
| 1228 |
-
try:
|
| 1229 |
-
if plant_name in src_rules:
|
| 1230 |
-
date_key = "_".join(parts[:3])
|
| 1231 |
-
src_key_desired = f"{date_key}_{src_plant}_frame{desired_frame}"
|
| 1232 |
-
src_key_same = f"{date_key}_{src_plant}_{frame_token}"
|
| 1233 |
-
copy_from = plants.get(src_key_desired) or plants.get(src_key_same)
|
| 1234 |
-
if copy_from is None:
|
| 1235 |
-
# Fallback: load source composite from filesystem if not present in plants dict
|
| 1236 |
-
try:
|
| 1237 |
-
from PIL import Image as _Image
|
| 1238 |
-
_date_folder = date_key.replace('_', '-')
|
| 1239 |
-
_date_dir = Path(self.config.paths.input_folder)
|
| 1240 |
-
if _date_dir.name != _date_folder:
|
| 1241 |
-
_date_dir = _date_dir / _date_folder
|
| 1242 |
-
_frame_path = _date_dir / src_plant / f"{src_plant}_frame{desired_frame}.tif"
|
| 1243 |
-
if not _frame_path.exists():
|
| 1244 |
-
_frame_path = _date_dir / src_plant / f"{src_plant}_frame{frame_num}.tif"
|
| 1245 |
-
if _frame_path.exists():
|
| 1246 |
-
_img = _Image.open(str(_frame_path))
|
| 1247 |
-
# Process to composite using preprocessor
|
| 1248 |
-
comp, spec = self.preprocessor.process_raw_image(_img)
|
| 1249 |
-
copy_from = {"composite": comp, "spectral_stack": spec}
|
| 1250 |
-
except Exception:
|
| 1251 |
-
copy_from = None
|
| 1252 |
-
if copy_from is not None:
|
| 1253 |
-
# Preserve the segmentation-time composite once
|
| 1254 |
-
if 'composite' in pdata and 'segmentation_composite' not in pdata:
|
| 1255 |
-
pdata['segmentation_composite'] = pdata['composite']
|
| 1256 |
-
if 'composite' in copy_from:
|
| 1257 |
-
pdata['composite'] = copy_from['composite']
|
| 1258 |
-
if 'spectral_stack' in copy_from:
|
| 1259 |
-
pdata['spectral_stack'] = copy_from['spectral_stack']
|
| 1260 |
-
# Ensure mask size matches the copied composite
|
| 1261 |
-
ch, cw = pdata['composite'].shape[:2]
|
| 1262 |
-
mh, mw = pdata['mask'].shape[:2]
|
| 1263 |
-
if (mh, mw) != (ch, cw):
|
| 1264 |
-
pdata['mask'] = cv2.resize(pdata['mask'].astype('uint8'), (cw, ch), interpolation=cv2.INTER_NEAREST)
|
| 1265 |
-
pdata['soft_mask'] = (pdata['mask'] > 0).astype(np.float32)
|
| 1266 |
-
except Exception:
|
| 1267 |
-
pass
|
| 1268 |
-
except Exception as e:
|
| 1269 |
-
logger.debug(f"Instance mask apply failed for {key}: {e}")
|
| 1270 |
-
|
| 1271 |
-
def _apply_instance_masks_from_mapping(self, plants: Dict[str, Any], mapping_file: Path) -> None:
|
| 1272 |
-
"""Apply instance masks using an explicit mapping file with absolute paths.
|
| 1273 |
-
|
| 1274 |
-
mapping JSON structure:
|
| 1275 |
-
{
|
| 1276 |
-
"plant1": {"frame": 8, "mask_path": "/abs/path/to/plant1/track_X/frame_08_mask.png"},
|
| 1277 |
-
"plant2": {"frame": 8, "mask_path": "/abs/path/.../frame_08_mask.png"},
|
| 1278 |
-
...
|
| 1279 |
-
}
|
| 1280 |
-
If a plant's mapping specifies a different frame, only entries matching that frame are updated.
|
| 1281 |
-
"""
|
| 1282 |
-
import json
|
| 1283 |
-
if not mapping_file.exists():
|
| 1284 |
-
raise FileNotFoundError(f"Mapping file not found: {mapping_file}")
|
| 1285 |
-
with open(mapping_file, "r") as f:
|
| 1286 |
-
mapping = json.load(f)
|
| 1287 |
-
# Normalize mapping plant keys to names like 'plantX'
|
| 1288 |
-
norm_map = {}
|
| 1289 |
-
for k, v in mapping.items():
|
| 1290 |
-
k_norm = k if str(k).startswith("plant") else f"plant{int(k)}" if str(k).isdigit() else str(k)
|
| 1291 |
-
norm_map[k_norm] = v
|
| 1292 |
-
|
| 1293 |
-
for key, pdata in plants.items():
|
| 1294 |
-
try:
|
| 1295 |
-
parts = key.split('_')
|
| 1296 |
-
if len(parts) < 3:
|
| 1297 |
-
continue
|
| 1298 |
-
plant_name = parts[-2]
|
| 1299 |
-
frame_token = parts[-1]
|
| 1300 |
-
if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
|
| 1301 |
-
continue
|
| 1302 |
-
frame_num = int(frame_token.replace('frame', ''))
|
| 1303 |
-
if plant_name not in norm_map:
|
| 1304 |
-
continue
|
| 1305 |
-
entry = norm_map[plant_name]
|
| 1306 |
-
target_frame = int(entry.get("frame", frame_num))
|
| 1307 |
-
if frame_num != target_frame:
|
| 1308 |
-
# Only update the designated frame for this plant
|
| 1309 |
-
continue
|
| 1310 |
-
mask_path_str = entry.get("mask_path")
|
| 1311 |
-
if not mask_path_str:
|
| 1312 |
-
continue
|
| 1313 |
-
mask_path = Path(mask_path_str)
|
| 1314 |
-
if not mask_path.exists():
|
| 1315 |
-
logger.warning(f"Mask path not found for {plant_name} {frame_token}: {mask_path}")
|
| 1316 |
-
continue
|
| 1317 |
-
inst_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
|
| 1318 |
-
if inst_mask is None:
|
| 1319 |
-
continue
|
| 1320 |
-
inst_mask_bin = (inst_mask > 0).astype(np.uint8) * 255
|
| 1321 |
-
pdata['original_mask'] = pdata.get('mask', inst_mask_bin.copy())
|
| 1322 |
-
pdata['mask'] = inst_mask_bin
|
| 1323 |
-
pdata['original_soft_mask'] = pdata.get('soft_mask', (inst_mask_bin / 255.0).astype(np.float32))
|
| 1324 |
-
pdata['soft_mask'] = (inst_mask_bin / 255.0).astype(np.float32)
|
| 1325 |
-
pdata['instance_applied'] = True
|
| 1326 |
-
|
| 1327 |
-
# Build mask3 = external(mask) AND BRIA(original_mask)
|
| 1328 |
-
try:
|
| 1329 |
-
_m1 = pdata.get('mask')
|
| 1330 |
-
_m2 = pdata.get('original_mask')
|
| 1331 |
-
if isinstance(_m1, np.ndarray) and isinstance(_m2, np.ndarray):
|
| 1332 |
-
_m1b = (_m1.astype(np.uint8) > 0)
|
| 1333 |
-
_m2b = (_m2.astype(np.uint8) > 0)
|
| 1334 |
-
mask3 = (_m1b & _m2b).astype(np.uint8) * 255
|
| 1335 |
-
pdata['mask3'] = mask3
|
| 1336 |
-
pdata['features_mask'] = mask3
|
| 1337 |
-
except Exception:
|
| 1338 |
-
pass
|
| 1339 |
-
except Exception as e:
|
| 1340 |
-
logger.debug(f"Instance mapping apply failed for {key}: {e}")
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
def run_pipeline(config_path: str, load_all_frames: bool = False, segmentation_only: bool = False, filter_plants: Optional[List[str]] = None) -> Dict[str, Any]:
|
| 1344 |
-
"""
|
| 1345 |
-
Convenience function to run the pipeline.
|
| 1346 |
-
|
| 1347 |
-
Args:
|
| 1348 |
-
config_path: Path to configuration file
|
| 1349 |
-
load_all_frames: Whether to load all frames or selected frames
|
| 1350 |
-
segmentation_only: If True, run segmentation only and skip feature extraction
|
| 1351 |
-
|
| 1352 |
-
Returns:
|
| 1353 |
-
Pipeline results
|
| 1354 |
-
"""
|
| 1355 |
-
pipeline = SorghumPipeline(config_path)
|
| 1356 |
-
return pipeline.run(load_all_frames, segmentation_only, filter_plants)
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
if __name__ == "__main__":
|
| 1360 |
-
import sys
|
| 1361 |
-
|
| 1362 |
-
config_path = sys.argv[1] if len(sys.argv) > 1 else "config.yml"
|
| 1363 |
-
load_all = "--all" in sys.argv
|
| 1364 |
-
seg_only = "--seg-only" in sys.argv
|
| 1365 |
-
# Basic arg parse for --plant=<name>
|
| 1366 |
-
plant_filter = None
|
| 1367 |
-
for arg in sys.argv[1:]:
|
| 1368 |
-
if arg.startswith("--plant="):
|
| 1369 |
-
plant_filter = [arg.split("=", 1)[1]]
|
| 1370 |
-
|
| 1371 |
-
try:
|
| 1372 |
-
results = run_pipeline(config_path, load_all, seg_only, plant_filter)
|
| 1373 |
-
print("Pipeline completed successfully!")
|
| 1374 |
-
print(f"Processed {results['summary']['total_plants']} plants")
|
| 1375 |
-
except Exception as e:
|
| 1376 |
-
print(f"Pipeline failed: {e}")
|
| 1377 |
-
sys.exit(1)
|
|
|
|
| 1 |
"""
|
| 2 |
Main pipeline class for the Sorghum Plant Phenotyping Pipeline.
|
| 3 |
|
| 4 |
+
Minimal single-image version for Hugging Face demo.
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import os
|
|
|
|
| 8 |
import logging
|
| 9 |
from pathlib import Path
|
| 10 |
+
from typing import Dict, Any, Optional
|
| 11 |
import numpy as np
|
| 12 |
import cv2
|
|
|
|
|
|
|
|
|
|
| 13 |
from sklearn.decomposition import PCA
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
from .config import Config
|
| 16 |
+
from .data import ImagePreprocessor, MaskHandler
|
| 17 |
from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
|
| 18 |
from .output import OutputManager
|
| 19 |
from .segmentation import SegmentationManager
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class SorghumPipeline:
|
| 25 |
+
"""Minimal pipeline for single-image plant phenotyping."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
def __init__(self, config: Config):
|
| 28 |
+
"""Initialize the minimal pipeline."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
self._setup_logging()
|
| 30 |
+
self.config = config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
self.config.validate()
|
| 32 |
+
self._initialize_components()
|
| 33 |
+
logger.info("Sorghum Pipeline initialized")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
def _setup_logging(self):
|
| 36 |
"""Setup logging configuration."""
|
| 37 |
logging.basicConfig(
|
| 38 |
level=logging.INFO,
|
| 39 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 40 |
+
handlers=[logging.StreamHandler()]
|
|
|
|
|
|
|
|
|
|
| 41 |
)
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
def _initialize_components(self):
|
| 44 |
+
"""Initialize pipeline components."""
|
| 45 |
+
self.preprocessor = ImagePreprocessor(target_size=None)
|
| 46 |
+
self.mask_handler = MaskHandler(min_area=1000, kernel_size=7)
|
| 47 |
+
self.texture_extractor = TextureExtractor()
|
| 48 |
+
self.vegetation_extractor = VegetationIndexExtractor()
|
| 49 |
+
self.morphology_extractor = MorphologyExtractor()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
self.segmentation_manager = SegmentationManager(
|
| 51 |
+
model_name="briaai/RMBG-2.0",
|
| 52 |
device=self.config.get_device(),
|
| 53 |
+
threshold=0.5,
|
| 54 |
+
trust_remote_code=True
|
|
|
|
|
|
|
| 55 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
self.output_manager = OutputManager(
|
| 57 |
output_folder=self.config.paths.output_folder,
|
| 58 |
settings=self.config.output
|
| 59 |
)
|
| 60 |
|
| 61 |
+
def run(self, single_image_path: str) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
"""
|
| 63 |
+
Run minimal pipeline on single image.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
Args:
|
| 66 |
+
single_image_path: Path to input image
|
|
|
|
| 67 |
|
| 68 |
Returns:
|
| 69 |
+
Dictionary containing results
|
| 70 |
"""
|
| 71 |
+
logger.info("Starting minimal single-image pipeline...")
|
| 72 |
|
| 73 |
try:
|
| 74 |
import time
|
| 75 |
+
from PIL import Image as _Image
|
| 76 |
+
|
| 77 |
total_start = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
# Load single image
|
| 80 |
+
_p = Path(single_image_path)
|
| 81 |
+
_img = _Image.open(str(_p))
|
| 82 |
+
plants = {
|
| 83 |
+
"demo_demo_frame1": {
|
| 84 |
+
"raw_image": (_img, _p.name),
|
| 85 |
+
"plant_name": "demo",
|
| 86 |
+
"file_path": str(_p)
|
| 87 |
}
|
| 88 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
+
# Create composite
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
plants = self.preprocessor.create_composites(plants)
|
|
|
|
| 92 |
|
| 93 |
+
# Segment
|
| 94 |
+
plants = self._segment_plants(plants)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
# Extract features
|
| 97 |
+
plants = self._extract_features(plants)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
# Generate outputs
|
| 100 |
+
self._generate_outputs(plants)
|
| 101 |
+
|
| 102 |
+
# Summary
|
| 103 |
+
summary = self._create_summary(plants)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
total_time = time.perf_counter() - total_start
|
| 106 |
+
logger.info(f"Pipeline completed in {total_time:.2f}s")
|
| 107 |
+
|
| 108 |
return {
|
| 109 |
"plants": plants,
|
| 110 |
"summary": summary,
|
|
|
|
| 116 |
logger.error(f"Pipeline failed: {e}")
|
| 117 |
raise
|
| 118 |
|
| 119 |
+
def _segment_plants(self, plants: Dict[str, Any]) -> Dict[str, Any]:
|
| 120 |
+
"""Segment plants using BRIA model (full image)."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
for key, pdata in plants.items():
|
| 122 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
composite = pdata['composite']
|
| 124 |
+
soft_mask = self.segmentation_manager.segment_image_soft(composite)
|
| 125 |
+
pdata['soft_mask'] = soft_mask
|
| 126 |
+
pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)
|
| 127 |
+
logger.info(f"Segmented {key}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
except Exception as e:
|
| 129 |
logger.error(f"Segmentation failed for {key}: {e}")
|
| 130 |
pdata['soft_mask'] = np.zeros(composite.shape[:2], dtype=np.float32)
|
| 131 |
pdata['mask'] = np.zeros(composite.shape[:2], dtype=np.uint8)
|
|
|
|
| 132 |
return plants
|
| 133 |
|
| 134 |
+
def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]:
|
| 135 |
+
"""Extract features from plants."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
for key, pdata in plants.items():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
pdata['texture_features'] = self._extract_texture_features(pdata)
|
|
|
|
|
|
|
| 139 |
pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)
|
|
|
|
|
|
|
| 140 |
pdata['morphology_features'] = self._extract_morphology_features(pdata)
|
| 141 |
+
logger.info(f"Features extracted for {key}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
except Exception as e:
|
| 143 |
logger.error(f"Feature extraction failed for {key}: {e}")
|
|
|
|
| 144 |
pdata['texture_features'] = {}
|
| 145 |
pdata['vegetation_indices'] = {}
|
| 146 |
pdata['morphology_features'] = {}
|
|
|
|
| 147 |
return plants
|
| 148 |
|
| 149 |
def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
|
| 150 |
+
"""Extract texture features from pseudo-color image only."""
|
| 151 |
features = {}
|
| 152 |
+
try:
|
| 153 |
+
# Only process pseudo-color composite
|
| 154 |
+
composite = pdata['composite']
|
| 155 |
+
mask = pdata.get('mask')
|
| 156 |
+
if mask is not None:
|
| 157 |
+
masked = self.mask_handler.apply_mask_to_image(composite, mask)
|
| 158 |
+
gray_image = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
|
| 159 |
+
else:
|
| 160 |
+
gray_image = cv2.cvtColor(composite, cv2.COLOR_BGR2GRAY)
|
| 161 |
+
|
| 162 |
+
band_features = self.texture_extractor.extract_all_texture_features(gray_image)
|
| 163 |
+
stats = self.texture_extractor.compute_texture_statistics(band_features, mask)
|
| 164 |
+
|
| 165 |
+
features['color'] = {
|
| 166 |
+
'features': band_features,
|
| 167 |
+
'statistics': stats
|
| 168 |
+
}
|
| 169 |
+
except Exception as e:
|
| 170 |
+
logger.error(f"Texture extraction failed: {e}")
|
| 171 |
+
features['color'] = {'features': {}, 'statistics': {}}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
return features
|
| 174 |
|
| 175 |
def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
|
| 176 |
+
"""Extract vegetation indices (NDVI, ARI, GNDVI only)."""
|
| 177 |
try:
|
| 178 |
spectral_stack = pdata.get('spectral_stack', {})
|
| 179 |
+
mask = pdata.get('mask')
|
|
|
|
|
|
|
| 180 |
if not spectral_stack or mask is None:
|
| 181 |
return {}
|
| 182 |
|
| 183 |
+
out: Dict[str, Any] = {}
|
| 184 |
+
for name in ("NDVI", "ARI", "GNDVI"):
|
| 185 |
+
bands = self.vegetation_extractor.index_bands.get(name, [])
|
| 186 |
+
if not all(b in spectral_stack for b in bands):
|
| 187 |
+
continue
|
| 188 |
+
arrays = []
|
| 189 |
+
for b in bands:
|
| 190 |
+
arr = spectral_stack[b]
|
| 191 |
+
if isinstance(arr, np.ndarray):
|
| 192 |
+
arr = arr.squeeze(-1)
|
| 193 |
+
arrays.append(np.asarray(arr, dtype=np.float64))
|
| 194 |
+
|
| 195 |
+
values = self.vegetation_extractor.index_formulas[name](*arrays).astype(np.float64)
|
| 196 |
+
binary_mask = (np.asarray(mask).astype(np.int32) > 0)
|
| 197 |
+
masked_values = np.where(binary_mask, values, np.nan)
|
| 198 |
+
valid = masked_values[~np.isnan(masked_values)]
|
| 199 |
+
|
| 200 |
+
stats = {
|
| 201 |
+
'mean': float(np.mean(valid)) if valid.size else 0.0,
|
| 202 |
+
'std': float(np.std(valid)) if valid.size else 0.0,
|
| 203 |
+
'min': float(np.min(valid)) if valid.size else 0.0,
|
| 204 |
+
'max': float(np.max(valid)) if valid.size else 0.0,
|
| 205 |
+
'median': float(np.median(valid)) if valid.size else 0.0,
|
| 206 |
+
}
|
| 207 |
+
out[name] = {'values': masked_values, 'statistics': stats}
|
| 208 |
+
return out
|
| 209 |
except Exception as e:
|
| 210 |
logger.error(f"Vegetation index extraction failed: {e}")
|
| 211 |
return {}
|
| 212 |
|
| 213 |
def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
|
| 214 |
+
"""Extract morphological features."""
|
| 215 |
try:
|
| 216 |
composite = pdata.get('composite')
|
| 217 |
+
mask = pdata.get('mask')
|
|
|
|
|
|
|
| 218 |
if composite is None or mask is None:
|
| 219 |
return {}
|
| 220 |
+
return self.morphology_extractor.extract_morphology_features(composite, mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
except Exception as e:
|
| 222 |
+
logger.error(f"Morphology extraction failed: {e}")
|
| 223 |
return {}
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
def _generate_outputs(self, plants: Dict[str, Any]) -> None:
|
| 226 |
+
"""Generate output files."""
|
| 227 |
self.output_manager.create_output_directories()
|
|
|
|
| 228 |
for key, pdata in plants.items():
|
| 229 |
try:
|
|
|
|
| 230 |
self.output_manager.save_plant_results(key, pdata)
|
| 231 |
except Exception as e:
|
| 232 |
logger.error(f"Output generation failed for {key}: {e}")
|
| 233 |
|
| 234 |
def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
|
| 235 |
+
"""Create summary of results."""
|
| 236 |
+
return {
|
| 237 |
"total_plants": len(plants),
|
| 238 |
+
"successful_plants": sum(1 for p in plants.values() if p.get('texture_features')),
|
|
|
|
| 239 |
"features_extracted": {
|
| 240 |
+
"texture": sum(1 for p in plants.values() if p.get('texture_features')),
|
| 241 |
+
"vegetation": sum(1 for p in plants.values() if p.get('vegetation_indices')),
|
| 242 |
+
"morphology": sum(1 for p in plants.values() if p.get('morphology_features'))
|
| 243 |
}
|
| 244 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sorghum_pipeline/segmentation/manager.py
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
This module handles image segmentation using the BRIA model
|
| 5 |
-
and provides post-processing capabilities.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import numpy as np
|
|
@@ -11,299 +8,51 @@ import torch
|
|
| 11 |
from PIL import Image
|
| 12 |
from torchvision import transforms
|
| 13 |
from transformers import AutoModelForImageSegmentation
|
| 14 |
-
from typing import Optional
|
| 15 |
import logging
|
| 16 |
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
| 19 |
|
| 20 |
class SegmentationManager:
|
| 21 |
-
"""
|
| 22 |
|
| 23 |
-
def __init__(self,
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
trust_remote_code: bool = True,
|
| 28 |
-
cache_dir: Optional[str] = None,
|
| 29 |
-
local_files_only: bool = False):
|
| 30 |
-
"""
|
| 31 |
-
Initialize segmentation manager.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
model_name: Name of the BRIA model
|
| 35 |
-
device: Device to run model on ("auto", "cpu", "cuda")
|
| 36 |
-
threshold: Segmentation threshold
|
| 37 |
-
trust_remote_code: Whether to trust remote code
|
| 38 |
-
cache_dir: Hugging Face cache directory for model weights
|
| 39 |
-
local_files_only: If True, only load from local cache
|
| 40 |
-
"""
|
| 41 |
self.model_name = model_name
|
| 42 |
self.threshold = threshold
|
| 43 |
-
self.
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
def _load_model(self):
|
| 59 |
-
"""Load the BRIA segmentation model."""
|
| 60 |
-
try:
|
| 61 |
-
logger.info(f"Loading BRIA model: {self.model_name}")
|
| 62 |
-
|
| 63 |
-
self.model = AutoModelForImageSegmentation.from_pretrained(
|
| 64 |
-
self.model_name,
|
| 65 |
-
trust_remote_code=self.trust_remote_code,
|
| 66 |
-
cache_dir=self.cache_dir if self.cache_dir else None,
|
| 67 |
-
local_files_only=self.local_files_only,
|
| 68 |
-
).eval().to(self.device)
|
| 69 |
-
|
| 70 |
-
# Define image transform
|
| 71 |
-
self.transform = transforms.Compose([
|
| 72 |
-
transforms.Resize((1024, 1024)),
|
| 73 |
-
transforms.ToTensor(),
|
| 74 |
-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 75 |
-
])
|
| 76 |
-
|
| 77 |
-
logger.info("BRIA model loaded successfully")
|
| 78 |
-
|
| 79 |
-
except Exception as e:
|
| 80 |
-
logger.error(f"Failed to load BRIA model: {e}")
|
| 81 |
-
raise
|
| 82 |
-
|
| 83 |
-
def segment_image(self, image: np.ndarray) -> np.ndarray:
|
| 84 |
-
"""
|
| 85 |
-
Segment an image using the BRIA model.
|
| 86 |
-
|
| 87 |
-
Args:
|
| 88 |
-
image: Input image (BGR format)
|
| 89 |
-
|
| 90 |
-
Returns:
|
| 91 |
-
Binary mask (0/255)
|
| 92 |
-
"""
|
| 93 |
-
if self.model is None:
|
| 94 |
-
raise RuntimeError("Model not loaded")
|
| 95 |
-
|
| 96 |
-
try:
|
| 97 |
-
# Convert BGR to RGB
|
| 98 |
-
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 99 |
-
pil_image = Image.fromarray(rgb_image)
|
| 100 |
-
|
| 101 |
-
# Apply transform
|
| 102 |
-
input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
|
| 103 |
-
|
| 104 |
-
# Run inference
|
| 105 |
-
with torch.no_grad():
|
| 106 |
-
predictions = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
|
| 107 |
-
|
| 108 |
-
# Apply threshold
|
| 109 |
-
mask = (predictions > self.threshold).astype(np.uint8) * 255
|
| 110 |
-
|
| 111 |
-
# Resize back to original size
|
| 112 |
-
original_size = (image.shape[1], image.shape[0]) # (width, height)
|
| 113 |
-
mask_resized = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)
|
| 114 |
-
|
| 115 |
-
return mask_resized
|
| 116 |
-
|
| 117 |
-
except Exception as e:
|
| 118 |
-
logger.error(f"Segmentation failed: {e}")
|
| 119 |
-
# Return empty mask
|
| 120 |
-
return np.zeros(image.shape[:2], dtype=np.uint8)
|
| 121 |
-
|
| 122 |
def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
|
| 123 |
-
"""
|
| 124 |
-
Segment an image and return a soft mask in [0, 1] resized to original size.
|
| 125 |
-
No thresholding or post-processing is applied.
|
| 126 |
-
|
| 127 |
-
Args:
|
| 128 |
-
image: Input image (BGR format)
|
| 129 |
-
|
| 130 |
-
Returns:
|
| 131 |
-
Float mask in [0,1] with shape (H, W)
|
| 132 |
-
"""
|
| 133 |
-
if self.model is None:
|
| 134 |
-
raise RuntimeError("Model not loaded")
|
| 135 |
try:
|
| 136 |
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 137 |
pil_image = Image.fromarray(rgb_image)
|
| 138 |
input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
|
|
|
|
| 139 |
with torch.no_grad():
|
| 140 |
preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
|
|
|
|
| 141 |
original_size = (image.shape[1], image.shape[0])
|
| 142 |
soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
|
| 143 |
return np.clip(soft_mask, 0.0, 1.0)
|
| 144 |
except Exception as e:
|
| 145 |
-
logger.error(f"
|
| 146 |
-
return np.zeros(image.shape[:2], dtype=np.float32)
|
| 147 |
-
|
| 148 |
-
def post_process_mask(self, mask: np.ndarray,
|
| 149 |
-
min_area: int = 1000,
|
| 150 |
-
kernel_size: int = 5) -> np.ndarray:
|
| 151 |
-
"""
|
| 152 |
-
Post-process segmentation mask.
|
| 153 |
-
|
| 154 |
-
Args:
|
| 155 |
-
mask: Input mask
|
| 156 |
-
min_area: Minimum area for connected components
|
| 157 |
-
kernel_size: Kernel size for morphological operations
|
| 158 |
-
|
| 159 |
-
Returns:
|
| 160 |
-
Post-processed mask
|
| 161 |
-
"""
|
| 162 |
-
try:
|
| 163 |
-
# Morphological opening to remove noise
|
| 164 |
-
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
| 165 |
-
opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
| 166 |
-
|
| 167 |
-
# Remove small connected components
|
| 168 |
-
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
| 169 |
-
opened, connectivity=8
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
processed_mask = np.zeros_like(opened)
|
| 173 |
-
for label in range(1, num_labels): # Skip background
|
| 174 |
-
if stats[label, cv2.CC_STAT_AREA] >= min_area:
|
| 175 |
-
processed_mask[labels == label] = 255
|
| 176 |
-
|
| 177 |
-
return processed_mask
|
| 178 |
-
|
| 179 |
-
except Exception as e:
|
| 180 |
-
logger.error(f"Mask post-processing failed: {e}")
|
| 181 |
-
return mask
|
| 182 |
-
|
| 183 |
-
def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
|
| 184 |
-
"""
|
| 185 |
-
Keep only the largest connected component.
|
| 186 |
-
|
| 187 |
-
Args:
|
| 188 |
-
mask: Input mask
|
| 189 |
-
|
| 190 |
-
Returns:
|
| 191 |
-
Mask with only the largest component
|
| 192 |
-
"""
|
| 193 |
-
try:
|
| 194 |
-
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 8)
|
| 195 |
-
|
| 196 |
-
if num_labels <= 1:
|
| 197 |
-
return mask
|
| 198 |
-
|
| 199 |
-
# Find the largest component (excluding background)
|
| 200 |
-
areas = stats[1:, cv2.CC_STAT_AREA]
|
| 201 |
-
largest_label = 1 + np.argmax(areas)
|
| 202 |
-
|
| 203 |
-
# Create mask with only the largest component
|
| 204 |
-
largest_mask = (labels == largest_label).astype(np.uint8) * 255
|
| 205 |
-
|
| 206 |
-
return largest_mask
|
| 207 |
-
|
| 208 |
-
except Exception as e:
|
| 209 |
-
logger.error(f"Largest component extraction failed: {e}")
|
| 210 |
-
return mask
|
| 211 |
-
|
| 212 |
-
def validate_mask(self, mask: np.ndarray) -> bool:
|
| 213 |
-
"""
|
| 214 |
-
Validate segmentation mask.
|
| 215 |
-
|
| 216 |
-
Args:
|
| 217 |
-
mask: Mask to validate
|
| 218 |
-
|
| 219 |
-
Returns:
|
| 220 |
-
True if valid, False otherwise
|
| 221 |
-
"""
|
| 222 |
-
if mask is None:
|
| 223 |
-
return False
|
| 224 |
-
|
| 225 |
-
if not isinstance(mask, np.ndarray):
|
| 226 |
-
return False
|
| 227 |
-
|
| 228 |
-
if mask.ndim != 2:
|
| 229 |
-
return False
|
| 230 |
-
|
| 231 |
-
if mask.dtype not in [np.uint8, np.bool_]:
|
| 232 |
-
return False
|
| 233 |
-
|
| 234 |
-
# Check if mask has any foreground pixels
|
| 235 |
-
if np.sum(mask > 0) == 0:
|
| 236 |
-
logger.warning("Mask has no foreground pixels")
|
| 237 |
-
return False
|
| 238 |
-
|
| 239 |
-
return True
|
| 240 |
-
|
| 241 |
-
def get_mask_properties(self, mask: np.ndarray) -> dict:
|
| 242 |
-
"""
|
| 243 |
-
Get properties of the segmentation mask.
|
| 244 |
-
|
| 245 |
-
Args:
|
| 246 |
-
mask: Binary mask
|
| 247 |
-
|
| 248 |
-
Returns:
|
| 249 |
-
Dictionary of mask properties
|
| 250 |
-
"""
|
| 251 |
-
if not self.validate_mask(mask):
|
| 252 |
-
return {}
|
| 253 |
-
|
| 254 |
-
try:
|
| 255 |
-
# Convert to binary
|
| 256 |
-
binary_mask = (mask > 127).astype(np.uint8)
|
| 257 |
-
|
| 258 |
-
# Calculate properties
|
| 259 |
-
area = np.sum(binary_mask)
|
| 260 |
-
perimeter = 0
|
| 261 |
-
|
| 262 |
-
# Find contours
|
| 263 |
-
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 264 |
-
if contours:
|
| 265 |
-
perimeter = cv2.arcLength(contours[0], True)
|
| 266 |
-
|
| 267 |
-
# Bounding box
|
| 268 |
-
x, y, w, h = cv2.boundingRect(contours[0])
|
| 269 |
-
bbox_area = w * h
|
| 270 |
-
aspect_ratio = w / h if h > 0 else 0
|
| 271 |
-
else:
|
| 272 |
-
bbox_area = 0
|
| 273 |
-
aspect_ratio = 0
|
| 274 |
-
|
| 275 |
-
return {
|
| 276 |
-
"area": int(area),
|
| 277 |
-
"perimeter": float(perimeter),
|
| 278 |
-
"bbox_area": int(bbox_area),
|
| 279 |
-
"aspect_ratio": float(aspect_ratio),
|
| 280 |
-
"coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0,
|
| 281 |
-
"num_components": len(contours)
|
| 282 |
-
}
|
| 283 |
-
|
| 284 |
-
except Exception as e:
|
| 285 |
-
logger.error(f"Mask property calculation failed: {e}")
|
| 286 |
-
return {}
|
| 287 |
-
|
| 288 |
-
def create_overlay(self, image: np.ndarray, mask: np.ndarray,
|
| 289 |
-
color: Tuple[int, int, int] = (0, 255, 0),
|
| 290 |
-
alpha: float = 0.5) -> np.ndarray:
|
| 291 |
-
"""
|
| 292 |
-
Create overlay of mask on image.
|
| 293 |
-
|
| 294 |
-
Args:
|
| 295 |
-
image: Base image
|
| 296 |
-
mask: Binary mask
|
| 297 |
-
color: Overlay color (B, G, R)
|
| 298 |
-
alpha: Overlay transparency
|
| 299 |
-
|
| 300 |
-
Returns:
|
| 301 |
-
Image with mask overlay
|
| 302 |
-
"""
|
| 303 |
-
try:
|
| 304 |
-
overlay = image.copy()
|
| 305 |
-
overlay[mask == 255] = color
|
| 306 |
-
return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)
|
| 307 |
-
except Exception as e:
|
| 308 |
-
logger.error(f"Overlay creation failed: {e}")
|
| 309 |
-
return image
|
|
|
|
| 1 |
"""
|
| 2 |
+
Minimal segmentation manager.
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
|
|
|
| 8 |
from PIL import Image
|
| 9 |
from torchvision import transforms
|
| 10 |
from transformers import AutoModelForImageSegmentation
|
| 11 |
+
from typing import Optional
|
| 12 |
import logging
|
| 13 |
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
|
| 17 |
class SegmentationManager:
|
| 18 |
+
"""Minimal BRIA segmentation."""
|
| 19 |
|
| 20 |
+
def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto",
|
| 21 |
+
threshold: float = 0.5, trust_remote_code: bool = True,
|
| 22 |
+
cache_dir: Optional[str] = None, local_files_only: bool = False):
|
| 23 |
+
"""Initialize segmentation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
self.model_name = model_name
|
| 25 |
self.threshold = threshold
|
| 26 |
+
self.device = "cuda" if device == "auto" and torch.cuda.is_available() else device
|
| 27 |
+
|
| 28 |
+
logger.info(f"Loading BRIA model: {model_name}")
|
| 29 |
+
self.model = AutoModelForImageSegmentation.from_pretrained(
|
| 30 |
+
model_name,
|
| 31 |
+
trust_remote_code=trust_remote_code,
|
| 32 |
+
cache_dir=cache_dir if cache_dir else None,
|
| 33 |
+
local_files_only=local_files_only,
|
| 34 |
+
).eval().to(self.device)
|
| 35 |
+
|
| 36 |
+
self.transform = transforms.Compose([
|
| 37 |
+
transforms.Resize((1024, 1024)),
|
| 38 |
+
transforms.ToTensor(),
|
| 39 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 40 |
+
])
|
| 41 |
+
logger.info("BRIA model loaded")
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
|
| 44 |
+
"""Segment image and return soft mask [0,1]."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
try:
|
| 46 |
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 47 |
pil_image = Image.fromarray(rgb_image)
|
| 48 |
input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
|
| 49 |
+
|
| 50 |
with torch.no_grad():
|
| 51 |
preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
|
| 52 |
+
|
| 53 |
original_size = (image.shape[1], image.shape[0])
|
| 54 |
soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
|
| 55 |
return np.clip(soft_mask, 0.0, 1.0)
|
| 56 |
except Exception as e:
|
| 57 |
+
logger.error(f"Segmentation failed: {e}")
|
| 58 |
+
return np.zeros(image.shape[:2], dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wrapper.py
CHANGED
|
@@ -3,9 +3,10 @@ from typing import Dict
|
|
| 3 |
import shutil
|
| 4 |
from PIL import Image
|
| 5 |
import glob
|
| 6 |
-
import
|
| 7 |
|
| 8 |
from sorghum_pipeline.pipeline import SorghumPipeline
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True) -> Dict[str, str]:
|
|
@@ -21,34 +22,38 @@ def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts:
|
|
| 21 |
input_copy = work / Path(input_image_path).name
|
| 22 |
shutil.copy(input_image_path, input_copy)
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
)
|
|
|
|
| 31 |
|
| 32 |
-
# Run the pipeline (single image
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
run_instance_segmentation=False,
|
| 37 |
-
features_frame_only=None
|
| 38 |
-
)
|
| 39 |
|
| 40 |
# Collect outputs
|
| 41 |
outputs: Dict[str, str] = {}
|
| 42 |
|
| 43 |
-
#
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
return outputs
|
|
|
|
| 3 |
import shutil
|
| 4 |
from PIL import Image
|
| 5 |
import glob
|
| 6 |
+
import os
|
| 7 |
|
| 8 |
from sorghum_pipeline.pipeline import SorghumPipeline
|
| 9 |
+
from sorghum_pipeline.config import Config, Paths
|
| 10 |
|
| 11 |
|
| 12 |
def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True) -> Dict[str, str]:
|
|
|
|
| 22 |
input_copy = work / Path(input_image_path).name
|
| 23 |
shutil.copy(input_image_path, input_copy)
|
| 24 |
|
| 25 |
+
# Build in-memory config pointing input/output to the working directory
|
| 26 |
+
cfg = Config()
|
| 27 |
+
cfg.paths = Paths(
|
| 28 |
+
input_folder=str(work),
|
| 29 |
+
output_folder=str(work),
|
| 30 |
+
boundingbox_dir=str(work)
|
| 31 |
)
|
| 32 |
+
pipeline = SorghumPipeline(config=cfg)
|
| 33 |
|
| 34 |
+
# Run the pipeline (single image minimal demo)
|
| 35 |
+
os.environ['MINIMAL_DEMO'] = '1'
|
| 36 |
+
os.environ['FAST_OUTPUT'] = '1'
|
| 37 |
+
results = pipeline.run(single_image_path=str(input_copy))
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Collect outputs
|
| 40 |
outputs: Dict[str, str] = {}
|
| 41 |
|
| 42 |
+
# Return only the requested 7 images with fixed keys
|
| 43 |
+
wanted = [
|
| 44 |
+
work / 'Vegetation_indices_images/ndvi.png',
|
| 45 |
+
work / 'Vegetation_indices_images/ari.png',
|
| 46 |
+
work / 'Vegetation_indices_images/gndvi.png',
|
| 47 |
+
work / 'texture_output/lbp.png',
|
| 48 |
+
work / 'texture_output/hog.png',
|
| 49 |
+
work / 'texture_output/lacunarity.png',
|
| 50 |
+
work / 'results/size.size_analysis.png',
|
| 51 |
+
]
|
| 52 |
+
labels = [
|
| 53 |
+
'NDVI', 'ARI', 'GNDVI', 'LBP', 'HOG', 'Lacunarity', 'SizeAnalysis'
|
| 54 |
+
]
|
| 55 |
+
for label, path in zip(labels, wanted):
|
| 56 |
+
if path.exists():
|
| 57 |
+
outputs[label] = str(path)
|
| 58 |
|
| 59 |
return outputs
|