MogensR commited on
Commit
6a16de4
Β·
1 Parent(s): f7c6a9c

Update utils/cv_processing.py

Browse files
Files changed (1) hide show
  1. utils/cv_processing.py +82 -72
utils/cv_processing.py CHANGED
@@ -1,20 +1,20 @@
1
  #!/usr/bin/env python3
2
  """
3
- cv_processing.py Β· slim orchestrator layer (self-contained)
4
  ──────────────────────────────────────────────────────────────────────────────
5
  Public API (unchanged):
6
- - segment_person_hq(frame, predictor=None, fallback_enabled=True) -> mask (H,W) float32 [0..1]
7
- - segment_person_hq_original(...) -> alias of segment_person_hq (back-compat)
8
- - refine_mask_hq(frame, mask, matanyone=None, fallback_enabled=True) -> mask (H,W) float32 [0..1]
9
- - replace_background_hq(frame, mask, background, fallback_enabled=True) -> frame uint8 (H,W,3)
10
- - create_professional_background(key_or_cfg, width, height) -> RGB uint8 (H,W,3)
11
  - validate_video_file(video_path) -> (bool, reason)
12
 
13
  Design:
14
  * NO imports from other utils.* modules β†’ avoids circular imports.
15
- * Torch & diffusers imported lazily inside functions.
16
- * All masks are single-channel float32 in [0..1] at boundaries between stages.
17
- * MatAnyOne step() is fed (N,C,H,W); no 5D tensors.
18
  """
19
 
20
  from __future__ import annotations
@@ -29,7 +29,7 @@
29
  logger = logging.getLogger(__name__)
30
 
31
  # ----------------------------------------------------------------------------
32
- # Background presets (minimal set; callers can keep their own catalog if needed)
33
  # ----------------------------------------------------------------------------
34
  PROFESSIONAL_BACKGROUNDS_LOCAL: Dict[str, Dict[str, Any]] = {
35
  "office": {"color": (240, 248, 255), "gradient": True},
@@ -39,30 +39,22 @@
39
  "white": {"color": (255, 255, 255), "gradient": False},
40
  "black": {"color": (0, 0, 0), "gradient": False},
41
  }
 
 
42
 
43
  # ----------------------------------------------------------------------------
44
  # Helpers
45
  # ----------------------------------------------------------------------------
46
  def _ensure_rgb(img: np.ndarray) -> np.ndarray:
47
- """Convert BGR→RGB if looks like BGR; otherwise pass-through."""
48
  if img is None:
49
  return img
50
  if img.ndim == 3 and img.shape[2] == 3:
51
- # Heuristic: assume OpenCV BGR
52
  return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
53
  return img
54
 
55
- def _ensure_bgr(img: np.ndarray) -> np.ndarray:
56
- """Convert RGB→BGR if looks like RGB; otherwise pass-through."""
57
- if img is None:
58
- return img
59
- if img.ndim == 3 and img.shape[2] == 3:
60
- # Heuristic: assume non-OpenCV images are RGB
61
- return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
62
- return img
63
-
64
  def _to_mask01(m: np.ndarray) -> np.ndarray:
65
- """Ensure single-channel float32 [0..1]."""
66
  if m is None:
67
  return None
68
  if m.ndim == 3:
@@ -73,7 +65,7 @@ def _to_mask01(m: np.ndarray) -> np.ndarray:
73
  return np.clip(m, 0.0, 1.0)
74
 
75
  def _feather(mask01: np.ndarray, k: int = 2) -> np.ndarray:
76
- """Small Gaussian feather for cleaner edges."""
77
  if mask01.ndim == 3:
78
  mask01 = mask01[..., 0]
79
  k = max(1, int(k) * 2 + 1)
@@ -90,13 +82,21 @@ def _vertical_gradient(top: Tuple[int,int,int], bottom: Tuple[int,int,int], widt
90
  bg[y, :] = (r, g, b)
91
  return bg
92
 
 
 
 
 
 
 
 
 
93
  # ----------------------------------------------------------------------------
94
- # Background creation (kept here to match public API)
95
  # ----------------------------------------------------------------------------
96
  def create_professional_background(key_or_cfg: Any, width: int, height: int) -> np.ndarray:
97
  """
98
  Accepts:
99
- - key: str in local preset dict
100
  - cfg: {"color": (r,g,b), "gradient": bool}
101
  Returns RGB uint8 image (H,W,3).
102
  """
@@ -113,19 +113,14 @@ def create_professional_background(key_or_cfg: Any, width: int, height: int) ->
113
  if not use_grad:
114
  return np.full((height, width, 3), color, dtype=np.uint8)
115
 
116
- # Simple vertical gradient dark->base color
117
  dark = (int(color[0]*0.7), int(color[1]*0.7), int(color[2]*0.7))
118
- bg = _vertical_gradient(dark, color, width, height)
119
- return bg # already RGB by convention
120
 
121
  # ----------------------------------------------------------------------------
122
  # Segmentation
123
  # ----------------------------------------------------------------------------
124
  def _simple_person_segmentation(frame_bgr: np.ndarray) -> np.ndarray:
125
- """
126
- Very simple fallback segmentation by suppressing green/white backgrounds.
127
- Returns mask01 (H,W) float32.
128
- """
129
  hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)
