MogensR commited on
Commit
06646af
·
1 Parent(s): 03cab85

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +85 -228
models/loaders/matanyone_loader.py CHANGED
@@ -1,5 +1,3 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
  """
4
  MatAnyone Loader - Stable Callable Wrapper for InferenceCore (extra-dim stripping)
5
  =================================================================================
@@ -11,16 +9,11 @@
11
  e.g. [B,T,C,H,W] -> [C,H,W] (use first slice when B/T > 1 with a warning)
12
  e.g. [B,C,H,W] -> [C,H,W]
13
  e.g. [H,W,C,1] -> [H,W,C]
14
- - Optional CUDA mixed precision (fp16/bf16)
15
  - Robust alpha extraction -> (H,W) float32 [0,1]
16
  """
 
17
 
18
- import os
19
- import time
20
  import logging
21
- import tempfile
22
- import traceback
23
- from pathlib import Path
24
  from typing import Optional, Dict, Any, Tuple, Union
25
 
26
  import numpy as np
@@ -28,6 +21,12 @@
28
 
29
  logger = logging.getLogger(__name__)
30
 
 
 
 
 
 
 
31
 
32
  # ------------------------------ Helpers ------------------------------
33
 
@@ -75,28 +74,23 @@ def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor], *, name: str = "
75
 
76
  if torch.is_tensor(image):
77
  t = image
78
- # Convert 4D (rare if caller passes) once more
79
  if t.ndim == 4:
80
  t = _strip_leading_extras_to_ndim(t, 3)
81
 
82
  if t.ndim == 3:
83
  c0, c1, c2 = t.shape
84
  if c0 in (1, 3, 4):
85
- # CHW
86
- pass
87
  elif c2 in (1, 3, 4):
88
- # HWC -> CHW
89
- t = t.permute(2, 0, 1)
90
  else:
91
- # Ambiguous, assume HWC-like and take first channel after moving to CHW
92
  logger.warning(f"{name}: ambiguous 3D shape {tuple(t.shape)}; attempting HWC->CHW then selecting first channel.")
93
  t = t.permute(2, 0, 1)
94
  if t.shape[0] > 1:
95
  t = t[0]
96
- t = t.unsqueeze(0) # back to 1HW
97
  elif t.ndim == 2:
98
- # HW -> 1HW
99
- t = t.unsqueeze(0)
100
  else:
101
  raise ValueError(f"{name}: unsupported tensor dims {tuple(t.shape)} after stripping.")
102
 
@@ -107,21 +101,20 @@ def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor], *, name: str = "
107
  logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
108
  return t
109
 
110
- # numpy path
111
  arr = np.asarray(image)
112
  if arr.ndim == 4:
113
  arr = _strip_leading_extras_to_ndim(arr, 3)
114
 
115
  if arr.ndim == 3:
116
- if arr.shape[0] in (1, 3, 4): # CHW
117
- pass
118
- elif arr.shape[-1] in (1, 3, 4): # HWC -> CHW
119
- arr = arr.transpose(2, 0, 1)
120
  else:
121
  logger.warning(f"{name}: ambiguous 3D shape {arr.shape}; trying HWC->CHW and selecting first channel.")
122
- arr = arr.transpose(2, 0, 1) # HWC->CHW
123
  if arr.shape[0] > 1:
124
- arr = arr[0:1, ...] # 1HW
125
  elif arr.ndim == 2:
126
  arr = arr[None, ...] # 1HW
127
  else:
@@ -144,24 +137,20 @@ def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor], *, name: str = "m
144
  if torch.is_tensor(mask):
145
  m = mask
146
  if m.ndim == 3:
147
- # 1HW or CHW or HWC-like
148
  if m.shape[0] == 1:
149
  pass # 1HW
150
  elif m.shape[-1] == 1:
151
  m = m.permute(2, 0, 1) # HW1 -> 1HW
152
  else:
