MogensR committed on
Commit
390b954
·
verified ·
1 Parent(s): d5d4d61

Update processing/two_stage/two_stage_processor.py

Browse files
processing/two_stage/two_stage_processor.py CHANGED
@@ -12,10 +12,12 @@
12
  - Fix: Add logging for background preparation issue
13
  - COMPOSITING FIX: Normalize frame and background scales to prevent dark backgrounds
14
  - MAJOR FIX: Use enhanced SAM2Handler instead of basic segment_person_hq
 
15
  """
16
  from __future__ import annotations
17
  import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
18
  from pathlib import Path
 
19
  from .quality_manager import quality_manager # New quality manager import
20
 
21
  # Project logger if available
@@ -179,7 +181,7 @@ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dic
179
  }
180
 
181
  # ---------------------------------------------------------------------------
182
- # Two-Stage Processor - FIXED VERSION
183
  # ---------------------------------------------------------------------------
184
  class TwoStageProcessor:
185
  def __init__(self, sam2_predictor=None, matanyone_model=None):
@@ -213,31 +215,30 @@ def _unwrap_sam2(self, predictor):
213
  return predictor
214
 
215
  def _get_mask(self, frame: np.ndarray) -> np.ndarray:
216
- """FIXED: Get segmentation mask using ENHANCED SAM2Handler with full-body detection."""
217
- logger.info("=== TwoStageProcessor _get_mask called ===")
218
 
219
  if self.sam2_handler is None:
220
  logger.warning("No SAM2Handler available - using fallback threshold")
221
- # Fallback: simple luminance threshold (kept to avoid breaking callers)
222
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
223
  _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
224
  return mask
225
 
226
  try:
227
- # CRITICAL FIX: Use the ENHANCED SAM2Handler instead of basic segment_person_hq
228
  if hasattr(self.sam2_handler, 'create_mask'):
229
- logger.info("Using ENHANCED SAM2Handler.create_mask() with full-body detection")
230
- # Convert BGR to RGB for SAM2Handler
231
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
232
  mask = self.sam2_handler.create_mask(frame_rgb)
233
 
234
  if mask is not None:
235
- logger.info(f"Enhanced SAM2 mask created - shape: {mask.shape}, coverage: {np.mean(mask/255.0):.3f}")
 
 
236
  return mask
237
  else:
238
  logger.warning("Enhanced SAM2Handler returned None mask")
239
  else:
240
- logger.warning("SAM2Handler doesn't have create_mask method - falling back to basic segmentation")
241
 
242
  # Fallback to basic SAM2 if enhanced handler fails
243
  if self.sam2 is not None:
@@ -245,10 +246,10 @@ def _get_mask(self, frame: np.ndarray) -> np.ndarray:
245
  try:
246
  from utils.cv_processing import segment_person_hq
247
  mask = segment_person_hq(frame, self.sam2)
248
- logger.info(f"Basic SAM2 mask created - coverage: {np.mean(mask/255.0 if mask.max() > 1 else mask):.3f}")
249
  return mask
250
  except ImportError:
251
- logger.warning("Could not import segment_person_hq - using threshold fallback")
252
  except Exception as e:
253
  logger.warning(f"Basic SAM2 segmentation failed: {e}")
254
 
@@ -261,6 +262,43 @@ def _get_mask(self, frame: np.ndarray) -> np.ndarray:
261
  _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
262
  return mask
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  @staticmethod
265
  def _to_binary_mask(mask: np.ndarray) -> Optional[np.ndarray]:
266
  """Convert mask to uint8(0..255)."""
 
12
  - Fix: Add logging for background preparation issue
13
  - COMPOSITING FIX: Normalize frame and background scales to prevent dark backgrounds
14
  - MAJOR FIX: Use enhanced SAM2Handler instead of basic segment_person_hq
15
+ - MAXIMUM QUALITY: Added aggressive mask cleaning for gap elimination
16
  """
17
  from __future__ import annotations
18
  import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
19
  from pathlib import Path
20
+ from typing import Optional, Callable, Dict, Any, Tuple, List
21
  from .quality_manager import quality_manager # New quality manager import
22
 
23
  # Project logger if available
 
181
  }
182
 
183
  # ---------------------------------------------------------------------------
184
+ # Two-Stage Processor - MAXIMUM QUALITY VERSION
185
  # ---------------------------------------------------------------------------
186
  class TwoStageProcessor:
187
  def __init__(self, sam2_predictor=None, matanyone_model=None):
 
215
  return predictor
216
 
217
  def _get_mask(self, frame: np.ndarray) -> np.ndarray:
218
+ """MAXIMUM QUALITY mask with enhanced cleaning."""
219
+ logger.info("=== TwoStageProcessor _get_mask called (MAX QUALITY) ===")
220
 
