File size: 7,386 Bytes
05f03a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
End-to-End Smoke Test (SAM2 → MatAnyone → TwoStageProcessor)
- Loads a short input video
- Extracts first frame, runs SAM2 coarse mask
- Bootstraps MatAnyone and saves its refined alpha for the first frame
- Runs full TwoStageProcessor pipeline (both stages)
- Writes out:
    out/sam2_mask0.png
    out/matanyone_alpha0.png
    out/greenscreen.mp4
    out/final.mp4
- Prints a compact summary and exits with a non-zero status on critical failure

Usage:
    python tools/e2e_smoke_test.py --video path/to/clip.mp4 --bg path/to/bg.jpg
    # or pass a solid background color:
    python tools/e2e_smoke_test.py --video path/to/clip.mp4 --bg-color 30 30 30
"""

# --- fix OMP/BLAS early (before numpy/torch/opencv import) ---
import os
# OMP_NUM_THREADS is kept only when it is a string of digits; an unset, empty,
# or otherwise non-numeric value is replaced with "2".
omp = os.environ.get("OMP_NUM_THREADS", "")
if not omp.strip().isdigit():
    os.environ["OMP_NUM_THREADS"] = "2"
# The other thread-count variables are only given a default, never overridden.
for _thread_var in ("MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
    os.environ.setdefault(_thread_var, "2")

import sys
import argparse
import time
import logging
from pathlib import Path

import cv2
import numpy as np

# Ensure the repo root is importable (this file lives in tools/).
# parents[1] is the directory above the one containing this file; it is placed
# at the front of sys.path so the project imports below are looked up there first.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

from models.loaders.sam2_loader import SAM2Loader
from models.loaders.matanyone_loader import MatAnyoneLoader
from processing.two_stage.two_stage_processor import TwoStageProcessor


def _read_first_frame(video_path: str):
    """Read the first frame of the video at *video_path*.

    Returns:
        ``(frame, None)`` on success, or ``(None, error_message)`` when the
        video cannot be opened or no frame can be read from it.
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return None, "Could not open video"

    success, first = capture.read()
    # Only one frame is needed, so the capture is released right away.
    capture.release()

    if success and first is not None:
        return first, None
    return None, "Could not read first frame"


def _ensure_dir(p: Path) -> None:
    """Create directory *p* along with any missing parents; an existing directory is fine."""
    p.mkdir(exist_ok=True, parents=True)


def _save_mask_png(mask01: np.ndarray, path: Path):
    """Write a mask to *path* as an 8-bit image.

    The mask is converted to float32. If its maximum is at most 1.0 it is
    treated as normalized and multiplied by 255; otherwise the values are
    used as they are. Values are clipped to 0..255 before the uint8 cast.
    """
    pixels = mask01.astype(np.float32)
    looks_normalized = pixels.max() <= 1.0
    if looks_normalized:
        pixels = pixels * 255.0
    as_u8 = np.clip(pixels, 0, 255).astype(np.uint8)
    cv2.imwrite(str(path), as_u8)


def _load_background(bg_path: str | None, size_wh: tuple[int, int], bg_color: tuple[int, int, int] | None):
    w, h = size_wh
    if bg_path:
        img = cv2.imread(bg_path, cv2.IMREAD_COLOR)
        if img is None:
            return None, f"Failed to read background image: {bg_path}"
        return cv2.resize(img, (w, h)), None
    # Solid color
    color = bg_color if bg_color is not None else (0, 0, 0)
    canvas = np.zeros((h, w, 3), np.uint8)
    canvas[:] = tuple(int(x) for x in color)
    return canvas, None


