Fahimeh Orvati Nia
committed on
Commit
·
69cba14
1
Parent(s):
2a055cf
update
Browse files
sorghum_pipeline/features/morphology.py
CHANGED
|
@@ -22,11 +22,14 @@ class MorphologyExtractor:
|
|
| 22 |
"""Morphology extraction: size analysis image + simple traits + YOLO tips overlay."""
|
| 23 |
|
| 24 |
def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None,
|
| 25 |
-
yolo_weights_path: str = "/home/grads/f/fahimehorvatinia/plant-analysis-demo/SSL_greenhouse_tip_detection.pt"
|
|
|
|
| 26 |
"""Initialize."""
|
| 27 |
self.pixel_to_cm = pixel_to_cm
|
| 28 |
self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
|
| 29 |
self.yolo_weights_path = yolo_weights_path
|
|
|
|
|
|
|
| 30 |
|
| 31 |
if PLANT_CV_AVAILABLE:
|
| 32 |
pcv.params.debug = None
|
|
@@ -48,12 +51,15 @@ class MorphologyExtractor:
|
|
| 48 |
if rgb is None:
|
| 49 |
return features
|
| 50 |
|
|
|
|
|
|
|
|
|
|
| 51 |
# Size analysis image via PlantCV if available
|
| 52 |
if PLANT_CV_AVAILABLE:
|
| 53 |
with contextlib.redirect_stdout(self._FilteredStream(sys.stdout)), \
|
| 54 |
contextlib.redirect_stderr(self._FilteredStream(sys.stderr)):
|
| 55 |
try:
|
| 56 |
-
labeled_mask, n_labels = pcv.create_labels(
|
| 57 |
size_analysis = pcv.analyze.size(rgb, labeled_mask, n_labels, label="default")
|
| 58 |
features['images']['size_analysis'] = size_analysis
|
| 59 |
features['success'] = True
|
|
@@ -61,7 +67,7 @@ class MorphologyExtractor:
|
|
| 61 |
logger.warning(f"Size analysis failed: {e}")
|
| 62 |
else:
|
| 63 |
# Fallback: make a simple contour visualization
|
| 64 |
-
vis = self._simple_size_visual(rgb,
|
| 65 |
features['images']['size_analysis'] = vis
|
| 66 |
features['success'] = True
|
| 67 |
|
|
@@ -83,6 +89,32 @@ class MorphologyExtractor:
|
|
| 83 |
logger.error(f"Morphology extraction failed: {e}")
|
| 84 |
|
| 85 |
return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
|
| 88 |
"""Preprocess mask."""
|
|
@@ -149,12 +181,36 @@ class MorphologyExtractor:
|
|
| 149 |
"""Detect tips using a YOLO model if available. Returns (overlay_img, tips_list)."""
|
| 150 |
try:
|
| 151 |
from ultralytics import YOLO # type: ignore
|
| 152 |
-
except Exception:
|
|
|
|
| 153 |
return None, []
|
| 154 |
|
| 155 |
try:
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
return None, []
|
| 159 |
|
| 160 |
try:
|
|
@@ -173,7 +229,8 @@ class MorphologyExtractor:
|
|
| 173 |
x, y = float(pt[0]), float(pt[1])
|
| 174 |
if not np.isnan(x) and not np.isnan(y):
|
| 175 |
conf = float(kps_conf[i][j]) if kps_conf is not None else 1.0
|
| 176 |
-
|
|
|
|
| 177 |
tips.append((int(x), int(y), conf))
|
| 178 |
|
| 179 |
# Draw tips
|
|
@@ -181,7 +238,8 @@ class MorphologyExtractor:
|
|
| 181 |
for (x, y, _c) in tips:
|
| 182 |
cv2.circle(vis, (int(x), int(y)), 8, (255, 0, 0), -1)
|
| 183 |
return vis, tips
|
| 184 |
-
except Exception:
|
|
|
|
| 185 |
return None, []
|
| 186 |
|
| 187 |
class _FilteredStream:
|
|
|
|
| 22 |
"""Morphology extraction: size analysis image + simple traits + YOLO tips overlay."""
|
| 23 |
|
| 24 |
def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None,
|
| 25 |
+
yolo_weights_path: str = "/home/grads/f/fahimehorvatinia/plant-analysis-demo/SSL_greenhouse_tip_detection.pt",
|
| 26 |
+
min_component_area_for_size: int = 3000):
|
| 27 |
"""Initialize."""
|
| 28 |
self.pixel_to_cm = pixel_to_cm
|
| 29 |
self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]
|
| 30 |
self.yolo_weights_path = yolo_weights_path
|
| 31 |
+
# Used only for the Morphology Size visualization (not for height or YOLO)
|
| 32 |
+
self.min_component_area_for_size = int(min_component_area_for_size)
|
| 33 |
|
| 34 |
if PLANT_CV_AVAILABLE:
|
| 35 |
pcv.params.debug = None
|
|
|
|
| 51 |
if rgb is None:
|
| 52 |
return features
|
| 53 |
|
| 54 |
+
# For the size visualization only, remove small connected components
|
| 55 |
+
size_mask = self._filter_small_components(clean_mask, self.min_component_area_for_size)
|
| 56 |
+
|
| 57 |
# Size analysis image via PlantCV if available
|
| 58 |
if PLANT_CV_AVAILABLE:
|
| 59 |
with contextlib.redirect_stdout(self._FilteredStream(sys.stdout)), \
|
| 60 |
contextlib.redirect_stderr(self._FilteredStream(sys.stderr)):
|
| 61 |
try:
|
| 62 |
+
labeled_mask, n_labels = pcv.create_labels(size_mask)
|
| 63 |
size_analysis = pcv.analyze.size(rgb, labeled_mask, n_labels, label="default")
|
| 64 |
features['images']['size_analysis'] = size_analysis
|
| 65 |
features['success'] = True
|
|
|
|
| 67 |
logger.warning(f"Size analysis failed: {e}")
|
| 68 |
else:
|
| 69 |
# Fallback: make a simple contour visualization
|
| 70 |
+
vis = self._simple_size_visual(rgb, size_mask)
|
| 71 |
features['images']['size_analysis'] = vis
|
| 72 |
features['success'] = True
|
| 73 |
|
|
|
|
| 89 |
logger.error(f"Morphology extraction failed: {e}")
|
| 90 |
|
| 91 |
return features
|
| 92 |
+
|
| 93 |
+
def _filter_small_components(self, mask: np.ndarray, min_area: int) -> np.ndarray:
|
| 94 |
+
"""Remove connected components smaller than min_area from a binary mask (0/255)."""
|
| 95 |
+
if mask is None or mask.size == 0:
|
| 96 |
+
return mask
|
| 97 |
+
try:
|
| 98 |
+
m = mask
|
| 99 |
+
if m.ndim == 3:
|
| 100 |
+
m = cv2.cvtColor(m, cv2.COLOR_BGR2GRAY)
|
| 101 |
+
m = (m.astype(np.uint8) > 0).astype(np.uint8) * 255
|
| 102 |
+
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
|
| 103 |
+
filtered = np.zeros_like(m)
|
| 104 |
+
kept = 0
|
| 105 |
+
for lbl in range(1, num_labels):
|
| 106 |
+
if stats[lbl, cv2.CC_STAT_AREA] >= max(0, int(min_area)):
|
| 107 |
+
filtered[labels == lbl] = 255
|
| 108 |
+
kept += 1
|
| 109 |
+
if kept == 0:
|
| 110 |
+
# If all removed, keep the largest to avoid empty visualization
|
| 111 |
+
if num_labels > 1:
|
| 112 |
+
largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
|
| 113 |
+
filtered[labels == largest] = 255
|
| 114 |
+
return filtered
|
| 115 |
+
except Exception as e:
|
| 116 |
+
logger.warning(f"Component filtering failed, using original mask: {e}")
|
| 117 |
+
return mask
|
| 118 |
|
| 119 |
def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
|
| 120 |
"""Preprocess mask."""
|
|
|
|
| 181 |
"""Detect tips using a YOLO model if available. Returns (overlay_img, tips_list)."""
|
| 182 |
try:
|
| 183 |
from ultralytics import YOLO # type: ignore
|
| 184 |
+
except Exception as e:
|
| 185 |
+
logger.warning(f"Ultralytics not available: {e}")
|
| 186 |
return None, []
|
| 187 |
|
| 188 |
try:
|
| 189 |
+
# Resolve weights path robustly
|
| 190 |
+
weights_path = self.yolo_weights_path
|
| 191 |
+
if not isinstance(weights_path, str) or not weights_path:
|
| 192 |
+
weights_path = "SSL_greenhouse_tip_detection.pt"
|
| 193 |
+
# Try absolute, then repo root, then cwd
|
| 194 |
+
candidates = [
|
| 195 |
+
weights_path,
|
| 196 |
+
"/home/grads/f/fahimehorvatinia/plant-analysis-demo/SSL_greenhouse_tip_detection.pt",
|
| 197 |
+
"./SSL_greenhouse_tip_detection.pt",
|
| 198 |
+
]
|
| 199 |
+
chosen = None
|
| 200 |
+
for p in candidates:
|
| 201 |
+
try:
|
| 202 |
+
import os
|
| 203 |
+
if os.path.exists(p):
|
| 204 |
+
chosen = p
|
| 205 |
+
break
|
| 206 |
+
except Exception:
|
| 207 |
+
pass
|
| 208 |
+
if chosen is None:
|
| 209 |
+
logger.warning("YOLO weights not found; skipping YOLO tips")
|
| 210 |
+
return None, []
|
| 211 |
+
model = YOLO(chosen)
|
| 212 |
+
except Exception as e:
|
| 213 |
+
logger.warning(f"Failed to load YOLO model: {e}")
|
| 214 |
return None, []
|
| 215 |
|
| 216 |
try:
|
|
|
|
| 229 |
x, y = float(pt[0]), float(pt[1])
|
| 230 |
if not np.isnan(x) and not np.isnan(y):
|
| 231 |
conf = float(kps_conf[i][j]) if kps_conf is not None else 1.0
|
| 232 |
+
# Slightly relax threshold to 0.4 to improve recall
|
| 233 |
+
if conf >= 0.4:
|
| 234 |
tips.append((int(x), int(y), conf))
|
| 235 |
|
| 236 |
# Draw tips
|
|
|
|
| 238 |
for (x, y, _c) in tips:
|
| 239 |
cv2.circle(vis, (int(x), int(y)), 8, (255, 0, 0), -1)
|
| 240 |
return vis, tips
|
| 241 |
+
except Exception as e:
|
| 242 |
+
logger.warning(f"YOLO detection failed: {e}")
|
| 243 |
return None, []
|
| 244 |
|
| 245 |
class _FilteredStream:
|
sorghum_pipeline/output/manager.py
CHANGED
|
@@ -227,6 +227,21 @@ class OutputManager:
|
|
| 227 |
if isinstance(yolo_img, np.ndarray) and yolo_img.size > 0:
|
| 228 |
titled = self._add_title_banner(yolo_img, 'YOLO Tips')
|
| 229 |
cv2.imwrite(str(results_dir / 'yolo_tips.png'), titled)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
except Exception as e:
|
| 231 |
logger.error(f"Failed to save size analysis: {e}")
|
| 232 |
|
|
@@ -343,4 +358,48 @@ class OutputManager:
|
|
| 343 |
# Put area text
|
| 344 |
cv2.putText(base_bgr, f"Area: {area_px} px", (10, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2, cv2.LINE_AA)
|
| 345 |
|
| 346 |
-
return base_bgr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
if isinstance(yolo_img, np.ndarray) and yolo_img.size > 0:
|
| 228 |
titled = self._add_title_banner(yolo_img, 'YOLO Tips')
|
| 229 |
cv2.imwrite(str(results_dir / 'yolo_tips.png'), titled)
|
| 230 |
+
else:
|
| 231 |
+
# Fallback YOLO visualization: draw nothing but preserve panel, or overlay mask centroids
|
| 232 |
+
try:
|
| 233 |
+
mask_for_tips = plant_data.get('mask')
|
| 234 |
+
base_img_for_tips = plant_data.get('composite')
|
| 235 |
+
if isinstance(base_img_for_tips, np.ndarray) and base_img_for_tips.size > 0:
|
| 236 |
+
base = base_img_for_tips
|
| 237 |
+
else:
|
| 238 |
+
# If no composite, synthesize white background overlay from mask
|
| 239 |
+
base = None
|
| 240 |
+
fallback = self._create_fallback_yolo_panel(mask_for_tips, base)
|
| 241 |
+
titled = self._add_title_banner(fallback, 'YOLO Tips (No Detections)')
|
| 242 |
+
cv2.imwrite(str(results_dir / 'yolo_tips.png'), titled)
|
| 243 |
+
except Exception as e:
|
| 244 |
+
logger.warning(f"Failed to create fallback YOLO tips image: {e}")
|
| 245 |
except Exception as e:
|
| 246 |
logger.error(f"Failed to save size analysis: {e}")
|
| 247 |
|
|
|
|
| 358 |
# Put area text
|
| 359 |
cv2.putText(base_bgr, f"Area: {area_px} px", (10, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2, cv2.LINE_AA)
|
| 360 |
|
| 361 |
+
return base_bgr
|
| 362 |
+
|
| 363 |
+
def _create_fallback_yolo_panel(self, mask: Any, base_image: Any = None) -> np.ndarray:
|
| 364 |
+
"""Create a fallback YOLO tips panel when detections are unavailable.
|
| 365 |
+
Uses the composite image if available; otherwise, creates a white canvas sized to mask.
|
| 366 |
+
"""
|
| 367 |
+
try:
|
| 368 |
+
if isinstance(base_image, np.ndarray) and base_image.size > 0:
|
| 369 |
+
img = base_image
|
| 370 |
+
if img.dtype != np.uint8:
|
| 371 |
+
img = self._normalize_to_uint8(img.astype(np.float64))
|
| 372 |
+
if img.ndim == 2:
|
| 373 |
+
panel = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
| 374 |
+
elif img.ndim == 3 and img.shape[2] == 3:
|
| 375 |
+
panel = img.copy()
|
| 376 |
+
elif img.ndim == 3 and img.shape[2] == 4:
|
| 377 |
+
panel = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
|
| 378 |
+
else:
|
| 379 |
+
norm = self._normalize_to_uint8(img.astype(np.float64))
|
| 380 |
+
panel = cv2.cvtColor(norm, cv2.COLOR_GRAY2BGR)
|
| 381 |
+
else:
|
| 382 |
+
if isinstance(mask, np.ndarray) and mask.size > 0:
|
| 383 |
+
h, w = mask.shape[:2]
|
| 384 |
+
else:
|
| 385 |
+
h, w = 256, 256
|
| 386 |
+
panel = np.full((h, w, 3), 255, dtype=np.uint8)
|
| 387 |
+
|
| 388 |
+
# Optionally show mask centroid as hint
|
| 389 |
+
try:
|
| 390 |
+
if isinstance(mask, np.ndarray) and mask.size > 0:
|
| 391 |
+
m = mask
|
| 392 |
+
if m.ndim == 3:
|
| 393 |
+
m = cv2.cvtColor(m, cv2.COLOR_BGR2GRAY)
|
| 394 |
+
_, bin_m = cv2.threshold(m.astype(np.uint8), 0, 255, cv2.THRESH_BINARY)
|
| 395 |
+
moments = cv2.moments(bin_m)
|
| 396 |
+
if moments['m00'] != 0:
|
| 397 |
+
cx = int(moments['m10'] / moments['m00'])
|
| 398 |
+
cy = int(moments['m01'] / moments['m00'])
|
| 399 |
+
cv2.drawMarker(panel, (cx, cy), (0, 0, 255), markerType=cv2.MARKER_TILTED_CROSS, markerSize=12, thickness=2)
|
| 400 |
+
except Exception:
|
| 401 |
+
pass
|
| 402 |
+
|
| 403 |
+
return panel
|
| 404 |
+
except Exception:
|
| 405 |
+
return np.full((256, 256, 3), 255, dtype=np.uint8)
|