MogensR committed on
Commit
9f4df99
·
1 Parent(s): 9bd9bcd

Create utils/cv_processing.py

Browse files
Files changed (1) hide show
  1. utils/cv_processing.py +1134 -0
utils/cv_processing.py ADDED
@@ -0,0 +1,1134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
    """Score a segmentation mask heuristically on a 0.0-1.0 scale.

    Four sub-scores are combined with fixed weights: covered area (0.3),
    centering (0.2), edge smoothness (0.3) and connectivity (0.2).

    Args:
        mask: Mask array; pixels above 127 count as foreground. It is handed
            to ``cv2.Canny`` and ``cv2.connectedComponents`` unchanged
            (presumably single-channel uint8 -- TODO confirm with callers).
        image: Image whose first two dimensions give the frame size.

    Returns:
        The weighted score as a plain ``float`` in [0, 1]. Returns 0.0 for a
        zero-area image and 0.5 when the assessment itself fails.
    """
    try:
        h, w = image.shape[:2]
        total_area = h * w
        if total_area == 0:
            # Nothing to score; also avoids dividing by zero below.
            return 0.0

        scores = []
        mask_binary = mask > 127

        # Area: full marks between 5% and 80% coverage, linear falloff outside.
        area_ratio = float(np.count_nonzero(mask_binary)) / total_area
        if area_ratio < 0.05:
            area_score = area_ratio / 0.05
        elif area_ratio <= 0.8:
            area_score = 1.0
        else:
            area_score = max(0.0, 1.0 - (area_ratio - 0.8) / 0.2)
        scores.append(area_score)

        # Centering: based on the normalised centroid of the foreground.
        rows, cols = np.nonzero(mask_binary)
        if rows.size:
            center_y = rows.mean() / h
            center_x = cols.mean() / w
            # NOTE(review): min() only penalises masks that are off-centre on
            # BOTH axes; max() may have been intended -- confirm before changing.
            scores.append(1.0 - min(abs(center_x - 0.5), abs(center_y - 0.5)))
        else:
            scores.append(0.0)

        # Smoothness: fewer Canny edge pixels per unit area scores higher.
        edges = cv2.Canny(mask, 50, 150)
        edge_density = np.count_nonzero(edges) / total_area
        scores.append(max(0.0, 1.0 - edge_density * 10))

        # Connectivity: the label count includes the background, so two labels
        # mean one foreground blob. Clamp from above as well: a single label
        # (empty mask) would otherwise yield 1.2.
        num_labels, _ = cv2.connectedComponents(mask)
        scores.append(min(1.0, max(0.0, 1.0 - (num_labels - 2) * 0.2)))

        weights = [0.3, 0.2, 0.3, 0.2]
        return float(np.average(scores, weights=weights))

    except Exception as e:
        logger.warning(f"Quality assessment failed: {e}")
        return 0.5
47
+
48
def _find_mask_errors(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
    """Flag pixels where image edges and mask edges disagree.

    Canny edges of the grayscale image are XOR-ed with Canny edges of the
    mask, and the difference is dilated once with a 5x5 elliptical kernel.

    Returns:
        Boolean array that is True inside the dilated disagreement regions;
        an all-False array shaped like ``mask`` if detection fails.
    """
    try:
        image_edges = cv2.Canny(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 50, 150)
        disagreement = cv2.bitwise_xor(image_edges, cv2.Canny(mask, 50, 150))
        grown = cv2.dilate(
            disagreement,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)),
            iterations=1,
        )
        return grown > 0
    except Exception as e:
        logger.warning(f"Error detection failed: {e}")
        return np.zeros_like(mask, dtype=bool)
