Update processing/two_stage/two_stage_processor.py
Browse files
processing/two_stage/two_stage_processor.py
CHANGED
|
@@ -12,10 +12,12 @@
|
|
| 12 |
- Fix: Add logging for background preparation issue
|
| 13 |
- COMPOSITING FIX: Normalize frame and background scales to prevent dark backgrounds
|
| 14 |
- MAJOR FIX: Use enhanced SAM2Handler instead of basic segment_person_hq
|
|
|
|
| 15 |
"""
|
| 16 |
from __future__ import annotations
|
| 17 |
import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
|
| 18 |
from pathlib import Path
|
|
|
|
| 19 |
from .quality_manager import quality_manager # New quality manager import
|
| 20 |
|
| 21 |
# Project logger if available
|
|
@@ -179,7 +181,7 @@ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dic
|
|
| 179 |
}
|
| 180 |
|
| 181 |
# ---------------------------------------------------------------------------
|
| 182 |
-
# Two-Stage Processor -
|
| 183 |
# ---------------------------------------------------------------------------
|
| 184 |
class TwoStageProcessor:
|
| 185 |
def __init__(self, sam2_predictor=None, matanyone_model=None):
|
|
@@ -213,31 +215,30 @@ def _unwrap_sam2(self, predictor):
|
|
| 213 |
return predictor
|
| 214 |
|
| 215 |
def _get_mask(self, frame: np.ndarray) -> np.ndarray:
|
| 216 |
-
"""
|
| 217 |
-
logger.info("=== TwoStageProcessor _get_mask called ===")
|
| 218 |
|
| 219 |
if self.sam2_handler is None:
|
| 220 |
logger.warning("No SAM2Handler available - using fallback threshold")
|
| 221 |
-
# Fallback: simple luminance threshold (kept to avoid breaking callers)
|
| 222 |
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
| 223 |
_, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
|
| 224 |
return mask
|
| 225 |
|
| 226 |
try:
|
| 227 |
-
# CRITICAL FIX: Use the ENHANCED SAM2Handler instead of basic segment_person_hq
|
| 228 |
if hasattr(self.sam2_handler, 'create_mask'):
|
| 229 |
-
logger.info("Using ENHANCED SAM2Handler.create_mask() with
|
| 230 |
-
# Convert BGR to RGB for SAM2Handler
|
| 231 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 232 |
mask = self.sam2_handler.create_mask(frame_rgb)
|
| 233 |
|
| 234 |
if mask is not None:
|
| 235 |
-
|
|
|
|
|
|
|
| 236 |
return mask
|
| 237 |
else:
|
| 238 |
logger.warning("Enhanced SAM2Handler returned None mask")
|
| 239 |
else:
|
| 240 |
-
logger.warning("SAM2Handler doesn't have create_mask method
|
| 241 |
|
| 242 |
# Fallback to basic SAM2 if enhanced handler fails
|
| 243 |
if self.sam2 is not None:
|
|
@@ -245,10 +246,10 @@ def _get_mask(self, frame: np.ndarray) -> np.ndarray:
|
|
| 245 |
try:
|
| 246 |
from utils.cv_processing import segment_person_hq
|
| 247 |
mask = segment_person_hq(frame, self.sam2)
|
| 248 |
-
|
| 249 |
return mask
|
| 250 |
except ImportError:
|
| 251 |
-
logger.warning("Could not import segment_person_hq
|
| 252 |
except Exception as e:
|
| 253 |
logger.warning(f"Basic SAM2 segmentation failed: {e}")
|
| 254 |
|
|
@@ -261,6 +262,43 @@ def _get_mask(self, frame: np.ndarray) -> np.ndarray:
|
|
| 261 |
_, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
|
| 262 |
return mask
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
@staticmethod
|
| 265 |
def _to_binary_mask(mask: np.ndarray) -> Optional[np.ndarray]:
|
| 266 |
"""Convert mask to uint8(0..255)."""
|
|
|
|
| 12 |
- Fix: Add logging for background preparation issue
|
| 13 |
- COMPOSITING FIX: Normalize frame and background scales to prevent dark backgrounds
|
| 14 |
- MAJOR FIX: Use enhanced SAM2Handler instead of basic segment_person_hq
|
| 15 |
+
- MAXIMUM QUALITY: Added aggressive mask cleaning for gap elimination
|
| 16 |
"""
|
| 17 |
from __future__ import annotations
|
| 18 |
import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
|
| 19 |
from pathlib import Path
|
| 20 |
+
from typing import Optional, Callable, Dict, Any, Tuple, List
|
| 21 |
from .quality_manager import quality_manager # New quality manager import
|
| 22 |
|
| 23 |
# Project logger if available
|
|
|
|
| 181 |
}
|
| 182 |
|
| 183 |
# ---------------------------------------------------------------------------
|
| 184 |
+
# Two-Stage Processor - MAXIMUM QUALITY VERSION
|
| 185 |
# ---------------------------------------------------------------------------
|
| 186 |
class TwoStageProcessor:
|
| 187 |
def __init__(self, sam2_predictor=None, matanyone_model=None):
|
|
|
|
| 215 |
return predictor
|
| 216 |
|
| 217 |
def _get_mask(self, frame: np.ndarray) -> np.ndarray:
|
| 218 |
+
"""MAXIMUM QUALITY mask with enhanced cleaning."""
|
| 219 |
+
logger.info("=== TwoStageProcessor _get_mask called (MAX QUALITY) ===")
|
| 220 |
|
| 221 |
if self.sam2_handler is None:
|
| 222 |
logger.warning("No SAM2Handler available - using fallback threshold")
|
|
|
|
| 223 |
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
| 224 |
_, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
|
| 225 |
return mask
|
| 226 |
|
| 227 |
try:
|
|
|
|
| 228 |
if hasattr(self.sam2_handler, 'create_mask'):
|
| 229 |
+
logger.info("Using ENHANCED SAM2Handler.create_mask() with MAXIMUM QUALITY")
|
|
|
|
| 230 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 231 |
mask = self.sam2_handler.create_mask(frame_rgb)
|
| 232 |
|
| 233 |
if mask is not None:
|
| 234 |
+
# MAXIMUM QUALITY POST-PROCESSING
|
| 235 |
+
mask = self._maximum_quality_mask_cleaning(mask)
|
| 236 |
+
logger.info(f"Enhanced SAM2 mask with max quality cleaning - coverage: {np.mean(mask/255.0):.3f}")
|
| 237 |
return mask
|
| 238 |
else:
|
| 239 |
logger.warning("Enhanced SAM2Handler returned None mask")
|
| 240 |
else:
|
| 241 |
+
logger.warning("SAM2Handler doesn't have create_mask method")
|
| 242 |
|
| 243 |
# Fallback to basic SAM2 if enhanced handler fails
|
| 244 |
if self.sam2 is not None:
|
|
|
|
| 246 |
try:
|
| 247 |
from utils.cv_processing import segment_person_hq
|
| 248 |
mask = segment_person_hq(frame, self.sam2)
|
| 249 |
+
mask = self._maximum_quality_mask_cleaning(mask)
|
| 250 |
return mask
|
| 251 |
except ImportError:
|
| 252 |
+
logger.warning("Could not import segment_person_hq")
|
| 253 |
except Exception as e:
|
| 254 |
logger.warning(f"Basic SAM2 segmentation failed: {e}")
|
| 255 |
|
|
|
|
| 262 |
_, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
|
| 263 |
return mask
|
| 264 |
|
| 265 |
+
def _maximum_quality_mask_cleaning(self, mask: np.ndarray) -> np.ndarray:
|
| 266 |
+
"""Maximum quality mask cleaning and refinement."""
|
| 267 |
+
try:
|
| 268 |
+
# Ensure uint8 format
|
| 269 |
+
if mask.max() <= 1.0:
|
| 270 |
+
mask_uint8 = (mask * 255).astype(np.uint8)
|
| 271 |
+
else:
|
| 272 |
+
mask_uint8 = mask.astype(np.uint8)
|
| 273 |
+
|
| 274 |
+
# Step 1: Fill small holes aggressively
|
| 275 |
+
kernel_fill = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
|
| 276 |
+
mask_filled = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel_fill)
|
| 277 |
+
|
| 278 |
+
# Step 2: Connect nearby regions
|
| 279 |
+
kernel_connect = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 280 |
+
mask_connected = cv2.morphologyEx(mask_filled, cv2.MORPH_CLOSE, kernel_connect)
|
| 281 |
+
|
| 282 |
+
# Step 3: Smooth boundaries heavily
|
| 283 |
+
mask_smooth1 = cv2.GaussianBlur(mask_connected, (7, 7), 2.0)
|
| 284 |
+
|
| 285 |
+
# Step 4: Re-threshold to crisp edges
|
| 286 |
+
_, mask_thresh = cv2.threshold(mask_smooth1, 127, 255, cv2.THRESH_BINARY)
|
| 287 |
+
|
| 288 |
+
# Step 5: Final edge smoothing
|
| 289 |
+
mask_final = cv2.GaussianBlur(mask_thresh, (5, 5), 1.0)
|
| 290 |
+
|
| 291 |
+
# Step 6: Dilate slightly to ensure full coverage
|
| 292 |
+
kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
| 293 |
+
mask_dilated = cv2.dilate(mask_final, kernel_dilate, iterations=1)
|
| 294 |
+
|
| 295 |
+
logger.info("Maximum quality mask cleaning applied successfully")
|
| 296 |
+
return mask_dilated
|
| 297 |
+
|
| 298 |
+
except Exception as e:
|
| 299 |
+
logger.warning(f"Maximum quality mask cleaning failed: {e}")
|
| 300 |
+
return mask
|
| 301 |
+
|
| 302 |
@staticmethod
|
| 303 |
def _to_binary_mask(mask: np.ndarray) -> Optional[np.ndarray]:
|
| 304 |
"""Convert mask to uint8(0..255)."""
|