Fahimeh Orvati Nia commited on
Commit
2c0bae7
·
1 Parent(s): 1c4433a
sorghum_pipeline/data/mask_handler.py CHANGED
@@ -2,6 +2,9 @@
2
 
3
  import numpy as np
4
  import cv2
 
 
 
5
 
6
 
7
  class MaskHandler:
@@ -16,20 +19,32 @@ class MaskHandler:
16
  if mask is None:
17
  return image
18
 
19
- # Get image dimensions
20
- img_h, img_w = image.shape[:2]
21
-
22
- # Ensure mask is 2D uint8
23
- if mask.ndim > 2:
24
- mask = mask[:, :, 0] # Take first channel if multi-channel
25
- mask = mask.astype(np.uint8)
26
-
27
- # Resize mask to exactly match image dimensions
28
- if mask.shape != (img_h, img_w):
29
- mask = cv2.resize(mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST)
30
-
31
- # Create binary mask (must be exactly same H x W as image)
32
- binary = (mask > 0).astype(np.uint8) * 255
33
-
34
- # Apply mask
35
- return cv2.bitwise_and(image, image, mask=binary)
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import numpy as np
4
  import cv2
5
+ import logging
6
+
7
+ logger = logging.getLogger(__name__)
8
 
9
 
10
  class MaskHandler:
 
19
  if mask is None:
20
  return image
21
 
22
+ try:
23
+ # Get image dimensions
24
+ img_h, img_w = image.shape[:2]
25
+
26
+ # Ensure mask is 2D uint8
27
+ if mask.ndim > 2:
28
+ mask = mask[:, :, 0] # Take first channel if multi-channel
29
+ mask = mask.astype(np.uint8)
30
+
31
+ # Resize mask to exactly match image dimensions
32
+ mask_h, mask_w = mask.shape[:2]
33
+ if (mask_h, mask_w) != (img_h, img_w):
34
+ logger.debug(f"Resizing mask from {mask.shape} to ({img_h}, {img_w})")
35
+ mask = cv2.resize(mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST)
36
+
37
+ # Create binary mask (must be exactly same H x W as image)
38
+ binary = (mask > 0).astype(np.uint8) * 255
39
+
40
+ # Verify dimensions match
41
+ if binary.shape != (img_h, img_w):
42
+ logger.error(f"Mask shape mismatch: binary {binary.shape} != image ({img_h}, {img_w})")
43
+ return image
44
+
45
+ # Apply mask
46
+ result = cv2.bitwise_and(image, image, mask=binary)
47
+ return result
48
+ except Exception as e:
49
+ logger.error(f"Mask application failed: {e}")
50
+ return image
sorghum_pipeline/pipeline.py CHANGED
@@ -79,8 +79,12 @@ class SorghumPipeline:
79
  """Segment using BRIA."""
80
  for key, pdata in plants.items():
81
  composite = pdata['composite']
 
82
  soft_mask = self.segmentation_manager.segment_image_soft(composite)
83
- pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)
 
 
 
84
  return plants
85
 
86
  def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]:
 
79
  """Segment using BRIA."""
80
  for key, pdata in plants.items():
81
  composite = pdata['composite']
82
+ logger.info(f"Composite shape: {composite.shape}")
83
  soft_mask = self.segmentation_manager.segment_image_soft(composite)
84
+ logger.info(f"Soft mask shape: {soft_mask.shape}")
85
+ mask_uint8 = (soft_mask * 255.0).astype(np.uint8)
86
+ logger.info(f"Mask uint8 shape: {mask_uint8.shape}")
87
+ pdata['mask'] = mask_uint8
88
  return plants
89
 
90
  def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]: