MogensR commited on
Commit
6a8b0ae
Β·
1 Parent(s): 6c6785e

Create compositing.py

Browse files
Files changed (1) hide show
  1. compositing.py +102 -0
compositing.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ utils.compositing
4
+ ─────────────────────────────────────────────────────────────────────────────
5
+ Handles frame-level compositing (foreground frame + mask + background).
6
+ Public API
7
+ ----------
8
+ replace_background_hq(frame_bgr, mask, background_bgr, fallback_enabled=True) β†’ np.ndarray
9
+ """
10
+ from __future__ import annotations
11
+ import logging
12
+ from typing import Tuple
13
+ import cv2
14
+ import numpy as np
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+ # Exception class that cv_processing.py expects
19
+ class BackgroundReplacementError(Exception):
20
+ """Exception raised for background replacement errors"""
21
+ pass
22
+
23
+ __all__ = ["replace_background_hq", "BackgroundReplacementError"]
24
+
25
+ # ────────────────────────────────────────────────────────────────────────────
26
+ # Main entry
27
+ # ────────────────────────────────────────────────────────────────────────────
28
+ def replace_background_hq(
29
+ frame_bgr: np.ndarray,
30
+ mask: np.ndarray,
31
+ background_bgr: np.ndarray,
32
+ fallback_enabled: bool = True,
33
+ ) -> np.ndarray:
34
+ """
35
+ β€’ Ensures background is resized to frame
36
+ β€’ Accepts mask in {0,1} or {0,255} or float32
37
+ β€’ Tries edge-feathered advanced blend, else simple overlay
38
+ """
39
+ if frame_bgr is None or mask is None or background_bgr is None:
40
+ raise ValueError("Invalid input to replace_background_hq")
41
+
42
+ h, w = frame_bgr.shape[:2]
43
+ background = cv2.resize(background_bgr, (w, h), interpolation=cv2.INTER_LANCZOS4)
44
+ m = _process_mask(mask)
45
+
46
+ try:
47
+ return _advanced_composite(frame_bgr, background, m)
48
+ except Exception as e:
49
+ log.warning(f"Advanced compositing failed: {e}")
50
+ if not fallback_enabled:
51
+ raise
52
+ return _simple_composite(frame_bgr, background, m)
53
+
54
+ # ────────────────────────────────────────────────────────────────────────────
55
+ # Advanced compositor (feather + subtle colour-match)
56
+ # ────────────────────────────────────────────────────────────────────────────
57
+ def _advanced_composite(fg, bg, mask_u8):
58
+ # 1) Smooth / feather
59
+ mask = cv2.GaussianBlur(mask_u8.astype(np.float32), (5, 5), 1.0) / 255.0
60
+ mask = np.power(mask, 0.8) # shrink bleed
61
+ mask3 = mask[..., None]
62
+
63
+ # 2) Edge colour-match to reduce halo
64
+ fg_adj = _colour_match_edges(fg, bg, mask)
65
+
66
+ # 3) Blend
67
+ comp = fg_adj.astype(np.float32) * mask3 + bg.astype(np.float32) * (1 - mask3)
68
+ return np.clip(comp, 0, 255).astype(np.uint8)
69
+
70
+ def _colour_match_edges(fg, bg, alpha):
71
+ edge = cv2.Sobel(alpha, cv2.CV_32F, 1, 1, ksize=3)
72
+ edge = (np.abs(edge) > 0.05).astype(np.float32)
73
+
74
+ if not np.any(edge):
75
+ return fg
76
+
77
+ adj = fg.astype(np.float32).copy()
78
+ mix = 0.1
79
+ adj[edge > 0] = adj[edge > 0] * (1 - mix) + bg[edge > 0] * mix
80
+ return adj.astype(np.uint8)
81
+
82
+ # ────────────────────────────────────────────────────────────────────────────
83
+ # Simple fallback compositor
84
+ # ────────────────────────────────────────────────────────────────────────────
85
+ def _simple_composite(fg, bg, mask_u8):
86
+ m = mask_u8.astype(np.float32) / 255.0
87
+ m3 = m[..., None]
88
+ return (fg.astype(np.float32) * m3 + bg.astype(np.float32) * (1 - m3)).astype(np.uint8)
89
+
90
+ # ────────────────────────────────────────────────────────────────────────────
91
+ # Utilities
92
+ # ────────────────────────────────────────────────────────────────────────────
93
+ def _process_mask(mask):
94
+ """Ensure uint8 0/255 single-channel"""
95
+ if mask.ndim == 3:
96
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
97
+
98
+ if mask.dtype != np.uint8:
99
+ mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
100
+
101
+ _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
102
+ return mask