File size: 26,097 Bytes
efe9b1b
26841b5
c3211a4
 
 
d2502a6
 
 
26841b5
9f4df99
efe9b1b
 
33839bc
f7c6a9c
efe9b1b
fd66920
f7c6a9c
 
 
26841b5
 
9f4df99
33839bc
 
 
 
 
 
 
 
 
 
 
 
 
c3211a4
 
 
 
 
efe9b1b
b3a57d5
f7c6a9c
 
 
 
 
 
 
 
 
b3a57d5
f7c6a9c
 
d2502a6
efe9b1b
f7c6a9c
d2502a6
 
 
 
f7c6a9c
 
d2502a6
 
 
 
 
 
efe9b1b
d2502a6
fd66920
d2502a6
 
fd66920
d2502a6
 
 
 
 
 
 
 
 
 
fd66920
f7c6a9c
 
 
d2502a6
f7c6a9c
d2502a6
 
 
 
 
f7c6a9c
 
fd66920
 
 
 
 
 
d2502a6
 
fd66920
 
d2502a6
fd66920
 
d2502a6
fd66920
 
d2502a6
fd66920
 
d2502a6
fd66920
 
 
 
d2502a6
fd66920
 
 
 
d2502a6
 
fd66920
f7c6a9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efe9b1b
c3211a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efe9b1b
b3a57d5
efe9b1b
f7c6a9c
 
 
 
 
 
 
efe9b1b
f7c6a9c
 
 
 
 
 
 
6a16de4
f7c6a9c
 
c3211a4
f7c6a9c
d2502a6
 
 
 
fd66920
b3a57d5
 
 
fd66920
d2502a6
f7c6a9c
 
 
fd66920
b3a57d5
 
fd66920
b3a57d5
 
 
fd66920
b3a57d5
 
 
 
 
fd66920
c3211a4
 
 
 
 
 
 
 
f7c6a9c
6a16de4
 
 
 
 
 
 
04ca462
c3211a4
 
d2502a6
04ca462
33839bc
 
 
 
d2502a6
 
fd66920
b3a57d5
33839bc
d2502a6
fd66920
b3a57d5
 
c3211a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2502a6
 
 
 
 
 
fd66920
d2502a6
fd66920
d2502a6
 
 
 
 
 
 
fd66920
b3a57d5
d2502a6
 
b3a57d5
04ca462
fd66920
d2502a6
b3a57d5
 
 
d2502a6
b3a57d5
 
d2502a6
fd66920
b3a57d5
d2502a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33839bc
c3211a4
 
d2502a6
c3211a4
 
 
 
d2502a6
b3a57d5
c3211a4
 
 
 
 
fd66920
b3a57d5
 
fd66920
b3a57d5
 
d2502a6
b3a57d5
 
04ca462
b3a57d5
04ca462
 
c3211a4
04ca462
6a16de4
 
 
fd66920
 
 
6a16de4
 
 
 
f7c6a9c
c3211a4
fd66920
 
 
 
 
 
 
 
 
f7c6a9c
33839bc
 
 
 
f7c6a9c
fd66920
b3a57d5
33839bc
c3211a4
 
 
 
b3a57d5
fd66920
 
b3a57d5
d2502a6
fd66920
 
 
 
 
 
 
 
d2502a6
c3211a4
 
fd66920
 
 
b3a57d5
fd66920
 
 
 
 
 
 
 
 
b3a57d5
 
fd66920
 
b3a57d5
 
fd66920
 
b3a57d5
 
fd66920
 
b3a57d5
 
fd66920
d2502a6
c3211a4
 
b3a57d5
 
fd66920
b3a57d5
 
fd66920
b3a57d5
 
c3211a4
b3a57d5
c3211a4
 
 
 
b3a57d5
 
c3211a4
 
 
 
 
 
 
 
 
 
 
b3a57d5
c3211a4
d2502a6
fd66920
b3a57d5
 
fd66920
b3a57d5
fd66920
b3a57d5
fd66920
b3a57d5
fd66920
d2502a6
 
b3a57d5
c3211a4
 
 
 
 
 
 
 
 
 
 
b3a57d5
 
d2502a6
fd66920
b3a57d5
fd66920
b3a57d5
 
 
fd66920
b3a57d5
fd66920
d2502a6
 
efe9b1b
 
c3211a4
f7c6a9c
6a16de4
 
 
 
 
 
 
c3211a4
f7c6a9c
 
fd66920
f7c6a9c
 
fd66920
 
 
c3211a4
 
 
 
 
 
 
 
fd66920
f7c6a9c
fd66920
f7c6a9c
fd66920
f7c6a9c
fd66920
f7c6a9c
 
b3a57d5
f7c6a9c
 
 
 
6a16de4
efe9b1b
 
 
26841b5
990992c
9f4df99
efe9b1b
 
 
 
b3a57d5
990992c
26841b5
 
b3a57d5
990992c
efe9b1b
b3a57d5
 
 
26841b5
990992c
efe9b1b
 