def main():
    """Run the end-to-end smoke test and return 0 on success.

    Steps: read the first frame of ``--video``, load SAM2 and save its mask
    for that frame, load MatAnyone and save its alpha for the same frame,
    build the Stage 2 background, then run both TwoStageProcessor stages.

    Each critical failure logs an error and terminates via ``sys.exit`` with
    a distinct code:
        2 - first frame could not be read
        3 - SAM2 failed to load
        4 - MatAnyone failed to load
        5 - background image could not be read
        6 - Stage 1 (greenscreen) failed
        7 - Stage 2 (final composite) failed
    """
    ap = argparse.ArgumentParser(description="E2E smoke test for SAM2 + MatAnyone + TwoStageProcessor")
    ap.add_argument("--video", required=True, help="Path to a short input video (3–10s ideal)")
    ap.add_argument("--bg", default=None, help="Optional path to background image for Stage 2")
    ap.add_argument("--bg-color", nargs=3, type=int, default=None, help="Solid BGR background (e.g. 30 30 30)")
    ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device for models")
    ap.add_argument("--model-size", default="auto", choices=["auto","tiny","small","base","large"], help="SAM2 size")
    ap.add_argument("--outdir", default="out", help="Output dir")
    args = ap.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
    log = logging.getLogger("e2e")

    outdir = Path(args.outdir)
    _ensure_dir(outdir)

    # 1) Load first frame
    frame0_bgr, err = _read_first_frame(args.video)
    if err:
        log.error(err)
        sys.exit(2)
    h0, w0 = frame0_bgr.shape[:2]
    log.info(f"First frame size: {w0}x{h0}")

    # 2) Load SAM2
    t0 = time.time()
    sam = SAM2Loader(device=args.device).load(args.model_size)
    if not sam:
        log.error("SAM2 failed to load")
        sys.exit(3)
    log.info(f"SAM2 loaded in {time.time()-t0:.2f}s")

    # 3) Coarse mask from SAM2 on frame 0
    sam.set_image(frame0_bgr)  # accepts BGR or RGB
    out = sam.predict(point_coords=None, point_labels=None)
    masks = out.get("masks", None)
    if masks is None or len(masks) == 0:
        # Keep the run going with an all-ones mask so the later stages are still exercised.
        log.warning("SAM2 returned no masks; using fallback ones mask")
        mask0 = np.ones((h0, w0), np.float32)
    else:
        mask0 = masks[0].astype(np.float32)
        # Bring the mask to the frame resolution if the predictor returned another size.
        if mask0.shape != (h0, w0):
            mask0 = cv2.resize(mask0, (w0, h0), interpolation=cv2.INTER_LINEAR)
    _save_mask_png(mask0, outdir / "sam2_mask0.png")
    log.info(f"Wrote {outdir/'sam2_mask0.png'}")

    # 4) Load MatAnyone (stateful session)
    t1 = time.time()
    mat_session = MatAnyoneLoader(device=args.device).load()
    if mat_session is None:
        log.error("MatAnyone failed to load")
        sys.exit(4)
    log.info(f"MatAnyone loaded in {time.time()-t1:.2f}s")

    # 5) Bootstrap MatAnyone on first frame (TwoStageProcessor also does this, but we test it explicitly here)
    frame0_rgb = cv2.cvtColor(frame0_bgr, cv2.COLOR_BGR2RGB)
    alpha0 = mat_session(frame0_rgb, mask0)  # returns 2-D float32 [H, W]
    _save_mask_png(alpha0, outdir / "matanyone_alpha0.png")
    log.info(f"Wrote {outdir/'matanyone_alpha0.png'}")

    # 6) Prepare background for Stage 2
    bg_img, err = _load_background(args.bg, (w0, h0), tuple(args.bg_color) if args.bg_color else None)
    if err:
        log.error(err)
        sys.exit(5)

    # 7) End-to-end pipeline (both stages)
    tsp = TwoStageProcessor(sam2_predictor=sam, matanyone_model=mat_session)

    def _progress(pct: float, desc: str):
        # keep console output compact: rewrite a single line, description fixed at 60 chars
        sys.stdout.write(f"\r[{pct*100:5.1f}%] {desc:60.60s}")
        sys.stdout.flush()

    # Write greenscreen intermediate and final composite
    greenscreen_path = str(outdir / "greenscreen.mp4")
    final_path       = str(outdir / "final.mp4")

    log.info("\nStage 1 → greenscreen...")
    st1_info, st1_msg = tsp.stage1_extract_to_greenscreen(
        video_path=args.video,
        output_path=greenscreen_path,
        key_color_mode="auto",
        progress_callback=_progress,
        stop_event=None,
    )
    print()  # newline after progress

    if st1_info is None:
        log.error(f"Stage 1 failed: {st1_msg}")
        sys.exit(6)
    else:
        log.info(f"Stage 1 OK: {st1_msg}")

    log.info("Stage 2 → final composite...")
    # We pass the ndarray bg_img directly (TwoStageProcessor accepts str or ndarray)
    st2_path, st2_msg = tsp.stage2_greenscreen_to_final(
        gs_path=st1_info["path"],
        background=bg_img,
        output_path=final_path,
        chroma_settings=None,
        progress_callback=_progress,
        stop_event=None,
    )
    print()  # newline after progress

    if st2_path is None:
        log.error(f"Stage 2 failed: {st2_msg}")
        sys.exit(7)
    else:
        log.info(f"Stage 2 OK: {st2_msg}")

    # 8) Summary
    log.info("----- SUMMARY -----")
    log.info(f"SAM2 first mask:     {outdir/'sam2_mask0.png'}")
    log.info(f"MatAnyone alpha 0:   {outdir/'matanyone_alpha0.png'}")
    log.info(f"Greenscreen video:   {greenscreen_path}")
    log.info(f"Final composite:     {final_path}")
    log.info("Smoke test completed successfully.")
    return 0


if __name__ == "__main__":
    # Script entry point: exit with main()'s integer result (0 for anything
    # that is not an int), or 130 when interrupted from the keyboard.
    try:
        status = main()
        if not isinstance(status, int):
            status = 0
        sys.exit(status)
    except KeyboardInterrupt:
        print("\nInterrupted.")
        sys.exit(130)