153
- # If multi-channel, take first
154
  logger.warning(f"{name}: multi-channel {tuple(m.shape)}; using first channel.")
155
- # Assume CHW or HWC-like already normalized earlier; prefer leading as channel
156
  if m.shape[0] in (3, 4):
157
  m = m[0:1, ...]
158
  elif m.shape[-1] in (3, 4):
159
  m = m.permute(2, 0, 1)[0:1, ...]
160
  else:
161
- # Ambiguous -> take first along first axis and ensure 1HW
162
  m = m[0:1, ...]
163
  elif m.ndim == 2:
164
- m = m.unsqueeze(0) # 1HW
165
  else:
166
  raise ValueError(f"{name}: unsupported tensor dims {tuple(m.shape)} after stripping.")
167
 
@@ -172,7 +161,6 @@ def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor], *, name: str = "m
172
  logger.debug(f"{name}: {orig_shape} -> {tuple(m.shape)} (1HW)")
173
  return m
174
 
175
- # numpy path
176
  arr = np.asarray(mask)
177
  if arr.ndim == 3:
178
  if arr.shape[0] == 1:
@@ -182,13 +170,13 @@ def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor], *, name: str = "m
182
  else:
183
  logger.warning(f"{name}: multi-channel {arr.shape}; using first channel.")
184
  if arr.shape[0] in (3, 4):
185
- arr = arr[0:1, ...] # CHW -> 1HW
186
  elif arr.shape[-1] in (3, 4):
187
- arr = arr.transpose(2, 0, 1)[0:1, ...] # HWC -> CHW -> 1HW
188
  else:
189
- arr = arr[0:1, ...] # ambiguous -> 1HW by slice
190
  elif arr.ndim == 2:
191
- arr = arr[None, ...] # 1HW
192
  else:
193
  raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")
194
 
@@ -207,7 +195,6 @@ def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
207
  result = result.detach().float().cpu()
208
 
209
  arr = np.asarray(result)
210
- # Strip to <= 3 dims, then extract
211
  while arr.ndim > 3:
212
  if arr.shape[0] > 1:
213
  logger.warning(f"Result has leading dim {arr.shape[0]}; taking first slice.")
@@ -216,14 +203,13 @@ def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
216
  if arr.ndim == 2:
217
  alpha = arr
218
  elif arr.ndim == 3:
219
- if arr.shape[0] in (1, 3, 4): # CHW -> take channel 0
220
  alpha = arr[0]
221
- elif arr.shape[-1] in (1, 3, 4): # HWC -> take channel 0
222
  alpha = arr[..., 0]
223
  else:
224
- alpha = arr[0] # ambiguous
225
  else:
226
- # 1D or 0D shouldn't happen; fallback
227
  alpha = np.full((512, 512), 0.5, dtype=np.float32)
228
 
229
  alpha = alpha.astype(np.float32, copy=False)
@@ -233,26 +219,18 @@ def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
233
 
234
  def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
235
  """Best-effort infer (H, W) for fallback mask sizing."""
236
- if torch.is_tensor(x):
237
- shape = tuple(x.shape)
238
- else:
239
- shape = np.asarray(x).shape
240
-
241
- # Try common orders
242
- if len(shape) == 2: # HW
243
  return shape[0], shape[1]
244
  if len(shape) == 3:
245
- if shape[0] in (1, 3, 4): # CHW
246
  return shape[1], shape[2]
247
- if shape[-1] in (1, 3, 4): # HWC
248
  return shape[0], shape[1]
249
- # Ambiguous -> treat as CHW
250
  return shape[1], shape[2]
251
  if len(shape) >= 4:
252
- # Assume leading are batch/time; try BCHW first
253
  if len(shape) >= 4 and (shape[1] in (1, 3, 4)):
254
  return shape[2], shape[3]
255
- # Else BHWC-ish
256
  return shape[-3], shape[-2]
257
  return 512, 512
258
 
