Fahimeh Orvati Nia
update sorghum for multiple plants
f8ac29e
from pathlib import Path
from typing import Dict, Callable, Optional, Generator, Any
import shutil
from PIL import Image
import glob
import os
from sorghum_pipeline.pipeline import SorghumPipeline
from sorghum_pipeline.config import Config, Paths
def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True,
                          progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
                          single_plant_mode: bool = False) -> Generator[Dict[str, str], None, None]:
    """
    Run the sorghum pipeline on a single image and stream its outputs.

    After every stage reported by ``SorghumPipeline.run_with_progress`` the
    artifacts currently present in ``work_dir`` are collected and yielded, so a
    gallery can be refreshed progressively.

    Args:
        input_image_path: Path to the input image (the caller is expected to have
            placed it in ``work_dir`` already).
        work_dir: Directory used as the pipeline's input, output and bounding-box
            folder; created if it does not exist.
        save_artifacts: Currently unused; kept for interface compatibility.
        progress_callback: Optional ``callback(stage_name, data)`` forwarded to the
            pipeline, which calls it after each stage.
        single_plant_mode: Forwarded both to the ``SorghumPipeline`` constructor
            and to ``run_with_progress``.

    Yields:
        One dict per pipeline stage mapping display label -> image path, plus a
        ``'StatsText'`` entry (multi-line summary text) when statistics are available.

    Side effects:
        Sets the process-wide environment variables ``MINIMAL_DEMO=1`` and
        ``FAST_OUTPUT=1``.
    """
    work = Path(work_dir)
    work.mkdir(parents=True, exist_ok=True)
    # The input is used in place (the caller already copied it into work_dir).
    input_path = Path(input_image_path)

    # Demo env vars must be set before the pipeline is constructed.
    os.environ['MINIMAL_DEMO'] = '1'
    os.environ['FAST_OUTPUT'] = '1'

    # In-memory config that points input, output and bounding boxes at work_dir.
    cfg = Config()
    cfg.paths = Paths(
        input_folder=str(work),
        output_folder=str(work),
        boundingbox_dir=str(work),
    )
    pipeline = SorghumPipeline(config=cfg, single_plant_mode=single_plant_mode)

    stages = pipeline.run_with_progress(
        single_image_path=str(input_path),
        progress_callback=progress_callback,
        single_plant_mode=single_plant_mode,
    )
    for stage_result in stages:
        # Re-scan work_dir after every stage so newly written artifacts show up immediately.
        yield _collect_outputs(work, stage_result.get('plants', {}))
def _collect_outputs(work: Path, plants: Dict[str, Any]) -> Dict[str, str]:
    """Collect the displayable outputs currently available for one pipeline run.

    Args:
        work: Working directory the pipeline writes its artifacts into.
        plants: Per-plant results of the current pipeline stage (may be empty).
            Only the first entry is used to build the stats text.

    Returns:
        Mapping of display label -> value. The image labels ('NDVI', 'GNDVI',
        'SAVI', 'Overlay', 'Mask', 'Composite', 'InputImage') map to file paths
        and are present only when the file exists. 'StatsText' maps to a
        multi-line summary and is present only when at least one statistic
        could be formatted.
    """
    _log_artifact_dirs(work)
    outputs = _collect_image_paths(work)
    stats_lines = _build_stats_lines(plants)
    if stats_lines:
        outputs['StatsText'] = "\n".join(stats_lines)
    return outputs


def _log_artifact_dirs(work: Path) -> None:
    """Print the files in each known artifact sub-directory (debugging aid, best effort)."""
    for sub in ['results', 'Vegetation_indices_images', 'texture_output']:
        p = work / sub
        try:
            if p.exists():
                files = sorted(x.name for x in p.iterdir() if x.is_file())
                print(f"Artifacts in {sub}: {files}")
        except OSError:
            # Diagnostics only: an unreadable directory must not block the results.
            continue


