MogensR commited on
Commit
05f03a8
·
1 Parent(s): 1fc7206

Create utils/tools/e2e_smoke_test.py

Browse files
Files changed (1) hide show
  1. utils/tools/e2e_smoke_test.py +213 -0
utils/tools/e2e_smoke_test.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ End-to-End Smoke Test (SAM2 → MatAnyone → TwoStageProcessor)
5
+ - Loads a short input video
6
+ - Extracts first frame, runs SAM2 coarse mask
7
+ - Bootstraps MatAnyone and saves its refined alpha for the first frame
8
+ - Runs full TwoStageProcessor pipeline (both stages)
9
+ - Writes out:
10
+ out/sam2_mask0.png
11
+ out/matanyone_alpha0.png
12
+ out/greenscreen.mp4
13
+ out/final.mp4
14
+ - Prints a compact summary and non-zero exit on critical failure
15
+
16
+ Usage:
17
+ python tools/e2e_smoke_test.py --video path/to/clip.mp4 --bg path/to/bg.jpg
18
+ # or pass a solid background color:
19
+ python tools/e2e_smoke_test.py --video path/to/clip.mp4 --bg-color 30 30 30
20
+ """
21
+
22
+ # --- fix OMP/BLAS early (before numpy/torch/opencv import) ---
23
+ import os
24
+ omp = os.environ.get("OMP_NUM_THREADS", "")
25
+ if not omp.strip().isdigit():
26
+ os.environ["OMP_NUM_THREADS"] = "2"
27
+ os.environ.setdefault("MKL_NUM_THREADS", "2")
28
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "2")
29
+ os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")
30
+
31
+ import sys
32
+ import argparse
33
+ import time
34
+ import logging
35
+ from pathlib import Path
36
+
37
+ import cv2
38
+ import numpy as np
39
+
40
+ # Ensure repo root on path (this file lives in tools/)
41
+ REPO_ROOT = Path(__file__).resolve().parents[1]
42
+ sys.path.insert(0, str(REPO_ROOT))
43
+
44
+ from models.loaders.sam2_loader import SAM2Loader
45
+ from models.loaders.matanyone_loader import MatAnyoneLoader
46
+ from processing.two_stage.two_stage_processor import TwoStageProcessor
47
+
48
+
49
+ def _read_first_frame(video_path: str):
50
+ cap = cv2.VideoCapture(video_path)
51
+ if not cap.isOpened():
52
+ return None, "Could not open video"
53
+ ok, frame = cap.read()
54
+ cap.release()
55
+ if not ok or frame is None:
56
+ return None, "Could not read first frame"
57
+ return frame, None
58
+
59
+
60
+ def _ensure_dir(p: Path):
61
+ p.mkdir(parents=True, exist_ok=True)
62
+
63
+
64
+ def _save_mask_png(mask01: np.ndarray, path: Path):
65
+ m = mask01.astype(np.float32)
66
+ if m.max() <= 1.0:
67
+ m = (m * 255.0)
68
+ cv2.imwrite(str(path), np.clip(m, 0, 255).astype(np.uint8))
69
+
70
+
71
+ def _load_background(bg_path: str | None, size_wh: tuple[int, int], bg_color: tuple[int, int, int] | None):
72
+ w, h = size_wh
73
+ if bg_path:
74
+ img = cv2.imread(bg_path, cv2.IMREAD_COLOR)
75
+ if img is None:
76
+ return None, f"Failed to read background image: {bg_path}"
77
+ return cv2.resize(img, (w, h)), None
78
+ # Solid color
79
+ color = bg_color if bg_color is not None else (0, 0, 0)
80
+ canvas = np.zeros((h, w, 3), np.uint8)
81
+ canvas[:] = tuple(int(x) for x in color)
82
+ return canvas, None
83
+
84
+
85
+ def main():
86
+ ap = argparse.ArgumentParser(description="E2E smoke test for SAM2 + MatAnyone + TwoStageProcessor")
87
+ ap.add_argument("--video", required=True, help="Path to a short input video (3–10s ideal)")
88
+ ap.add_argument("--bg", default=None, help="Optional path to background image for Stage 2")
89
+ ap.add_argument("--bg-color", nargs=3, type=int, default=None, help="Solid BGR background (e.g. 30 30 30)")
90
+ ap.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device for models")
91
+ ap.add_argument("--model-size", default="auto", choices=["auto","tiny","small","base","large"], help="SAM2 size")
92
+ ap.add_argument("--outdir", default="out", help="Output dir")
93
+ args = ap.parse_args()
94
+
95
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
96
+ log = logging.getLogger("e2e")
97
+
98
+ outdir = Path(args.outdir)
99
+ _ensure_dir(outdir)
100
+
101
+ # 1) Load first frame
102
+ frame0_bgr, err = _read_first_frame(args.video)
103
+ if err:
104
+ log.error(err)
105
+ sys.exit(2)
106
+ h0, w0 = frame0_bgr.shape[:2]
107
+ log.info(f"First frame size: {w0}x{h0}")
108
+
109
+ # 2) Load SAM2
110
+ t0 = time.time()
111
+ sam = SAM2Loader(device=args.device).load(args.model_size)
112
+ if not sam:
113
+ log.error("SAM2 failed to load")
114
+ sys.exit(3)
115
+ log.info(f"SAM2 loaded in {time.time()-t0:.2f}s")
116
+
117
+ # 3) Coarse mask from SAM2 on frame 0
118
+ sam.set_image(frame0_bgr) # accepts BGR or RGB
119
+ out = sam.predict(point_coords=None, point_labels=None)
120
+ masks = out.get("masks", None)
121
+ if masks is None or len(masks) == 0:
122
+ log.warning("SAM2 returned no masks; using fallback ones mask")
123
+ mask0 = np.ones((h0, w0), np.float32)
124
+ else:
125
+ mask0 = masks[0].astype(np.float32)
126
+ if mask0.shape != (h0, w0):
127
+ mask0 = cv2.resize(mask0, (w0, h0), interpolation=cv2.INTER_LINEAR)
128
+ _save_mask_png(mask0, outdir / "sam2_mask0.png")
129
+ log.info(f"Wrote {outdir/'sam2_mask0.png'}")
130
+
131
+ # 4) Load MatAnyone (stateful session)
132
+ t1 = time.time()
133
+ mat_session = MatAnyoneLoader(device=args.device).load()
134
+ if mat_session is None:
135
+ log.error("MatAnyone failed to load")
136
+ sys.exit(4)
137
+ log.info(f"MatAnyone loaded in {time.time()-t1:.2f}s")
138
+
139
+ # 5) Bootstrap MatAnyone on first frame (TwoStageProcessor also does this, but we test it explicitly here)
140
+ frame0_rgb = cv2.cvtColor(frame0_bgr, cv2.COLOR_BGR2RGB)
141
+ alpha0 = mat_session(frame0_rgb, mask0) # returns 2-D float32 [H, W]
142
+ _save_mask_png(alpha0, outdir / "matanyone_alpha0.png")
143
+ log.info(f"Wrote {outdir/'matanyone_alpha0.png'}")
144
+
145
+ # 6) Prepare background for Stage 2
146
+ bg_img, err = _load_background(args.bg, (w0, h0), tuple(args.bg_color) if args.bg_color else None)
147
+ if err:
148
+ log.error(err)
149
+ sys.exit(5)
150
+
151
+ # 7) End-to-end pipeline (both stages)
152
+ tsp = TwoStageProcessor(sam2_predictor=sam, matanyone_model=mat_session)
153
+
154
+ def _progress(pct: float, desc: str):
155
+ # keep console output compact
156
+ sys.stdout.write(f"\r[{pct*100:5.1f}%] {desc:60.60s}")
157
+ sys.stdout.flush()
158
+
159
+ # Write greenscreen intermediate and final composite
160
+ greenscreen_path = str(outdir / "greenscreen.mp4")
161
+ final_path = str(outdir / "final.mp4")
162
+
163
+ log.info("\nStage 1 → greenscreen...")
164
+ st1_info, st1_msg = tsp.stage1_extract_to_greenscreen(
165
+ video_path=args.video,
166
+ output_path=greenscreen_path,
167
+ key_color_mode="auto",
168
+ progress_callback=_progress,
169
+ stop_event=None,
170
+ )
171
+ print() # newline after progress
172
+
173
+ if st1_info is None:
174
+ log.error(f"Stage 1 failed: {st1_msg}")
175
+ sys.exit(6)
176
+ else:
177
+ log.info(f"Stage 1 OK: {st1_msg}")
178
+
179
+ log.info("Stage 2 → final composite...")
180
+ # We pass the ndarray bg_img directly (TwoStageProcessor accepts str or ndarray)
181
+ st2_path, st2_msg = tsp.stage2_greenscreen_to_final(
182
+ gs_path=st1_info["path"],
183
+ background=bg_img,
184
+ output_path=final_path,
185
+ chroma_settings=None,
186
+ progress_callback=_progress,
187
+ stop_event=None,
188
+ )
189
+ print()
190
+
191
+ if st2_path is None:
192
+ log.error(f"Stage 2 failed: {st2_msg}")
193
+ sys.exit(7)
194
+ else:
195
+ log.info(f"Stage 2 OK: {st2_msg}")
196
+
197
+ # 8) Summary
198
+ log.info("----- SUMMARY -----")
199
+ log.info(f"SAM2 first mask: {outdir/'sam2_mask0.png'}")
200
+ log.info(f"MatAnyone alpha 0: {outdir/'matanyone_alpha0.png'}")
201
+ log.info(f"Greenscreen video: {greenscreen_path}")
202
+ log.info(f"Final composite: {final_path}")
203
+ log.info("Smoke test completed successfully.")
204
+ return 0
205
+
206
+
207
+ if __name__ == "__main__":
208
+ try:
209
+ rc = main()
210
+ sys.exit(rc if isinstance(rc, int) else 0)
211
+ except KeyboardInterrupt:
212
+ print("\nInterrupted.")
213
+ sys.exit(130)