MogensR committed on
Commit
00acd62
·
1 Parent(s): 05f03a8

Create tools/self_check.py

Browse files
Files changed (1) hide show
  1. tools/self_check.py +155 -0
tools/self_check.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Quiet, one-shot startup self-check for HF Spaces.
5
+
6
+ What it does:
7
+ - loads SAM2Loader + MatAnyoneLoader (device from env or cuda/cpu auto)
8
+ - runs a minimal first-frame path (synthetic frame) to validate
9
+ - caches status in module state for later UI queries
10
+ - does NOT print unless failure; logs via `BackgroundFX`/root logger
11
+
12
+ Control via env:
13
+ - DISABLE_SELF_CHECK=1 → skip entirely
14
+ - SELF_CHECK_DEVICE=cpu|cuda → override device
15
+ - SELF_CHECK_TIMEOUT=seconds → default 45 (advisory: exceeding it only logs a warning; the check is not interrupted)
16
+ """
17
+
18
+ from __future__ import annotations
19
+ import os, time, threading, logging
20
+ from typing import Optional, Dict, Any
21
+
22
+ import numpy as np
23
+ import cv2
24
+ import torch
25
+
26
+ # Import loaders and processor from your project
27
+ from models.loaders.sam2_loader import SAM2Loader
28
+ from models.loaders.matanyone_loader import MatAnyoneLoader
29
+ from processing.two_stage.two_stage_processor import TwoStageProcessor
30
+
31
# Project logger. logging.getLogger() always returns a Logger instance, and
# Logger objects are always truthy, so no `or` fallback is needed here.
logger = logging.getLogger("BackgroundFX")

# Module-level cache of the latest self-check outcome.
# Writers and get_self_check_status() hold _SELF_CHECK_LOCK while touching it.
_SELF_CHECK_LOCK = threading.Lock()
_SELF_CHECK_DONE = False  # True once a check finished (or was disabled)
_SELF_CHECK_OK = False  # True only when the check passed (or was disabled)
_SELF_CHECK_MSG = "Self-check did not run yet."  # human-readable status
_SELF_CHECK_DURATION = 0.0  # wall-clock seconds spent in the check
39
+
40
+ def _pick_device() -> str:
41
+ dev = os.environ.get("SELF_CHECK_DEVICE", "").strip().lower()
42
+ if dev in ("cpu", "cuda"):
43
+ return dev
44
+ return "cuda" if torch.cuda.is_available() else "cpu"
45
+
46
def _synth_frame(w=640, h=360) -> np.ndarray:
    """
    Build a synthetic (h, w, 3) uint8 test frame.

    The frame is dark gray, with a solid green band covering the right 35%
    of the width and a filled ellipse-plus-rectangle "person" drawn around
    (w // 3, h // 2). The blob dimensions are fixed pixel sizes and do not
    scale with w/h. Image quality is irrelevant; the frame only has to be a
    plausible input for the self-check.
    """
    person_color = (60, 60, 200)

    # Gray base canvas.
    frame = np.full((h, w, 3), 40, dtype=np.uint8)

    # Green-screen-like strip on the right-hand side.
    strip_left = int(0.65 * w)
    cv2.rectangle(frame, (strip_left, 0), (w, h), (0, 255, 0), cv2.FILLED)

    # "Person": a head-ish ellipse sitting above a torso-ish rectangle.
    center_x = w // 3
    center_y = h // 2
    cv2.ellipse(
        frame, (center_x, center_y - 40), (35, 45), 0, 0, 360,
        person_color, cv2.FILLED,
    )
    cv2.rectangle(
        frame,
        (center_x - 40, center_y - 10),
        (center_x + 40, center_y + 80),
        person_color,
        cv2.FILLED,
    )
    return frame
61
+
62
def _run_once(timeout_s: float = 45.0) -> tuple[bool, str, float]:
    """
    Run the startup self-check pipeline once and report the outcome.

    Steps: load SAM2, segment a synthetic frame, load MatAnyone, run it on
    that frame with the first mask, then instantiate TwoStageProcessor.

    Args:
        timeout_s: Advisory time budget in seconds. Nothing in this function
            interrupts the work; exceeding the budget only logs a warning.

    Returns:
        ``(ok, message, duration_seconds)``. ``message`` is ``"OK"`` on
        success and a short failure description otherwise. Exceptions raised
        by any step are caught and reported as a failure, not propagated.
    """
    t0 = time.time()
    device = _pick_device()
    try:
        # 1) Load SAM2
        sam = SAM2Loader(device=device).load("auto")
        if sam is None:
            return False, "SAM2 failed to load", time.time() - t0

        # 2) Segment a synthetic frame
        bgr = _synth_frame()
        sam.set_image(bgr)
        out = sam.predict(point_coords=None, point_labels=None)
        masks = out.get("masks", None)
        h, w = bgr.shape[:2]
        if masks is None or len(masks) == 0:
            # No mask is tolerated: fall back to an all-ones mask so the
            # matting step can still be exercised.
            logger.warning("Self-check: SAM2 returned no masks; accepting fallback.")
            mask0 = np.ones((h, w), np.float32)
        else:
            mask0 = masks[0].astype(np.float32)
            if mask0.shape != (h, w):
                mask0 = cv2.resize(mask0, (w, h), interpolation=cv2.INTER_LINEAR)

        # 3) Load MatAnyone stateful session
        session = MatAnyoneLoader(device=device).load()
        if session is None:
            return False, "MatAnyone failed to load", time.time() - t0

        # 4) Bootstrap (frame 0 must have a mask; fallback already ensured)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        alpha0 = session(rgb, mask0)
        if not isinstance(alpha0, np.ndarray) or alpha0.shape != (h, w):
            return (
                False,
                f"MatAnyone alpha shape unexpected: {getattr(alpha0, 'shape', None)}",
                time.time() - t0,
            )

        # 5) Minimal TwoStageProcessor wiring (no file IO, just instantiate)
        _ = TwoStageProcessor(sam2_predictor=sam, matanyone_model=session)

        return True, "OK", time.time() - t0

    except Exception as e:
        return False, f"Self-check error: {e}", time.time() - t0
    finally:
        # Report-only timeout check. There is deliberately no `return` here:
        # returning from `finally` would replace the result chosen above and
        # hide real failures behind a bogus "OK".
        dur = time.time() - t0
        if dur > timeout_s:
            logger.warning(f"Self-check exceeded timeout {timeout_s:.1f}s (took {dur:.2f}s)")
111
+
112
def _runner(timeout_s: float):
    """
    Thread target: run the self-check once and publish the result.

    Stores the outcome in the module-level status cache under
    _SELF_CHECK_LOCK, then logs one line: info on success, error on failure.
    The check counts as passed only when `_run_once` returns a truthy flag
    together with the exact message "OK".
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    ok, msg, dur = _run_once(timeout_s=timeout_s)

    with _SELF_CHECK_LOCK:
        _SELF_CHECK_DONE = True
        _SELF_CHECK_OK = bool(ok) and msg == "OK"
        _SELF_CHECK_MSG = msg
        _SELF_CHECK_DURATION = float(dur)

    if not _SELF_CHECK_OK:
        logger.error(f"❌ Startup self-check FAILED in {dur:.2f}s: {msg}")
        return
    logger.info(f"✅ Startup self-check OK in {dur:.2f}s")
124
+
125
def launch_self_check_async(timeout_s: Optional[float] = None):
    """
    Fire-and-forget startup check, run on a daemon thread.

    Behavior:
    - DISABLE_SELF_CHECK=1: no thread is started; the cached status is set
      to done/ok with message "Disabled".
    - Otherwise the check is launched at most once per process; later calls
      return immediately.

    Args:
        timeout_s: Advisory budget in seconds passed to the check. Falsy
            values (None, 0) mean 45 seconds. The SELF_CHECK_TIMEOUT
            environment variable overrides this argument when it parses as
            a float; an unparsable value is ignored with a warning instead
            of raising.
    """
    global _SELF_CHECK_DONE, _SELF_CHECK_OK, _SELF_CHECK_MSG, _SELF_CHECK_DURATION

    if os.environ.get("DISABLE_SELF_CHECK", "0") == "1":
        logger.info("Self-check disabled via DISABLE_SELF_CHECK=1")
        with _SELF_CHECK_LOCK:
            _SELF_CHECK_DONE = True
            _SELF_CHECK_OK = True
            _SELF_CHECK_MSG = "Disabled"
            _SELF_CHECK_DURATION = 0.0
        return

    fallback_s = float(timeout_s or 45.0)
    raw_timeout = os.environ.get("SELF_CHECK_TIMEOUT")
    if raw_timeout is None:
        timeout_s = fallback_s
    else:
        try:
            timeout_s = float(raw_timeout)
        except ValueError:
            # A typo in the environment must not break application startup.
            logger.warning(
                "Ignoring invalid SELF_CHECK_TIMEOUT=%r; using %.1fs",
                raw_timeout,
                fallback_s,
            )
            timeout_s = fallback_s

    # Only launch once; the flag lives on the function object itself.
    with _SELF_CHECK_LOCK:
        if getattr(launch_self_check_async, "_started", False):
            return
        launch_self_check_async._started = True  # type: ignore[attr-defined]

    th = threading.Thread(target=_runner, args=(timeout_s,), daemon=True)
    th.start()
147
+
148
def get_self_check_status() -> Dict[str, Any]:
    """
    Return a snapshot of the cached self-check status.

    The dict has the keys "done", "ok", "message" and "duration", read
    together while holding _SELF_CHECK_LOCK.
    """
    with _SELF_CHECK_LOCK:
        snapshot: Dict[str, Any] = dict(
            done=_SELF_CHECK_DONE,
            ok=_SELF_CHECK_OK,
            message=_SELF_CHECK_MSG,
            duration=_SELF_CHECK_DURATION,
        )
    return snapshot