def _collect_image_paths(work: Path) -> Dict[str, str]:
    """Return label -> path for every expected image artifact that exists in *work*."""
    candidates = [
        # Vegetation indices shown in the gallery (SAVI replaces ARI).
        ('NDVI', work / 'Vegetation_indices_images/ndvi.png'),
        ('GNDVI', work / 'Vegetation_indices_images/gndvi.png'),
        ('SAVI', work / 'Vegetation_indices_images/savi.png'),
        # Segmentation / visualisation artifacts.
        ('Overlay', work / 'results/overlay.png'),
        ('Mask', work / 'results/mask.png'),
        ('Composite', work / 'results/composite.png'),
        ('InputImage', work / 'results/input_image.png'),
    ]
    # Insertion order is kept, so the gallery order matches the list above.
    return {label: str(path) for label, path in candidates if path.exists()}


def _build_stats_lines(plants: Dict[str, Any]) -> list:
    """Build the human-readable stats lines from the first plant's results."""
    if not plants or not isinstance(plants, dict):
        return []
    _, pdata = next(iter(plants.items()))
    if not isinstance(pdata, dict):
        return []

    lines = _vegetation_stat_lines(pdata.get('vegetation_indices', {}))

    morph = pdata.get('morphology_features', {})
    traits = morph.get('traits', {}) if isinstance(morph, dict) else {}
    if not isinstance(traits, dict):
        traits = {}
    try:
        lines.extend(_morphology_stat_lines(traits))
    except (TypeError, ValueError) as exc:
        # Malformed morphology data only drops the morphology lines; index stats are kept.
        print(f"Skipping morphology stats: {exc}")
    return lines


def _vegetation_stat_lines(veg: Any) -> list:
    """Format 'NAME: mean=…, std=…' lines for the NDVI/GNDVI/SAVI statistics that are present."""
    if not isinstance(veg, dict):
        return []
    lines = []
    for name in ['NDVI', 'GNDVI', 'SAVI']:
        entry = veg.get(name, {})
        st = entry.get('statistics', {}) if isinstance(entry, dict) else {}
        if not isinstance(st, dict) or not st:
            continue
        try:
            lines.append(f"{name}: mean={st.get('mean', 0):.3f}, std={st.get('std', 0):.3f}")
        except (TypeError, ValueError):
            # Non-numeric statistic: skip this index instead of losing every stat.
            continue
    return lines


def _morphology_stat_lines(traits: Dict[str, Any]) -> list:
    """Format plant-count and plant-height lines from the morphology traits.

    Per-plant heights ('plant_heights', keyed like 'plant_<n>') are preferred.
    The legacy single 'plant_height_cm' field is used when there are no
    per-plant entries for a single plant.

    Raises:
        TypeError / ValueError: if the count or a height is not numeric.
    """
    plant_heights = traits.get('plant_heights', {})
    num_plants = traits.get('num_plants', 0)

    if num_plants > 0 and isinstance(plant_heights, dict):
        if plant_heights and (num_plants == 1 or len(plant_heights) == 1):
            # Single plant: report the (only / first) recorded height.
            height_cm = next(iter(plant_heights.values()))
            return ["Number of plants: 1", f"Plant height: {height_cm:.2f} cm"]
        if num_plants != 1:
            lines = [f"Number of plants: {num_plants}"]
            for plant_name in sorted(plant_heights, key=_plant_sort_key):
                height_cm = plant_heights[plant_name]
                lines.append(f" Plant {_plant_label(plant_name)}: {height_cm:.2f} cm")
            return lines
        # num_plants == 1 but no per-plant heights were recorded: use the legacy field.

    height_cm = traits.get('plant_height_cm')
    if isinstance(height_cm, (int, float)) and height_cm > 0:
        return ["Number of plants: 1", f"Plant height: {height_cm:.2f} cm"]
    return []


def _plant_sort_key(plant_name: Any) -> tuple:
    """Sort 'plant_<n>' keys by n; names that do not parse go last, in their original order."""
    try:
        return (0, int(str(plant_name).split('_')[1]))
    except (IndexError, ValueError):
        return (1, 0)


def _plant_label(plant_name: Any) -> str:
    """Return the '<n>' part of 'plant_<n>', or the whole name if it has no '_' separator."""
    parts = str(plant_name).split('_')
    return parts[1] if len(parts) > 1 else str(plant_name)