@@ -270,25 +248,16 @@ class MatAnyoneCallableWrapper:
270
  - Strips any extra dims from inputs before calling core.
271
  """
272
 
273
- def __init__(self, inference_core, device: str = "cuda", mixed_precision: Optional[str] = "fp16"):
274
  self.core = inference_core
275
  self.initialized = False
276
- self.device = device if (device in ("cuda", "cpu")) else ("cuda" if torch.cuda.is_available() else "cpu")
277
- self.mixed_precision = mixed_precision if self.device == "cuda" else None # "fp16"|"bf16"|None
278
-
279
- def _maybe_autocast(self):
280
- if self.device == "cuda" and self.mixed_precision in ("fp16", "bf16"):
281
- dtype = torch.float16 if self.mixed_precision == "fp16" else torch.bfloat16
282
- return torch.autocast(device_type="cuda", dtype=dtype)
283
- # no-op ctx
284
- class _NullCtx:
285
- def __enter__(self): return None
286
- def __exit__(self, *exc): return False
287
- return _NullCtx()
288
 
289
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
290
  try:
291
- # Preprocess (unbatched)
292
  img_chw = _ensure_chw_float01(image, name="image").to(self.device, non_blocking=True)
293
 
294
  if not self.initialized:
@@ -300,35 +269,32 @@ def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
300
  m_1hw = _ensure_1hw_float01(mask, name="mask").to(self.device, non_blocking=True)
301
 
302
  with torch.inference_mode():
303
- with self._maybe_autocast():
304
- if hasattr(self.core, "step"):
305
- result = self.core.step(image=img_chw, mask=m_1hw, **kwargs)
306
- elif hasattr(self.core, "process_frame"):
307
- result = self.core.process_frame(img_chw, m_1hw, **kwargs)
308
- else:
309
- logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
310
- return _alpha_from_result(mask)
311
 
312
  self.initialized = True
313
  return _alpha_from_result(result)
314
 
315
  # Subsequent frames (no mask)
316
  with torch.inference_mode():
317
- with self._maybe_autocast():
318
- if hasattr(self.core, "step"):
319
- result = self.core.step(image=img_chw, **kwargs)
320
- elif hasattr(self.core, "process_frame"):
321
- result = self.core.process_frame(img_chw, **kwargs)
322
- else:
323
- h, w = _hw_from_image_like(image)
324
- logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
325
- return np.full((h, w), 0.5, dtype=np.float32)
326
 
327
  return _alpha_from_result(result)
328
 
329
  except Exception as e:
330
  logger.error(f"MatAnyone wrapper call failed: {e}")
331
- logger.debug(traceback.format_exc())
332
  # Fallbacks
333
  if mask is not None:
334
  try:
@@ -353,160 +319,51 @@ def reset(self):
353
  logger.debug(f"Core clear_memory() failed: {e}")
354
 
355
 
356
- # ------------------------------- Loader -------------------------------
357
-
358
  class MatAnyoneLoader:
359
- """
360
- Loads MatAnyone's InferenceCore and returns a callable wrapper.
361
-
362
- Usage:
363
- loader = MatAnyoneLoader(device="cuda")
364
- session = loader.load() # callable
365
- alpha = session(frame, first_frame_mask) # returns (H, W) float32
366
- """
367
 
368
- def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache",
369
- mixed_precision: Optional[str] = "fp16"):
370
- self.device = self._select_device(device)
371
- self.cache_dir = cache_dir
372
- os.makedirs(self.cache_dir, exist_ok=True)
373
-
374
- self.processor = None
375
- self.wrapper = None
376
- self.model_id = "PeiqingYang/MatAnyone"
377
- self.load_time = 0.0
378
- self.loaded = False
379
- self.load_error = None
380
- self.temp_dir = Path(tempfile.mkdtemp())
381
- self.mixed_precision = mixed_precision if self.device == "cuda" else None
382
-
383
- def _select_device(self, pref: str) -> str:
384
- pref = (pref or "").lower()
385
- if pref.startswith("cuda"):
386
- return "cuda" if torch.cuda.is_available() else "cpu"
387
- if pref == "cpu":
388
- return "cpu"
389
- return "cuda" if torch.cuda.is_available() else "cpu"
390
-
391
- def _try_build_core(self):
392
  """