26841b5
b3a57d5
efe9b1b
b3a57d5
efe9b1b
b3a57d5
efe9b1b
 
990992c
b3a57d5
990992c
9f4df99
efe9b1b
f7c6a9c
 
 
 
 
 
 
 
 
 
 
 
6a16de4
33839bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
#!/usr/bin/env python3
"""
cv_processing.py · MAXIMUM QUALITY VERSION with enhanced SAM2Handler integration
Updated to work with enhanced SAM2Handler that has full-body detection strategies
Now includes maximum quality mask cleaning and aggressive post-processing

All public functions in this module expect RGB images (H,W,3) unless stated otherwise.
CoreVideoProcessor already converts BGR→RGB before calling into this module.
"""

from __future__ import annotations

import os
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Callable

import cv2
import numpy as np

# Module-level logger; every helper below reports through it.
logger = logging.getLogger(__name__)

# ----------------------------------------------------------------------------
# Environment variable helpers
# ----------------------------------------------------------------------------
def _use_sam2_enabled() -> bool:
    """Return True when the ``USE_SAM2`` environment variable enables SAM2.

    The variable defaults to ``"1"`` (enabled) when unset. Matching is
    case-insensitive and ignores surrounding whitespace: ``1``, ``true``,
    ``yes`` and ``on`` enable SAM2; any other value disables it.
    """
    # Strip first so values such as "1 " or " true" (common in .env files
    # and shell exports) are not silently treated as "disabled".
    flag = os.getenv("USE_SAM2", "1").strip().lower()
    return flag in ("1", "true", "yes", "on")

def _use_matanyone_enabled() -> bool:
    """Return True when the ``USE_MATANYONE`` environment variable enables MatAnyone.

    The variable defaults to ``"1"`` (enabled) when unset. Matching is
    case-insensitive and ignores surrounding whitespace: ``1``, ``true``,
    ``yes`` and ``on`` enable MatAnyone; any other value disables it.
    """
    # Strip first so values such as "1 " or " true" are not misread as "off".
    flag = os.getenv("USE_MATANYONE", "1").strip().lower()
    return flag in ("1", "true", "yes", "on")

def _use_max_quality_enabled() -> bool:
    """Return True when ``BFX_QUALITY`` selects maximum-quality processing.

    The variable defaults to ``"max"`` when unset. Only the value ``max``
    (case-insensitive, surrounding whitespace ignored) enables the
    maximum-quality path; every other value selects standard processing.
    """
    # Strip first so "max " / " MAX" are not silently downgraded.
    return os.getenv("BFX_QUALITY", "max").strip().lower() == "max"

# ----------------------------------------------------------------------------
# Background presets
# ----------------------------------------------------------------------------
# Built-in background presets, keyed by preset name. Each entry holds:
#   "color":    base colour as a 3-tuple of 0-255 ints (RGB order, matching
#               the module-wide RGB convention)
#   "gradient": True to render a vertical gradient from a darkened copy of
#               the colour down to the colour itself, False for a flat fill
# "office" doubles as the default preset in create_professional_background().
PROFESSIONAL_BACKGROUNDS_LOCAL: Dict[str, Dict[str, Any]] = {
    "office":   {"color": (240, 248, 255), "gradient": True},
    "studio":   {"color": (32, 32, 32),    "gradient": False},
    "nature":   {"color": (34, 139, 34),   "gradient": True},
    "abstract": {"color": (75, 0, 130),    "gradient": True},
    "white":    {"color": (255, 255, 255), "gradient": False},
    "black":    {"color": (0, 0, 0),       "gradient": False},
}
# Public alias bound to the very same dict object (not a copy).
PROFESSIONAL_BACKGROUNDS = PROFESSIONAL_BACKGROUNDS_LOCAL

# ----------------------------------------------------------------------------
# Helpers (RGB-safe)
# ----------------------------------------------------------------------------
def _ensure_rgb(img: np.ndarray) -> np.ndarray:
    """
    Return the image in HWC layout with at most three channels.

    * ``None`` is passed straight through.
    * A 3-D array whose last axis has 3 or 4 entries is treated as HWC and
      trimmed to its first three channels.
    * A 3-D array whose first axis has 1, 3 or 4 entries (and whose last axis
      does not look like a channel axis) is treated as CHW and transposed.
    * Anything else is returned unchanged as an ndarray.

    No BGR↔RGB swap is ever performed; colour order is the caller's business.
    """
    if img is None:
        return img

    arr = np.asarray(img)
    if arr.ndim != 3:
        return arr

    trailing = arr.shape[-1]
    if trailing in (3, 4):
        # Already channels-last; drop a possible alpha channel.
        return arr[..., :3]

    looks_channels_first = arr.shape[0] in (1, 3, 4) and trailing not in (1, 3, 4)
    if looks_channels_first:
        return np.transpose(arr, (1, 2, 0))[..., :3]

    return arr