221
  if self.sam2_handler is None:
222
  logger.warning("No SAM2Handler available - using fallback threshold")
 
223
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
224
  _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
225
  return mask
226
 
227
  try:
 
228
  if hasattr(self.sam2_handler, 'create_mask'):
229
+ logger.info("Using ENHANCED SAM2Handler.create_mask() with MAXIMUM QUALITY")
 
230
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
231
  mask = self.sam2_handler.create_mask(frame_rgb)
232
 
233
  if mask is not None:
234
+ # MAXIMUM QUALITY POST-PROCESSING
235
+ mask = self._maximum_quality_mask_cleaning(mask)
236
+ logger.info(f"Enhanced SAM2 mask with max quality cleaning - coverage: {np.mean(mask/255.0):.3f}")
237
  return mask
238
  else:
239
  logger.warning("Enhanced SAM2Handler returned None mask")
240
  else:
241
+ logger.warning("SAM2Handler doesn't have create_mask method")
242
 
243
  # Fallback to basic SAM2 if enhanced handler fails
244
  if self.sam2 is not None:
 
246
  try:
247
  from utils.cv_processing import segment_person_hq
248
  mask = segment_person_hq(frame, self.sam2)
249
+ mask = self._maximum_quality_mask_cleaning(mask)
250
  return mask
251
  except ImportError:
252
+ logger.warning("Could not import segment_person_hq")
253
  except Exception as e:
254
  logger.warning(f"Basic SAM2 segmentation failed: {e}")
255
 
 
262
  _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
263
  return mask
264
 
265
def _maximum_quality_mask_cleaning(self, mask: np.ndarray) -> np.ndarray:
    """Clean and refine a segmentation mask with morphology and smoothing.

    Pipeline (all steps operate on a uint8 copy of ``mask``):
      1. Morphological close with a 9x9 elliptical kernel to fill small holes.
      2. Morphological close with a 7x7 elliptical kernel to join nearby regions.
      3. Gaussian blur (7x7, sigma 2.0) to smooth the boundary.
      4. Threshold at 127 back to a 0/255 mask.
      5. Gaussian blur (5x5, sigma 1.0) to soften the final edge.
      6. One 3x3 elliptical dilation to slightly grow the coverage.

    Because the last blur runs before the dilation, the returned mask is
    soft-edged (it contains intermediate gray values), not strictly binary.

    Args:
        mask: Mask array. If its maximum is <= 1.0 it is treated as being in
            the 0..1 range and scaled by 255; otherwise it is treated as
            already being in the 0..255 range.

    Returns:
        The cleaned uint8 mask. If ``mask`` is ``None``/empty, or any step
        raises, the input is returned unchanged (best-effort behavior, so a
        cleaning failure never discards the segmentation result).
    """
    # Explicit guard: previously a None/empty mask was only "handled" by
    # tripping the broad except below with a misleading failure message.
    if mask is None or (isinstance(mask, np.ndarray) and mask.size == 0):
        logger.warning("Maximum quality mask cleaning skipped: empty or missing mask")
        return mask

    try:
        # Bring the mask to the 0..255 scale, then clip BEFORE casting.
        # Casting out-of-range values (e.g. small negative undershoot from
        # interpolation, or values above 255) straight to uint8 wraps around
        # or is undefined, which would show up as bright specks that the
        # closing/dilation steps below then enlarge.
        scaled = mask * 255 if mask.max() <= 1.0 else mask
        mask_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)

        # Step 1: fill small holes aggressively.
        fill_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
        cleaned = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, fill_kernel)

        # Step 2: connect nearby regions.
        connect_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, connect_kernel)

        # Step 3: smooth boundaries heavily.
        cleaned = cv2.GaussianBlur(cleaned, (7, 7), 2.0)

        # Step 4: re-threshold so the smoothed boundary becomes a clean 0/255 edge.
        _, cleaned = cv2.threshold(cleaned, 127, 255, cv2.THRESH_BINARY)

        # Step 5: final edge smoothing (reintroduces a soft edge on purpose).
        cleaned = cv2.GaussianBlur(cleaned, (5, 5), 1.0)

        # Step 6: dilate slightly to ensure full coverage of the subject.
        dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        cleaned = cv2.dilate(cleaned, dilate_kernel, iterations=1)

        logger.info("Maximum quality mask cleaning applied successfully")
        return cleaned

    except Exception as e:
        # Deliberate best-effort: fall back to the uncleaned mask rather than
        # failing the whole segmentation step.
        logger.warning("Maximum quality mask cleaning failed: %s", e)
        return mask
301
+
302
  @staticmethod
303
  def _to_binary_mask(mask: np.ndarray) -> Optional[np.ndarray]:
304
  """Convert mask to uint8(0..255)."""