393
- Try multiple constructor patterns to survive API changes.
 
394
  """
395
- from matanyone.inference.inference_core import InferenceCore
396
-
397
- # 1) Preferred: from_pretrained(...)
398
  try:
399
- core = InferenceCore.from_pretrained(self.model_id, device=self.device, cache_dir=self.cache_dir)
400
- logger.info("Loaded MatAnyone via InferenceCore.from_pretrained(...)")
401
- return core
402
- except Exception as e:
403
- logger.debug(f"from_pretrained failed: {e}")
404
-
405
- # 2) Direct ctor with device/cache_dir
406
- try:
407
- core = InferenceCore(self.model_id, device=self.device, cache_dir=self.cache_dir)
408
- logger.info("Loaded MatAnyone via InferenceCore(model_id, device, cache_dir)")
409
- return core
410
- except Exception as e:
411
- logger.debug(f"ctor(model_id, device, cache_dir) failed: {e}")
412
-
413
- # 3) Minimal ctor
414
- core = InferenceCore(self.model_id)
415
- logger.info("Loaded MatAnyone via InferenceCore(model_id) [minimal]")
416
- return core
417
-
418
- def load(self) -> Optional[MatAnyoneCallableWrapper]:
419
- """Load MatAnyone and return the callable wrapper."""
420
- if self.loaded and self.wrapper is not None:
421
- return self.wrapper
422
-
423
- logger.info(f"Loading MatAnyone: {self.model_id} (device={self.device})")
424
- t0 = time.time()
425
-
426
- try:
427
- self.processor = self._try_build_core()
428
- # Optional device move
429
- try:
430
- if hasattr(self.processor, "to"):
431
- self.processor.to(self.device)
432
- elif hasattr(self.processor, "set_device"):
433
- self.processor.set_device(self.device)
434
- except Exception as e:
435
- logger.debug(f"Optional device move failed: {e}")
436
-
437
- self.wrapper = MatAnyoneCallableWrapper(
438
- self.processor, device=self.device, mixed_precision=self.mixed_precision
439
  )
440
- self.loaded = True
441
- self.load_time = time.time() - t0
442
- logger.info(f"MatAnyone loaded and wrapped in {self.load_time:.2f}s")
443
- return self.wrapper
444
-
445
- except ImportError as e:
446
- self.load_error = f"MatAnyone not installed: {e}"
447
- logger.error("Failed to import MatAnyone. Install with: "
448
- "pip install git+https://github.com/pq-yang/MatAnyone.git@main")
449
- return None
450
  except Exception as e:
451
- self.load_error = str(e)
452
- logger.error(f"Failed to load MatAnyone: {e}")
453
- logger.debug(traceback.format_exc())
454
  return None
455
 
456
- def cleanup(self):
457
- """Cleanup temporary files and release resources."""
458
- self.processor = None
459
- self.wrapper = None
460
-
461
- # Clean temp directory
462
- if self.temp_dir.exists():
463
- import shutil
464
- shutil.rmtree(self.temp_dir, ignore_errors=True)
465
-
466
- # Clear CUDA cache if available
467
- if torch.cuda.is_available():
468
- torch.cuda.empty_cache()
469
 
470
  def get_info(self) -> Dict[str, Any]:
471
- """Get model information and interface flags."""
472
- info = {
473
- "loaded": self.loaded,
474
  "model_id": self.model_id,
475
- "device": str(self.device),
476
- "load_time": float(self.load_time),
477
- "error": self.load_error,
478
- "api": "InferenceCore (wrapped)",
479
- "mixed_precision": self.mixed_precision,
480
- }
481
- proc = self.processor
482
- if proc is not None:
483
- info["has_step"] = hasattr(proc, "step")
484
- info["has_process_frame"] = hasattr(proc, "process_frame")
485
- info["has_process_video"] = hasattr(proc, "process_video")
486
- return info
487
-
488
- def reset(self):
489
- """Reset the processor for a new video."""
490
- if self.wrapper:
491
- self.wrapper.reset()
492
- logger.info("MatAnyone session reset")
493
-
494
- # Make the loader itself callable (direct compatibility)
495
- def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
496
- if self.wrapper is None:
497
- if self.load() is None:
498
- # Fallback if loading fails
499
- if mask is not None:
500
- try:
501
- return _alpha_from_result(mask)
502
- except Exception:
503
- pass
504
- h, w = _hw_from_image_like(image)
505
- return np.zeros((h, w), dtype=np.float32)
506
- return self.wrapper(image, mask, **kwargs)
507
-
508
-
509
- # Backwards compatibility alias
510
- _MatAnyoneSession = MatAnyoneCallableWrapper
511
-
512
- __all__ = ["MatAnyoneLoader", "_MatAnyoneSession", "MatAnyoneCallableWrapper"]
 
 
 
1
  """