def _ensure_rgb01(frame_rgb: np.ndarray) -> np.ndarray:
    """
    Return the frame as a fresh float32 HWC array with values in [0,1].

    uint8 input is divided by 255. Floating-point input is clipped to [0,1]
    (it is not rescaled). Any other dtype is first clipped to [0,255] and
    cast to uint8, then divided by 255. Channel order is left untouched.

    Raises:
        ValueError: if ``frame_rgb`` is None.
    """
    if frame_rgb is None:
        raise ValueError("frame_rgb is None")

    hwc = _ensure_rgb(frame_rgb)

    if np.issubdtype(hwc.dtype, np.floating):
        as_float = hwc.astype(np.float32)
        return np.clip(as_float, 0.0, 1.0).copy()

    if hwc.dtype != np.uint8:
        # Remaining dtypes (wider ints, bool, ...) are squeezed into 8 bits.
        hwc = np.clip(hwc, 0, 255).astype(np.uint8)

    scaled = hwc.astype(np.float32) / 255.0
    return scaled.copy()

def _to_mask01(m: np.ndarray) -> Optional[np.ndarray]:
    """
    Normalize a mask to float32 values clipped to [0,1].

    Args:
        m: Mask as an ndarray or any array-like (e.g. nested lists), or None.
           A 3-D input whose third axis has 1, 3 or 4 entries is reduced to
           its first channel. uint8 input is divided by 255; other dtypes are
           cast to float32 without rescaling.

    Returns:
        The normalized mask, or None when ``m`` is None. Note that 3-D inputs
        that do not look channels-last are returned with their rank unchanged.
    """
    if m is None:
        return None
    # Convert before touching .ndim/.shape so array-likes without those
    # attributes (plain lists, tuples) are accepted instead of raising.
    m = np.asarray(m)
    if m.ndim == 3 and m.shape[2] in (1, 3, 4):
        m = m[..., 0]
    if m.dtype == np.uint8:
        m = m.astype(np.float32) / 255.0
    elif m.dtype != np.float32:
        m = m.astype(np.float32)
    return np.clip(m, 0.0, 1.0)

def _mask_to_2d(mask: np.ndarray) -> np.ndarray:
    """
    Reduce any mask to 2-D float32 [H,W], contiguous, in [0,1].
    Handles HWC/CHW/B1HW/1HW/HW, etc.

    Args:
        mask: ndarray or array-like. uint8 input is divided by 255; other
              dtypes are cast to float32 without rescaling and then clipped.

    Returns:
        A C-contiguous float32 array of rank 2. When the input cannot be
        reduced to two dimensions, a constant 0.5 ("neutral") mask is returned
        instead; its size is taken from the last two axes when available and
        is 512x512 otherwise.
    """
    m = np.asarray(mask)
    # Remember the shape of the converted input for the warning below; the
    # raw argument may be a list and have no .shape attribute.
    input_shape = m.shape

    # CHW with single channel
    if m.ndim == 3 and m.shape[0] == 1 and (m.shape[1] > 1 and m.shape[2] > 1):
        m = m[0]
    # HWC with single channel
    if m.ndim == 3 and m.shape[-1] == 1:
        m = m[..., 0]
    # generic 3D -> take first channel
    if m.ndim == 3:
        m = m[..., 0] if m.shape[-1] in (1, 3, 4) else m[0]

    # Drops leading batch/channel singleton axes (e.g. 1x1xHxW -> HxW).
    m = np.squeeze(m)
    if m.ndim != 2:
        # fall back to neutral 0.5 mask
        h = int(m.shape[-2]) if m.ndim >= 2 else 512
        w = int(m.shape[-1]) if m.ndim >= 2 else 512
        logger.warning(f"_mask_to_2d: unexpected shape {input_shape}, creating neutral mask.")
        m = np.full((h, w), 0.5, dtype=np.float32)

    if m.dtype == np.uint8:
        m = m.astype(np.float32) / 255.0
    elif m.dtype != np.float32:
        m = m.astype(np.float32)

    return np.ascontiguousarray(np.clip(m, 0.0, 1.0))

def _feather(mask01: np.ndarray, k: int = 2) -> np.ndarray:
    """
    Soften mask edges with a Gaussian blur.

    The mask (first channel only, if 3-D) is scaled to 8-bit, blurred with a
    square kernel of side ``2*k + 1`` (never smaller than 1) and scaled back
    to float32.
    """
    plane = mask01[..., 0] if mask01.ndim == 3 else mask01
    side = max(1, 2 * int(k) + 1)  # always odd, as GaussianBlur requires
    plane_u8 = (plane * 255.0).astype(np.uint8)
    blurred = cv2.GaussianBlur(plane_u8, (side, side), 0)
    return blurred.astype(np.float32) / 255.0