130
 
131
  lower_green = np.array([40, 40, 40], dtype=np.uint8)
@@ -145,13 +140,22 @@ def _simple_person_segmentation(frame_bgr: np.ndarray) -> np.ndarray:
145
 
146
  return (person_mask.astype(np.float32) / 255.0)
147
 
148
- def segment_person_hq(frame: np.ndarray, predictor: Optional[Any] = None, fallback_enabled: bool = True) -> np.ndarray:
 
 
 
 
 
 
 
149
  """
150
  Try SAM2 predictor if available; return single-channel float32 mask in [0..1].
151
- - predictor.set_image expects RGB
152
- - predictor.predict returns masks with shapes (N,H,W) or (H,W)
153
  """
154
  try:
 
 
 
155
  if predictor is not None and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
156
  rgb = _ensure_rgb(frame)
157
  predictor.set_image(rgb)
@@ -164,18 +168,11 @@ def segment_person_hq(frame: np.ndarray, predictor: Optional[Any] = None, fallba
164
  multimask_output=True
165
  )
166
 
167
- # Normalize and pick best
168
- if isinstance(masks, np.ndarray):
169
- m = masks
170
- else:
171
- m = np.array(masks)
172
-
173
- if m.ndim == 3: # N,H,W
174
  idx = int(np.argmax(scores)) if scores is not None else 0
175
  m = m[idx]
176
- elif m.ndim == 2: # H,W
177
- pass
178
- else:
179
  raise RuntimeError(f"Unexpected SAM2 mask shape: {m.shape}")
180
 
181
  return _to_mask01(m)