2
  MatAnyone Loader - Stable Callable Wrapper for InferenceCore (extra-dim stripping)
3
  =================================================================================
 
9
  e.g. [B,T,C,H,W] -> [C,H,W] (use first slice when B/T > 1 with a warning)
10
  e.g. [B,C,H,W] -> [C,H,W]
11
  e.g. [H,W,C,1] -> [H,W,C]
 
12
  - Robust alpha extraction -> (H,W) float32 [0,1]
13
  """
14
+ from __future__ import annotations
15
 
 
 
16
  import logging
 
 
 
17
  from typing import Optional, Dict, Any, Tuple, Union
18
 
19
  import numpy as np
 
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
+ try:
25
+ # Official import path
26
+ from matanyone.inference.inference_core import InferenceCore
27
+ except Exception: # keep import error defered until load()
28
+ InferenceCore = None # type: ignore
29
+
30
 
31
  # ------------------------------ Helpers ------------------------------
32
 
 
74
 
75
  if torch.is_tensor(image):
76
  t = image
 
77
  if t.ndim == 4:
78
  t = _strip_leading_extras_to_ndim(t, 3)
79
 
80
  if t.ndim == 3:
81
  c0, c1, c2 = t.shape
82
  if c0 in (1, 3, 4):
83
+ pass # CHW
 
84
  elif c2 in (1, 3, 4):
85
+ t = t.permute(2, 0, 1) # HWC -> CHW
 
86
  else:
 
87
  logger.warning(f"{name}: ambiguous 3D shape {tuple(t.shape)}; attempting HWC->CHW then selecting first channel.")
88
  t = t.permute(2, 0, 1)
89
  if t.shape[0] > 1:
90
  t = t[0]
91
+ t = t.unsqueeze(0)
92
  elif t.ndim == 2:
93
+ t = t.unsqueeze(0) # 1HW
 
94
  else:
95
  raise ValueError(f"{name}: unsupported tensor dims {tuple(t.shape)} after stripping.")
96
 
 
101
  logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
102
  return t
103
 
 
104
  arr = np.asarray(image)
105
  if arr.ndim == 4:
106
  arr = _strip_leading_extras_to_ndim(arr, 3)
107
 
108
  if arr.ndim == 3:
109
+ if arr.shape[0] in (1, 3, 4):
110
+ pass # CHW
111
+ elif arr.shape[-1] in (1, 3, 4):
112
+ arr = arr.transpose(2, 0, 1) # HWC -> CHW
113
  else:
114
  logger.warning(f"{name}: ambiguous 3D shape {arr.shape}; trying HWC->CHW and selecting first channel.")
115
+ arr = arr.transpose(2, 0, 1)
116
  if arr.shape[0] > 1:
117
+ arr = arr[0:1, ...]
118
  elif arr.ndim == 2:
119
  arr = arr[None, ...] # 1HW
120
  else:
 
137
  if torch.is_tensor(mask):
138
  m = mask
139
  if m.ndim == 3:
 
140
  if m.shape[0] == 1:
141
  pass # 1HW
142
  elif m.shape[-1] == 1:
143
  m = m.permute(2, 0, 1) # HW1 -> 1HW
144
  else:
 
145
  logger.warning(f"{name}: multi-channel {tuple(m.shape)}; using first channel.")
 
146
  if m.shape[0] in (3, 4):
147
  m = m[0:1, ...]
148
  elif m.shape[-1] in (3, 4):
149
  m = m.permute(2, 0, 1)[0:1, ...]
150
  else:
 
151
  m = m[0:1, ...]
152
  elif m.ndim == 2:
153
+ m = m.unsqueeze(0)
154
  else:
155
  raise ValueError(f"{name}: unsupported tensor dims {tuple(m.shape)} after stripping.")
156
 
 
161
  logger.debug(f"{name}: {orig_shape} -> {tuple(m.shape)} (1HW)")
162
  return m
163
 
 
164
  arr = np.asarray(mask)
165
  if arr.ndim == 3:
166
  if arr.shape[0] == 1:
 
170
  else:
171
  logger.warning(f"{name}: multi-channel {arr.shape}; using first channel.")
172
  if arr.shape[0] in (3, 4):
173
+ arr = arr[0:1, ...]
174
  elif arr.shape[-1] in (3, 4):
175
+ arr = arr.transpose(2, 0, 1)[0:1, ...]
176
  else:
177
+ arr = arr[0:1, ...]
178
  elif arr.ndim == 2:
179
+ arr = arr[None, ...]
180
  else:
181
  raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")
182
 
 
195
  result = result.detach().float().cpu()
196
 
197
  arr = np.asarray(result)
 
198
  while arr.ndim > 3:
199
  if arr.shape[0] > 1:
200
  logger.warning(f"Result has leading dim {arr.shape[0]}; taking first slice.")
 
203
  if arr.ndim == 2:
204
  alpha = arr
205
  elif arr.ndim == 3:
206
+ if arr.shape[0] in (1, 3, 4):
207
  alpha = arr[0]
208
+ elif arr.shape[-1] in (1, 3, 4):
209
  alpha = arr[..., 0]
210
  else:
211
+ alpha = arr[0]
212
  else:
 
213
  alpha = np.full((512, 512), 0.5, dtype=np.float32)
214
 
215
  alpha = alpha.astype(np.float32, copy=False)
 
219
 
220
  def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
221
  """Best-effort infer (H, W) for fallback mask sizing."""
222
+ shape = tuple(x.shape) if torch.is_tensor(x) else np.asarray(x).shape
223
+ if len(shape) == 2:
 
 
 
 
 
224
  return shape[0], shape[1]
225
  if len(shape) == 3:
226
+ if shape[0] in (1, 3, 4):
227
  return shape[1], shape[2]
228
+ if shape[-1] in (1, 3, 4):
229
  return shape[0], shape[1]
 
230
  return shape[1], shape[2]
231
  if len(shape) >= 4:
 
232
  if len(shape) >= 4 and (shape[1] in (1, 3, 4)):
233
  return shape[2], shape[3]
 
234
  return shape[-3], shape[-2]
235
  return 512, 512
236
 
 
248
  - Strips any extra dims from inputs before calling core.
249
  """