def _vertical_gradient(top: Tuple[int,int,int], bottom: Tuple[int,int,int], width: int, height: int) -> np.ndarray:
    """
    Build a (height, width, 3) uint8 image fading from ``top`` to ``bottom``.

    Row 0 carries the ``top`` colour and the last row the ``bottom`` colour;
    rows in between are linearly interpolated and truncated to integers.
    A single-row image gets the ``top`` colour. Channel values outside
    [0,255] are clipped rather than left to overflow uint8.
    """
    # Interpolation weight per row: 0.0 at the top, 1.0 at the bottom.
    t = np.arange(height, dtype=np.float64) / max(1, height - 1)
    top_rgb = np.asarray(top[:3], dtype=np.float64)
    bottom_rgb = np.asarray(bottom[:3], dtype=np.float64)

    # One colour per row, computed for all rows at once instead of looping.
    rows = top_rgb[None, :] * (1.0 - t)[:, None] + bottom_rgb[None, :] * t[:, None]
    rows = np.clip(rows, 0, 255).astype(np.uint8)  # cast truncates like int()

    # Replicate each row colour across the full width.
    return np.repeat(rows[:, None, :], width, axis=1)

# ----------------------------------------------------------------------------
# Maximum Quality Mask Cleaning (integrated from TwoStageProcessor)
# ----------------------------------------------------------------------------
def _maximum_quality_mask_cleaning(mask: np.ndarray) -> np.ndarray:
    """
    Aggressively clean a mask: close holes, smooth, re-threshold, dilate.

    Input whose maximum is <= 1.0 is treated as a [0,1] mask and scaled to
    8-bit; anything else is cast to uint8 as-is. The result is float32 in
    [0,1]. On any failure a warning is logged and the input mask is returned
    unchanged.
    """
    def _ellipse(side: int) -> np.ndarray:
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))

    try:
        # Bring the mask into 8-bit range for the OpenCV morphology calls.
        work = (mask * 255).astype(np.uint8) if mask.max() <= 1.0 else mask.astype(np.uint8)

        # 1) close small holes, 2) bridge nearby regions
        work = cv2.morphologyEx(work, cv2.MORPH_CLOSE, _ellipse(9))
        work = cv2.morphologyEx(work, cv2.MORPH_CLOSE, _ellipse(7))

        # 3) heavy boundary smoothing, 4) back to a crisp binary edge
        work = cv2.GaussianBlur(work, (7, 7), 2.0)
        work = cv2.threshold(work, 127, 255, cv2.THRESH_BINARY)[1]

        # 5) light final smoothing, 6) slight growth for full coverage
        work = cv2.GaussianBlur(work, (5, 5), 1.0)
        work = cv2.dilate(work, _ellipse(3), iterations=1)

        logger.info("Maximum quality mask cleaning applied successfully")
        return work.astype(np.float32) / 255.0

    except Exception as e:
        logger.warning(f"Maximum quality mask cleaning failed: {e}")
        return mask

# ----------------------------------------------------------------------------
# Background creation
# ----------------------------------------------------------------------------
def create_professional_background(key_or_cfg: Any, width: int, height: int) -> np.ndarray:
    """
    Render a background image of shape (height, width, 3), dtype uint8.

    Args:
        key_or_cfg: Either a preset name (unknown names fall back to
            "office"), or a config dict with optional "color" (3 ints,
            default white) and "gradient" (bool, default False) entries.
            Any other type selects the "office" preset.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A flat fill of the colour, or, when the gradient flag is set, a
        vertical gradient from a 70%-brightness copy of the colour (top)
        to the colour itself (bottom).
    """
    if isinstance(key_or_cfg, dict):
        cfg = key_or_cfg
    else:
        office = PROFESSIONAL_BACKGROUNDS_LOCAL["office"]
        if isinstance(key_or_cfg, str):
            cfg = PROFESSIONAL_BACKGROUNDS_LOCAL.get(key_or_cfg, office)
        else:
            cfg = office

    rgb = tuple(int(channel) for channel in cfg.get("color", (255, 255, 255)))

    if bool(cfg.get("gradient", False)):
        shaded = (int(rgb[0] * 0.7), int(rgb[1] * 0.7), int(rgb[2] * 0.7))
        return _vertical_gradient(shaded, rgb, width, height)

    return np.full((height, width, 3), rgb, dtype=np.uint8)

