MogensR committed on
Commit
e5e6fe5
·
1 Parent(s): 94a9c1b

Create processing/matting.py

Browse files
Files changed (1) hide show
  1. processing/matting.py +450 -0
processing/matting.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced matting algorithms for BackgroundFX Pro.
3
+ Implements multiple matting techniques with automatic fallback.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import cv2
11
+ from typing import Dict, Tuple, Optional, List
12
+ from dataclasses import dataclass
13
+ import logging
14
+
15
+ from ..utils.logger import setup_logger
16
+ from ..utils.device import DeviceManager
17
+ from ..utils.config import ConfigManager
18
+ from ..core.models import ModelFactory, ModelType
19
+ from ..core.quality import QualityAnalyzer
20
+ from ..core.edge import EdgeRefinement
21
+
22
+ logger = setup_logger(__name__)
23
+
24
+
25
@dataclass
class MattingConfig:
    """Tunable parameters shared by the matting routines."""

    # --- thresholds -------------------------------------------------------
    # Cut-off on the 0-1 alpha scale -- presumably for deciding foreground
    # membership; confirm against callers.
    alpha_threshold: float = 0.5

    # --- morphological clean-up ------------------------------------------
    # Iteration count handed to the MORPH_CLOSE (hole filling) pass.
    erode_iterations: int = 2
    # Iteration count handed to the MORPH_OPEN (speckle removal) pass.
    dilate_iterations: int = 2
    # Gaussian blur radius; the kernel is (2 * radius + 1) wide and a value
    # of 0 skips the blur.
    blur_radius: int = 3

    # --- trimap ----------------------------------------------------------
    # Default size of the elliptical structuring element used to grow the
    # unknown band around the mask boundary.
    trimap_size: int = 30
    # NOTE(review): not read by the matting code visible here -- presumably
    # consumed by callers judging the returned confidence; confirm.
    confidence_threshold: float = 0.7

    # --- guided filter ---------------------------------------------------
    # Whether alpha mattes are smoothed with the guided filter.
    use_guided_filter: bool = True
    # Default filter radius when the caller passes none.
    guided_filter_radius: int = 8
    # Default regularisation term added to the guide variance.
    guided_filter_eps: float = 1e-6

    # --- temporal smoothing ----------------------------------------------
    # NOTE(review): neither field is read by the matting code visible here;
    # presumably used by a video pipeline elsewhere -- confirm.
    use_temporal_smoothing: bool = False
    temporal_window: int = 5
