File size: 47,835 Bytes
47df7fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1a40e
47df7fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1a40e
47df7fb
bb1a40e
47df7fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb1a40e
 
47df7fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
#!/usr/bin/env python3
"""
Video Background Replacer (GPU-Optimized)

- MatAnyone (primary), SAM2 (mask seeding), rembg (fallback)
- K-Governor guards torch.topk/kthvalue (no __wrapped__ assumption)
- Adaptive MatAnyone loader (from_pretrained | constructor network/model | repo-id)
- Optional repo pinning via MATANYONE_COMMIT / SAM2_COMMIT
- First-run warmup → READY ✅ before first request
- Robust Gradio input coercion (path | dict | file-like | PIL | NumPy)
- Alpha probing & (optional) stitching alpha_*.png sequences to a video
- Short-clip stabilizer (pre-roll) with correct trim
- Concurrency lock for MatAnyone core
"""

# =========================
# EARLY env & imports
# =========================
import os, sys, re, time, gc, shutil, subprocess, tempfile, threading, traceback, inspect, glob
from pathlib import Path

# ---- Thread/env sanitization (must run BEFORE numpy/torch/cv2) ----
def _safe_int_env(var: str, default: int = 2, cap: int = 8) -> int:
    v = os.environ.get(var, "").strip()
    if not v or not re.fullmatch(r"\d+", v):
        os.environ[var] = str(default); return default
    iv = max(1, min(int(v), cap))
    os.environ[var] = str(iv); return iv

_safe_int_env("OMP_NUM_THREADS", 2, 8)
_safe_int_env("MKL_NUM_THREADS", 2, 8)