61
+
62
def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
                                 problem_areas: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Generate corrective point prompts from problem regions.

    One point is emitted at the centroid of every external contour of
    ``problem_areas`` whose area exceeds 100 pixels. The label is 1 when the
    mask value at that centroid is below 127 and 0 otherwise.

    Args:
        image: Unused; kept so the signature stays compatible with callers.
        mask: Current mask, indexed as ``mask[y, x]``.
        problem_areas: Array marking problematic pixels (cast to uint8).

    Returns:
        ``(points, labels)``: ``points`` is float32 with shape (N, 2) holding
        ``[x, y]`` pairs, ``labels`` is int32 with shape (N,). Both are empty,
        with the same dtypes, when nothing qualifies or on failure.
    """
    # Empty results share the dtypes of the populated ones so that callers
    # can concatenate them without silently upcasting to float64.
    empty_points = np.empty((0, 2), dtype=np.float32)
    empty_labels = np.empty((0,), dtype=np.int32)

    try:
        contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        corrective_points = []
        corrective_labels = []

        for contour in contours:
            if cv2.contourArea(contour) <= 100:
                continue  # ignore small specks
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue  # degenerate contour, centroid undefined

            cx = int(M["m10"] / M["m00"])
            cy = int(M["m01"] / M["m00"])

            corrective_points.append([cx, cy])
            corrective_labels.append(1 if mask[cy, cx] < 127 else 0)

        if not corrective_points:
            return empty_points, empty_labels

        return (np.array(corrective_points, dtype=np.float32),
                np.array(corrective_labels, dtype=np.int32))

    except Exception as e:
        logger.warning(f"Corrective prompt generation failed: {e}")
        return empty_points, empty_labels
94
+
95
+ # ============================================================================
96
+ # HELPER FUNCTIONS - PROCESSING
97
+ # ============================================================================
98
+
99
def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Normalise a raw mask to a clean, binary, 2-D uint8 array (0 or 255).

    Steps: squeeze or drop extra dimensions, convert the dtype to uint8 in the
    0-255 range, apply a 3x3 morphological close then open, and threshold at
    127.

    Args:
        mask: Raw mask. Boolean, floating-point (any width) and integer arrays
            are accepted; masks whose values all lie in [0, 1] are treated as
            normalised and scaled by 255.

    Returns:
        Binary uint8 mask. If processing fails, a mask with a centred filled
        rectangle is returned (256x256 when the input shape is unusable).
    """
    try:
        if len(mask.shape) > 2:
            mask = mask.squeeze()

        if len(mask.shape) > 2:
            mask = mask[:, :, 0] if mask.shape[2] > 0 else mask.sum(axis=2)

        if mask.dtype == bool:
            mask = mask.astype(np.uint8) * 255
        elif np.issubdtype(mask.dtype, np.floating):
            # Covers float16 as well as float32/float64.
            if mask.max() <= 1.0:
                mask = np.clip(mask * 255, 0, 255).astype(np.uint8)
            else:
                mask = np.clip(mask, 0, 255).astype(np.uint8)
        elif mask.min() >= 0 and mask.max() <= 1:
            # Integer 0/1 mask: without scaling, the 127 threshold below
            # would wipe it out entirely.
            mask = mask.astype(np.uint8) * 255
        else:
            mask = mask.astype(np.uint8)

        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

        return mask

    except Exception as e:
        logger.error(f"Mask processing failed: {e}")
        # getattr keeps the fallback itself from raising on odd inputs.
        shape = getattr(mask, "shape", ())
        h, w = shape[:2] if len(shape) >= 2 else (256, 256)
        fallback = np.zeros((h, w), dtype=np.uint8)
        fallback[h//4:3*h//4, w//4:3*w//4] = 255
        return fallback
132
+
133
+ def _validate_mask_quality(mask: np.ndarray, image_shape: Tuple[int, int]) -> bool:
134
+ """Validate that the mask meets quality criteria"""
135
+ try:
136
+ h, w = image_shape
137
+ mask_area = np.sum(mask > 127)
138
+ total_area = h * w
139
+
140
+ area_ratio = mask_area / total_area
141
+ if area_ratio < 0.05 or area_ratio > 0.8:
142
+ logger.warning(f"Suspicious mask area ratio: {area_ratio:.3f}")
143
+ return False
144
+
145
+ mask_binary = mask > 127
146
+ mask_center_y, mask_center_x = np.where(mask_binary)
147
+
148
+ if len(mask_center_y) == 0:
149
+ logger.warning("Empty mask")
150
+ return False
151
+
152
+ center_y = np.mean(mask_center_y)
153
+ center_x = np.mean(mask_center_x)
154
+
155
+ if center_y < h * 0.2 or center_y > h * 0.9:
156
+ logger.warning(f"Mask center too far from expected person location: y={center_y/h:.2f}")
157
+ return False
158
+
159
+ return True
160
+
161
+ except Exception as e:
162
+ logger.warning(f"Mask validation error: {e}")
163
+ return True
164
+
165
def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
    """Produce a person mask without any AI model.

    Strategies, tried in order:
      1. Background subtraction: the median gray level of the border pixels
         is taken as background, pixels differing by more than 30 become
         foreground, and the result is cleaned with a 7x7 close/open. It is
         used only if ``_validate_mask_quality`` accepts it.
      2. A centred ellipse with radii of about w/3 and h/2.5.
      3. If everything above raises, a centred rectangle.

    Args:
        image: Input frame; 3-channel input is converted with BGR2GRAY,
            2-D input is used as grayscale directly.

    Returns:
        uint8 mask with values 0 or 255 and the same height/width as ``image``.
    """
    try:
        logger.info("Using fallback segmentation strategy")
        h, w = image.shape[:2]

        try:
            # A 2-D frame is already grayscale; cvtColor would reject it.
            gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

            edge_pixels = np.concatenate([
                gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]
            ])
            bg_color = np.median(edge_pixels)

            diff = np.abs(gray.astype(float) - bg_color)
            mask = (diff > 30).astype(np.uint8) * 255

            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

            if _validate_mask_quality(mask, image.shape[:2]):
                logger.info("Background subtraction fallback successful")
                return mask

        except Exception as e:
            logger.warning(f"Background subtraction fallback failed: {e}")

        mask = np.zeros((h, w), dtype=np.uint8)

        center_x, center_y = w // 2, h // 2
        # Clamp the radii: on very small frames the floor divisions give 0,
        # which would divide by zero in the ellipse equation.
        radius_x = max(w // 3, 1)
        radius_y = max(h // 2.5, 1.0)

        y, x = np.ogrid[:h, :w]
        mask_ellipse = ((x - center_x) / radius_x) ** 2 + ((y - center_y) / radius_y) ** 2 <= 1
        mask[mask_ellipse] = 255

        logger.info("Using geometric fallback mask")
        return mask

    except Exception as e:
        logger.error(f"All fallback strategies failed: {e}")
        h, w = image.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)
        mask[h//6:5*h//6, w//4:3*w//4] = 255
        return mask
211
+
212
+ def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
213
+ """Attempt MatAnyone mask refinement"""
214
+ try:
215
+ if hasattr(processor, 'infer'):
216
+ refined_mask = processor.infer(image, mask)
217
+ elif hasattr(processor, 'process'):
218
+ refined_mask = processor.process(image, mask)
219
+ elif callable(processor):
220
+ refined_mask = processor(image, mask)
221
+ else:
222
+ logger.warning("Unknown MatAnyone interface")
223
+ return None
224
+
225
+ if refined_mask is None:
226
+ return None
227
+
228
+ refined_mask = _process_mask(refined_mask)
229
+ logger.debug("MatAnyone refinement successful")
230
+ return refined_mask
231
+
232
+ except Exception as e:
233
+ logger.warning(f"MatAnyone processing error: {e}")
234
+ return None
235
+
236
def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
    """Edge-aware smoothing of a mask using a box-filter guided filter.

    Args:
        guide: Guidance image; 3-dimensional input is converted with BGR2GRAY.
        mask: Mask to smooth; scaled by 1/255 before filtering.
        radius: Box window radius; the window side is ``2 * radius + 1``.
        eps: Regulariser added to the guide variance.

    Returns:
        Filtered mask as uint8 in 0-255, or the untouched ``mask`` on failure.
    """
    try:
        window = (2 * radius + 1, 2 * radius + 1)

        def box_mean(values):
            # Local mean over the square window.
            return cv2.boxFilter(values, -1, window)

        gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY) if len(guide.shape) == 3 else guide
        I = gray.astype(np.float32) / 255.0
        p = mask.astype(np.float32) / 255.0

        mean_I = box_mean(I)
        mean_p = box_mean(p)
        covariance = box_mean(I * p) - mean_I * mean_p
        variance = box_mean(I * I) - mean_I * mean_I

        # Per-pixel linear model q = a * I + b, averaged over the windows.
        a = covariance / (variance + eps)
        b = mean_p - a * mean_I
        q = box_mean(a) * I + box_mean(b)

        return np.clip(q * 255, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.warning(f"Guided filter approximation failed: {e}")
        return mask
267
+
268
+ # ============================================================================
269
+ # HELPER FUNCTIONS - COMPOSITING
270
+ # ============================================================================
271
+
272
def _advanced_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Blend ``frame`` over ``background`` using a feathered alpha from ``mask``.

    The mask is thresholded at 100, cleaned with a 5x5 close/open, blurred,
    and reshaped (power 0.8, then a 1.1 boost above 0.5 and a 0.9 damping
    below). The frame is passed through ``_color_match_edges`` before the
    alpha blend.

    Returns:
        Composited uint8 image.

    Raises:
        Exception: Any failure is logged and re-raised to the caller.
    """
    try:
        _, hard_mask = cv2.threshold(mask, 100, 255, cv2.THRESH_BINARY)

        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
            hard_mask = cv2.morphologyEx(hard_mask, operation, ellipse)

        alpha = cv2.GaussianBlur(hard_mask.astype(np.float32), (5, 5), 1.0) / 255.0
        alpha = np.power(alpha, 0.8)

        # Push confident foreground up and weak foreground down.
        boosted = np.minimum(alpha * 1.1, 1.0)
        damped = alpha * 0.9
        alpha = np.where(alpha > 0.5, boosted, damped)

        matched = _color_match_edges(frame, background, alpha)

        alpha_3ch = np.stack((alpha, alpha, alpha), axis=2)
        foreground_part = matched.astype(np.float32) * alpha_3ch
        background_part = background.astype(np.float32) * (1 - alpha_3ch)

        return np.clip(foreground_part + background_part, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.error(f"Advanced compositing error: {e}")
        raise
306
+
307
def _color_match_edges(frame: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Blend a little background colour into the frame along alpha edges.

    Edge pixels are those where the gradient magnitude of ``alpha`` exceeds
    0.1. There the frame is mixed with the background at 10% strength.

    Args:
        frame: Foreground image.
        background: Background image, indexed with the same edge mask as
            ``frame`` (so it must match the frame's height and width).
        alpha: Float alpha matte used to locate edges.

    Returns:
        Adjusted uint8 frame. The original ``frame`` is returned unchanged
        when no edge is found or when the adjustment fails.
    """
    try:
        # Separate x and y derivatives: Sobel with dx=1 and dy=1 in one call
        # yields the mixed second derivative, which is zero along straight
        # horizontal or vertical edges and therefore misses them.
        grad_x = cv2.Sobel(alpha, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(alpha, cv2.CV_64F, 0, 1, ksize=3)
        edge_areas = np.hypot(grad_x, grad_y) > 0.1

        if not np.any(edge_areas):
            return frame

        adjustment_strength = 0.1
        frame_adjusted = frame.astype(np.float32)  # astype returns a copy
        background_float = background.astype(np.float32)

        # Boolean indexing applies the blend to every channel at once.
        frame_adjusted[edge_areas] = (
            frame_adjusted[edge_areas] * (1 - adjustment_strength)
            + background_float[edge_areas] * adjustment_strength
        )

        return np.clip(frame_adjusted, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.warning(f"Color matching failed: {e}")
        return frame
335
+
336
def _simple_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Hard-edged fallback compositing.

    The background is resized to the frame size. The mask is reduced to one
    channel, scaled by 255 when its maximum is at most 1.0, and thresholded
    at 127; it then selects frame pixels over background pixels.

    Returns:
        Composited uint8 image, or the unmodified ``frame`` on failure.
    """
    try:
        logger.info("Using simple compositing fallback")

        target_size = (frame.shape[1], frame.shape[0])  # cv2 expects (width, height)
        background = cv2.resize(background, target_size)

        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)

        _, hard_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

        weight = hard_mask.astype(np.float32) / 255.0
        weight_3ch = np.stack((weight, weight, weight), axis=2)

        blended = frame * weight_3ch + background * (1 - weight_3ch)
        return blended.astype(np.uint8)

    except Exception as e:
        logger.error(f"Simple compositing failed: {e}")
        return frame
359
+
360
+ # ============================================================================
361
+ # HELPER FUNCTIONS - BACKGROUND CREATION
362
+ # ============================================================================
363
+
364
+ def _create_solid_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
365
+ """Create solid color background"""
366
+ color_hex = bg_config["colors"][0].lstrip('#')
367
+ color_rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
368
+ color_bgr = color_rgb[::-1]
369
+ return np.full((height, width, 3), color_bgr, dtype=np.uint8)
370
+
371
def _create_gradient_background_enhanced(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Create a gradient background in BGR channel order.

    Args:
        bg_config: Configuration with a ``"colors"`` list of hex strings
            (6-digit, or 3-digit shorthand) and an optional ``"direction"``:
            ``"vertical"`` (default), ``"horizontal"``, ``"diagonal"``,
            ``"radial"`` or ``"soft_radial"``. Unknown directions fall back
            to vertical.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        uint8 image of shape ``(height, width, 3)``. If anything fails, a
        flat (128, 128, 128) image is returned instead.
    """
    try:
        colors = bg_config["colors"]
        direction = bg_config.get("direction", "vertical")

        rgb_colors = []
        for color_hex in colors:
            color_hex = color_hex.lstrip('#')
            if len(color_hex) == 3:
                # Expand CSS shorthand (#abc -> #aabbcc). Unexpanded, parsing
                # raises and the whole gradient silently degrades to grey.
                color_hex = "".join(digit * 2 for digit in color_hex)
            rgb_colors.append(tuple(int(color_hex[i:i + 2], 16) for i in (0, 2, 4)))

        if not rgb_colors:
            rgb_colors = [(128, 128, 128)]

        if direction == "horizontal":
            background = _create_horizontal_gradient(rgb_colors, width, height)
        elif direction == "diagonal":
            background = _create_diagonal_gradient(rgb_colors, width, height)
        elif direction in ("radial", "soft_radial"):
            background = _create_radial_gradient(rgb_colors, width, height, direction == "soft_radial")
        else:
            # "vertical" and any unrecognised direction.
            background = _create_vertical_gradient(rgb_colors, width, height)

        # The gradient helpers work in RGB; the rest of the module uses BGR.
        return cv2.cvtColor(background, cv2.COLOR_RGB2BGR)

    except Exception as e:
        logger.error(f"Gradient creation error: {e}")
        return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
402
+
403
def _create_vertical_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Build a top-to-bottom gradient, one interpolated colour per row.

    Row ``y`` uses progress ``y / height``, so the final colour stop is
    approached but not reached exactly.

    Returns:
        uint8 array of shape ``(height, width, 3)``.
    """
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for row in range(height):
        canvas[row, :] = _interpolate_color(colors, row / height)
    return canvas
413
+
414
def _create_horizontal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Build a left-to-right gradient, one interpolated colour per column.

    Column ``x`` uses progress ``x / width``, so the final colour stop is
    approached but not reached exactly.

    Returns:
        uint8 array of shape ``(height, width, 3)``.
    """
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for column in range(width):
        canvas[:, column] = _interpolate_color(colors, column / width)
    return canvas
424
+
425
def _create_diagonal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Build a gradient running from the top-left to the bottom-right corner.

    Progress at pixel ``(x, y)`` is ``(x + y) / (width + height)``, clipped
    to [0, 1]; each channel is filled by ``_vectorized_color_interpolation``.

    Returns:
        uint8 array of shape ``(height, width, 3)``.
    """
    # Outer sum of row and column indices gives x + y for every pixel.
    index_sum = np.add.outer(np.arange(height), np.arange(width))
    progress = np.clip(index_sum / (width + height), 0, 1)

    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for channel in range(3):
        canvas[:, :, channel] = _vectorized_color_interpolation(colors, progress, channel)
    return canvas
437
+
438
def _create_radial_gradient(colors: list, width: int, height: int, soft: bool = False) -> np.ndarray:
    """Build a gradient radiating outwards from the image centre.

    Progress is the distance to ``(width // 2, height // 2)`` divided by the
    distance from that centre to the origin, clipped to [0, 1]. With
    ``soft=True`` the progress is raised to the power 0.7.

    Returns:
        uint8 array of shape ``(height, width, 3)``.
    """
    center_x, center_y = width // 2, height // 2
    max_distance = np.sqrt(center_x**2 + center_y**2)

    # Broadcast a row of x offsets against a column of y offsets.
    dx = np.arange(width) - center_x
    dy = np.arange(height)[:, np.newaxis] - center_y
    progress = np.clip(np.sqrt(dx**2 + dy**2) / max_distance, 0, 1)

    if soft:
        progress = np.power(progress, 0.7)

    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for channel in range(3):
        canvas[:, :, channel] = _vectorized_color_interpolation(colors, progress, channel)
    return canvas
456
+
457
+ def _vectorized_color_interpolation(colors: list, progress: np.ndarray, channel: int) -> np.ndarray:
458
+ """Vectorized color interpolation for performance"""
459
+ if len(colors) == 1:
460
+ return np.full_like(progress, colors[0][channel], dtype=np.uint8)
461
+
462
+ num_segments = len(colors) - 1
463
+ segment_progress = progress * num_segments
464
+ segment_indices = np.floor(segment_progress).astype(int)
465
+ segment_indices = np.clip(segment_indices, 0, num_segments - 1)
466
+ local_progress = segment_progress - segment_indices
467
+
468
+ start_colors = np.array([colors[i][channel] for i in range(len(colors))])
469
+ end_colors = np.array([colors[min(i + 1, len(colors) - 1)][channel] for i in range(len(colors))])
470
+
471
+ start_vals = start_colors[segment_indices]
472
+ end_vals = end_colors[segment_indices]
473
+
474
+ result = start_vals + (end_vals - start_vals) * local_progress
475
+ return np.clip(result, 0, 255).astype(np.uint8)
476
+
477
+ def _interpolate_color(colors: list, progress: float) -> tuple:
478
+ """Interpolate between multiple colors"""
479
+ if len(colors) == 1:
480
+ return colors[0]
481
+ elif len(colors) == 2:
482
+ r = int(colors[0][0] + (colors[1][0] - colors[0][0]) * progress)
483
+ g = int(colors[0][1] + (colors[1][1] - colors[0][1]) * progress)
484
+ b = int(colors[0][2] + (colors[1][2] - colors[0][2]) * progress)
485
+ return (r, g, b)
486
+ else:
487
+ segment = progress * (len(colors) - 1)
488
+ idx = int(segment)
489
+ local_progress = segment - idx
490
+ if idx >= len(colors) - 1:
491
+ return colors[-1]
492
+ c1, c2 = colors[idx], colors[idx + 1]
493
+ r = int(c1[0] + (c2[0] - c1[0]) * local_progress)
494
+ g = int(c1[1] + (c2[1] - c1[1]) * local_progress)
495
+ b = int(c1[2] + (c2[2] - c1[2]) * local_progress)
496
+ return (r, g, b)
497
+
498
+ def _apply_background_adjustments(background: np.ndarray, bg_config: Dict[str, Any]) -> np.ndarray:
499
+ """Apply brightness and contrast adjustments to background"""
500
+ try:
501
+ brightness = bg_config.get("brightness", 1.0)
502
+ contrast = bg_config.get("contrast", 1.0)
503
+
504
+ if brightness != 1.0 or contrast != 1.0:
505
+ background = background.astype(np.float32)
506
+ background = background * contrast * brightness
507
+ background = np.clip(background, 0, 255).astype(np.uint8)
508
+
509
+ return background
510
+
511
+ except Exception as e:
512
+ logger.warning(f"Background adjustment failed: {e}")
513
+ return background"""
514
+ Computer Vision Processing Module for BackgroundFX Pro
515
+ Contains segmentation, mask refinement, background replacement, and helper functions
516
+ """
517
+
518
+ # Set OMP_NUM_THREADS at the very beginning to prevent libgomp errors
519
+ import os
520
+ if 'OMP_NUM_THREADS' not in os.environ:
521
+ os.environ['OMP_NUM_THREADS'] = '4'
522
+ os.environ['MKL_NUM_THREADS'] = '4'
523
+
524
+ import logging
525
+ from typing import Optional, Tuple, Dict, Any
526
+ import numpy as np
527
+ import cv2
528
+ import torch
529
+
530
+ logger = logging.getLogger(__name__)
531
+
532
+ # ============================================================================
533
+ # CONFIGURATION AND CONSTANTS
534
+ # ============================================================================
535
+
536
# Version control flags for CV functions.
# segment_person_hq falls back to segment_person_hq_original when
# USE_ENHANCED_SEGMENTATION is False; the prompting and refinement flags
# select the corresponding steps inside the enhanced path.
USE_ENHANCED_SEGMENTATION = True
USE_AUTO_TEMPORAL_CONSISTENCY = True
USE_INTELLIGENT_PROMPTING = True
USE_ITERATIVE_REFINEMENT = True

# Professional background templates.
# "colors" entries are always written as full 6-digit hex strings, because
# the background builders slice them at fixed offsets (0, 2, 4).
PROFESSIONAL_BACKGROUNDS = {
    "office_modern": {
        "name": "Modern Office",
        "type": "gradient",
        "colors": ["#f8f9fa", "#e9ecef", "#dee2e6"],
        "direction": "diagonal",
        "description": "Clean, contemporary office environment",
        "brightness": 0.95,
        "contrast": 1.1,
    },
    "studio_blue": {
        "name": "Professional Blue",
        "type": "gradient",
        "colors": ["#1e3c72", "#2a5298", "#3498db"],
        "direction": "radial",
        "description": "Broadcast-quality blue studio",
        "brightness": 0.9,
        "contrast": 1.2,
    },
    "studio_green": {
        "name": "Broadcast Green",
        "type": "color",
        "colors": ["#00b894"],
        "chroma_key": True,
        "description": "Professional green screen replacement",
        "brightness": 1.0,
        "contrast": 1.0,
    },
    "minimalist": {
        "name": "Minimalist White",
        "type": "gradient",
        # Was the shorthand "#ddd", which the 6-digit parsers cannot read.
        "colors": ["#ffffff", "#f1f2f6", "#dddddd"],
        "direction": "soft_radial",
        "description": "Clean, minimal background",
        "brightness": 0.98,
        "contrast": 0.9,
    },
    "warm_gradient": {
        "name": "Warm Sunset",
        "type": "gradient",
        "colors": ["#ff7675", "#fd79a8", "#fdcb6e"],
        "direction": "diagonal",
        "description": "Warm, inviting atmosphere",
        "brightness": 0.85,
        "contrast": 1.15,
    },
    "tech_dark": {
        "name": "Tech Dark",
        "type": "gradient",
        "colors": ["#0c0c0c", "#2d3748", "#4a5568"],
        "direction": "vertical",
        "description": "Modern tech/gaming setup",
        "brightness": 0.7,
        "contrast": 1.3,
    },
}
599
+
600
+ # ============================================================================
601
+ # CUSTOM EXCEPTIONS
602
+ # ============================================================================
603
+
604
class SegmentationError(Exception):
    """Signals that person segmentation failed and no fallback was used."""
607
+
608
class MaskRefinementError(Exception):
    """Signals that mask refinement failed and no fallback was used."""
611
+
612
class BackgroundReplacementError(Exception):
    """Signals that background replacement could not be performed."""
615
+
616
+ # ============================================================================
617
+ # MAIN SEGMENTATION FUNCTIONS
618
+ # ============================================================================
619
+
620
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """Segment the person in ``image`` using the enhanced pipeline.

    When ``USE_ENHANCED_SEGMENTATION`` is False the call is delegated to
    ``segment_person_hq_original``. Otherwise the predictor is primed with
    the image, a mask is produced via intelligent or basic prompting,
    optionally refined iteratively, and finally validated.

    Args:
        image: Input frame.
        predictor: Segmentation predictor exposing ``set_image``; may be None.
        fallback_enabled: If True, every failure returns the result of
            ``_fallback_segmentation`` instead of raising.

    Returns:
        The segmentation mask.

    Raises:
        SegmentationError: If the image is None or empty (always), or if any
            step fails while ``fallback_enabled`` is False.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    logger.debug("Using ENHANCED segmentation with intelligent automation")

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
                return _fallback_segmentation(image)
            raise SegmentationError("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            if fallback_enabled:
                return _fallback_segmentation(image)
            raise SegmentationError(f"Predictor setup failed: {e}") from e

        if USE_INTELLIGENT_PROMPTING:
            mask = _segment_with_intelligent_prompts(image, predictor)
        else:
            mask = _segment_with_basic_prompts(image, predictor)

        if USE_ITERATIVE_REFINEMENT and mask is not None:
            mask = _auto_refine_mask_iteratively(image, mask, predictor)

        if mask is None:
            # Handle the missing mask explicitly rather than relying on an
            # AttributeError further down to reach the generic handler.
            logger.warning("Segmentation produced no mask")
            if fallback_enabled:
                return _fallback_segmentation(image)
            raise SegmentationError("No masks generated")

        if not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            if fallback_enabled:
                return _fallback_segmentation(image)
            raise SegmentationError("Poor mask quality")

        logger.debug(f"Enhanced segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask

    except SegmentationError:
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        if fallback_enabled:
            return _fallback_segmentation(image)
        raise SegmentationError(f"Unexpected error: {e}") from e
673
+
674
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """Baseline person segmentation, kept as a rollback path.

    The predictor is primed with the image and queried with eight fixed
    positive points spread over the frame. The highest-scoring mask is
    normalised with ``_process_mask`` and validated.

    Args:
        image: Input frame.
        predictor: Predictor exposing ``set_image`` and ``predict``; may be None.
        fallback_enabled: If True, every failure returns the result of
            ``_fallback_segmentation`` instead of raising.

    Returns:
        The processed segmentation mask.

    Raises:
        SegmentationError: If the image is None or empty (always), or if any
            step fails while ``fallback_enabled`` is False.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    def _give_up(reason: str) -> np.ndarray:
        # Shared failure exit: fallback mask if allowed, otherwise raise.
        if fallback_enabled:
            return _fallback_segmentation(image)
        raise SegmentationError(reason)

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
            return _give_up("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            return _give_up(f"Predictor setup failed: {e}")

        h, w = image.shape[:2]

        # Fixed [x, y] prompt points along the centre column and mid-height row.
        prompt_xy = [
            [w//2, h//4],
            [w//2, h//2],
            [w//2, 3*h//4],
            [w//3, h//2],
            [2*w//3, h//2],
            [w//2, h//6],
            [w//4, 2*h//3],
            [3*w//4, 2*h//3],
        ]
        points = np.array(prompt_xy, dtype=np.float32)
        labels = np.ones(len(points), dtype=np.int32)

        try:
            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True
                )
        except Exception as e:
            logger.error(f"SAM2 prediction failed: {e}")
            return _give_up(f"Prediction failed: {e}")

        if masks is None or len(masks) == 0:
            logger.warning("SAM2 returned no masks")
            return _give_up("No masks generated")

        if scores is None or len(scores) == 0:
            logger.warning("SAM2 returned no scores")
            best_mask = masks[0]
        else:
            best_idx = np.argmax(scores)
            best_mask = masks[best_idx]
            logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")

        mask = _process_mask(best_mask)

        if not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            return _give_up("Poor mask quality")

        logger.debug(f"Segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask

    except SegmentationError:
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        return _give_up(f"Unexpected error: {e}")
760
+
761
+ # ============================================================================
762
+ # MASK REFINEMENT FUNCTIONS
763
+ # ============================================================================
764
+
765
def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
                   fallback_enabled: bool = True) -> np.ndarray:
    """Refine a mask with MatAnyone, falling back to OpenCV enhancement.

    The mask is first normalised with ``_process_mask``. If a processor is
    given, its refined mask is returned when it is not None and passes
    ``_validate_mask_quality``. Otherwise the OpenCV path is used if allowed.

    Args:
        image: Source frame.
        mask: Mask to refine.
        matanyone_processor: Refinement processor, or None to skip it.
        fallback_enabled: If True, use ``enhance_mask_opencv_advanced`` when
            MatAnyone is unavailable, fails, or yields a poor mask.

    Returns:
        The refined mask.

    Raises:
        MaskRefinementError: If ``image`` or ``mask`` is None (always), or if
            refinement fails while ``fallback_enabled`` is False.
    """
    if image is None or mask is None:
        raise MaskRefinementError("Invalid input image or mask")

    try:
        mask = _process_mask(mask)

        if matanyone_processor is not None:
            try:
                logger.debug("Attempting MatAnyone refinement")
                candidate = _matanyone_refine(image, mask, matanyone_processor)

                accepted = candidate is not None and _validate_mask_quality(candidate, image.shape[:2])
                if accepted:
                    logger.debug("MatAnyone refinement successful")
                    return candidate
                logger.warning("MatAnyone produced poor quality mask")

            except Exception as e:
                logger.warning(f"MatAnyone refinement failed: {e}")

        if not fallback_enabled:
            raise MaskRefinementError("MatAnyone failed and fallback disabled")

        logger.debug("Using enhanced OpenCV refinement")
        return enhance_mask_opencv_advanced(image, mask)

    except MaskRefinementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected mask refinement error: {e}")
        if not fallback_enabled:
            raise MaskRefinementError(f"Unexpected error: {e}")
        return enhance_mask_opencv_advanced(image, mask)
802
+
803
def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Clean up a mask using OpenCV filters only.

    Pipeline: reduce to one channel, scale by 255 when the maximum is at most
    1.0, bilateral filter, guided-filter approximation against ``image``,
    5x5 close, 3x3 open, light Gaussian blur, binary threshold at 127.

    Returns:
        The binary refined mask; if any step raises, a Gaussian-blurred copy
        of the mask as it stood when the failure occurred.
    """
    try:
        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)

        working = cv2.bilateralFilter(mask, 9, 75, 75)
        working = _guided_filter_approx(image, working, radius=8, eps=0.2)

        # Close small holes first, then remove small specks.
        for operation, size in ((cv2.MORPH_CLOSE, (5, 5)), (cv2.MORPH_OPEN, (3, 3))):
            element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, size)
            working = cv2.morphologyEx(working, operation, element)

        working = cv2.GaussianBlur(working, (3, 3), 0.8)

        return cv2.threshold(working, 127, 255, cv2.THRESH_BINARY)[1]

    except Exception as e:
        logger.warning(f"Enhanced OpenCV refinement failed: {e}")
        return cv2.GaussianBlur(mask, (5, 5), 1.0)
830
+
831
+ # ============================================================================
832
+ # BACKGROUND REPLACEMENT FUNCTIONS
833
+ # ============================================================================
834
+
835
def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
                          fallback_enabled: bool = True) -> np.ndarray:
    """Composite ``frame`` over ``background`` using ``mask``.

    The background is resized to the frame size (Lanczos interpolation) and
    the mask is reduced to one channel and normalised to ``uint8`` 0-255
    before ``_advanced_compositing`` is attempted. If that raises,
    ``_simple_compositing`` is used when ``fallback_enabled`` is true.

    Args:
        frame: Foreground frame.
        mask: Mask in either 0-1 or 0-255 range, any numeric dtype.
        background: Replacement background; resized to match ``frame``.
        fallback_enabled: Whether ``_simple_compositing`` may be used when
            the advanced path or the preprocessing fails.

    Returns:
        The composited frame.

    Raises:
        BackgroundReplacementError: If any input is ``None``, or if
            compositing fails and the fallback is disabled.
    """
    if frame is None or mask is None or background is None:
        raise BackgroundReplacementError("Invalid input frame, mask, or background")

    try:
        background = cv2.resize(background, (frame.shape[1], frame.shape[0]),
                                interpolation=cv2.INTER_LANCZOS4)

        if mask.ndim == 3:
            if mask.shape[2] == 1:
                # BGR2GRAY rejects single-channel input; just drop the axis.
                mask = mask[:, :, 0]
            else:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        if mask.dtype != np.uint8:
            # Scale BEFORE casting: casting a 0-1 float mask to uint8 first
            # would truncate every soft alpha value to 0 (or 1).
            mask = mask.astype(np.float32)
            if mask.max() <= 1.0:
                logger.debug("Converting normalized mask to 0-255 range")
                mask = mask * 255
            mask = np.clip(mask, 0, 255).astype(np.uint8)
        elif mask.max() <= 1:
            logger.debug("Converting normalized mask to 0-255 range")
            mask = (mask * 255).astype(np.uint8)

        try:
            result = _advanced_compositing(frame, mask, background)
            logger.debug("Advanced compositing successful")
            return result

        except Exception as e:
            logger.warning(f"Advanced compositing failed: {e}")
            if fallback_enabled:
                return _simple_compositing(frame, mask, background)
            raise BackgroundReplacementError(f"Advanced compositing failed: {e}") from e

    except BackgroundReplacementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected background replacement error: {e}")
        if fallback_enabled:
            return _simple_compositing(frame, mask, background)
        raise BackgroundReplacementError(f"Unexpected error: {e}") from e
875
+
876
def create_professional_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Build a background image of the requested size from ``bg_config``.

    ``bg_config["type"]`` selects the generator: ``"color"`` uses
    ``_create_solid_background``, ``"gradient"`` uses
    ``_create_gradient_background_enhanced``, and anything else yields a
    uniform mid-gray image. The result is then passed through
    ``_apply_background_adjustments``.

    Args:
        bg_config: Background description; must contain a ``"type"`` key.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        The generated background. If any step raises, the error is logged
        and an unadjusted ``(height, width, 3)`` uint8 gray image is returned.
    """
    try:
        kind = bg_config["type"]

        if kind == "color":
            canvas = _create_solid_background(bg_config, width, height)
        elif kind == "gradient":
            canvas = _create_gradient_background_enhanced(bg_config, width, height)
        else:
            # Unknown type: neutral gray placeholder.
            canvas = np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)

        return _apply_background_adjustments(canvas, bg_config)

    except Exception as e:
        logger.error(f"Background creation error: {e}")
        return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
893
+
894
+ # ============================================================================
895
+ # VALIDATION FUNCTION
896
+ # ============================================================================
897
+
898
def validate_video_file(video_path: str) -> Tuple[bool, str]:
    """Check that a video file exists, opens, and is within processing limits.

    Checks, in order: path exists; file is non-empty and at most 2 GB;
    OpenCV can open it; frame count is positive; frame rate is in (0, 120];
    resolution is positive and at most 4096x4096; duration is at most 300 s.

    Args:
        video_path: Path to the video file.

    Returns:
        ``(True, summary)`` when every check passes, otherwise
        ``(False, reason)``. Unexpected exceptions are reported as a
        ``(False, "Error validating video: ...")`` result, not raised.
    """
    if not video_path or not os.path.exists(video_path):
        return False, "Video file not found"

    try:
        file_size = os.path.getsize(video_path)
        if file_size == 0:
            return False, "Video file is empty"

        if file_size > 2 * 1024 * 1024 * 1024:
            return False, "Video file too large (>2GB)"

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return False, "Cannot open video file"

            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            # Release on every path, including "cannot open" and errors
            # raised while reading properties.
            cap.release()

        # <= 0 rather than == 0: a negative count must not slip through and
        # produce a negative duration below.
        if frame_count <= 0:
            return False, f"Video appears to be empty ({frame_count} frames)"

        # Written as a positive range test so that NaN is rejected too
        # (every comparison with NaN is False).
        if not (0 < fps <= 120):
            return False, f"Invalid frame rate: {fps}"

        if width <= 0 or height <= 0:
            return False, f"Invalid resolution: {width}x{height}"

        if width > 4096 or height > 4096:
            return False, f"Resolution too high: {width}x{height} (max 4096x4096)"

        duration = frame_count / fps
        if duration > 300:
            return False, f"Video too long: {duration:.1f}s (max 300s)"

        return True, f"Valid video: {width}x{height}, {fps:.1f}fps, {duration:.1f}s"

    except Exception as e:
        return False, f"Error validating video: {str(e)}"
942
+
943
+ # ============================================================================
944
+ # HELPER FUNCTIONS - SEGMENTATION
945
+ # ============================================================================
946
+
947
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Segment ``image`` using automatically generated point prompts.

    Positive and negative points come from ``_generate_smart_prompts``; if
    no positive point is returned, the image centre is used. The predictor
    is asked for multiple masks and the highest-scoring one (or the first,
    when no scores are returned) is passed through ``_process_mask``.

    Args:
        image: Image to segment; only its first two dimensions are read here.
        predictor: Object exposing ``predict(point_coords, point_labels,
            multimask_output)`` returning ``(masks, scores, _)``.

    Returns:
        The processed best mask.

    Raises:
        SegmentationError: If the predictor returns no masks.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        height, width = image.shape[:2]
        pos_points, neg_points = _generate_smart_prompts(image)

        if len(pos_points) == 0:
            # Nothing salient found: assume the subject is centred.
            pos_points = np.array([[width // 2, height // 2]], dtype=np.float32)

        prompt_coords = np.vstack([pos_points, neg_points])
        # Label 1 = foreground point, 0 = background point.
        prompt_labels = np.hstack([
            np.ones(len(pos_points), dtype=np.int32),
            np.zeros(len(neg_points), dtype=np.int32),
        ])

        logger.debug(f"Using {len(pos_points)} positive, {len(neg_points)} negative points")

        with torch.no_grad():
            masks, scores, _ = predictor.predict(
                point_coords=prompt_coords,
                point_labels=prompt_labels,
                multimask_output=True,
            )

        if masks is None or len(masks) == 0:
            raise SegmentationError("No masks generated")

        scores_available = scores is not None and len(scores) > 0
        if not scores_available:
            best_mask = masks[0]
        else:
            best_idx = np.argmax(scores)
            best_mask = masks[best_idx]
            logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")

        return _process_mask(best_mask)

    except Exception as e:
        logger.error(f"Intelligent prompting failed: {e}")
        raise
986
+
987
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Segment ``image`` using a fixed layout of point prompts.

    Three positive points are placed on the vertical centre line (at 1/3,
    1/2 and 2/3 of the height) and four negative points near the corners
    (10% in from each edge). The highest-scoring mask returned by the
    predictor (or the first, when no scores are returned) is passed through
    ``_process_mask``.

    Args:
        image: Image to segment; only its first two dimensions are read here.
        predictor: Object exposing ``predict(point_coords, point_labels,
            multimask_output)`` returning ``(masks, scores, _)``.

    Returns:
        The processed best mask.

    Raises:
        SegmentationError: If the predictor returns no masks.
    """
    h, w = image.shape[:2]

    centre_x = w // 2
    foreground = np.array(
        [[centre_x, h // 3], [centre_x, h // 2], [centre_x, 2 * h // 3]],
        dtype=np.float32,
    )

    left, right = w // 10, 9 * w // 10
    top, bottom = h // 10, 9 * h // 10
    background_pts = np.array(
        [[left, top], [right, top], [left, bottom], [right, bottom]],
        dtype=np.float32,
    )

    coords = np.concatenate([foreground, background_pts], axis=0)
    # 1 = foreground, 0 = background, in the same order as ``coords``.
    labels = np.array([1] * len(foreground) + [0] * len(background_pts), dtype=np.int32)

    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=coords,
            point_labels=labels,
            multimask_output=True,
        )

    if masks is None or len(masks) == 0:
        raise SegmentationError("No masks generated")

    if scores is not None and len(scores) > 0:
        winner = np.argmax(scores)
    else:
        winner = 0

    return _process_mask(masks[winner])
1021
+
1022
def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Generate positive and negative point prompts for ``image``.

    Positive points are the centroids of the (up to) three largest regions
    of a thresholded spectral-residual saliency map. If saliency is
    unavailable, reports failure, or yields no usable centroid, three points
    on the vertical centre line are used instead. Negative points are placed
    near the four corners and at the top and bottom centre.

    Args:
        image: Image to analyse; its first two dimensions are taken as
            height and width.

    Returns:
        ``(positive_points, negative_points)``, both float32 arrays of
        ``[x, y]`` rows.
    """
    try:
        h, w = image.shape[:2]

        try:
            saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
            success, saliency_map = saliency.computeSaliency(image)
            if not success:
                # Raise so the except-branch fallback below runs; otherwise
                # ``positive_points`` would be left unbound.
                raise ValueError("Saliency computation failed")

            saliency_thresh = cv2.threshold(saliency_map, 0.7, 1, cv2.THRESH_BINARY)[1]
            contours, _ = cv2.findContours((saliency_thresh * 255).astype(np.uint8),
                                           cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            positive_points = []
            for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
                M = cv2.moments(contour)
                if M["m00"] == 0:
                    # Degenerate contour: centroid undefined.
                    continue
                cx = int(M["m10"] / M["m00"])
                cy = int(M["m01"] / M["m00"])
                if 0 < cx < w and 0 < cy < h:
                    positive_points.append([cx, cy])

            if not positive_points:
                raise ValueError("No valid saliency points found")

            logger.debug(f"Generated {len(positive_points)} saliency-based points")
            positive_points = np.array(positive_points, dtype=np.float32)

        except Exception as e:
            logger.debug(f"Saliency method failed: {e}, using fallback")
            positive_points = np.array([
                [w//2, h//3],
                [w//2, h//2],
                [w//2, 2*h//3],
            ], dtype=np.float32)

        # Margins are 10 px (corners) and 5 px (top/bottom centre), shrunk on
        # very small images so the points stay inside the frame.
        mx = min(10, max((w - 1) // 2, 0))
        my = min(10, max((h - 1) // 2, 0))
        edge = min(5, max((h - 1) // 2, 0))

        negative_points = np.array([
            [mx, my],
            [w - mx, my],
            [mx, h - my],
            [w - mx, h - my],
            [w//2, edge],
            [w//2, h - edge],
        ], dtype=np.float32)

        return positive_points, negative_points

    except Exception as e:
        logger.warning(f"Smart prompt generation failed: {e}")
        h, w = image.shape[:2]
        positive_points = np.array([[w//2, h//2]], dtype=np.float32)
        negative_points = np.array([[10, 10], [w-10, 10]], dtype=np.float32)
        return positive_points, negative_points
1077
+
1078
+ # ============================================================================
1079
+ # HELPER FUNCTIONS - REFINEMENT
1080
+ # ============================================================================
1081
+
1082
def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
                                  predictor: Any, max_iterations: int = 2) -> np.ndarray:
    """Try to improve a mask by re-prompting the predictor at problem areas.

    Each iteration scores the current mask with ``_assess_mask_quality`` and
    stops when the score exceeds 0.85, when ``_find_mask_errors`` reports
    nothing, when a re-prediction does not beat the current score, or when
    the re-prediction raises. A candidate replaces the current mask only if
    it scores strictly higher.

    Args:
        image: Image the mask belongs to.
        initial_mask: Starting mask; it is copied, not modified.
        predictor: Object exposing ``predict(point_coords, point_labels,
            mask_input, multimask_output)``.
        max_iterations: Upper bound on refinement rounds.

    Returns:
        The best mask found, or ``initial_mask`` if anything outside the
        per-iteration prediction step raises.
    """
    try:
        current_mask = initial_mask.copy()

        for iteration in range(max_iterations):
            quality_score = _assess_mask_quality(current_mask, image)
            logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")

            if quality_score > 0.85:
                logger.debug(f"Quality sufficient after {iteration} iterations")
                break

            problem_areas = _find_mask_errors(current_mask, image)
            if not np.any(problem_areas):
                logger.debug("No problem areas detected")
                break

            fix_points, fix_labels = _generate_corrective_prompts(
                image, current_mask, problem_areas
            )
            if len(fix_points) == 0:
                # Nothing to prompt with; move on to the next round.
                continue

            try:
                # NOTE(review): SAM-style predictors usually expect
                # ``mask_input`` to be low-resolution logits from a previous
                # predict() call, not a full-size mask - confirm this
                # predictor accepts the mask in this form.
                with torch.no_grad():
                    candidates, _, _ = predictor.predict(
                        point_coords=fix_points,
                        point_labels=fix_labels,
                        mask_input=current_mask[None, :, :],
                        multimask_output=False,
                    )

                if candidates is None or len(candidates) == 0:
                    continue

                candidate = _process_mask(candidates[0])

                # Written as "not (a > b)" so a NaN score counts as
                # "no improvement".
                if not (_assess_mask_quality(candidate, image) > quality_score):
                    logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
                    break

                current_mask = candidate
                logger.debug(f"Improved mask in iteration {iteration}")

            except Exception as e:
                logger.debug(f"Refinement iteration {iteration} failed: {e}")
                break

        return current_mask

    except Exception as e:
        logger.warning(f"Iterative refinement failed: {e}")
        return initial_mask