250
 
251
+ def __init__(self, inference_core, device: str = None):
252
  self.core = inference_core
253
  self.initialized = False
254
+ # Best-effort device selection if available
255
+ if device is None:
256
+ device = "cuda" if torch.cuda.is_available() else "cpu"
257
+ self.device = device
 
 
 
 
 
 
 
 
258
 
259
  def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
260
  try:
 
261
  img_chw = _ensure_chw_float01(image, name="image").to(self.device, non_blocking=True)
262
 
263
  if not self.initialized:
 
269
  m_1hw = _ensure_1hw_float01(mask, name="mask").to(self.device, non_blocking=True)
270
 
271
  with torch.inference_mode():
272
+ if hasattr(self.core, "step"):
273
+ result = self.core.step(image=img_chw, mask=m_1hw, **kwargs)
274
+ elif hasattr(self.core, "process_frame"):
275
+ result = self.core.process_frame(img_chw, m_1hw, **kwargs)
276
+ else:
277
+ logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
278
+ return _alpha_from_result(mask)
 
279
 
280
  self.initialized = True
281
  return _alpha_from_result(result)
282
 
283
  # Subsequent frames (no mask)
284
  with torch.inference_mode():
285
+ if hasattr(self.core, "step"):
286
+ result = self.core.step(image=img_chw, **kwargs)
287
+ elif hasattr(self.core, "process_frame"):
288
+ result = self.core.process_frame(img_chw, **kwargs)
289
+ else:
290
+ h, w = _hw_from_image_like(image)
291
+ logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
292
+ return np.full((h, w), 0.5, dtype=np.float32)
 