# General runtime defaults
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:512")
os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY")
os.environ.setdefault("PYTHONUNBUFFERED", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# MatAnyone prefs
os.environ.setdefault("MATANYONE_MAX_EDGE", "1024")
os.environ.setdefault("MATANYONE_TARGET_PIXELS", "1000000")
os.environ.setdefault("MATANYONE_WINDOWED", "1")
os.environ.setdefault("MATANYONE_WINDOW", "16")
os.environ.setdefault("MAX_MODEL_SIZE", "1920")

# CUDA + cuDNN
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
os.environ.setdefault("TORCH_CUDNN_V8_API_ENABLED", "1")
os.environ.setdefault("CUDNN_BENCHMARK", "1")

# HF cache
os.environ.setdefault("HF_HOME", "./checkpoints/hf")
os.environ.setdefault("TRANSFORMERS_CACHE", "./checkpoints/hf")
os.environ.setdefault("HF_DATASETS_CACHE", "./checkpoints/hf")
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
# huggingface_hub refuses to download when HF_HUB_ENABLE_HF_TRANSFER=1 but the
# optional `hf_transfer` package is not installed, so only default it on when
# that package is importable (an explicit user setting is still respected).
import importlib.util as _importlib_util
if _importlib_util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")

# Gradio
os.environ.setdefault("GRADIO_SERVER_NAME", "0.0.0.0")
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")

# Features
os.environ.setdefault("USE_MATANYONE", "true")
os.environ.setdefault("USE_SAM2", "true")
os.environ.setdefault("SELF_CHECK_MODE", "false")

# Stabilizer defaults
os.environ.setdefault("MATANYONE_STABILIZE", "true")
os.environ.setdefault("MATANYONE_PREROLL_FRAMES", "12")

# Optional strict re-sanitization later
os.environ.setdefault("STRICT_ENV_GUARD", "1")

# =========================
# Std imports (safe now)
# =========================
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from moviepy.editor import VideoFileClip, ImageSequenceClip, concatenate_videoclips

# Startup banner. The timestamp comes from time.strftime rather than shelling
# out to `date` (which spawns a shell, and on Windows waits for input).
print("=" * 50)
print("Application Startup at", time.strftime("%a %b %d %H:%M:%S %Z %Y"))
print("=" * 50)
print("Environment Configuration:")
print(f"Python: {sys.version}")
print(f"Working directory: {os.getcwd()}")
print(f"CUDA_MODULE_LOADING: {os.getenv('CUDA_MODULE_LOADING')}")
print(f"OMP_NUM_THREADS: {os.getenv('OMP_NUM_THREADS')}")
print("=" * 50)

# =========================
# Third-party repos & optional pinning
# =========================
# Anchor project paths to this file's directory rather than the CWD.
BASE_DIR = Path(__file__).resolve().parent
TP_DIR = BASE_DIR / "third_party"
CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
TP_DIR.mkdir(exist_ok=True)
CHECKPOINTS_DIR.mkdir(exist_ok=True)

def _git_clone_if_missing(url: str, path: Path, name: str):
    if path.exists():
        return
    print(f"Cloning {name}…")
    try:
        subprocess.run(["git", "clone", "--depth", "1", url, str(path)], check=True, timeout=300)
        print(f"{name} cloned successfully")
    except Exception as e:
        print(f"Failed to clone {name}: {e}")

# Vendored dependencies (skipped when the target directory already exists).
_git_clone_if_missing(
    "https://github.com/facebookresearch/segment-anything-2.git",
    TP_DIR / "sam2",
    "SAM2",
)
_git_clone_if_missing(
    "https://github.com/pq-yang/MatAnyone.git",
    TP_DIR / "matanyone",
    "MatAnyone",
)

def _checkout(repo_dir: Path, commit: str):
    if not commit:
        print(f"{repo_dir.name} not pinned (env is empty) — using current HEAD.")
        return
    try:
        subprocess.run(["git", "-C", str(repo_dir), "fetch", "--depth", "1", "origin", commit], check=True)
        subprocess.run(["git", "-C", str(repo_dir), "checkout", "--detach", commit], check=True)
        print(f"Locked {repo_dir.name} to {commit}")
    except Exception as e:
        print(f"Warning: failed to lock {repo_dir.name} to {commit}: {e}")

# Optional commit pinning: an empty env var keeps the cloned HEAD.
MATANYONE_COMMIT = os.getenv("MATANYONE_COMMIT", "").strip()
SAM2_COMMIT = os.getenv("SAM2_COMMIT", "").strip()
_checkout(TP_DIR / "matanyone", MATANYONE_COMMIT)
_checkout(TP_DIR / "sam2", SAM2_COMMIT)

# Make the vendored checkouts importable; each one is pushed to the front of
# sys.path, so matanyone (inserted last) ends up ahead of sam2.
for p in (TP_DIR / "sam2", TP_DIR / "matanyone"):
    if not p.exists() or str(p) in sys.path:
        continue
    sys.path.insert(0, str(p))
    print(f"Added to path: {p}")

# =========================
# K-Governor (with bypass; robust for PyTorch 2.2)
# =========================
if os.getenv("SAFE_TOPK_BYPASS", "0") not in ("1","true","TRUE"):
    import re as _re
    import ast as _ast

    # Runtime shim written into the vendored MatAnyone package. Patched sources
    # call safe_topk/safe_kthvalue instead of torch.topk/torch.kthvalue, so an
    # oversized k is clamped to the dimension size instead of raising.
    _SAFE_OPS_SRC = """
import os
import torch

_VERBOSE = os.environ.get("SAFE_TOPK_VERBOSE", "1").strip().lower() not in ("0", "false", "no", "")

# Robust for builds where topk/kthvalue are builtins without attributes.
_ORIG_TOPK = getattr(torch.topk, "__wrapped__", torch.topk)
_ORIG_KTH  = getattr(torch.kthvalue, "__wrapped__", torch.kthvalue)

def _log(msg):
    if _VERBOSE:
        print(f"[K-Governor] {msg}")

def safe_topk(x, k, dim=None, largest=True, sorted=True):
    # Clamp k into [0, size(dim)] (k == 0 is valid for torch.topk), then pad
    # the result back to k entries so callers get the shape they asked for.
    if not isinstance(k, int):
        k = int(k)
    if dim is None:
        dim = -1
    n = int(x.size(dim))
    k_eff = max(0, min(k, n))
    if k_eff != k:
        _log(f"torch.topk: clamp k {k} -> {k_eff} for dim={dim} shape={tuple(x.shape)}")
    values, indices = _ORIG_TOPK(x, k_eff, dim=dim, largest=largest, sorted=sorted)
    if k_eff < k:
        pad = k - k_eff
        pad_shape = list(values.shape); pad_shape[dim] = pad
        # Pad with the value that ranks last for this direction and dtype.
        if values.is_floating_point():
            fill = float('-inf') if largest else float('inf')
        else:
            info = torch.iinfo(values.dtype)
            fill = info.min if largest else info.max
        pad_vals = values.new_full(pad_shape, fill)
        pad_idx  = indices.new_zeros(pad_shape, dtype=indices.dtype)
        values = torch.cat([values, pad_vals], dim=dim)
        indices = torch.cat([indices, pad_idx],  dim=dim)
    return values, indices

def safe_kthvalue(x, k, dim=None, keepdim=False):
    if not isinstance(k, int):
        k = int(k)
    if dim is None:
        dim = -1
    n = x.size(dim)
    k_eff = max(1, min(k, int(n)))
    if k_eff != k:
        _log(f"torch.kthvalue: clamp k {k} -> {k_eff} for dim={dim} shape={tuple(x.shape)}")
    return _ORIG_KTH(x, k_eff, dim=dim, keepdim=keepdim)
""".lstrip()

    # Import line added to every source file whose calls were rewritten.
    _HEADER_IMPORT = "from matanyone.utils.safe_ops import safe_topk, safe_kthvalue"

    def _write_safe_ops_file(pkg_root: Path):
        """Write ``safe_ops.py`` into MatAnyone's ``utils`` package (``pkg_root/utils`` if absent)."""
        utils_dir = pkg_root / "matanyone" / "utils"
        if not utils_dir.exists(): utils_dir = pkg_root / "utils"
        utils_dir.mkdir(parents=True, exist_ok=True)
        (utils_dir / "safe_ops.py").write_text(_SAFE_OPS_SRC, encoding="utf-8")

    def _header_insert_line(src: str):
        """Return the 0-based line index just after *src*'s leading docstring/import run.

        Returns None when *src* does not parse. Using the AST (instead of
        "last line starting with import") keeps the insertion out of
        multi-line parenthesised imports.
        """
        try:
            tree = _ast.parse(src)
        except (SyntaxError, ValueError):
            return None
        at = 0
        for i, node in enumerate(tree.body):
            is_docstring = (
                i == 0 and isinstance(node, _ast.Expr)
                and isinstance(node.value, _ast.Constant)
                and isinstance(node.value.value, str)
            )
            if is_docstring or isinstance(node, (_ast.Import, _ast.ImportFrom)):
                at = node.end_lineno  # 1-based last line == 0-based index after it
            else:
                break
        return at

    def _patch_matanyone_sources(repo_dir: Path) -> int:
        """Rewrite topk/kthvalue calls in MatAnyone sources to the safe wrappers.

        Only files containing a rewritten call receive the safe_ops import,
        and a file is written back only if the rewritten text still parses.
        Returns the number of files rewritten.
        """
        root = repo_dir / "matanyone"
        if not root.exists(): root = repo_dir
        changed = 0
        pt = _re.compile(r"\btorch\.topk\s*\(")
        pm = _re.compile(r"(\b[\w\.]+)\.topk\s*\(")
        kt = _re.compile(r"\btorch\.kthvalue\s*\(")
        km = _re.compile(r"(\b[\w\.]+)\.kthvalue\s*\(")
        for py in root.rglob("*.py"):
            # Skip the shim itself (it would import itself) and package
            # __init__ files (importing safe_ops there risks circular imports).
            if py.name in ("safe_ops.py", "__init__.py"):
                continue
            try:
                orig = py.read_text(encoding="utf-8")
                txt = pt.sub("safe_topk(", orig)
                txt = kt.sub("safe_kthvalue(", txt)
                txt = pm.sub(lambda m: f"safe_topk({m.group(1)}, ", txt)
                txt = km.sub(lambda m: f"safe_kthvalue({m.group(1)}, ", txt)
                if txt == orig:
                    continue  # nothing to rewrite; leave the file untouched
                at = _header_insert_line(txt)
                if at is None:
                    print(f"[K-Governor] Patch skipped for {py}: rewritten source does not parse.")
                    continue
                if _HEADER_IMPORT not in txt:
                    # Split on "\n" (not splitlines) so indices agree with ast's line numbers.
                    lines = txt.split("\n")
                    lines.insert(at, _HEADER_IMPORT)
                    txt = "\n".join(lines)
                py.write_text(txt, encoding="utf-8"); changed += 1
            except Exception as e:
                print(f"[K-Governor] Patch warning on {py}: {e}")
        return changed

    try:
        MATANY_REPO_DIR = TP_DIR / "matanyone"
        if not MATANY_REPO_DIR.exists():
            # Don't create the directory here: _git_clone_if_missing() treats an
            # existing path as an already-cloned repo and would never retry.
            print("[K-Governor] MatAnyone repo missing; skipping source patch.")
        else:
            _write_safe_ops_file(MATANY_REPO_DIR)
            patched_files = _patch_matanyone_sources(MATANY_REPO_DIR)
            print(f"[K-Governor] Patched MatAnyone sources: {patched_files} files updated.")
    except Exception as e:
        print(f"[K-Governor] Patch failed: {e}")
else:
    print("[K-Governor] BYPASSED via SAFE_TOPK_BYPASS")

# =========================
# Torch & device
# =========================
# Module-wide device flags. The defaults describe a CPU-only environment and
# are overwritten step by step as each probe below succeeds.
# NOTE(review): assignments happen in order, so if a later CUDA call raises
# (e.g. get_device_name), TORCH_AVAILABLE/CUDA_AVAILABLE can be True while
# DEVICE is still "cpu" — confirm downstream code tolerates that combination.
TORCH_AVAILABLE = False; CUDA_AVAILABLE = False; GPU_NAME = "N/A"; DEVICE = "cpu"
try:
    import torch
    TORCH_AVAILABLE = True
    CUDA_AVAILABLE = torch.cuda.is_available()
    if CUDA_AVAILABLE:
        # Favor speed over reproducibility: cuDNN autotuner on, determinism off.
        # NOTE(review): benchmark is forced True here regardless of the
        # CUDNN_BENCHMARK env default set earlier in the file.
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        GPU_NAME = torch.cuda.get_device_name(0); DEVICE = "cuda"
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU: {GPU_NAME}")
        print(f"VRAM: {gpu_memory:.1f} GB")
        print(f"CUDA Capability: {torch.cuda.get_device_capability(0)}")
        # Best-effort: limit this process's allocator to 90% of the current
        # CUDA device's memory; failures are deliberately ignored.
        try: torch.cuda.set_per_process_memory_fraction(0.9)
        except Exception: pass
    print(f"Torch version: {torch.__version__}")
    print(f"CUDA available: {CUDA_AVAILABLE}")
    print(f"Device: {DEVICE}")
except Exception as e:
    # Any failure (including torch not being installed) lands here; the flags
    # keep whatever values were assigned before the failure.
    print(f"Torch not available: {e}")

# =========================
# Light GPU monitor
# =========================
class GPUMonitor:
    """Background sampler of GPU memory (via torch) and utilization (via NVML).

    Sampling runs on a daemon thread roughly once per second. NVML
    (``pynvml``) is optional; when it cannot be initialised, ``gpu_util``
    keeps its initial value of 0.
    """

    def __init__(self):
        self.monitoring = False
        self.stats = {"gpu_util": 0, "memory_used": 0, "memory_total": 0}
        self._thread = None
        # Guards self.stats between the sampler thread and get_stats() callers.
        self._stats_lock = threading.Lock()

    def start_monitoring(self):
        """Start the sampler thread; no-op without CUDA or if one is already running."""
        if not CUDA_AVAILABLE: return
        self.monitoring = True
        # Re-use a live sampler (e.g. after a quick stop/start) rather than
        # spawning a second thread that would race on self.stats.
        if self._thread is not None and self._thread.is_alive():
            return
        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def stop_monitoring(self):
        """Ask the sampler thread to exit after its current iteration."""
        self.monitoring = False

    @staticmethod
    def _init_nvml():
        """Initialise NVML once; return ``(pynvml, handle_for_gpu0)`` or ``(None, None)``."""
        try:
            import pynvml
            pynvml.nvmlInit()
            return pynvml, pynvml.nvmlDeviceGetHandleByIndex(0)
        except Exception:
            return None, None

    def _monitor_loop(self):
        nvml, handle = self._init_nvml()
        mem_total = None  # device capacity is constant; query it once
        try:
            while self.monitoring:
                try:
                    if CUDA_AVAILABLE:
                        if mem_total is None:
                            mem_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
                        mem_used = torch.cuda.memory_allocated(0) / 1024**3
                        update = {
                            "memory_used": mem_used, "memory_total": mem_total,
                            "memory_percent": (mem_used/mem_total)*100 if mem_total else 0
                        }
                        if handle is not None:
                            try:
                                update["gpu_util"] = nvml.nvmlDeviceGetUtilizationRates(handle).gpu
                            except Exception:
                                pass
                        with self._stats_lock:
                            self.stats.update(update)
                except Exception as e:
                    print(f"GPU monitoring error: {e}")
                time.sleep(1)
        finally:
            if nvml is not None:
                try: nvml.nvmlShutdown()
                except Exception: pass

    def get_stats(self):
        """Return a snapshot copy of the latest stats."""
        with self._stats_lock:
            return self.stats.copy()

# Module-wide monitor; start_monitoring() returns immediately when CUDA is unavailable.
gpu_monitor = GPUMonitor()
gpu_monitor.start_monitoring()

# =========================
# SAM2 (verified micro-inference)
# =========================
# SAM2 image predictor used to seed the first-frame mask. SAM2_AVAILABLE is set
# only after a tiny 64x64 predict() call succeeds.
SAM2_IMPORTED = False; SAM2_AVAILABLE = False; SAM2_PREDICTOR = None
if TORCH_AVAILABLE and os.getenv("USE_SAM2","true").lower()=="true":
    try:
        print("Setting up SAM2…")
        from hydra import initialize_config_dir, compose
        from hydra.core.global_hydra import GlobalHydra
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor
        SAM2_IMPORTED = True
        ckpt = Path("./checkpoints/sam2.1_hiera_tiny.pt")
        ckpt.parent.mkdir(parents=True, exist_ok=True)
        if not ckpt.exists():
            print("Downloading SAM2.1 checkpoint…")
            import requests
            url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt"
            # Stream into a side file and rename only once complete: an
            # interrupted download must not leave a truncated .pt that the
            # exists() check above would accept on every later start.
            part = ckpt.with_name(ckpt.name + ".part")
            try:
                with requests.get(url, stream=True, timeout=60) as r:
                    r.raise_for_status()
                    with open(part, "wb") as f:
                        for ch in r.iter_content(chunk_size=1 << 20):
                            if ch: f.write(ch)
                os.replace(part, ckpt)
            except BaseException:
                part.unlink(missing_ok=True)
                raise
            print(f"SAM2 checkpoint downloaded to {ckpt}")
        if GlobalHydra().is_initialized():
            GlobalHydra.instance().clear()
        config_dir = str(TP_DIR / "sam2" / "sam2" / "configs")
        config_file = "sam2.1/sam2.1_hiera_t.yaml"
        initialize_config_dir(config_dir=config_dir, version_base=None)
        _ = compose(config_name=config_file)
        model = build_sam2(config_file, str(ckpt), device="cuda" if CUDA_AVAILABLE else "cpu")
        if CUDA_AVAILABLE and hasattr(torch, "compile"):
            try: model = torch.compile(model, mode="max-autotune")
            except Exception as _e: print(f"torch.compile not used: {_e}")
        SAM2_PREDICTOR = SAM2ImagePredictor(model)
        try:
            dummy = np.zeros((64,64,3), dtype=np.uint8)
            SAM2_PREDICTOR.set_image(dummy)
            pts = np.array([[32,32]], dtype=np.int32); lbs = np.array([1], dtype=np.int32)
            _m,_s,_l = SAM2_PREDICTOR.predict(point_coords=pts, point_labels=lbs, multimask_output=True)
            SAM2_AVAILABLE = True; print("✅ SAM2 verified via micro-inference.")
        except Exception as ver_e:
            SAM2_AVAILABLE = False; SAM2_PREDICTOR = None
            print(f"SAM2 verification failed: {ver_e}")
    except Exception as e:
        print(f"SAM2 setup failed: {e}")

# =========================
# MatAnyone import (canonical first, fallback)
# =========================
# Resolve the MatAnyone InferenceCore class: the canonical module path first,
# then the package-level re-export. Both failures are reported if neither works.
MATANYONE_IMPORTED = False; MatAnyInferenceCore = None
try:
    from matanyone.inference.inference_core import InferenceCore as MatAnyInferenceCore
    MATANYONE_IMPORTED = True
    print("MatAnyone import OK: matanyone.inference.inference_core.InferenceCore")
except Exception as e1:
    try:
        from matanyone import InferenceCore as MatAnyInferenceCore
        MATANYONE_IMPORTED = True
        print("MatAnyone import OK: matanyone.InferenceCore")
    except Exception as e2:
        # Exceptions are always truthy, so the old `e2 or e1` never showed e1;
        # report both so the canonical-path failure isn't hidden.
        print(f"MatAnyone not importable: canonical path failed ({type(e1).__name__}: {e1}); "
              f"fallback failed ({type(e2).__name__}: {e2})")

# =========================
# rembg fallback
# =========================
# Optional rembg fallback; REMBG_AVAILABLE records whether `remove` imported.
REMBG_AVAILABLE = False
try:
    from rembg import remove
except Exception as e:
    print(f"rembg not available: {e}")
else:
    REMBG_AVAILABLE = True
    print("rembg import OK (fallback ready).")

# =========================
# Background helpers
# =========================
def make_solid(w, h, rgb):
    """Return an ``(h, w, 3)`` uint8 canvas filled with the colour *rgb*."""
    shape = (h, w, 3)
    return np.full(shape, rgb, dtype=np.uint8)
def make_vertical_gradient(w, h, top_rgb, bottom_rgb):
    """Return an ``(h, w, 3)`` uint8 image blending *top_rgb* (row 0) to *bottom_rgb* (last row).

    Each row is a linear mix of the two colours; every column in a row is identical.
    """
    start = np.asarray(top_rgb, dtype=np.float32)
    end = np.asarray(bottom_rgb, dtype=np.float32)
    weights = np.linspace(0, 1, h, dtype=np.float32).reshape(-1, 1)
    rows = np.clip((1 - weights) * start + weights * end, 0, 255).astype(np.uint8)
    return np.tile(rows[:, None, :], (1, w, 1))
def build_professional_bg(w, h, preset: str) -> np.ndarray:
    """Render a ``w`` x ``h`` background for a named *preset* (case-insensitive).

    Unknown or empty presets fall back to a light-gray solid fill.
    """
    gradient_presets = {
        "office (soft gray)": ((245, 246, 248), (220, 223, 228)),
        "studio (charcoal)": ((32, 32, 36), (64, 64, 70)),
        "nature (green tint)": ((180, 220, 190), (100, 160, 120)),
    }
    key = (preset or "").lower()
    if key in gradient_presets:
        top, bottom = gradient_presets[key]
        return make_vertical_gradient(w, h, top, bottom)
    solid = (18, 112, 214) if key == "brand blue" else (240, 240, 240)
    return make_solid(w, h, solid)

# =========================
# MatAnyone wrapper (+ lock, adaptive constructor, alpha stitching)
# =========================
class OptimizedMatAnyoneProcessor:
    def __init__(self):
        self.processor = None
        self.device = "cuda" if (TORCH_AVAILABLE and CUDA_AVAILABLE) else "cpu"
        self.initialized = False
        self.verified = False
        self.last_error = None
        self.stabilize = os.getenv("MATANYONE_STABILIZE","true").lower()=="true"
        try: self.preroll_frames = max(0, int(os.getenv("MATANYONE_PREROLL_FRAMES","12")))
        except Exception: self.preroll_frames = 12
        self._lock = threading.Lock()

    # ---- Adaptive core constructor
    def _construct_inference_core(self, network_or_repo):
        # prefer classmethod if available
        try:
            if hasattr(MatAnyInferenceCore, "from_pretrained"):
                return MatAnyInferenceCore.from_pretrained(
                    network_or_repo,
                    device=("cuda" if CUDA_AVAILABLE else "cpu")
                )
        except Exception:
            pass
        # try constructor with introspection
        try:
            sig = inspect.signature(MatAnyInferenceCore)
            if isinstance(network_or_repo, str):
                return MatAnyInferenceCore(network_or_repo)
            if "network" in sig.parameters:
                return MatAnyInferenceCore(network=network_or_repo)
            if "model" in sig.parameters:
                return MatAnyInferenceCore(model=network_or_repo)
            return MatAnyInferenceCore(network_or_repo)
        except Exception as e:
            raise RuntimeError(f"InferenceCore construction failed: {type(e).__name__}: {e}")

    # ---- Normalize return + disk probe + png sequence stitch
    def _stitch_alpha_sequence(self, outdir: str, fps: float) -> str | None:
        # common patterns
        patt_list = ["alpha_%04d.png", "alpha_%03d.png", "alpha_%05d.png", "alpha_*.png"]
        frames = []
        for patt in patt_list:
            frames = sorted(glob.glob(os.path.join(outdir, patt.replace("%0", "*").replace("d",""))))
            if frames:
                break
        if not frames:
            return None
        # read as float [0,1]
        ary = []
        for p in frames:
            im = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
            if im is None: continue
            ary.append((im.astype(np.float32) / 255.0))
        if not ary:
            return None
        clip = ImageSequenceClip([f for f in ary], fps=max(1, int(round(fps or 24))))
        alpha_mp4 = tempfile.NamedTemporaryFile(delete=False, suffix="_alpha_seq.mp4").name
        clip.write_videofile(alpha_mp4, audio=False, logger=None)
        clip.close()
        return alpha_mp4

    def _normalize_ret_and_probe(self, ret, outdir: str, fallback_fps: float = 24.0):
        fg_path = alpha_path = None
        if isinstance(ret, (list, tuple)):
            if len(ret) >= 2: fg_path, alpha_path = ret[0], ret[1]
            elif len(ret) == 1: alpha_path = ret[0]
        elif isinstance(ret, str):
            alpha_path = ret

        def _valid(p: str) -> bool:
            return p and os.path.exists(p) and os.path.getsize(p) > 0

        # probe common video names
        if not _valid(alpha_path):
            for cand in ("alpha.mp4","alpha.mkv","alpha.mov","alpha.webm"):
                p = os.path.join(outdir, cand)
                if _valid(p):
                    alpha_path = p; break

        # try stitching sequences if needed
        if not _valid(alpha_path):
            stitched = self._stitch_alpha_sequence(outdir, fallback_fps)
            if stitched and _valid(stitched):
                alpha_path = stitched

        return fg_path, alpha_path

    def _warmup(self) -> None:
        import numpy as _np, cv2 as _cv2, os as _os
        from moviepy.editor import ImageSequenceClip as _ISC
        with tempfile.TemporaryDirectory() as td:
            frames = []
            for t in range(8):
                fr = _np.zeros((64,64,3), _np.uint8); x = 8 + t*4
                _cv2.rectangle(fr, (x,20), (x+12,44), 200, -1); frames.append(fr)
            vid = _os.path.join(td,"warmup.mp4"); _ISC(frames, fps=10).write_videofile(vid, audio=False, logger=None)
            m = _np.zeros((64,64), _np.uint8); _cv2.rectangle(m,(24,24),(40,40),255,-1)
            mask = _os.path.join(td,"mask.png"); _cv2.imwrite(mask, m)
            outdir = _os.path.join(td,"out"); os.makedirs(outdir, exist_ok=True)
            # ensure method exists
            if not hasattr(self.processor, "process_video"):
                if hasattr(self.processor, "process"):
                    self.processor.process_video = self.processor.process
                else:
                    raise RuntimeError("MatAnyone core lacks process_video/process")

            ret = self.processor.process_video(input_path=vid, mask_path=mask, output_path=outdir, max_size=512)
            _fg, alpha = self._normalize_ret_and_probe(ret, outdir, fallback_fps=10)
            if not alpha or not os.path.exists(alpha) or os.path.getsize(alpha) == 0:
                raise RuntimeError("Warmup: MatAnyone produced no alpha")

    def initialize(self) -> bool:
        with self._lock:
            if not MATANYONE_IMPORTED:
                print("MatAnyone not importable; skipping init."); return False
            if self.initialized and self.processor is not None:
                return True
            self.last_error = None

            # HF path first
            try:
                print(f"Initializing MatAnyone (HF repo-id) on {self.device}…")
                self.processor = self._construct_inference_core("PeiqingYang/MatAnyone")
                if self.device == "cuda":
                    import torch as _t
                    _t.cuda.empty_cache(); _ = _t.rand(1, device="cuda") * 0.0
                # alias method if needed
                if not hasattr(self.processor, "process_video") and hasattr(self.processor, "process"):
                    self.processor.process_video = self.processor.process
                self._warmup()
                self.verified = True; self.initialized = True
                print("✅ MatAnyone initialized & warmed up (HF repo-id).")
                return True
            except Exception as e:
                self.last_error = f"HF init failed: {type(e).__name__}: {e}"
                print(self.last_error)

            # Local ckpt fallback
            try:
                print("Falling back to local checkpoint init for MatAnyone…")
                from hydra.core.global_hydra import GlobalHydra
                if hasattr(GlobalHydra,"instance") and GlobalHydra().is_initialized():
                    GlobalHydra.instance().clear()
                import requests
                from matanyone.utils.get_default_model import get_matanyone_model
                ckpt_dir = Path("./pretrained_models"); ckpt_dir.mkdir(parents=True, exist_ok=True)
                ckpt_path = ckpt_dir / "matanyone.pth"
                if not ckpt_path.exists():
                    url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
                    print(f"Downloading MatAnyone checkpoint from: {url}")
                    with requests.get(url, stream=True, timeout=180) as r:
                        r.raise_for_status()
                        with open(ckpt_path, "wb") as f:
                            for chunk in r.iter_content(chunk_size=8192):
                                if chunk: f.write(chunk)
                    print(f"Checkpoint saved to {ckpt_path}")
                network = get_matanyone_model(str(ckpt_path), device=("cuda" if CUDA_AVAILABLE else "cpu"))
                self.processor = self._construct_inference_core(network)
                if self.device == "cuda":
                    import torch as _t
                    _t.cuda.empty_cache(); _ = _t.rand(1, device="cuda") * 0.0
                if not hasattr(self.processor, "process_video") and hasattr(self.processor, "process"):
                    self.processor.process_video = self.processor.process
                self._warmup()
                self.verified = True; self.initialized = True
                print("✅ MatAnyone initialized & warmed up (local checkpoint).")
                return True
            except Exception as e:
                self.last_error = f"Local init/warmup failed: {type(e).__name__}: {e}"
                print(f"MatAnyone initialization failed: {self.last_error}")
                traceback.print_exc(); return False

    # ---- Pre-roll & trimming
    @staticmethod
    def _build_preroll_concat(input_path: str, frames: int) -> tuple[str, float, float]:
        clip = VideoFileClip(input_path)
        fps = float(clip.fps or 24.0)
        preroll_frames = max(0, frames)
        if preroll_frames == 0:
            out = input_path; clip.close(); return out, 0.0, fps
        first = clip.get_frame(0)
        pre = ImageSequenceClip([first]*preroll_frames, fps=max(1, int(round(fps))))
        concat = concatenate_videoclips([pre, clip])
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix="_concat.mp4")
        concat.write_videofile(tmp.name, audio=False, logger=None)
        pre.close(); concat.close(); clip.close()
        return tmp.name, preroll_frames / fps, fps

    @staticmethod
    def _trim_head(video_path: str, seconds: float) -> str:
        if seconds <= 0: return video_path
        clip = VideoFileClip(video_path); dur = clip.duration or 0
        start = min(seconds, max(0.0, dur - 0.001))
        trimmed = tempfile.NamedTemporaryFile(delete=False, suffix="_trim.mp4").name
        clip.subclip(start, None).write_videofile(trimmed, audio=False, logger=None)
        clip.close(); return trimmed

    def create_mask_optimized(self, video_path: str, output_path: str) -> str:
        cap = cv2.VideoCapture(video_path); ret, frame = cap.read(); cap.release()
        if not ret: raise ValueError("Could not read first frame from video.")
        if SAM2_AVAILABLE and SAM2_PREDICTOR is not None:
            try:
                print("Creating mask with SAM2 (first frame)…")
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                SAM2_PREDICTOR.set_image(rgb)
                h, w = rgb.shape[:2]
                pts = np.array([[w//2, h//2],[w//3, h//3],[2*w//3, 2*h//3]], dtype=np.int32)
                lbs = np.array([1,1,1], dtype=np.int32)
                masks, scores, _ = SAM2_PREDICTOR.predict(point_coords=pts, point_labels=lbs, multimask_output=True)
                best = masks[np.argmax(scores)]
                mask = ((best.astype(np.uint8) > 0).astype(np.uint8)) * 255  # 1ch u8 {0,255}
                cv2.imwrite(output_path, mask)
                print(f"Self-test mask uniques: {np.unique(mask//255)}")
                return output_path
            except Exception as e:
                print(f"SAM2 mask creation failed; fallback rectangle. Error: {e}")
        # Fallback: centered box
        h, w = frame.shape[:2]
        mask = np.zeros((h,w), dtype=np.uint8)
        mx, my = int(w*0.15), int(h*0.10)
        mask[my:h-my, mx:w-mx] = 255
        cv2.imwrite(output_path, mask); return output_path

    def process_video_optimized(self, input_path: str, output_dir: str):
        with self._lock:
            if not self.initialized and not self.initialize():
                return None
            try:
                print("🚀 MatAnyone processing…")
                if CUDA_AVAILABLE:
                    import torch as _t
                    _t.cuda.empty_cache(); gc.collect()

                concat_path = input_path; preroll_sec = 0.0; fps_used = 24.0
                if self.stabilize and self.preroll_frames > 0:
                    concat_path, preroll_sec, fps_used = self._build_preroll_concat(input_path, self.preroll_frames)
                    print(f"[Stabilizer] Pre-rolled {self.preroll_frames} frames ({preroll_sec:.3f}s).")

                mask_path = os.path.join(output_dir, "mask.png")
                self.create_mask_optimized(input_path, mask_path)

                if not hasattr(self.processor, "process_video") and hasattr(self.processor, "process"):
                    self.processor.process_video = self.processor.process

                ret = self.processor.process_video(
                    input_path=concat_path,
                    mask_path=mask_path,
                    output_path=output_dir,
                    max_size=int(os.getenv("MAX_MODEL_SIZE","1920"))
                )
                fg_path, alpha_path = self._normalize_ret_and_probe(ret, output_dir, fallback_fps=fps_used)

                if not alpha_path or not os.path.exists(alpha_path):
                    raise RuntimeError("MatAnyone finished without a valid alpha video on disk.")

                if preroll_sec > 0.0:
                    alpha_path = self._trim_head(alpha_path, preroll_sec)
                    print(f"[Stabilizer] Trimmed {preroll_sec:.3f}s from alpha.")

                if not os.path.exists(alpha_path) or os.path.getsize(alpha_path) == 0:
                    raise RuntimeError("Alpha exists but is empty/zero bytes after trim.")

                return alpha_path

            except Exception as e:
                print(f"❌ MatAnyone processing failed: {e}")
                traceback.print_exc()
                return None

# Module-wide processor instance. Construction does not load the model; the
# core is built by initialize(), which process_video_optimized() calls when
# the instance is not yet initialized.
matanyone_processor = OptimizedMatAnyoneProcessor()

# =========================
# rembg helpers
# =========================
REMBG_AVAILABLE = REMBG_AVAILABLE
def process_frame_rembg_optimized(frame_bgr_u8, bg_img_rgb_u8):
    """Cut the subject out of one frame with rembg and blend it over a background.

    Args:
        frame_bgr_u8: Input frame in OpenCV BGR channel order (uint8).
        bg_img_rgb_u8: Background in RGB order; must broadcast against the frame.

    Returns:
        The composited frame as RGB uint8. When rembg is unavailable, or any step
        fails (the error is printed), the unmodified frame converted to RGB.
    """
    if not REMBG_AVAILABLE:
        return cv2.cvtColor(frame_bgr_u8, cv2.COLOR_BGR2RGB)
    try:
        rgb = cv2.cvtColor(frame_bgr_u8, cv2.COLOR_BGR2RGB)
        pil_frame = Image.fromarray(rgb)
        from rembg import remove  # lazy import in case plugin is heavy
        rgba = np.array(remove(pil_frame).convert("RGBA"))
        if rgba.shape[2] != 4:
            return rgba.astype(np.uint8)
        # Straight alpha blend: out = a * fg + (1 - a) * bg, with a in [0, 1].
        weight = (rgba[:, :, 3:4].astype(np.float32) / 255.0)
        blended = weight * rgba[:, :, :3].astype(np.float32) + (1 - weight) * bg_img_rgb_u8.astype(np.float32)
        return blended.astype(np.uint8)
    except Exception as e:
        print(f"rembg processing error: {e}")
        return cv2.cvtColor(frame_bgr_u8, cv2.COLOR_BGR2RGB)

# =========================
# Compositing
# =========================
def composite_with_background(original_path, alpha_path, bg_path=None, bg_preset=None):
    """Composite the original video over a background using an alpha (matte) video.

    Args:
        original_path: Path of the source/foreground video.
        alpha_path: Path of the matte video; only its first channel is used.
        bg_path: Optional background image path. It is resized to the video
            frame size. When omitted, ``build_professional_bg`` renders
            ``bg_preset`` instead.
        bg_preset: Preset name forwarded to ``build_professional_bg``.

    Returns:
        Path of the written video ("final_output.mp4", relative to the CWD).

    Raises:
        ValueError: If ``bg_path`` is given but cannot be read.
    """
    print("🎬 Compositing final video…")
    orig_clip = VideoFileClip(original_path)
    alpha_clip = None
    new_clip = None
    try:
        alpha_clip = VideoFileClip(alpha_path)
        fps = orig_clip.fps or 24
        w, h = orig_clip.size
        if bg_path:
            bg_img = cv2.imread(bg_path)
            if bg_img is None:
                raise ValueError(f"Could not read background image: {bg_path}")
            bg_img = cv2.resize(cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB), (w, h))
        else:
            bg_img = build_professional_bg(w, h, bg_preset)
        # Loop invariant: normalise the background to [0, 1] once, not per frame.
        bg_f32 = bg_img.astype(np.float32) / 255.0

        def to_unit(arr):
            """Return *arr* as float32 in [0, 1]; integer (uint8) frames are divided by 255."""
            arr = np.asarray(arr)
            if np.issubdtype(arr.dtype, np.integer):
                return arr.astype(np.float32) / 255.0
            return np.clip(arr.astype(np.float32), 0.0, 1.0)

        def process_func(get_frame, t):
            # VideoFileClip frames are uint8 in [0, 255]. Clipping them straight
            # to [0, 1] would binarise the matte, and mixing the scales would
            # darken the background, so normalise every input first.
            fg = to_unit(get_frame(t))
            a = to_unit(alpha_clip.get_frame(t))
            if a.ndim == 3:
                a = a[..., 0]  # only the first channel is used as the matte
            if a.shape != fg.shape[:2]:
                # The matting model is invoked with a max_size limit, so the
                # matte may not share the source resolution.
                a = cv2.resize(a, (fg.shape[1], fg.shape[0]), interpolation=cv2.INTER_LINEAR)
            a = a[..., None]
            comp = a * fg + (1.0 - a) * bg_f32
            # MoviePy's writer converts frames with a plain uint8 cast (no
            # rescaling), so hand back [0, 255] uint8 rather than [0, 1] floats.
            return np.clip(comp * 255.0 + 0.5, 0.0, 255.0).astype(np.uint8)

        new_clip = orig_clip.fl(process_func).set_fps(fps)
        output_path = "final_output.mp4"
        new_clip.write_videofile(output_path, audio=False, logger=None)
        return output_path
    finally:
        # Release the ffmpeg readers even when compositing fails.
        for clip in (new_clip, alpha_clip, orig_clip):
            if clip is not None:
                clip.close()

# =========================
# rembg whole-video fallback
# =========================
def process_video_rembg_fallback(video_path, bg_image_path=None, bg_preset=None):
    """Replace the background of every frame with rembg (no MatAnyone).

    Args:
        video_path: Path of the input video.
        bg_image_path: Optional background image path. It is resized to the
            frame size. When omitted, ``build_professional_bg`` renders
            ``bg_preset`` instead.
        bg_preset: Preset name forwarded to ``build_professional_bg``.

    Returns:
        Path of the written video ("rembg_output.mp4", relative to the CWD).

    Raises:
        ValueError: If the first frame cannot be read or the background image
            cannot be loaded.
    """
    print("🔄 Processing with rembg fallback…")
    cap = cv2.VideoCapture(video_path)
    try:
        ok, _ = cap.read()
    finally:
        cap.release()
    if not ok:
        raise ValueError("Could not read video")

    clip = VideoFileClip(video_path)
    new_clip = None
    try:
        # Size the background from the clip MoviePy actually decodes, so it
        # always matches the frames handed to process_func below.
        w, h = clip.size
        if bg_image_path:
            bg_img = cv2.imread(bg_image_path)
            if bg_img is None:
                raise ValueError(f"Could not read background image: {bg_image_path}")
            bg_img = cv2.resize(cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB), (w, h))
        else:
            bg_img = build_professional_bg(w, h, bg_preset)
        fps = clip.fps or 24

        def process_func(get_frame, t):
            frame = np.asarray(get_frame(t))
            # VideoFileClip yields uint8 RGB in [0, 255]; multiplying that by
            # 255 would overflow, so only genuine float frames are rescaled.
            if np.issubdtype(frame.dtype, np.integer):
                frame_u8 = frame.astype(np.uint8)
            else:
                frame_u8 = (np.clip(frame, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
            comp = process_frame_rembg_optimized(cv2.cvtColor(frame_u8, cv2.COLOR_RGB2BGR), bg_img)
            # MoviePy's writer converts frames with a plain uint8 cast (no
            # rescaling), so return [0, 255] values rather than [0, 1] floats.
            return np.asarray(comp, dtype=np.uint8)

        new_clip = clip.fl(process_func).set_fps(fps)
        output_path = "rembg_output.mp4"
        new_clip.write_videofile(output_path, audio=False, logger=None)
        return output_path
    finally:
        # Release the ffmpeg reader(s) even when processing fails.
        if new_clip is not None:
            new_clip.close()
        clip.close()

# =========================
# Self-test harness
# =========================
def _ok(flag):
    """Map a test outcome to a check mark (truthy) or a cross (falsy)."""
    return ("❌", "✅")[bool(flag)]
def self_test_cuda():
    """Prove the CUDA path works by running a 1024x1024 matmul on the GPU.

    Returns:
        (ok, message). Exceptions are reported as (False, message).
    """
    try:
        if not TORCH_AVAILABLE:
            return False, "Torch not importable"
        if not CUDA_AVAILABLE:
            return False, "CUDA not available"
        import torch as _t
        lhs = _t.randn((1024, 1024), device="cuda")
        rhs = _t.randn((1024, 1024), device="cuda")
        mean_val = (lhs @ rhs).mean().item()
        return True, f"CUDA matmul ok, mean={mean_val:.6f}"
    except Exception as e:
        return False, f"CUDA op failed: {e}"
def self_test_ffmpeg_moviepy():
    """Round-trip a tiny clip through FFmpeg/MoviePy: encode it, then decode one frame.

    Returns:
        (ok, message). Exceptions are reported as (False, message).
    """
    try:
        if not shutil.which("ffmpeg"):
            return False, "ffmpeg not found on PATH"
        # Eight flat 64x64 frames with increasing brightness (0, 25, ..., 175).
        frames = [np.full((64, 64, 3), level, dtype=np.uint8) for level in range(0, 200, 25)]
        clip = ImageSequenceClip(frames, fps=4)
        try:
            with tempfile.TemporaryDirectory() as td:
                vp = os.path.join(td, "tiny.mp4")
                clip.write_videofile(vp, audio=False, logger=None)
                reader = VideoFileClip(vp)
                try:
                    reader.get_frame(0.1)
                finally:
                    # Close before the temp dir is removed, also when decoding
                    # fails, so the reader does not keep the file open.
                    reader.close()
        finally:
            clip.close()
        return True, "FFmpeg/MoviePy encode/decode ok"
    except Exception as e:
        return False, f"FFmpeg/MoviePy test failed: {e}"
def self_test_rembg():
    """Run rembg on a solid-green 64x64 image and check it returns a 64x64 PIL image.

    Returns:
        (ok, message). Exceptions are reported as (False, message).
    """
    try:
        if not REMBG_AVAILABLE:
            return False, "rembg not importable"
        from rembg import remove
        green = np.zeros((64, 64, 3), dtype=np.uint8)
        green[:, :] = (0, 255, 0)
        result = remove(Image.fromarray(green))
        passed = isinstance(result, Image.Image) and result.size == (64, 64)
        message = "rembg ok" if passed else "rembg returned unexpected output"
        return passed, message
    except Exception as e:
        return False, f"rembg failed: {e}"
def self_test_sam2():
    """Run a single-point SAM2 prediction on a blank 64x64 image.

    Returns:
        (ok, message). Exceptions are reported as (False, message).
    """
    try:
        if not SAM2_IMPORTED:
            return False, "SAM2 not importable"
        if not SAM2_PREDICTOR:
            return False, "SAM2 predictor not initialized"
        blank = np.zeros((64, 64, 3), dtype=np.uint8)
        SAM2_PREDICTOR.set_image(blank)
        # One positive click in the image centre.
        masks, _scores, _ = SAM2_PREDICTOR.predict(
            point_coords=np.array([[32, 32]], dtype=np.int32),
            point_labels=np.array([1], dtype=np.int32),
            multimask_output=True,
        )
        if masks is not None and len(masks) > 0:
            return True, "SAM2 micro-inference ok"
        return False, "SAM2 predict returned no masks"
    except Exception as e:
        return False, f"SAM2 micro-inference failed: {e}"
def self_test_matanyone():
    """Run MatAnyone end-to-end on a tiny synthetic clip and decode the alpha output.

    Returns:
        (ok, message). Exceptions are reported as (False, message).
    """
    try:
        if not matanyone_processor.initialize():
            return False, f"MatAnyone init failed: {getattr(matanyone_processor,'last_error','no details')}"
        if not matanyone_processor.verified:
            return False, "MatAnyone missing process_video API"

        def draw_frame(step):
            # 64x64 black frame with a filled box sliding 4 px right per step.
            canvas = np.zeros((64, 64, 3), dtype=np.uint8)
            left = 8 + step * 4
            cv2.rectangle(canvas, (left, 20), (left + 12, 44), 200, -1)
            return canvas

        with tempfile.TemporaryDirectory() as td:
            vid_path = os.path.join(td, "tiny_input.mp4")
            clip = ImageSequenceClip([draw_frame(step) for step in range(8)], fps=8)
            clip.write_videofile(vid_path, audio=False, logger=None)
            clip.close()
            # NOTE: process_video_optimized creates its own <output_dir>/mask.png
            # via create_mask_optimized, so this file may be overwritten.
            mask = np.zeros((64, 64), dtype=np.uint8)
            cv2.rectangle(mask, (24, 24), (40, 40), 255, -1)
            cv2.imwrite(os.path.join(td, "mask.png"), mask)
            alpha = matanyone_processor.process_video_optimized(vid_path, td)
            if alpha is None or not os.path.exists(alpha):
                return False, "MatAnyone did not produce alpha video"
            probe = VideoFileClip(alpha)
            probe.get_frame(0.1)
            probe.close()
            return True, "MatAnyone process_video ok"
    except Exception as e:
        return False, f"MatAnyone test failed: {e}"
def run_self_test() -> str:
    """Run every subsystem self-test and return a human-readable report.

    Each result line starts with ✅ or ❌. The CLI entrypoint relies on that
    prefix to compute its exit code.
    """
    header = [
        "=== SELF TEST REPORT ===",
        f"Python: {sys.version.split()[0]}",
        f"Torch: {torch.__version__ if TORCH_AVAILABLE else 'N/A'} | CUDA: {CUDA_AVAILABLE} | Device: {DEVICE} | GPU: {GPU_NAME}",
        f"FFmpeg on PATH: {bool(shutil.which('ffmpeg'))}",
        "",
    ]
    checks = (
        ("CUDA", self_test_cuda),
        ("FFmpeg/MoviePy", self_test_ffmpeg_moviepy),
        ("rembg", self_test_rembg),
        ("SAM2", self_test_sam2),
        ("MatAnyone", self_test_matanyone),
    )
    results = []
    for label, check in checks:
        started = time.time()
        passed, detail = check()
        elapsed = time.time() - started
        results.append(f"{_ok(passed)} {label}: {detail} [{elapsed:.2f}s]")
    return "\n".join(header + results)

# =========================
# Gradio input coercion helpers
# =========================
def _coerce_video_to_path(video_file):
    """Return a filesystem path for a Gradio video input, or None.

    Accepts:
        * None, which yields None.
        * str and os.PathLike paths.
        * Dicts with a "name" key, or failing that a "path" key (the key used
          by Gradio 4 file payloads).
        * Any object exposing a ``name`` attribute, such as temp-file wrappers.
    """
    if video_file is None:
        return None
    if isinstance(video_file, (str, os.PathLike)):
        # os.fspath keeps the full path; Path.name would be only the basename.
        return os.fspath(video_file)
    if isinstance(video_file, dict):
        if "name" in video_file:
            return video_file["name"]
        return video_file.get("path")
    return getattr(video_file, "name", None)

def _coerce_bg_to_path(bg_image, temp_dir):
    """Return filesystem path for background image, writing it to temp_dir if needed.

    Accepts:
        * None, which yields None.
        * str and os.PathLike paths.
        * Dicts with a "name" key, or failing that a "path" key (the key used
          by Gradio 4 file payloads).
        * Objects with a string ``name`` attribute.
        * PIL images and numpy arrays; these are saved as
          ``<temp_dir>/bg_uploaded.png``. Arrays are assumed to be RGB/RGBA.

    Returns None for anything else; the caller then uses the background preset.
    """
    if bg_image is None:
        return None
    if isinstance(bg_image, (str, os.PathLike)):
        # os.fspath keeps the full path; Path.name would be only the basename.
        return os.fspath(bg_image)
    if isinstance(bg_image, dict):
        if "name" in bg_image:
            return bg_image["name"]
        return bg_image.get("path")
    if hasattr(bg_image, "name") and isinstance(bg_image.name, str):
        return bg_image.name
    if isinstance(bg_image, Image.Image):
        p = os.path.join(temp_dir, "bg_uploaded.png")
        bg_image.save(p)
        return p
    if isinstance(bg_image, np.ndarray):
        p = os.path.join(temp_dir, "bg_uploaded.png")
        arr = bg_image
        # cv2.imwrite expects BGR(A); convert so red/blue are not swapped.
        if arr.ndim == 3 and arr.shape[2] == 3:
            arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
        elif arr.ndim == 3 and arr.shape[2] == 4:
            arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2BGRA)
        cv2.imwrite(p, arr)
        return p
    return None

# =========================
# Gradio callback
# =========================
def gradio_interface_optimized(video_file, bg_image, use_matanyone=True, bg_preset="Office (Soft Gray)", stabilize=True, preroll_frames=12):
    """Gradio callback: replace the background of an uploaded video.

    MatAnyone (matting plus compositing) is tried first when requested and
    importable. The callback falls back to per-frame rembg when:
        * MatAnyone cannot be initialised or verified, or
        * MatAnyone returns no alpha video.

    Args:
        video_file: Uploaded video as delivered by Gradio (path, dict or file object).
        bg_image: Optional background image (path, dict, file object, PIL image or array).
        use_matanyone: Whether to try MatAnyone first.
        bg_preset: Preset used when no background image is supplied.
        stabilize: Enables the MatAnyone short-clip pre-roll stabiliser.
        preroll_frames: Pre-roll frame count. Negative values clamp to 0;
            unparsable values keep the previous setting.

    Returns:
        (output_video, download_file, status_text). Both paths are None on error.
    """
    try:
        if video_file is None:
            return None, None, "Please upload a video."
        print(f"UI types: video={type(video_file)}, bg={type(bg_image)}")

        with tempfile.TemporaryDirectory() as temp_dir:
            video_path = _coerce_video_to_path(video_file)
            if not video_path or not os.path.exists(video_path):
                return None, None, "Could not read the uploaded video path."
            bg_path = _coerce_bg_to_path(bg_image, temp_dir)  # may be None → preset is used

            # Reflect UI choices (note: this mutates the shared, module-wide processor).
            matanyone_processor.stabilize = bool(stabilize)
            try:
                matanyone_processor.preroll_frames = max(0, int(preroll_frames))
            except Exception:
                pass  # keep the previous pre-roll setting on unparsable input

            start_time = time.time()

            if use_matanyone and MATANYONE_IMPORTED:
                if not matanyone_processor.initialized:
                    try:
                        matanyone_processor.initialize()
                    except Exception as init_err:
                        # Non-fatal: if the processor is still not ready, the
                        # "not verified" branch below serves this request via rembg.
                        print(f"MatAnyone initialization failed: {init_err}")
                        traceback.print_exc()

                if matanyone_processor.initialized and matanyone_processor.verified:
                    alpha_video_path = matanyone_processor.process_video_optimized(video_path, temp_dir)
                    if alpha_video_path is None:
                        out = process_video_rembg_fallback(video_path, bg_path, bg_preset=bg_preset)
                        method = "rembg (fallback after MatAnyone error)"
                    else:
                        out = composite_with_background(video_path, alpha_video_path, bg_path, bg_preset=bg_preset)
                        method = f"MatAnyone (GPU: {CUDA_AVAILABLE})"
                else:
                    out = process_video_rembg_fallback(video_path, bg_path, bg_preset=bg_preset)
                    method = "rembg (MatAnyone not verified)"
            else:
                out = process_video_rembg_fallback(video_path, bg_path, bg_preset=bg_preset)
                method = "rembg"

            final_gpu = gpu_monitor.get_stats()
            elapsed = time.time() - start_time
            status = (
                f"✅ Processing complete\n"
                f"Method: {method}\n"
                f"Time: {elapsed:.2f}s\n"
                f"Output: {out}\n\n"
                f"GPU Stats:\n"
                f"• Mem: {final_gpu.get('memory_used', 0):.2f}GB / {final_gpu.get('memory_total', 0):.2f}GB"
                f" ({final_gpu.get('memory_percent', 0):.1f}%)\n"
                f"• Util: {final_gpu.get('gpu_util', 0)}%\n"
                f"• CUDA: {CUDA_AVAILABLE}"
            )
            return out, out, status

    except Exception as e:
        traceback.print_exc()
        # getattr keeps this error report from raising on its own if the
        # processor is missing one of these attributes.
        msg = (
            f"❌ Error: {e}\n"
            f"- MatAnyone imported: {MATANYONE_IMPORTED}\n"
            f"- MatAnyone initialized: {getattr(matanyone_processor, 'initialized', None)}\n"
            f"- MatAnyone verified: {getattr(matanyone_processor, 'verified', None)}\n"
            f"- MatAnyone last_error: {getattr(matanyone_processor, 'last_error', None)}\n"
            f"- SAM2 imported: {SAM2_IMPORTED}\n"
            f"- SAM2 verified: {SAM2_AVAILABLE}\n"
            f"- rembg: {REMBG_AVAILABLE}\n"
            f"- CUDA: {CUDA_AVAILABLE}\n"
            f"(see server logs for traceback)"
        )
        return None, None, msg

def gradio_run_self_test():
    """Gradio button handler: run every self-test and return the text report."""
    report = run_self_test()
    return report
def show_matanyone_diag():
    """Return "READY ✅" when MatAnyone is initialised and verified, else its last error.

    Falls back to "Not initialized yet" when no error has been recorded. Any
    exception is reported as a "Diag error: ..." string.
    """
    try:
        proc = matanyone_processor
        if proc.initialized and proc.verified:
            return "READY ✅"
        return proc.last_error or "Not initialized yet"
    except Exception as e:
        return f"Diag error: {e}"

# =========================
# UI
# =========================
with gr.Blocks(title="Video Background Replacer - GPU Optimized", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 Video Background Replacer (GPU Optimized)")
    gr.Markdown("All green checks are earned by real tests. No guesses.")

    def _system_status_html():
        """Render the "System Status" panel from the current module/processor state.

        The panel is rendered once while the layout is built and again on every
        page load (via ``demo.load`` at the end of this block). This matters
        because ``matanyone_processor.verified`` can change after the layout
        exists: the import-time warmup and the first processing request both
        initialise the processor. A string baked in at build time would keep
        showing the stale value.
        """
        gpu_status = f"✅ {GPU_NAME}" if CUDA_AVAILABLE else "❌ CPU Only"
        matany_status = "✅ Module Imported" if MATANYONE_IMPORTED else "❌ Not Importable"
        sam2_status = "✅ Verified" if SAM2_AVAILABLE else ("⚠️ Imported but unverified" if SAM2_IMPORTED else "❌ Not Ready")
        rembg_status = "✅ Ready" if REMBG_AVAILABLE else "❌ Not Available"
        torch_status = "✅ GPU" if CUDA_AVAILABLE else "❌ CPU"
        matany_ready = "✅ Yes" if getattr(matanyone_processor, "verified", False) else "❌ No"
        return f"""
    <div style='padding: 15px; background: #f8f9fa; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #6c757d;'>
        <h4 style='margin-top: 0;'>🖥️ System Status (verified)</h4>
        <strong>GPU:</strong> {gpu_status}<br>
        <strong>Device:</strong> {DEVICE}<br>
        <strong>MatAnyone module:</strong> {matany_status}<br>
        <strong>MatAnyone ready:</strong> {matany_ready}<br>
        <strong>SAM2:</strong> {sam2_status}<br>
        <strong>rembg:</strong> {rembg_status}<br>
        <strong>PyTorch:</strong> {torch_status}
    </div>
    """

    status_html = _system_status_html()
    status_panel = gr.HTML(status_html)

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="📹 Input Video")
            bg_input = gr.Image(label="🖼️ Background Image (optional)", type="filepath")
            bg_preset = gr.Dropdown(
                label="🎨 Background Preset (if no image)",
                choices=["Office (Soft Gray)","Studio (Charcoal)","Nature (Green Tint)","Brand Blue","Plain Light"],
                value="Office (Soft Gray)",
            )
            use_matanyone = gr.Checkbox(label="🚀 Use MatAnyone (GPU accelerated, best quality)",
                                        value=MATANYONE_IMPORTED, interactive=MATANYONE_IMPORTED)
            stabilize = gr.Checkbox(label="🧱 Stabilize short clips (pre-roll first frame)",
                                    value=os.getenv("MATANYONE_STABILIZE","true").lower()=="true")
            preroll_frames = gr.Slider(label="Pre-roll frames", minimum=0, maximum=24, step=1,
                                       value=int(os.getenv("MATANYONE_PREROLL_FRAMES","12")))
            process_btn = gr.Button("🚀 Process Video", variant="primary")
            gr.Markdown("### 🔎 Self-Verification"); selftest_btn = gr.Button("Run Self-Test")
            selftest_out = gr.Textbox(label="Self-Test Report", lines=16)
            gr.Markdown("### 🛠 MatAnyone Diagnostics"); mat_diag_btn = gr.Button("Show MatAnyone Diagnostics")
            mat_diag_out = gr.Textbox(label="MatAnyone Last Error / Status", lines=6)
        with gr.Column():
            output_video = gr.Video(label="✨ Result")
            download_file = gr.File(label="💾 Download")
            status_text = gr.Textbox(label="📊 Status & Performance", lines=8)

    process_btn.click(fn=gradio_interface_optimized,
                      inputs=[video_input, bg_input, use_matanyone, bg_preset, stabilize, preroll_frames],
                      outputs=[output_video, download_file, status_text])
    selftest_btn.click(fn=gradio_run_self_test, inputs=[], outputs=[selftest_out])
    mat_diag_btn.click(fn=show_matanyone_diag, inputs=[], outputs=[mat_diag_out])
    # Re-render the status panel on each page load so it is not frozen at the
    # state captured while this layout was being built.
    demo.load(fn=_system_status_html, inputs=[], outputs=[status_panel])

    gr.Markdown("---")
    gr.Markdown("""
    **Notes**
    - K-Governor clamps/pads Top-K inside MatAnyone to prevent 'k out of range' crashes.
    - Short-clip stabilizer pre-roll is trimmed out of alpha automatically.
    - SAM2 shows ✅ only after a real micro-inference passes.
    - FFmpeg/MoviePy, CUDA, and rembg are validated by actually running them.
    """)

# =========================
# Proactive warmup at boot (after the UI layout is built, before demo.launch())
# =========================
# Import-time MatAnyone warmup, gated by USE_MATANYONE (default "true").
# It is best-effort: failures are logged and the app still starts, because the
# Gradio callback retries initialize() per request and otherwise uses rembg.
try:
    if MATANYONE_IMPORTED and os.getenv("USE_MATANYONE","true").lower()=="true":
        print("Warming up MatAnyone…")
        # initialize() returns a success flag, so only claim completion when it succeeded.
        if matanyone_processor.initialize():
            print("MatAnyone warmup complete.")
        else:
            print(f"MatAnyone warmup did not complete (non-fatal): "
                  f"{getattr(matanyone_processor, 'last_error', None) or 'no details'}")
except Exception as e:
    print(f"MatAnyone warmup failed (non-fatal): {e}")
    traceback.print_exc()

# =========================
# Late re-sanitization for external .env overrides
# =========================
def _re_sanitize_threads():
    """Reset OMP_NUM_THREADS / MKL_NUM_THREADS to "2" unless each holds a positive integer.

    Rejected values include "" (unset), "0", "7500m", "1.5" and non-ASCII
    digits; OpenMP/MKL thread counts must be positive integers.

    Only ``os.environ`` is updated. Runtimes that were already initialised in
    this process may not re-read the variables.
    """
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
        raw = str(os.environ.get(var, ""))
        valid = raw.isascii() and raw.isdigit() and int(raw) > 0
        if not valid:
            os.environ[var] = "2"
            print(f"{var} had invalid value; reset to 2")

# Enabled by default. Set STRICT_ENV_GUARD to anything other than "1"/"true"
# (case-insensitive, matching the file's other boolean env flags) to skip it.
if os.getenv("STRICT_ENV_GUARD", "1").strip().lower() in ("1", "true"):
    _re_sanitize_threads()

# =========================
# Entrypoint / CLI self-test
# =========================
if __name__ == "__main__":
    if "--self-test" in sys.argv:
        # CLI mode: print the report and exit 2 if any check line failed (❌).
        report = run_self_test()
        print(report)
        failed = any(line.startswith("❌") for line in report.splitlines())
        sys.exit(2 if failed else 0)
    banner = "=" * 50
    print("\n" + banner)
    print("🚀 Starting GPU-optimized Gradio app…")
    print("URL: http://0.0.0.0:7860")
    print(f"GPU Monitoring: {'Active' if CUDA_AVAILABLE else 'Disabled'}")
    print(banner + "\n")
    demo.launch(server_name="0.0.0.0", server_port=7860)