# ----------------------------------------------------------------------------
# Improved Segmentation (expects RGB input) - ENHANCED FOR SAM2Handler
# ----------------------------------------------------------------------------
def _simple_person_segmentation(frame_rgb: np.ndarray) -> np.ndarray:
    """
    Basic fallback segmentation using colour detection on RGB frames.

    Everything that is not green-screen-like, plus skin-toned pixels, is
    treated as foreground; the result is cleaned morphologically and reduced
    to its largest external contour (filled).

    Args:
        frame_rgb: RGB image (H,W,3). uint8 is used as-is; floating-point
            input is assumed to be in [0,1] and is scaled to 8-bit; other
            integer dtypes are clipped to [0,255].

    Returns:
        float32 mask in [0,1], additionally passed through the maximum
        quality cleaning when that mode is enabled.
    """
    # The HSV thresholds below are defined for OpenCV's 8-bit HSV ranges
    # (H in 0..179, S/V in 0..255). Float images use different ranges, so
    # convert first instead of thresholding the wrong scale.
    frame_u8 = np.asarray(frame_rgb)
    if frame_u8.dtype != np.uint8:
        if np.issubdtype(frame_u8.dtype, np.floating):
            frame_u8 = np.clip(frame_u8 * 255.0, 0, 255).astype(np.uint8)
        else:
            frame_u8 = np.clip(frame_u8, 0, 255).astype(np.uint8)

    hsv = cv2.cvtColor(frame_u8, cv2.COLOR_RGB2HSV)

    lower_skin = np.array([0, 20, 70], dtype=np.uint8)
    upper_skin = np.array([20, 255, 255], dtype=np.uint8)
    skin_mask = cv2.inRange(hsv, lower_skin, upper_skin)

    # detect greenscreen-ish
    lower_green = np.array([40, 40, 40], dtype=np.uint8)
    upper_green = np.array([80, 255, 255], dtype=np.uint8)
    green_mask = cv2.inRange(hsv, lower_green, upper_green)

    # Foreground = "not green" OR "skin-coloured".
    person_mask = cv2.bitwise_not(green_mask)
    person_mask = cv2.bitwise_or(person_mask, skin_mask)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    person_mask = cv2.morphologyEx(person_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
    person_mask = cv2.morphologyEx(person_mask, cv2.MORPH_OPEN, kernel, iterations=1)

    # Keep only the largest connected outline, drawn filled.
    contours, _ = cv2.findContours(person_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        person_mask = np.zeros_like(person_mask)
        cv2.drawContours(person_mask, [largest_contour], -1, 255, -1)

    mask_result = (person_mask.astype(np.float32) / 255.0)

    # Apply maximum quality cleaning if enabled
    if _use_max_quality_enabled():
        mask_result = _maximum_quality_mask_cleaning(mask_result)
        logger.info("Applied maximum quality cleaning to fallback segmentation")

    return mask_result

def _frame_to_uint8(frame_rgb: np.ndarray) -> np.ndarray:
    """Return the frame as uint8; floating-point input is assumed in [0,1] and scaled by 255."""
    if frame_rgb.dtype == np.uint8:
        return frame_rgb
    if np.issubdtype(frame_rgb.dtype, np.floating):
        return np.clip(frame_rgb * 255.0, 0, 255).astype(np.uint8)
    return frame_rgb.astype(np.uint8)


def _pick_best_mask(masks: Any, scores: Any) -> Optional[np.ndarray]:
    """Pick one mask from a legacy predictor result, preferring the highest score; None if the shape is unsupported."""
    masks = np.asarray(masks)
    if masks.ndim == 2:
        return masks

    is_stack = masks.ndim == 3 and masks.shape[0] > 0
    is_stack_1ch = masks.ndim == 4 and masks.shape[1] == 1  # (N,1,H,W)
    if not (is_stack or is_stack_1ch):
        logger.warning(f"Unexpected mask shape from SAM2: {masks.shape}")
        return None

    best_idx = int(np.argmax(np.asarray(scores))) if scores is not None else 0
    return masks[best_idx, 0] if is_stack_1ch else masks[best_idx]


def segment_person_hq(
    frame: np.ndarray,
    predictor: Optional[Any] = None,
    fallback_enabled: bool = True,
    use_sam2: Optional[bool] = None,
    **_compat_kwargs,
) -> np.ndarray:
    """
    High-quality person segmentation with ENHANCED SAM2Handler integration.
    Expects RGB frame (H,W,3), uint8 or float in [0,1].

    Args:
        frame: RGB image, HWC or CHW (channels-first input is transposed).
        predictor: Optional segmentation backend. An object exposing
            ``create_mask(rgb_uint8)`` is preferred; otherwise an object with
            ``set_image`` and ``predict`` is driven through point prompts.
        fallback_enabled: When no usable predictor mask is produced, use the
            colour-based fallback segmentation (True) or return an all-ones
            mask (False).
        use_sam2: Explicit on/off switch; ``None`` defers to the ``USE_SAM2``
            environment variable.
        **_compat_kwargs: Accepted and ignored for backward compatibility.

    Returns:
        float32 mask in [0,1]. Predictor errors are logged and never
        propagate; they lead to the fallback path.
    """
    # Override with environment variable if not explicitly set
    if use_sam2 is None:
        use_sam2 = _use_sam2_enabled()

    frame_rgb = _ensure_rgb(frame)
    h, w = frame_rgb.shape[:2]

    if use_sam2 is False:
        logger.info("SAM2 disabled by environment variable, using fallback segmentation")
        # The colour-threshold fallback is tuned for 8-bit pixels.
        return _simple_person_segmentation(_frame_to_uint8(frame_rgb))

    if predictor is not None:
        try:
            # ENHANCED: Check if this is the new SAM2Handler with create_mask method
            if hasattr(predictor, 'create_mask'):
                logger.info("Using ENHANCED SAM2Handler.create_mask() with full-body detection")
                # SAM2Handler expects RGB uint8
                mask = predictor.create_mask(_frame_to_uint8(frame_rgb))

                if mask is not None:
                    # Convert to float format
                    mask_float = _to_mask01(mask)
                    logger.info(f"Enhanced SAM2Handler mask stats: shape={mask_float.shape}, min={mask_float.min():.3f}, max={mask_float.max():.3f}, mean={mask_float.mean():.3f}")

                    if float(mask_float.max()) > 0.1:
                        # Apply additional maximum quality cleaning if enabled
                        if _use_max_quality_enabled():
                            mask_float = _maximum_quality_mask_cleaning(mask_float)
                            logger.info("Applied additional maximum quality cleaning to enhanced SAM2 result")
                        return np.ascontiguousarray(mask_float)
                    else:
                        logger.warning("Enhanced SAM2Handler mask too weak, using fallback")
                else:
                    logger.warning("Enhanced SAM2Handler returned None mask")

            # FALLBACK: Basic SAM2 predictor handling (legacy compatibility)
            elif hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
                logger.info("Using legacy SAM2 predictor interface")
                # Predictor adapter expects RGB uint8; convert if needed
                predictor.set_image(_frame_to_uint8(frame_rgb))

                # Center + a couple of body-biased prompts
                points = np.array([
                    [w // 2, h // 2],
                    [w // 2, h // 4],
                    [w // 2, h // 2 + h // 8],
                ], dtype=np.float32)
                labels = np.array([1, 1, 1], dtype=np.int32)

                result = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True
                )

                # normalize outputs: dict, (masks, scores, ...) sequence, or bare masks
                if isinstance(result, dict):
                    masks = result.get("masks", None)
                    scores = result.get("scores", None)
                elif isinstance(result, (tuple, list)) and len(result) >= 2:
                    masks, scores = result[0], result[1]
                else:
                    masks, scores = result, None

                if masks is not None:
                    mask = _pick_best_mask(masks, scores)

                    if mask is not None:
                        mask = _to_mask01(mask)
                        # Add debug logging
                        logger.info(f"Legacy SAM2 mask stats: shape={mask.shape}, min={mask.min():.3f}, max={mask.max():.3f}, mean={mask.mean():.3f}")

                        if float(mask.max()) > 0.1:
                            # Apply maximum quality cleaning if enabled
                            if _use_max_quality_enabled():
                                mask = _maximum_quality_mask_cleaning(mask)
                                logger.info("Applied maximum quality cleaning to legacy SAM2 result")
                            return np.ascontiguousarray(mask)
                        else:
                            logger.warning("Legacy SAM2 mask too weak, using fallback")
                    else:
                        logger.warning("Legacy SAM2 returned no masks")
            else:
                logger.warning("Predictor doesn't have expected SAM2 interface")

        except Exception as e:
            # Best-effort by design: any predictor failure drops to the fallback.
            logger.warning(f"SAM2 segmentation error: {e}")

    if fallback_enabled:
        logger.debug("Using fallback segmentation")
        # The colour-threshold fallback is tuned for 8-bit pixels.
        return _simple_person_segmentation(_frame_to_uint8(frame_rgb))
    else:
        return np.ones((h, w), dtype=np.float32)

# Backward-compatible alias for callers importing the old name.
segment_person_hq_original = segment_person_hq

# ----------------------------------------------------------------------------
# MatAnyone Refinement (Stateful-capable) - ENHANCED WITH MAX QUALITY
# ----------------------------------------------------------------------------
def refine_mask_hq(
    frame: np.ndarray,
    mask: np.ndarray,
    matanyone: Optional[Callable] = None,
    *,
    frame_idx: Optional[int] = None,
    fallback_enabled: bool = True,
    use_matanyone: Optional[bool] = None,
    **_compat_kwargs,
) -> np.ndarray:
    """
    Refine mask with MatAnyone + maximum quality post-processing.

    Modes:
      • Stateful (preferred): provide `frame_idx`. On frame_idx==0, the session encodes with the mask.
        On subsequent frames, the session propagates without a mask.
      • Backward-compat (stateless): if `frame_idx` is None, we try callable/step/process with (frame, mask)
        like before. This path is also used when the stateful call yields a weak mask.

    Args:
        frame: RGB image; converted to float32 [0,1] before being handed to MatAnyone.
        mask: Initial mask; normalized to float32 [0,1].
        matanyone: Callable refinement session; ignored when None or not callable.
        frame_idx: Frame index for the stateful mode, or None for stateless.
        fallback_enabled: Use the OpenCV fallback refinement when MatAnyone is
            unavailable or fails; otherwise return the (optionally cleaned) input mask.
        use_matanyone: Explicit on/off switch; None defers to ``USE_MATANYONE``.
        **_compat_kwargs: Accepted and ignored for backward compatibility.

    Returns:
      2-D float32 alpha [H,W], contiguous, in [0,1] (OpenCV-safe) on the
      MatAnyone and fallback-refinement paths. MatAnyone errors are logged
      and never propagate.
    """
    # Override with environment variable if not explicitly set
    if use_matanyone is None:
        use_matanyone = _use_matanyone_enabled()

    mask01 = _to_mask01(mask)

    if use_matanyone is False:
        logger.info("MatAnyone disabled by environment variable, returning unrefined mask")
        # Still apply maximum quality cleaning if enabled
        if _use_max_quality_enabled():
            mask01 = _maximum_quality_mask_cleaning(mask01)
            logger.info("Applied maximum quality cleaning to unrefined mask")
        return mask01

    if matanyone is not None and callable(matanyone):
        try:
            rgb01 = _ensure_rgb01(frame)  # RGB float32 in [0,1]

            # Stateful path (preferred)
            if frame_idx is not None:
                if frame_idx == 0:
                    refined = matanyone(rgb01, mask01)        # encode + first-frame predict inside
                else:
                    refined = matanyone(rgb01)                 # propagate without mask
                refined = _mask_to_2d(refined)
                if float(refined.max()) > 0.1:
                    return _postprocess_mask_max_quality(refined)
                logger.warning("MatAnyone stateful refinement produced empty/weak mask; falling back.")

            # Backward-compat (stateless) path: try the direct call, then
            # .step(image, mask), then .process(image, mask).
            refined = None
            for label, attr in (("callable", None), ("step", "step"), ("process", "process")):
                if attr is None:
                    fn = matanyone
                elif hasattr(matanyone, attr):
                    fn = getattr(matanyone, attr)
                else:
                    continue
                try:
                    # Assign only after the 2-D conversion succeeded, so a
                    # failed conversion cannot leave a raw, unnormalized
                    # output behind and suppress the remaining attempts.
                    refined = _mask_to_2d(fn(rgb01, mask01))
                    break
                except Exception as e:
                    refined = None
                    logger.debug(f"MatAnyone {label} failed: {e}")

            if refined is not None and float(refined.max()) > 0.1:
                return _postprocess_mask_max_quality(refined)
            else:
                logger.warning("MatAnyone refinement failed or produced empty mask")

        except Exception as e:
            # Best-effort by design: fall through to the non-MatAnyone paths.
            logger.warning(f"MatAnyone error: {e}")

    # Fallback refinement
    if fallback_enabled:
        return _fallback_refine_max_quality(mask01)
    else:
        # Still apply maximum quality cleaning if enabled
        if _use_max_quality_enabled():
            mask01 = _maximum_quality_mask_cleaning(mask01)
            logger.info("Applied maximum quality cleaning to fallback mask")
        return mask01

def _postprocess_mask_max_quality(mask01: np.ndarray) -> np.ndarray:
    """
    Post-process a MatAnyone result.

    Delegates to the standard post-processing unless maximum quality mode is
    enabled, in which case the aggressive cleaning is applied (and logged).
    """
    if not _use_max_quality_enabled():
        return _postprocess_mask(mask01)

    cleaned = _maximum_quality_mask_cleaning(mask01)
    logger.info("Applied maximum quality post-processing to MatAnyone result")
    return cleaned

def _postprocess_mask(mask01: np.ndarray) -> np.ndarray:
    """
    Standard clean-up: close gaps, smooth, binarize, then soften the edge.

    Returns a contiguous float32 mask in [0,1].
    """
    work = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)

    # Close small gaps with a 5x5 elliptical kernel.
    ellipse5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    work = cv2.morphologyEx(work, cv2.MORPH_CLOSE, ellipse5)

    # Light blur followed by a hard threshold gives a clean binary outline...
    work = cv2.GaussianBlur(work, (3, 3), 0)
    work = cv2.threshold(work, 127, 255, cv2.THRESH_BINARY)[1]

    # ...which is then feathered slightly.
    work = cv2.GaussianBlur(work, (5, 5), 1)

    return np.ascontiguousarray(work.astype(np.float32) / 255.0)

def _fallback_refine_max_quality(mask01: np.ndarray) -> np.ndarray:
    """
    Refine a mask without MatAnyone.

    Delegates to the simple fallback refinement unless maximum quality mode
    is enabled, in which case the aggressive cleaning is applied (and logged).
    """
    if not _use_max_quality_enabled():
        return _fallback_refine(mask01)

    cleaned = _maximum_quality_mask_cleaning(mask01)
    logger.info("Applied maximum quality cleaning to fallback refinement")
    return cleaned

def _fallback_refine(mask01: np.ndarray) -> np.ndarray:
    """
    Simple OpenCV-only refinement: bilateral filter, close, open, blur.

    Returns a contiguous float32 mask in [0,1].
    """
    work = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)

    # Edge-preserving smoothing first.
    work = cv2.bilateralFilter(work, 9, 75, 75)

    # Close then open with the same 3x3 elliptical kernel.
    ellipse3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        work = cv2.morphologyEx(work, operation, ellipse3)

    work = cv2.GaussianBlur(work, (5, 5), 1)

    return np.ascontiguousarray(work.astype(np.float32) / 255.0)

# ----------------------------------------------------------------------------
# Compositing (expects RGB inputs) - ENHANCED WITH MAX QUALITY
# ----------------------------------------------------------------------------
def replace_background_hq(
    frame: np.ndarray,
    mask01: np.ndarray,
    background: np.ndarray,
    fallback_enabled: bool = True,
    **_compat,
) -> np.ndarray:
    """
    High-quality background replacement with alpha blending (RGB in/out) - enhanced with max quality.

    Args:
        frame: Foreground image (H,W,3). Blended as-is and the result is
            clipped to [0,255], so pixel values are expected on a 0-255 scale.
        mask01: Alpha mask in any layout accepted by the mask normalizers;
            it is resized to the frame's resolution when it differs.
        background: Replacement background; resized to the frame's
            resolution when it differs.
        fallback_enabled: On failure, log a warning and return ``frame``
            unchanged (True) or re-raise the error (False).
        **_compat: Accepted and ignored for backward compatibility.

    Returns:
        uint8 composite of shape (H,W,3), or the original ``frame`` when
        compositing failed and ``fallback_enabled`` is True.
    """
    try:
        H, W = frame.shape[:2]

        if background.shape[:2] != (H, W):
            background = cv2.resize(background, (W, H), interpolation=cv2.INTER_LANCZOS4)

        m = _mask_to_2d(_to_mask01(mask01))

        # Masks are often produced at a different resolution than the frame.
        # Bring them to frame size here; otherwise the blend below fails on
        # shape mismatch and the uncomposited frame would be returned.
        if m.shape != (H, W):
            m = cv2.resize(m, (W, H), interpolation=cv2.INTER_LINEAR)

        # Apply maximum quality cleaning to mask before compositing
        if _use_max_quality_enabled():
            m = _maximum_quality_mask_cleaning(m)
            logger.debug("Applied maximum quality cleaning to compositing mask")

        # Enhanced feathering for maximum quality
        feather_strength = 3 if _use_max_quality_enabled() else 1
        m = _feather(m, k=feather_strength)

        # Broadcast the single-channel alpha across the three colour channels.
        m3 = np.repeat(m[:, :, None], 3, axis=2)

        comp = frame.astype(np.float32) * m3 + background.astype(np.float32) * (1.0 - m3)

        return np.clip(comp, 0, 255).astype(np.uint8)

    except Exception as e:
        if fallback_enabled:
            logger.warning(f"Compositing failed ({e}) – returning original frame")
            return frame
        raise

# ----------------------------------------------------------------------------
# Video validation
# ----------------------------------------------------------------------------
def validate_video_file(video_path: str) -> Tuple[bool, str]:
    """
    Check that a video file exists, can be opened and is within limits.

    Limits enforced: non-empty file of at most 2 GB, at least one frame,
    FPS in (0, 120], resolution in (0, 4096] per side, duration of at most
    300 seconds.

    Args:
        video_path: Path to the video file.

    Returns:
        ``(True, summary)`` when the file passes every check, otherwise
        ``(False, reason)``. Unexpected errors are logged and reported as a
        ``(False, "Validation error: ...")`` result rather than raised.
    """
    if not video_path or not Path(video_path).exists():
        return False, "Video file not found"

    try:
        size = Path(video_path).stat().st_size
        if size == 0:
            return False, "File is empty"
        if size > 2 * 1024 * 1024 * 1024:
            return False, "File > 2 GB"

        cap = cv2.VideoCapture(str(video_path))
        try:
            if not cap.isOpened():
                return False, "Cannot read file"

            n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            # Always hand the capture back, including on the early return
            # above and when a property read raises.
            cap.release()

        # "<= 0" also rejects the negative counts some containers report.
        if n_frames <= 0:
            return False, "No frames detected"
        # Written as a positive range test so a NaN FPS is rejected too
        # (every comparison with NaN is False).
        if not (0 < fps <= 120):
            return False, f"Invalid FPS: {fps}"
        if w <= 0 or h <= 0:
            return False, "Invalid resolution"
        if w > 4096 or h > 4096:
            return False, f"Resolution {w}×{h} too high"
        if (n_frames / fps) > 300:
            return False, "Video longer than 5 minutes"

        return True, f"OK → {w}×{h}, {fps:.1f} fps, {n_frames/fps:.1f}s"

    except Exception as e:
        logger.error(f"validate_video_file: {e}")
        return False, f"Validation error: {e}"

# ----------------------------------------------------------------------------
# Public symbols
# ----------------------------------------------------------------------------
# Names exported by a star-import of this module; the underscore-prefixed
# helpers above are intentionally left out.
__all__ = [
    "segment_person_hq",
    "segment_person_hq_original",
    "refine_mask_hq",
    "replace_background_hq",
    "create_professional_background",
    "validate_video_file",
    "PROFESSIONAL_BACKGROUNDS",
]