@@ -185,7 +182,7 @@ def segment_person_hq(frame: np.ndarray, predictor: Optional[Any] = None, fallba
185
 
186
  return _simple_person_segmentation(frame) if fallback_enabled else np.ones(frame.shape[:2], dtype=np.float32)
187
 
188
- # Back-compat alias (some code may import this)
189
  segment_person_hq_original = segment_person_hq
190
 
191
  # ----------------------------------------------------------------------------
@@ -194,13 +191,11 @@ def segment_person_hq(frame: np.ndarray, predictor: Optional[Any] = None, fallba
194
  def _to_tensor_chw(img_uint8_bgr: np.ndarray) -> "torch.Tensor":
195
  import torch
196
  rgb = cv2.cvtColor(img_uint8_bgr, cv2.COLOR_BGR2RGB)
197
- t = torch.from_numpy(rgb).permute(2, 0, 1).contiguous().float() / 255.0 # (3,H,W)
198
- return t
199
 
200
  def _mask_to_tensor01(mask01: np.ndarray) -> "torch.Tensor":
201
  import torch
202
- m = torch.from_numpy(mask01.astype(np.float32)).unsqueeze(0).unsqueeze(0) # (1,1,H,W)
203
- return m
204
 
205
  def _tensor_to_mask01(t: "torch.Tensor") -> np.ndarray:
206
  import torch
@@ -216,23 +211,36 @@ def _simple_mask_refinement(mask01: np.ndarray) -> np.ndarray:
216
  m = cv2.bilateralFilter(m, 9, 75, 75)
217
  return (m.astype(np.float32) / 255.0)
218
 
219
- def refine_mask_hq(frame: np.ndarray, mask: np.ndarray, matanyone: Optional[Any] = None, fallback_enabled: bool = True) -> np.ndarray:
 
 
 
 
 
 
 
 
220
  """
221
- If MatAnyOne processor is available, refine the mask (single-channel).
222
- - Converts inputs to tensors with shapes:
223
- image: (1,3,H,W)
224
- mask: (1,1,H,W)
225
- - No 5D tensors; avoids conv2d errors like [1,1,3,720,1280].
226
  """
227
- H, W = frame.shape[:2]
 
 
 
228
  mask01 = _to_mask01(mask)
229
 
230
  try:
 
 
 
231
  if matanyone is not None:
232
  import torch
233
 
234
- img_t = _to_tensor_chw(frame).unsqueeze(0) # (1,3,H,W)
235
- mask_t = _mask_to_tensor01(mask01) # (1,1,H,W)
236
 
237
  device = "cuda" if torch.cuda.is_available() else "cpu"
238
  img_t = img_t.to(device)
@@ -246,19 +254,15 @@ def refine_mask_hq(frame: np.ndarray, mask: np.ndarray, matanyone: Optional[Any]
246
  objects=None,
247
  first_frame_pred=True
248
  )
249
- # out should be (1,1,H,W)
250
  if hasattr(matanyone, "output_prob_to_mask"):
251
  out = matanyone.output_prob_to_mask(out)
252
  return _tensor_to_mask01(out)
253
 
254
- elif hasattr(matanyone, "process"):
255
- # Generic .process(image, mask) path; accepts numpy/PIL
256
  refined = matanyone.process(frame, mask01)
257
- refined = np.asarray(refined).astype(np.float32)
258
- return _to_mask01(refined)
259
 
260
- else:
261
- logger.warning("MatAnyOne provided but no 'step' or 'process' method found.")
262
 
263
  except Exception as e:
264
  logger.warning("MatAnyOne refinement failed: %s", e)
@@ -268,23 +272,28 @@ def refine_mask_hq(frame: np.ndarray, mask: np.ndarray, matanyone: Optional[Any]
268
  # ----------------------------------------------------------------------------
269
  # Compositing
270
  # ----------------------------------------------------------------------------
271
- def replace_background_hq(frame: np.ndarray, mask01: np.ndarray, background: np.ndarray, fallback_enabled: bool = True) -> np.ndarray:
 
 
 
 
 
 
272
  """
273
  Composite frame over background using feathered mask.
274
  Inputs:
275
- - frame: (H,W,3) uint8 (BGR or RGB, doesn't matter for linear blend)
276
  - mask01: (H,W) or (H,W,1) float32 in [0..1]
277
  - background: (H,W,3) uint8
278
  Returns:
279
- - composited frame (H,W,3) uint8 (same channel order as inputs)
280
  """
281
  try:
282
  H, W = frame.shape[:2]
283
  if background.shape[:2] != (H, W):
284
  background = cv2.resize(background, (W, H), interpolation=cv2.INTER_LANCZOS4)
285
 
286
- m = _to_mask01(mask01)
287
- m = _feather(m, k=2)
288
  m3 = np.repeat(m[:, :, None], 3, axis=2)
289
 
290
  comp = frame.astype(np.float32) * m3 + background.astype(np.float32) * (1.0 - m3)
@@ -296,7 +305,7 @@ def replace_background_hq(frame: np.ndarray, mask01: np.ndarray, background: np.
296
  raise
297
 
298
  # ----------------------------------------------------------------------------
299
- # Video validation (detailed)
300
  # ----------------------------------------------------------------------------
301
  def validate_video_file(video_path: str) -> Tuple[bool, str]:
302
  """
@@ -350,4 +359,5 @@ def validate_video_file(video_path: str) -> Tuple[bool, str]:
350
  "replace_background_hq",
351
  "create_professional_background",
352
  "validate_video_file",
 
353
  ]
 
1
  #!/usr/bin/env python3
2
  """
3
+ cv_processing.py Β· slim orchestrator layer (self-contained, backward-compatible)
4
  ──────────────────────────────────────────────────────────────────────────────
5
  Public API (unchanged):
6
+ - segment_person_hq(frame, predictor=None, fallback_enabled=True, **compat)
7
+ - segment_person_hq_original(...)
8
+ - refine_mask_hq(frame, mask, matanyone=None, fallback_enabled=True, **compat)
9
+ - replace_background_hq(frame, mask, background, fallback_enabled=True)
10
+ - create_professional_background(key_or_cfg, width, height)
11
  - validate_video_file(video_path) -> (bool, reason)
12
 
13
  Design:
14
  * NO imports from other utils.* modules β†’ avoids circular imports.
15
+ * Torch is imported lazily inside functions.
16
+ * All masks are single-channel float32 in [0..1] at stage boundaries.
17
+ * MatAnyOne gets (N,C,H,W) β€” no 5D tensors.
18
  """
19
 
20
  from __future__ import annotations
 
29
  logger = logging.getLogger(__name__)
30
 
31
  # ----------------------------------------------------------------------------
32
+ # Background presets (local copy; safe defaults)
33
  # ----------------------------------------------------------------------------
34
  PROFESSIONAL_BACKGROUNDS_LOCAL: Dict[str, Dict[str, Any]] = {
35
  "office": {"color": (240, 248, 255), "gradient": True},
 
39
  "white": {"color": (255, 255, 255), "gradient": False},
40
  "black": {"color": (0, 0, 0), "gradient": False},
41
  }
42
+ # Optional alias if callers import by this name
43
+ PROFESSIONAL_BACKGROUNDS = PROFESSIONAL_BACKGROUNDS_LOCAL
44
 
45
  # ----------------------------------------------------------------------------
46
  # Helpers
47
  # ----------------------------------------------------------------------------
48
  def _ensure_rgb(img: np.ndarray) -> np.ndarray:
49
+ """Convert BGR→RGB if it looks like BGR (OpenCV convention)."""
50
  if img is None:
51
  return img
52
  if img.ndim == 3 and img.shape[2] == 3:
 
53
  return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
54
  return img
55
 
 
 
 
 
 
 
 
 
 
56
  def _to_mask01(m: np.ndarray) -> np.ndarray:
57
+ """Return single-channel float32 in [0..1]."""
58
  if m is None:
59
  return None
60
  if m.ndim == 3:
 
65
  return np.clip(m, 0.0, 1.0)
66
 
67
  def _feather(mask01: np.ndarray, k: int = 2) -> np.ndarray:
68
+ """Tiny Gaussian feather for smoother edges."""
69
  if mask01.ndim == 3:
70
  mask01 = mask01[..., 0]
71
  k = max(1, int(k) * 2 + 1)
 
82
  bg[y, :] = (r, g, b)
83
  return bg
84
 
85
+ def _looks_like_mask(x: Any) -> bool:
86
+ return (
87
+ isinstance(x, np.ndarray)
88
+ and x.ndim in (2, 3)
89
+ and (x.ndim == 2 or (x.ndim == 3 and x.shape[2] in (1, 3)))
90
+ and x.dtype != object
91
+ )
92
+
93
  # ----------------------------------------------------------------------------
94
+ # Background creation (RGB)
95
  # ----------------------------------------------------------------------------
96
  def create_professional_background(key_or_cfg: Any, width: int, height: int) -> np.ndarray:
97
  """
98
  Accepts:
99
+ - key: str in preset dict
100
  - cfg: {"color": (r,g,b), "gradient": bool}
101
  Returns RGB uint8 image (H,W,3).
102
  """
 
113
  if not use_grad:
114
  return np.full((height, width, 3), color, dtype=np.uint8)
115
 
 
116
  dark = (int(color[0]*0.7), int(color[1]*0.7), int(color[2]*0.7))
117
+ return _vertical_gradient(dark, color, width, height)
 
118
 
119
  # ----------------------------------------------------------------------------
120
  # Segmentation
121
  # ----------------------------------------------------------------------------
122
  def _simple_person_segmentation(frame_bgr: np.ndarray) -> np.ndarray:
123
+ """Very simple fallback segmentation by suppressing green/white backgrounds."""
 
 
 
124
  hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)
125
 
126
  lower_green = np.array([40, 40, 40], dtype=np.uint8)
 
140
 
141
  return (person_mask.astype(np.float32) / 255.0)
142
 
143
+ def segment_person_hq(
144
+ frame: np.ndarray,
145
+ predictor: Optional[Any] = None,
146
+ fallback_enabled: bool = True,
147
+ # backward-compat shim:
148
+ use_sam2: Optional[bool] = None,
149
+ **_compat_kwargs,
150
+ ) -> np.ndarray:
151
  """
152
  Try SAM2 predictor if available; return single-channel float32 mask in [0..1].
153
+ Backward-compat: accepts use_sam2 (if False β†’ force fallback).
 
154
  """
155
  try:
156
+ if use_sam2 is False:
157
+ return _simple_person_segmentation(frame)
158
+
159
  if predictor is not None and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
160
  rgb = _ensure_rgb(frame)
161
  predictor.set_image(rgb)
 
168
  multimask_output=True
169
  )
170
 
171
+ m = np.array(masks)
172
+ if m.ndim == 3: # (N,H,W)
 
 
 
 
 
173
  idx = int(np.argmax(scores)) if scores is not None else 0
174
  m = m[idx]
175
+ elif m.ndim != 2: # not (H,W)
 
 
176
  raise RuntimeError(f"Unexpected SAM2 mask shape: {m.shape}")
177
 
178
  return _to_mask01(m)
 
182
 
183
  return _simple_person_segmentation(frame) if fallback_enabled else np.ones(frame.shape[:2], dtype=np.float32)
184
 
185
+ # Back-compat alias
186
  segment_person_hq_original = segment_person_hq
187
 
188
  # ----------------------------------------------------------------------------
 
191
  def _to_tensor_chw(img_uint8_bgr: np.ndarray) -> "torch.Tensor":
192
  import torch
193
  rgb = cv2.cvtColor(img_uint8_bgr, cv2.COLOR_BGR2RGB)
194
+ return torch.from_numpy(rgb).permute(2, 0, 1).contiguous().float() / 255.0 # (3,H,W)
 
195
 
196
  def _mask_to_tensor01(mask01: np.ndarray) -> "torch.Tensor":
197
  import torch
198
+ return torch.from_numpy(mask01.astype(np.float32)).unsqueeze(0).unsqueeze(0) # (1,1,H,W)
 
199
 
200
  def _tensor_to_mask01(t: "torch.Tensor") -> np.ndarray:
201
  import torch
 
211
  m = cv2.bilateralFilter(m, 9, 75, 75)
212
  return (m.astype(np.float32) / 255.0)
213
 
214
+ def refine_mask_hq(
215
+ frame: np.ndarray,
216
+ mask: np.ndarray,
217
+ matanyone: Optional[Any] = None,
218
+ fallback_enabled: bool = True,
219
+ # backward-compat shims:
220
+ use_matanyone: Optional[bool] = None,
221
+ **_compat_kwargs,
222
+ ) -> np.ndarray:
223
  """
224
+ Refine single-channel mask with MatAnyOne if available.
225
+ Backward-compat:
226
+ - accepts use_matanyone (False β†’ skip model)
227
+ - tolerates legacy arg order refine_mask_hq(mask, frame, ...)
 
228
  """
229
+ # tolerate legacy order: refine_mask_hq(mask, frame, ...)
230
+ if _looks_like_mask(frame) and isinstance(mask, np.ndarray) and mask.ndim == 3 and mask.shape[2] == 3:
231
+ frame, mask = mask, frame
232
+
233
  mask01 = _to_mask01(mask)
234
 
235
  try:
236
+ if use_matanyone is False:
237
+ return _simple_mask_refinement(mask01)
238
+
239
  if matanyone is not None:
240
  import torch
241
 
242
+ img_t = _to_tensor_chw(frame).unsqueeze(0) # (1,3,H,W)
243
+ mask_t = _mask_to_tensor01(mask01) # (1,1,H,W)
244
 
245
  device = "cuda" if torch.cuda.is_available() else "cpu"
246
  img_t = img_t.to(device)
 
254
  objects=None,
255
  first_frame_pred=True
256
  )
 
257
  if hasattr(matanyone, "output_prob_to_mask"):
258
  out = matanyone.output_prob_to_mask(out)
259
  return _tensor_to_mask01(out)
260
 
261
+ if hasattr(matanyone, "process"):
 
262
  refined = matanyone.process(frame, mask01)
263
+ return _to_mask01(np.asarray(refined))
 
264
 
265
+ logger.warning("MatAnyOne provided but neither 'step' nor 'process' found.")
 
266
 
267
  except Exception as e:
268
  logger.warning("MatAnyOne refinement failed: %s", e)
 
272
  # ----------------------------------------------------------------------------
273
  # Compositing
274
  # ----------------------------------------------------------------------------
275
+ def replace_background_hq(
276
+ frame: np.ndarray,
277
+ mask01: np.ndarray,
278
+ background: np.ndarray,
279
+ fallback_enabled: bool = True,
280
+ **_compat,
281
+ ) -> np.ndarray:
282
  """
283
  Composite frame over background using feathered mask.
284
  Inputs:
285
+ - frame: (H,W,3) uint8 (BGR or RGB, linear blend anyway)
286
  - mask01: (H,W) or (H,W,1) float32 in [0..1]
287
  - background: (H,W,3) uint8
288
  Returns:
289
+ - composited frame (H,W,3) uint8
290
  """
291
  try:
292
  H, W = frame.shape[:2]
293
  if background.shape[:2] != (H, W):
294
  background = cv2.resize(background, (W, H), interpolation=cv2.INTER_LANCZOS4)
295
 
296
+ m = _feather(_to_mask01(mask01), k=2)
 
297
  m3 = np.repeat(m[:, :, None], 3, axis=2)
298
 
299
  comp = frame.astype(np.float32) * m3 + background.astype(np.float32) * (1.0 - m3)
 
305
  raise
306
 
307
  # ----------------------------------------------------------------------------
308
+ # Video validation
309
  # ----------------------------------------------------------------------------
310
  def validate_video_file(video_path: str) -> Tuple[bool, str]:
311
  """
 
359
  "replace_background_hq",
360
  "create_professional_background",
361
  "validate_video_file",
362
+ "PROFESSIONAL_BACKGROUNDS",
363
  ]