293
 
294
  return _alpha_from_result(result)
295
 
296
  except Exception as e:
297
  logger.error(f"MatAnyone wrapper call failed: {e}")
 
298
  # Fallbacks
299
  if mask is not None:
300
  try:
 
319
  logger.debug(f"Core clear_memory() failed: {e}")
320
 
321
 
 
 
322
  class MatAnyoneLoader:
323
+ def __init__(self, device: str = "auto", model_id: str = "PeiqingYang/MatAnyone"):
324
+ self.device = device
325
+ self.model_id = model_id
326
+ self._processor: Optional[InferenceCore] = None # type: ignore
327
+ self._wrapper: Optional[MatAnyoneCallableWrapper] = None
 
 
 
328
 
329
+ def load(self) -> Optional[Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  """
331
+ Initialize and return a callable wrapper around InferenceCore.
332
+ Returns MatAnyoneCallableWrapper if successful, else None.
333
  """
334
+ global InferenceCore
 
 
335
  try:
336
+ if InferenceCore is None:
337
+ from matanyone.inference.inference_core import InferenceCore as _IC # type: ignore
338
+ InferenceCore = _IC # type: ignore
339
+
340
+ logger.info("Loading MatAnyone InferenceCore ...")
341
+ self._processor = InferenceCore(self.model_id) # type: ignore
342
+ logger.info("MatAnyone InferenceCore loaded successfully")
343
+
344
+ # Choose device
345
+ dev = (
346
+ "cuda" if (str(self.device).startswith("cuda") and torch.cuda.is_available()) else
347
+ ("cpu" if str(self.device) == "cpu" else ("cuda" if torch.cuda.is_available() else "cpu"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  )
349
+
350
+ self._wrapper = MatAnyoneCallableWrapper(self._processor, device=dev)
351
+ logger.info("MatAnyone wrapped with dimension-safe callable")
352
+ return self._wrapper
 
 
 
 
 
 
353
  except Exception as e:
354
+ logger.error(f"Failed to load MatAnyone InferenceCore: {e}")
355
+ self._processor = None
356
+ self._wrapper = None
357
  return None
358
 
359
+ def get(self) -> Optional[Any]:
360
+ """Return the cached callable if loaded."""
361
+ return self._wrapper or self._processor
 
 
 
 
 
 
 
 
 
 
362
 
363
  def get_info(self) -> Dict[str, Any]:
364
+ """Metadata for diagnostics."""
365
+ return {
 
366
  "model_id": self.model_id,
367
+ "loaded": self._wrapper is not None or self._processor is not None,
368
+ "wrapped": self._wrapper is not None,
369
+ }