39
+
40
+
41
class AlphaMatting:
    """Estimate an alpha matte from an image and a coarse segmentation mask.

    Three strategies are available: trimap propagation, a torch-based
    refinement and guided filtering. ``process`` picks one, applies
    morphological and edge clean-up, and hands back the raw mask if any
    step raises.
    """

    def __init__(self, config: Optional["MattingConfig"] = None):
        """Store the configuration and build the helper components.

        Args:
            config: Matting parameters; a default ``MattingConfig`` is
                created when omitted.
        """
        self.config = config or MattingConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.edge_refinement = EdgeRefinement()

    @staticmethod
    def _as_uint8_mask(mask: np.ndarray) -> np.ndarray:
        """Return *mask* as a uint8 array on the 0-255 scale.

        uint8 input is returned unchanged. Boolean input maps True to 255.
        Any other dtype is read as 0-1 when its maximum is at most 1 and as
        already being 0-255 otherwise; the result is clipped to 0-255.
        """
        if mask.dtype == np.uint8:
            return mask
        if mask.dtype == np.bool_:
            return mask.astype(np.uint8) * 255
        scaled = mask.astype(np.float32)
        if scaled.size == 0 or scaled.max() <= 1.0:
            scaled = scaled * 255.0
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def create_trimap(self, mask: np.ndarray,
                      dilation_size: Optional[int] = None) -> np.ndarray:
        """
        Create a trimap from a segmentation mask.

        The mask is binarised at ``config.alpha_threshold`` first so the
        result only ever contains the three documented levels, even when a
        soft (grey-level) mask is supplied.

        Args:
            mask: Mask (H, W); uint8 0-255, boolean, or float 0-1
            dilation_size: Size of the elliptical structuring element that
                defines the unknown band. ``None`` or 0 selects
                ``config.trimap_size``.

        Returns:
            uint8 trimap with 0 (background), 128 (unknown), 255 (foreground)
        """
        size = max(1, int(dilation_size or self.config.trimap_size))

        mask = self._as_uint8_mask(mask)
        cutoff = self.config.alpha_threshold * 255.0
        binary = np.where(mask >= cutoff, 255, 0).astype(np.uint8)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
        dilated = cv2.dilate(binary, kernel, iterations=1)
        eroded = cv2.erode(binary, kernel, iterations=1)

        # Everything the dilation reaches is at least "unknown"; whatever
        # survives the erosion is confidently foreground.
        trimap = np.zeros_like(binary)
        trimap[dilated == 255] = 128
        trimap[eroded == 255] = 255

        return trimap

    def guided_filter(self, image: np.ndarray,
                      guide: np.ndarray,
                      radius: Optional[int] = None,
                      eps: Optional[float] = None) -> np.ndarray:
        """
        Apply a guided filter for edge-preserving smoothing.

        Args:
            image: Input to filter, on the 0-255 scale
            guide: Guide image on the 0-255 scale; a 3-channel guide is
                reduced to grey first
            radius: Filter radius; the averaging window is
                (2 * radius + 1) pixels wide. ``None`` or 0 selects
                ``config.guided_filter_radius``.
            eps: Regularisation added to the guide variance. ``None`` or 0
                selects ``config.guided_filter_eps``.

        Returns:
            Filtered image as uint8 (0-255)
        """
        radius = int(radius or self.config.guided_filter_radius)
        eps = eps or self.config.guided_filter_eps

        if guide.ndim == 3:
            if guide.shape[2] == 1:
                guide = guide[:, :, 0]
            else:
                # NOTE(review): the docstrings in this class speak of RGB
                # images while this conversion uses BGR channel weights --
                # confirm the channel order callers actually pass.
                guide = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)

        guide = guide.astype(np.float32) / 255.0
        image = image.astype(np.float32) / 255.0

        # A radius-r window spans 2r + 1 pixels; an odd size also keeps the
        # window centred on the pixel being filtered.
        window = (2 * radius + 1, 2 * radius + 1)

        def box_mean(arr):
            return cv2.boxFilter(arr, -1, window)

        mean_I = box_mean(guide)
        mean_p = box_mean(image)
        cov_Ip = box_mean(guide * image) - mean_I * mean_p
        var_I = box_mean(guide * guide) - mean_I * mean_I

        # Per-window linear model: output = a * guide + b.
        a = cov_Ip / (var_I + eps)
        b = mean_p - a * mean_I

        output = box_mean(a) * guide + box_mean(b)
        return np.clip(output * 255, 0, 255).astype(np.uint8)

    def closed_form_matting(self, image: np.ndarray,
                            trimap: np.ndarray) -> np.ndarray:
        """
        Fill the unknown trimap band by distance-weighted interpolation.

        This is a lightweight stand-in for closed-form (Laplacian) matting:
        each unknown pixel receives
        ``d_bg / (d_fg + d_bg)``, where ``d_fg`` and ``d_bg`` are its
        distances to the nearest definite foreground and background pixel.

        Args:
            image: Guide image for the optional guided-filter pass
            trimap: uint8 trimap; 255 is foreground, 0 is background and
                every other value is treated as unknown

        Returns:
            float32 alpha matte in [0, 1]
        """
        alpha = trimap.astype(np.float32) / 255.0

        is_fg = trimap == 255
        is_bg = trimap == 0
        is_unknown = ~(is_fg | is_bg)

        if not np.any(is_unknown):
            return alpha

        has_fg = bool(is_fg.any())
        has_bg = bool(is_bg.any())

        if has_fg and has_bg:
            # distanceTransform measures, for every non-zero pixel, the
            # distance to the nearest ZERO pixel -- so the masks are
            # inverted to obtain "distance to nearest fg / bg pixel".
            dist_to_fg = cv2.distanceTransform(
                (~is_fg).astype(np.uint8), cv2.DIST_L2, 5
            )
            dist_to_bg = cv2.distanceTransform(
                (~is_bg).astype(np.uint8), cv2.DIST_L2, 5
            )
            # No pixel is both fg and bg, so the denominator is positive.
            blend = dist_to_bg / (dist_to_fg + dist_to_bg)
            alpha[is_unknown] = blend[is_unknown]
        elif has_fg:
            # Nothing is known to be background: resolve towards foreground.
            alpha[is_unknown] = 1.0
        elif has_bg:
            # Nothing is known to be foreground: resolve towards background.
            alpha[is_unknown] = 0.0
        # With neither class present the scaled trimap values are kept.

        if self.config.use_guided_filter:
            smoothed = self.guided_filter(
                (alpha * 255).astype(np.uint8),
                image
            )
            alpha = smoothed.astype(np.float32) / 255.0

        return np.clip(alpha, 0.0, 1.0)

    def deep_matting(self, image: np.ndarray,
                     mask: np.ndarray,
                     model: Optional[nn.Module] = None) -> np.ndarray:
        """
        Refine a mask with a torch model at a fixed 512x512 resolution.

        Args:
            image: 3-channel image (H, W, 3) on the 0-255 scale
            mask: Initial mask (H, W); uint8 0-255, boolean, or float 0-1
            model: Optional module called as ``model(image, mask)`` with
                (1, 3, 512, 512) and (1, 1, 512, 512) float tensors in
                [0, 1]. Without one, ``_simple_refine_network`` is used.

        Returns:
            Alpha matte at the original resolution, clipped to [0, 1]
        """
        device = self.device_manager.get_device()
        h, w = image.shape[:2]

        input_size = (512, 512)
        image_resized = cv2.resize(image, input_size)
        mask_resized = cv2.resize(self._as_uint8_mask(mask), input_size)

        # HWC -> 1xCxHxW, scaled to [0, 1].
        image_tensor = torch.from_numpy(
            image_resized.transpose(2, 0, 1)
        ).float().unsqueeze(0) / 255.0
        mask_tensor = torch.from_numpy(
            mask_resized
        ).float().unsqueeze(0).unsqueeze(0) / 255.0

        image_tensor = image_tensor.to(device)
        mask_tensor = mask_tensor.to(device)

        with torch.no_grad():
            if model is None:
                stacked = torch.cat([image_tensor, mask_tensor], dim=1)
                refined = self._simple_refine_network(stacked)
            else:
                refined = model(image_tensor, mask_tensor)

        alpha = refined.squeeze().cpu().numpy()
        alpha = cv2.resize(alpha, (w, h))

        return np.clip(alpha, 0, 1)

    def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
        """Smooth and re-sharpen the mask channel of *x*.

        Only channel 3 (the mask) of the (N, 4, H, W) input is used; the
        image channels are ignored. Three rounds of 3x3 average pooling,
        each followed by a steep sigmoid centred on 0.5, soften the boundary
        while pushing values back towards 0 or 1.
        """
        refined = x[:, 3:4, :, :]

        for _ in range(3):
            refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
            refined = torch.sigmoid((refined - 0.5) * 10)

        return refined

    def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
        """
        Clean up an alpha matte with morphological close/open and a blur.

        Args:
            alpha: Alpha matte on the 0-1 scale

        Returns:
            float32 alpha matte in [0, 1]
        """
        # Clip first: values outside [0, 1] would wrap around in uint8.
        alpha_uint8 = (np.clip(alpha, 0.0, 1.0) * 255).astype(np.uint8)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

        # Closing fills small holes.
        alpha_uint8 = cv2.morphologyEx(
            alpha_uint8, cv2.MORPH_CLOSE, kernel,
            iterations=self.config.erode_iterations
        )

        # Opening removes small isolated components.
        alpha_uint8 = cv2.morphologyEx(
            alpha_uint8, cv2.MORPH_OPEN, kernel,
            iterations=self.config.dilate_iterations
        )

        if self.config.blur_radius > 0:
            ksize = self.config.blur_radius * 2 + 1
            alpha_uint8 = cv2.GaussianBlur(alpha_uint8, (ksize, ksize), 0)

        return alpha_uint8.astype(np.float32) / 255.0

    def process(self, image: np.ndarray,
                mask: np.ndarray,
                method: str = 'auto') -> Dict[str, object]:
        """
        Produce an alpha matte with the selected matting method.

        With ``method='auto'`` the choice follows the quality analyser:
        ``blur_score > 50`` selects 'guided', otherwise
        ``edge_clarity > 0.7`` selects 'trimap', otherwise 'deep'. Any other
        unrecognised method name uses the mask as the initial alpha.

        Args:
            image: Image the mask belongs to
            mask: Initial segmentation mask; uint8 0-255, boolean, or
                float 0-1
            method: Matting method ('auto', 'trimap', 'deep', 'guided')

        Returns:
            Dictionary with keys 'alpha', 'confidence', 'method_used' and
            'quality_metrics'. If any step raises, 'alpha' is the input mask
            scaled to 0-1, 'confidence' is 0.0, 'method_used' is 'fallback'
            and an 'error' message is added.
        """
        # Normalise once so every branch, including the fallback, sees the
        # same 0-255 uint8 mask.
        mask_u8 = self._as_uint8_mask(np.asarray(mask))
        quality_metrics: Dict = {}

        try:
            quality_metrics = self.quality_analyzer.analyze_frame(image)

            if method == 'auto':
                if quality_metrics['blur_score'] > 50:
                    method = 'guided'
                elif quality_metrics['edge_clarity'] > 0.7:
                    method = 'trimap'
                else:
                    method = 'deep'

            logger.info(f"Using matting method: {method}")

            if method == 'trimap':
                trimap = self.create_trimap(mask_u8)
                alpha = self.closed_form_matting(image, trimap)
            elif method == 'deep':
                alpha = self.deep_matting(image, mask_u8)
            elif method == 'guided' and self.config.use_guided_filter:
                filtered = self.guided_filter(mask_u8, image)
                alpha = filtered.astype(np.float32) / 255.0
            else:
                alpha = mask_u8.astype(np.float32) / 255.0

            alpha = self.morphological_refinement(alpha)

            alpha = self.edge_refinement.refine_edges(
                image,
                (alpha * 255).astype(np.uint8)
            ) / 255.0

            confidence = self._calculate_confidence(alpha, quality_metrics)

            return {
                'alpha': alpha,
                'confidence': confidence,
                'method_used': method,
                'quality_metrics': quality_metrics
            }

        except Exception as e:
            # Top-level boundary: matting is best-effort, so report the
            # failure and return the unrefined mask instead of raising.
            logger.error(f"Matting processing failed: {e}")
            return {
                'alpha': mask_u8.astype(np.float32) / 255.0,
                'confidence': 0.0,
                'method_used': 'fallback',
                'quality_metrics': quality_metrics,
                'error': str(e)
            }

    def _calculate_confidence(self, alpha: np.ndarray,
                              quality_metrics: Dict) -> float:
        """Score the matte in [0, 1].

        Starts from ``quality_metrics['overall_quality']`` (0.5 when
        absent), boosts it by 20% when the matte has a balanced mean and a
        wide spread, by a further 10% when fewer than 10% of pixels are
        Canny edges, and clips the result to [0, 1].
        """
        confidence = quality_metrics.get('overall_quality', 0.5)

        alpha_mean = np.mean(alpha)
        alpha_std = np.std(alpha)

        # A balanced mean with a large spread indicates clear separation.
        if 0.3 < alpha_mean < 0.7 and alpha_std > 0.3:
            confidence *= 1.2

        edges = cv2.Canny((alpha * 255).astype(np.uint8), 50, 150)
        edge_ratio = np.sum(edges > 0) / edges.size

        # Few edge pixels means clean boundaries.
        if edge_ratio < 0.1:
            confidence *= 1.1

        return float(np.clip(confidence, 0.0, 1.0))
