VideoBackgroundReplacer / utils /cv_processing.py
MogensR's picture
Update utils/cv_processing.py
d2502a6
raw
history blame
17.9 kB
#!/usr/bin/env python3
"""
cv_processing.py · FIXED VERSION with proper SAM2 handling + MatAnyone stateful integration
All public functions in this module expect RGB images (H,W,3) unless stated otherwise.
CoreVideoProcessor already converts BGR→RGB before calling into this module.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Callable
import cv2
import numpy as np
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------
# Background presets
# ----------------------------------------------------------------------------
# Built-in background presets, keyed by preset name.
#   "color"    - base colour triple; this module works in RGB, so read it as (R, G, B).
#   "gradient" - when True, create_professional_background() renders a vertical
#                gradient from a 70%-darkened shade (top row) to "color" (bottom row);
#                when False the whole frame is filled with the flat colour.
PROFESSIONAL_BACKGROUNDS_LOCAL: Dict[str, Dict[str, Any]] = {
"office": {"color": (240, 248, 255), "gradient": True},
"studio": {"color": (32, 32, 32), "gradient": False},
"nature": {"color": (34, 139, 34), "gradient": True},
"abstract": {"color": (75, 0, 130), "gradient": True},
"white": {"color": (255, 255, 255), "gradient": False},
"black": {"color": (0, 0, 0), "gradient": False},
}
# Public alias bound to the very same dict object (this is the name listed in __all__).
PROFESSIONAL_BACKGROUNDS = PROFESSIONAL_BACKGROUNDS_LOCAL
# ----------------------------------------------------------------------------
# Helpers (RGB-safe)
# ----------------------------------------------------------------------------
def _ensure_rgb(img: np.ndarray) -> np.ndarray:
    """
    Return *img* as a channels-last array limited to its first three channels.

    * ``None`` is passed through unchanged.
    * 3-D input whose last axis has 3 or 4 entries is treated as HWC; any
      fourth (alpha) channel is sliced off.
    * 3-D input whose first axis has 1, 3 or 4 entries (and whose last axis
      does not look like a channel axis) is treated as CHW and transposed.
    * Everything else is returned as converted by ``np.asarray``.

    No BGR<->RGB swap is performed; colour order is the caller's business.
    """
    if img is None:
        return None

    arr = np.asarray(img)
    if arr.ndim != 3:
        return arr

    first_axis, last_axis = arr.shape[0], arr.shape[-1]

    # Channels-last already: just drop a possible alpha channel.
    if last_axis == 3 or last_axis == 4:
        return arr[..., :3]

    # Channels-first: move the channel axis to the end, then drop alpha.
    if first_axis in (1, 3, 4) and last_axis != 1:
        return arr.transpose(1, 2, 0)[..., :3]

    return arr
def _ensure_rgb01(frame_rgb: np.ndarray) -> np.ndarray:
    """
    Produce a fresh float32 HWC copy of *frame_rgb* with values in [0,1].

    uint8 input is divided by 255, floating input is clipped to [0,1], and any
    other dtype is first clipped to 0..255 and handled like uint8.
    Channel order is left untouched.

    Raises:
        ValueError: if *frame_rgb* is None.
    """
    if frame_rgb is None:
        raise ValueError("frame_rgb is None")

    img = _ensure_rgb(frame_rgb)

    if np.issubdtype(img.dtype, np.floating):
        as_float = img.astype(np.float32)
        return np.clip(as_float, 0.0, 1.0).copy()

    if img.dtype != np.uint8:
        # other integer types: saturate into the 8-bit range first
        img = np.clip(img, 0, 255).astype(np.uint8)

    scaled = img.astype(np.float32) / 255.0
    return scaled.copy()
def _to_mask01(m: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """
    Convert a mask to float32 values clipped to [0,1].

    Accepts any array-like (ndarray, nested list, ...). A 3-D channels-last
    mask with 1, 3 or 4 channels is reduced to its first channel. uint8 input
    is scaled by 1/255; every other dtype is cast to float32 as-is.

    Returns:
        The converted mask, or ``None`` when *m* is ``None``.
    """
    if m is None:
        return None

    # Convert first so that plain sequences expose .ndim/.shape/.dtype below.
    arr = np.asarray(m)

    if arr.ndim == 3 and arr.shape[2] in (1, 3, 4):
        arr = arr[..., 0]

    if arr.dtype == np.uint8:
        arr = arr.astype(np.float32) / 255.0
    elif arr.dtype != np.float32:
        arr = arr.astype(np.float32)

    return np.clip(arr, 0.0, 1.0)
def _mask_to_2d(mask: np.ndarray) -> np.ndarray:
    """
    Reduce any mask to 2-D float32 [H,W], contiguous, in [0,1].

    Handles HWC/CHW/B1HW/1HW/HW layouts by dropping singleton axes and, for
    multi-channel input, keeping only the first channel. Array-likes such as
    nested lists are accepted.

    If no 2-D plane can be extracted, a warning is logged and a neutral mask
    filled with 0.5 is returned instead; its size is taken from the last two
    axes when there are at least two, otherwise it is 512x512.
    """
    m = np.asarray(mask)
    # Keep the incoming shape for the log message: `mask` itself may be a
    # plain sequence without a `.shape` attribute.
    input_shape = m.shape

    # CHW with single channel
    if m.ndim == 3 and m.shape[0] == 1 and (m.shape[1] > 1 and m.shape[2] > 1):
        m = m[0]
    # HWC with single channel
    if m.ndim == 3 and m.shape[-1] == 1:
        m = m[..., 0]
    # generic 3D -> take first channel
    if m.ndim == 3:
        m = m[..., 0] if m.shape[-1] in (1, 3, 4) else m[0]

    m = np.squeeze(m)

    if m.ndim != 2:
        # fall back to neutral 0.5 mask
        h = int(m.shape[-2]) if m.ndim >= 2 else 512
        w = int(m.shape[-1]) if m.ndim >= 2 else 512
        logger.warning(
            "_mask_to_2d: unexpected shape %s, creating neutral mask.", input_shape
        )
        m = np.full((h, w), 0.5, dtype=np.float32)

    if m.dtype == np.uint8:
        m = m.astype(np.float32) / 255.0
    elif m.dtype != np.float32:
        m = m.astype(np.float32)

    return np.ascontiguousarray(np.clip(m, 0.0, 1.0))
def _feather(mask01: np.ndarray, k: int = 2) -> np.ndarray:
    """
    Soften a [0,1] mask with a Gaussian blur and return it as float32 in [0,1].

    The blur runs on an 8-bit copy of the mask with a square kernel of side
    ``2*k + 1`` (never smaller than 1). For 3-D input only the first channel
    is used.
    """
    plane = mask01[..., 0] if mask01.ndim == 3 else mask01
    side = max(1, 2 * int(k) + 1)
    plane_u8 = (plane * 255.0).astype(np.uint8)
    blurred = cv2.GaussianBlur(plane_u8, (side, side), 0)
    return blurred.astype(np.float32) / 255.0
def _vertical_gradient(top: Tuple[int,int,int], bottom: Tuple[int,int,int], width: int, height: int) -> np.ndarray:
    """
    Build an (height, width, 3) uint8 image fading linearly from *top* to *bottom*.

    Row 0 has the *top* colour and the last row the *bottom* colour; a
    single-row image uses *top*. Interpolated components are truncated to
    integers and clipped to 0..255, so out-of-range colour components saturate
    instead of overflowing the uint8 buffer.
    """
    # Per-row blend factor in [0,1]; max(1, ...) avoids 0/0 when height == 1.
    t = (np.arange(height, dtype=np.float64) / max(1, height - 1))[:, None]
    top_rgb = np.asarray(top[:3], dtype=np.float64)
    bottom_rgb = np.asarray(bottom[:3], dtype=np.float64)

    # (height, 3) table of row colours, computed once instead of per row in Python.
    row_colors = top_rgb * (1 - t) + bottom_rgb * t
    row_colors = np.clip(row_colors, 0, 255).astype(np.uint8)

    # Replicate every row colour across the full width.
    return np.repeat(row_colors[:, None, :], width, axis=1)
# ----------------------------------------------------------------------------
# Background creation
# ----------------------------------------------------------------------------
def create_professional_background(key_or_cfg: Any, width: int, height: int) -> np.ndarray:
    """
    Render a background image of shape (height, width, 3), dtype uint8.

    Args:
        key_or_cfg: preset name looked up in ``PROFESSIONAL_BACKGROUNDS_LOCAL``
            (unknown names fall back to ``"office"``), or a config dict with
            optional ``"color"`` and ``"gradient"`` entries. Any other type
            also selects the ``"office"`` preset.
        width: output width in pixels.
        height: output height in pixels.

    Returns:
        A flat fill of ``color`` (default white) when ``gradient`` is falsy,
        otherwise a vertical gradient running from a 70%-darkened shade at the
        top to ``color`` at the bottom.
    """
    if isinstance(key_or_cfg, dict):
        cfg = key_or_cfg
    elif isinstance(key_or_cfg, str):
        cfg = PROFESSIONAL_BACKGROUNDS_LOCAL.get(
            key_or_cfg, PROFESSIONAL_BACKGROUNDS_LOCAL["office"]
        )
    else:
        cfg = PROFESSIONAL_BACKGROUNDS_LOCAL["office"]

    color = tuple(int(component) for component in cfg.get("color", (255, 255, 255)))

    if bool(cfg.get("gradient", False)):
        darker = tuple(int(color[i] * 0.7) for i in range(3))
        return _vertical_gradient(darker, color, width, height)

    return np.full((height, width, 3), color, dtype=np.uint8)
# ----------------------------------------------------------------------------
# Improved Segmentation (expects RGB input)
# ----------------------------------------------------------------------------
def _simple_person_segmentation(frame_rgb: np.ndarray) -> np.ndarray:
    """
    Basic fallback segmentation using colour detection on RGB frames.

    Everything that is not "greenscreen-ish" in HSV, plus skin-toned pixels, is
    treated as foreground; the result is cleaned with morphological close/open
    and reduced to the largest external contour (filled).

    Args:
        frame_rgb: RGB image, uint8 or float in [0,1]. Non-uint8 input is
            converted to uint8 first because the HSV thresholds below are
            expressed in OpenCV's 8-bit HSV ranges.

    Returns:
        float32 mask [H,W] with values 0.0 or 1.0.
    """
    rgb = np.asarray(frame_rgb)
    if rgb.dtype != np.uint8:
        if np.issubdtype(rgb.dtype, np.floating):
            # floats are taken to be normalised to [0,1]
            rgb = np.clip(rgb * 255.0, 0, 255).astype(np.uint8)
        else:
            rgb = np.clip(rgb, 0, 255).astype(np.uint8)
    # Sliced/transposed views may be non-contiguous; hand OpenCV a plain buffer.
    rgb = np.ascontiguousarray(rgb)

    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)

    lower_skin = np.array([0, 20, 70], dtype=np.uint8)
    upper_skin = np.array([20, 255, 255], dtype=np.uint8)
    skin_mask = cv2.inRange(hsv, lower_skin, upper_skin)

    # detect greenscreen-ish
    lower_green = np.array([40, 40, 40], dtype=np.uint8)
    upper_green = np.array([80, 255, 255], dtype=np.uint8)
    green_mask = cv2.inRange(hsv, lower_green, upper_green)

    # foreground = not green, or skin-coloured
    person_mask = cv2.bitwise_not(green_mask)
    person_mask = cv2.bitwise_or(person_mask, skin_mask)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    person_mask = cv2.morphologyEx(person_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
    person_mask = cv2.morphologyEx(person_mask, cv2.MORPH_OPEN, kernel, iterations=1)

    # Keep only the largest connected outline, filled solid.
    contours, _ = cv2.findContours(person_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        person_mask = np.zeros_like(person_mask)
        cv2.drawContours(person_mask, [largest_contour], -1, 255, -1)

    return person_mask.astype(np.float32) / 255.0
def _frame_to_uint8(frame_rgb: np.ndarray) -> np.ndarray:
    """Return *frame_rgb* as uint8; floating input is assumed to be in [0,1]."""
    if frame_rgb.dtype == np.uint8:
        return frame_rgb
    if np.issubdtype(frame_rgb.dtype, np.floating):
        return np.clip(frame_rgb * 255.0, 0, 255).astype(np.uint8)
    return np.clip(frame_rgb, 0, 255).astype(np.uint8)


def _select_best_mask(masks: np.ndarray, scores: Optional[Any]) -> Optional[np.ndarray]:
    """Pick one 2-D mask from an (H,W), (N,H,W) or (N,1,H,W) stack, preferring the top score."""
    if masks.ndim == 2:
        return masks

    stacked = masks.ndim == 3 and masks.shape[0] > 0
    stacked_1ch = masks.ndim == 4 and masks.shape[0] > 0 and masks.shape[1] == 1
    if not (stacked or stacked_1ch):
        logger.warning("Unexpected mask shape from SAM2: %s", masks.shape)
        return None

    best_idx = 0
    if scores is not None:
        flat_scores = np.asarray(scores).ravel()
        if flat_scores.size:
            best_idx = int(np.argmax(flat_scores))
        # Guard against a scores array that is longer than the mask stack.
        if best_idx >= masks.shape[0]:
            best_idx = 0

    return masks[best_idx, 0] if stacked_1ch else masks[best_idx]


def segment_person_hq(
    frame: np.ndarray,
    predictor: Optional[Any] = None,
    fallback_enabled: bool = True,
    use_sam2: Optional[bool] = None,
    **_compat_kwargs,
) -> np.ndarray:
    """
    High-quality person segmentation with proper SAM2 handling.

    Expects RGB frame (H,W,3), uint8 or float in [0,1].

    Args:
        frame: input image; converted to uint8 once and reused for both the
            predictor and the colour-based fallback.
        predictor: optional object exposing ``set_image(rgb_u8)`` and
            ``predict(point_coords=..., point_labels=..., multimask_output=...)``.
            Its result may be a dict with ``"masks"``/``"scores"``, a
            ``(masks, scores, ...)`` sequence, or a bare mask array.
        fallback_enabled: when no usable predictor mask is produced, use the
            colour-based segmentation (True) or return an all-ones mask (False).
        use_sam2: ``False`` bypasses the predictor and goes straight to the
            colour-based segmentation.
        **_compat_kwargs: accepted and ignored for backward compatibility.

    Returns:
        float32 mask in [0,1].
    """
    frame_rgb = _ensure_rgb(frame)
    h, w = frame_rgb.shape[:2]
    rgb_u8 = _frame_to_uint8(frame_rgb)

    if use_sam2 is False:
        return _simple_person_segmentation(rgb_u8)

    if predictor is not None:
        try:
            if hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
                predictor.set_image(rgb_u8)

                # Center + a couple of body-biased prompts
                points = np.array(
                    [
                        [w // 2, h // 2],
                        [w // 2, h // 4],
                        [w // 2, h // 2 + h // 8],
                    ],
                    dtype=np.float32,
                )
                labels = np.array([1, 1, 1], dtype=np.int32)
                result = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )

                # normalize outputs
                if isinstance(result, dict):
                    masks = result.get("masks", None)
                    scores = result.get("scores", None)
                elif isinstance(result, (tuple, list)) and len(result) >= 2:
                    masks, scores = result[0], result[1]
                else:
                    masks, scores = result, None

                if masks is None:
                    logger.warning("SAM2 returned no masks")
                else:
                    mask = _select_best_mask(np.asarray(masks), scores)
                    if mask is not None:
                        mask = _to_mask01(mask)
                        if float(mask.max()) > 0.1:
                            return np.ascontiguousarray(mask)
                        logger.warning("SAM2 mask too weak, using fallback")
        except Exception as e:
            # The predictor is an external component: any failure degrades to
            # the fallback below rather than aborting the frame.
            logger.warning("SAM2 segmentation error: %s", e)

    if fallback_enabled:
        logger.debug("Using fallback segmentation")
        return _simple_person_segmentation(rgb_u8)
    return np.ones((h, w), dtype=np.float32)


segment_person_hq_original = segment_person_hq
# ----------------------------------------------------------------------------
# MatAnyone Refinement (Stateful-capable)
# ----------------------------------------------------------------------------
def _stateless_matanyone_refine(
    matanyone: Callable,
    rgb01: np.ndarray,
    mask01: np.ndarray,
) -> Optional[np.ndarray]:
    """Try the (frame, mask) entry points in order: the object itself, ``step``, ``process``.

    Returns the first result that normalises to a 2-D mask, or None if every
    attempt fails. A result is only accepted after normalisation succeeded, so
    a raw, un-normalised output can never leak to the caller.
    """
    for label in ("callable", "step", "process"):
        if label != "callable" and not hasattr(matanyone, label):
            continue
        try:
            entry_point = matanyone if label == "callable" else getattr(matanyone, label)
            return _mask_to_2d(entry_point(rgb01, mask01))
        except Exception as e:
            logger.debug("MatAnyone %s failed: %s", label, e)
    return None


def refine_mask_hq(
    frame: np.ndarray,
    mask: np.ndarray,
    matanyone: Optional[Callable] = None,
    *,
    frame_idx: Optional[int] = None,
    fallback_enabled: bool = True,
    use_matanyone: Optional[bool] = None,
    **_compat_kwargs,
) -> np.ndarray:
    """
    Refine mask with MatAnyone.

    Modes:
      • Stateful (preferred): provide `frame_idx`. On frame_idx==0, the session is called
        with (frame, mask). On subsequent frames, it is called with the frame only.
        If the stateful result is empty/weak, the stateless path below is tried as well.
      • Backward-compat (stateless): if `frame_idx` is None, we try callable/step/process
        with (frame, mask), in that order.

    Args:
        frame: RGB image; handed to MatAnyone as float32 in [0,1].
        mask: coarse mask; converted to float32 in [0,1] before use.
        matanyone: callable refinement session, or None to skip MatAnyone.
        frame_idx: frame number for the stateful mode (None selects stateless mode).
        fallback_enabled: when MatAnyone yields nothing usable, apply the simple
            filter-based refinement (True) or return the converted input mask (False).
        use_matanyone: ``False`` returns the converted input mask immediately.
        **_compat_kwargs: accepted and ignored for backward compatibility.

    Returns:
        float32 alpha in [0,1]. Results coming from MatAnyone or the fallback
        refinement are 2-D [H,W] and contiguous (OpenCV-safe).
    """
    mask01 = _to_mask01(mask)

    if use_matanyone is False:
        return mask01

    if matanyone is not None and callable(matanyone):
        try:
            rgb01 = _ensure_rgb01(frame)  # RGB float32 in [0,1]

            # Stateful path (preferred)
            if frame_idx is not None:
                if frame_idx == 0:
                    refined = matanyone(rgb01, mask01)  # first frame: called with the mask
                else:
                    refined = matanyone(rgb01)  # later frames: called without a mask
                refined = _mask_to_2d(refined)
                if float(refined.max()) > 0.1:
                    return _postprocess_mask(refined)
                logger.warning("MatAnyone stateful refinement produced empty/weak mask; falling back.")

            # Backward-compat (stateless) path
            refined = _stateless_matanyone_refine(matanyone, rgb01, mask01)
            if refined is not None and float(refined.max()) > 0.1:
                return _postprocess_mask(refined)
            logger.warning("MatAnyone refinement failed or produced empty mask")
        except Exception as e:
            # MatAnyone is an external component: degrade to the fallback below.
            logger.warning("MatAnyone error: %s", e)

    # Fallback refinement
    if fallback_enabled:
        return _fallback_refine(mask01)
    return mask01
def _postprocess_mask(mask01: np.ndarray) -> np.ndarray:
    """
    Clean up a refined mask and return it as contiguous float32 in [0,1].

    Pipeline (on an 8-bit copy): 5x5 elliptical closing, light 3x3 Gaussian
    blur, binarisation at 127, then a 5x5 Gaussian blur (sigma 1) to soften
    the resulting edge.
    """
    alpha_u8 = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)

    closing_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    closed = cv2.morphologyEx(alpha_u8, cv2.MORPH_CLOSE, closing_kernel)

    smoothed = cv2.GaussianBlur(closed, (3, 3), 0)
    binary = cv2.threshold(smoothed, 127, 255, cv2.THRESH_BINARY)[1]
    feathered = cv2.GaussianBlur(binary, (5, 5), 1)

    return np.ascontiguousarray(feathered.astype(np.float32) / 255.0)
def _fallback_refine(mask01: np.ndarray) -> np.ndarray:
    """
    Refine a mask without MatAnyone; returns contiguous float32 in [0,1].

    Works on an 8-bit copy: bilateral filter, 3x3 elliptical close followed by
    open, and a final 5x5 Gaussian blur (sigma 1).
    """
    work = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)
    work = cv2.bilateralFilter(work, 9, 75, 75)

    ellipse3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    for morph_op in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        work = cv2.morphologyEx(work, morph_op, ellipse3)

    work = cv2.GaussianBlur(work, (5, 5), 1)
    return np.ascontiguousarray(work.astype(np.float32) / 255.0)
# ----------------------------------------------------------------------------
# Compositing (expects RGB inputs)
# ----------------------------------------------------------------------------
def replace_background_hq(
    frame: np.ndarray,
    mask01: np.ndarray,
    background: np.ndarray,
    fallback_enabled: bool = True,
    **_compat,
) -> np.ndarray:
    """
    High-quality background replacement with alpha blending (RGB in/out).

    Args:
        frame: foreground image (H,W,3) with values on a 0..255 scale.
        mask01: foreground alpha; any layout accepted by ``_mask_to_2d``. A mask
            whose resolution differs from the frame is resized to match.
        background: replacement image; RGBA is reduced to RGB, a single-channel
            image is replicated to three channels, and the result is resized to
            the frame size if necessary.
        fallback_enabled: on any failure, log a warning and return *frame*
            unchanged (True) or re-raise the exception (False).
        **_compat: accepted and ignored for backward compatibility.

    Returns:
        uint8 composite ``frame * alpha + background * (1 - alpha)``.
    """
    try:
        if mask01 is None:
            raise ValueError("mask01 is None")

        H, W = frame.shape[:2]

        bg = _ensure_rgb(background)
        if bg.ndim == 2:
            # grayscale background -> three identical channels
            bg = np.repeat(bg[:, :, None], 3, axis=2)
        if bg.shape[:2] != (H, W):
            bg = cv2.resize(np.ascontiguousarray(bg), (W, H), interpolation=cv2.INTER_LANCZOS4)

        m = _mask_to_2d(_to_mask01(mask01))
        if m.shape != (H, W):
            # Masks may be produced at a different resolution than the frame.
            m = cv2.resize(m, (W, H), interpolation=cv2.INTER_LINEAR)
        m = _feather(m, k=1)

        # (H,W,1) alpha broadcasts over the colour channels; no 3x copy needed.
        alpha = m[:, :, None]
        comp = frame.astype(np.float32) * alpha + bg.astype(np.float32) * (1.0 - alpha)
        return np.clip(comp, 0, 255).astype(np.uint8)
    except Exception as e:
        if fallback_enabled:
            logger.warning(f"Compositing failed ({e}) – returning original frame")
            return frame
        raise
# ----------------------------------------------------------------------------
# Video validation
# ----------------------------------------------------------------------------
def validate_video_file(video_path: str) -> Tuple[bool, str]:
    """
    Check that *video_path* points to a video this pipeline can process.

    Limits enforced: non-empty file of at most 2 GB, readable by OpenCV, at
    least one frame, 0 < FPS <= 120, resolution within 4096x4096 and a
    duration of at most 5 minutes.

    Returns:
        ``(True, summary)`` on success, otherwise ``(False, reason)``.
        Exceptions are not propagated; they are logged and reported as a
        ``"Validation error: ..."`` reason.
    """
    if not video_path or not Path(video_path).exists():
        return False, "Video file not found"

    cap = None
    try:
        size = Path(video_path).stat().st_size
        if size == 0:
            return False, "File is empty"
        if size > 2 * 1024 * 1024 * 1024:
            return False, "File > 2 GB"

        # str() so that pathlib.Path inputs also work with OpenCV.
        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            return False, "Cannot read file"

        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # A negative frame count is treated the same as zero.
        if n_frames <= 0:
            return False, "No frames detected"
        # Written as a negated range test so that NaN is rejected as well.
        if not (0 < fps <= 120):
            return False, f"Invalid FPS: {fps}"
        if w <= 0 or h <= 0:
            return False, "Invalid resolution"
        if w > 4096 or h > 4096:
            return False, f"Resolution {w}×{h} too high"

        duration = n_frames / fps
        if duration > 300:
            return False, "Video longer than 5 minutes"
        return True, f"OK → {w}×{h}, {fps:.1f} fps, {duration:.1f}s"
    except Exception as e:
        logger.error("validate_video_file: %s", e)
        return False, f"Validation error: {e}"
    finally:
        # Always give the capture handle back, including on early returns and errors.
        if cap is not None:
            cap.release()
# ----------------------------------------------------------------------------
# Public symbols
# ----------------------------------------------------------------------------
# Names exported by a star-import of this module. The underscore-prefixed
# helpers defined above are deliberately left out.
__all__ = [
"segment_person_hq",
"segment_person_hq_original",
"refine_mask_hq",
"replace_background_hq",
"create_professional_background",
"validate_video_file",
"PROFESSIONAL_BACKGROUNDS",
]