Fahimeh Orvati Nia
update
c170961
raw
history blame
6.04 kB
"""
Minimal single-image pipeline for Hugging Face demo.
"""
import logging
from pathlib import Path
from typing import Dict, Any
import numpy as np
import cv2
from .config import Config
from .data import ImagePreprocessor, MaskHandler
from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
from .output import OutputManager
from .segmentation import SegmentationManager
from .features.morphology import MorphologyExtractor
# Module-level logger used by the SorghumPipeline methods below.
logger = logging.getLogger(__name__)
class SorghumPipeline:
    """Minimal pipeline for single-image processing.

    Wires together composite creation, segmentation with the BRIA RMBG-2.0
    model, feature extraction (green-band LBP texture, vegetation indices,
    morphology) and saving of per-plant results.
    """

    def __init__(self, config: Config):
        """Initialize the pipeline and all of its components.

        Args:
            config: Pipeline configuration. ``validate()`` is called on it, its
                ``get_device()`` selects the segmentation device, and its
                ``paths.output_folder`` / ``output`` settings are forwarded to
                the ``OutputManager``.
        """
        # NOTE(review): configuring the root logger from a constructor suits a
        # demo entry point; logging.basicConfig is a no-op if the root logger
        # already has handlers.
        logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
        self.config = config
        self.config.validate()

        # Components are constructed with their default settings.
        self.preprocessor = ImagePreprocessor()
        self.mask_handler = MaskHandler()
        self.texture_extractor = TextureExtractor()
        self.vegetation_extractor = VegetationIndexExtractor()
        self.morphology_extractor = MorphologyExtractor()
        self.segmentation_manager = SegmentationManager(
            model_name="briaai/RMBG-2.0",
            device=self.config.get_device(),
            trust_remote_code=True,
        )
        self.output_manager = OutputManager(
            output_folder=self.config.paths.output_folder,
            settings=self.config.output,
        )
        logger.info("Pipeline initialized")

    def run(self, single_image_path: str) -> Dict[str, Any]:
        """Run the full pipeline on a single image.

        Steps: load image -> build composites -> segment -> extract features ->
        save results.

        Args:
            single_image_path: Path to the input image file.

        Returns:
            ``{"plants": <per-plant data dicts keyed by plant id>,
            "timing": <elapsed wall-clock seconds>}``.
        """
        logger.info("Processing single image...")
        from PIL import Image
        import time

        start = time.perf_counter()

        # NOTE(review): the Pillow handle is intentionally not closed here; the
        # preprocessor consumes the lazily-loaded image (presumably it may need
        # every frame of a multi-page file) — confirm before wrapping in `with`.
        img = Image.open(single_image_path)
        plants = {
            "demo": {
                "raw_image": (img, Path(single_image_path).name),
                "plant_name": "demo",
            }
        }

        # Process: composite -> segment -> features -> save
        plants = self.preprocessor.create_composites(plants)
        plants = self._segment(plants)
        plants = self._extract_features(plants)

        self.output_manager.create_output_directories()
        for key, pdata in plants.items():
            self.output_manager.save_plant_results(key, pdata)

        elapsed = time.perf_counter() - start
        logger.info(f"Completed in {elapsed:.2f}s")
        return {"plants": plants, "timing": elapsed}

    def _segment(self, plants: Dict[str, Any]) -> Dict[str, Any]:
        """Segment each plant's composite using BRIA.

        Stores an 8-bit mask (0..255) under ``pdata['mask']``, scaled from the
        model's soft mask. NaNs are treated as background and values are
        clipped to [0, 1] before scaling.
        """
        for pdata in plants.values():
            composite = pdata['composite']
            logger.info(f"Composite shape: {composite.shape}")
            soft_mask = self.segmentation_manager.segment_image_soft(composite)
            logger.info(f"Soft mask shape: {soft_mask.shape}")
            # Casting floats outside [0, 255] (or NaN) to uint8 is undefined in
            # NumPy and can wrap around, so sanitize before scaling.
            soft = np.clip(np.nan_to_num(soft_mask, nan=0.0), 0.0, 1.0)
            mask_uint8 = (soft * 255.0).astype(np.uint8)
            logger.info(f"Mask uint8 shape: {mask_uint8.shape}")
            pdata['mask'] = mask_uint8
        return plants

    def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]:
        """Extract texture, vegetation-index and morphology features per plant.

        For each plant dict this writes:
          * ``texture_features``: LBP map of the green band inside the mask
            (``{}`` when the spectral stack has no green band);
          * ``vegetation_indices``: NDVI/GNDVI/SAVI maps and statistics
            (``{}`` when there is no spectral stack or no mask);
          * ``morphology_features``: morphology extractor output computed on
            the composite converted BGR->RGB (``{}`` if extraction fails; the
            failure is logged).
        """
        for key, pdata in plants.items():
            composite = pdata['composite']
            mask = pdata.get('mask')
            spectral = pdata.get('spectral_stack') or {}

            # Texture: ONLY LBP on green band within mask
            pdata['texture_features'] = self._extract_texture(spectral, mask)

            # Vegetation: NDVI, GNDVI, SAVI
            if spectral and mask is not None:
                pdata['vegetation_indices'] = self._compute_vegetation(spectral, mask)
            else:
                pdata['vegetation_indices'] = {}

            # Morphology is best-effort: a failure must not abort the run, but
            # it should be visible instead of silently swallowed.
            try:
                pdata['morphology_features'] = self.morphology_extractor.extract_morphology_features(
                    cv2.cvtColor(composite, cv2.COLOR_BGR2RGB), mask
                )
            except Exception:
                logger.warning("Morphology extraction failed for plant %s", key, exc_info=True)
                pdata['morphology_features'] = {}
        return plants

    def _extract_texture(self, spectral: Dict[str, np.ndarray], mask: Any) -> Dict[str, Any]:
        """Compute an LBP map of the green band restricted to ``mask``.

        Args:
            spectral: Band name -> array; ``'green'`` is squeezed on its last
                axis before use.
            mask: Segmentation mask (pixels > 0 are foreground) or ``None`` to
                use the whole band.

        Returns:
            ``{'green': {'features': {'lbp': lbp_map}}}``, or ``{}`` when no
            green band is present.
        """
        if 'green' not in spectral:
            return {}

        green_band = spectral['green'].squeeze(-1).astype(np.float64)
        if mask is not None:
            valid = np.where(mask > 0, green_band, np.nan)
        else:
            valid = green_band

        # Fill excluded pixels with the lowest in-mask value so they normalize
        # to 0. If nothing is valid (e.g. an empty mask), np.nanmin would
        # return NaN and the uint8 cast below would be undefined; use 0.0.
        has_value = ~np.isnan(valid)
        fill = float(np.min(valid[has_value])) if has_value.any() else 0.0
        v = np.nan_to_num(valid, nan=fill)

        # Min-max normalize to uint8 for LBP; guard against a flat image.
        m, M = np.min(v), np.max(v)
        denom = (M - m) if (M - m) > 1e-6 else 1.0
        gray8 = ((v - m) / denom * 255.0).astype(np.uint8)

        lbp_map = self.texture_extractor.extract_lbp(gray8)
        return {'green': {'features': {'lbp': lbp_map}}}

    def _compute_vegetation(self, spectral: Dict[str, np.ndarray], mask: np.ndarray) -> Dict[str, Any]:
        """Compute NDVI, GNDVI and SAVI maps restricted to ``mask``.

        An index is skipped when its band list is unknown or any required band
        is missing from ``spectral``. Each computed index maps to
        ``{'values': map, 'statistics': {'mean': ..., 'std': ...}}``, where
        ``map`` is NaN outside the mask. Statistics use only finite in-mask
        values, so a non-finite formula result (e.g. from a zero denominator)
        cannot poison them. They are 0.0 when no such values exist.
        """
        binary_mask = mask > 0  # loop-invariant
        bands_by_index = self.vegetation_extractor.index_bands
        formulas = self.vegetation_extractor.index_formulas
        out = {}
        for name in ("NDVI", "GNDVI", "SAVI"):
            bands = bands_by_index.get(name, [])
            # An empty band list would pass all([]) and then call the formula
            # with no arguments, so treat it as "not computable".
            if not bands or not all(b in spectral for b in bands):
                continue
            arrays = [np.asarray(spectral[b].squeeze(-1), dtype=np.float64) for b in bands]
            values = np.asarray(formulas[name](*arrays), dtype=np.float64)
            masked_values = np.where(binary_mask, values, np.nan)
            valid = masked_values[np.isfinite(masked_values)]
            stats = {
                'mean': float(np.mean(valid)) if valid.size else 0.0,
                'std': float(np.std(valid)) if valid.size else 0.0,
            }
            out[name] = {'values': masked_values, 'statistics': stats}
        return out