395
+
396
+
397
class CompositingEngine:
    """Alpha compositing and premultiplication for 0-255 images."""

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.CompositingEngine")

    @staticmethod
    def _prepare_alpha(alpha: np.ndarray, image_ndim: int) -> np.ndarray:
        """Return *alpha* as float32 on the 0-1 scale, shaped to broadcast.

        An alpha whose maximum exceeds 1 is taken to be on the 0-255 scale
        and divided by 255. A 2-D alpha gains a trailing axis when the image
        is 3-D, and a single-channel 3-D alpha loses it when the image is
        2-D, so the product broadcasts for any channel count.
        """
        a = alpha.astype(np.float32)
        # Guard the reduction: max() raises on an empty array.
        if a.size and a.max() > 1.0:
            a = a / 255.0

        if image_ndim == 3 and a.ndim == 2:
            a = a[:, :, np.newaxis]
        elif image_ndim == 2 and a.ndim == 3 and a.shape[2] == 1:
            a = a[:, :, 0]
        return a

    @staticmethod
    def _to_uint8(values: np.ndarray) -> np.ndarray:
        """Round to the nearest level, clip to 0-255 and cast to uint8.

        Rounding (rather than truncating) keeps float32 error such as
        149.99999 from landing one level too dark.
        """
        return np.clip(np.rint(values), 0, 255).astype(np.uint8)

    def composite(self, foreground: np.ndarray,
                  background: np.ndarray,
                  alpha: np.ndarray) -> np.ndarray:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: Foreground image on the 0-255 scale, (H, W, C) or
                (H, W)
            background: Background image on the 0-255 scale; must broadcast
                against the foreground
            alpha: Alpha matte (H, W) or (H, W, 1), either 0-1 or 0-255

        Returns:
            Composited uint8 image
        """
        a = self._prepare_alpha(alpha, foreground.ndim)

        fg = foreground.astype(np.float32) / 255.0
        bg = background.astype(np.float32) / 255.0

        blended = fg * a + bg * (1 - a)

        return self._to_uint8(blended * 255)

    def premultiply_alpha(self, image: np.ndarray,
                          alpha: np.ndarray) -> np.ndarray:
        """Multiply each pixel of *image* by its alpha.

        Args:
            image: Image on the 0-255 scale, (H, W, C) or (H, W)
            alpha: Alpha matte (H, W) or (H, W, 1), either 0-1 or 0-255

        Returns:
            Premultiplied uint8 image
        """
        a = self._prepare_alpha(alpha, image.ndim)
        return self._to_uint8(image.astype(np.float32) * a)