#!/usr/bin/env python3
"""
Memory Manager for BackgroundFX Pro
- Safe on CPU/CUDA/MPS (mostly CUDA/T4 on Spaces)
- Accepts `device` as str or torch.device
- Optional per-process VRAM cap (env or method)
- Detailed usage reporting (CPU/RAM + VRAM + torch allocator)
- Light and aggressive cleanup paths
- Background monitor (optional)

Env switches:
  BFX_DISABLE_LIMIT=1          -> do not set VRAM fraction automatically
  BFX_CUDA_FRACTION=0.80       -> fraction to cap per-process VRAM (0.10..0.95)
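
Example (hypothetical launch command, for illustration only):
  BFX_CUDA_FRACTION=0.75 python app.py   -> cap this process at ~75% of device VRAM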
"""

from __future__ import annotations
import gc
import os
import time
import logging
import threading
from typing import Dict, Any, Optional, Callable

# Optional deps
try:
    import psutil
except Exception:  # pragma: no cover
    psutil = None

try:
    import torch
except Exception:  # pragma: no cover
    torch = None

logger = logging.getLogger(__name__)

# ---- local exception to avoid shadowing built-in MemoryError ----
class MemoryManagerError(Exception):
    pass


def _bytes_to_gb(x: int | float) -> float:
    try:
        return float(x) / (1024**3)
    except Exception:
        return 0.0


def _normalize_device(dev) -> "torch.device":
    if torch is None:
        # fake CPU device
        class _Fake:
            type = "cpu"
            index = None
        return _Fake()  # type: ignore[return-value]

    if isinstance(dev, str):
        return torch.device(dev)
    if hasattr(dev, "type"):
        return dev
    # default CPU
    return torch.device("cpu")


def _cuda_index(device) -> Optional[int]:
    if getattr(device, "type", "cpu") != "cuda":
        return None
    idx = getattr(device, "index", None)
    if idx is None:
        # normalize bare "cuda" to 0
        return 0
    return int(idx)


class MemoryManager:
    """
    Comprehensive memory management with VRAM cap + cleanup utilities.
    """

    def __init__(self, device, memory_limit_gb: Optional[float] = None):
        self.device = _normalize_device(device)
        self.device_type = getattr(self.device, "type", "cpu")
        self.cuda_idx = _cuda_index(self.device)

        self.gpu_available = bool(
            torch and self.device_type == "cuda" and torch.cuda.is_available()
        )
        self.mps_available = bool(
            torch and self.device_type == "mps" and getattr(torch.backends, "mps", None)
            and torch.backends.mps.is_available()
        )

        self.memory_limit_gb = memory_limit_gb
        self.cleanup_callbacks: list[Callable] = []
        self.monitoring_active = False
        self.monitoring_thread: Optional[threading.Thread] = None
        self.stats = {
            "cleanup_count": 0,
            "peak_memory_usage": 0.0,
            "total_allocated": 0.0,
            "total_freed": 0.0,
        }
        self.applied_fraction: Optional[float] = None

        self._initialize_memory_limits()
        self._maybe_apply_vram_fraction()
        logger.info(f"MemoryManager initialized (device={self.device}, cuda={self.gpu_available})")

    # -------------------------------
    # init helpers
    # -------------------------------
    def _initialize_memory_limits(self):
        try:
            if self.gpu_available:
                props = torch.cuda.get_device_properties(self.cuda_idx or 0)
                total_gb = _bytes_to_gb(props.total_memory)
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.80)  # default 80%
                logger.info(
                    f"CUDA memory limit baseline ~{self.memory_limit_gb:.1f}GB "
                    f"(device total {total_gb:.1f}GB)"
                )
            elif self.mps_available:
                vm = psutil.virtual_memory() if psutil else None
                total_gb = _bytes_to_gb(vm.total) if vm else 0.0
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.50)
                logger.info(f"MPS memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
            else:
                vm = psutil.virtual_memory() if psutil else None
                total_gb = _bytes_to_gb(vm.total) if vm else 0.0
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.60)
                logger.info(f"CPU memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
        except Exception as e:
            logger.warning(f"Memory limit init failed: {e}")
            if self.memory_limit_gb is None:
                self.memory_limit_gb = 4.0  # conservative fallback

    def _maybe_apply_vram_fraction(self):
        if not self.gpu_available or torch is None:
            return
        if os.environ.get("BFX_DISABLE_LIMIT", ""):
            return
        frac_env = os.environ.get("BFX_CUDA_FRACTION", "").strip()
        try:
            fraction = float(frac_env) if frac_env else 0.80
        except Exception:
            fraction = 0.80
        applied = self.limit_cuda_memory(fraction=fraction)
        if applied:
            logger.info(f"Per-process CUDA memory fraction set to {applied:.2f} on device {self.cuda_idx or 0}")

    # -------------------------------
    # public API
    # -------------------------------
    def get_memory_usage(self) -> Dict[str, Any]:
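        """Return a snapshot of system RAM, swap, process RSS/VMS, VRAM, and torch allocator usage."""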
        usage: Dict[str, Any] = {
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "timestamp": time.time(),
        }

        # CPU / system
        if psutil:
            try:
                vm = psutil.virtual_memory()
                usage.update(
                    dict(
                        system_total_gb=round(_bytes_to_gb(vm.total), 3),
                        system_available_gb=round(_bytes_to_gb(vm.available), 3),
                        system_used_gb=round(_bytes_to_gb(vm.used), 3),
                        system_percent=float(vm.percent),
                    )
                )
                swap = psutil.swap_memory()
                usage.update(
                    dict(
                        swap_total_gb=round(_bytes_to_gb(swap.total), 3),
                        swap_used_gb=round(_bytes_to_gb(swap.used), 3),
                        swap_percent=float(swap.percent),
                    )
                )
                proc = psutil.Process()
                mi = proc.memory_info()
                usage.update(
                    dict(
                        process_rss_gb=round(_bytes_to_gb(mi.rss), 3),
                        process_vms_gb=round(_bytes_to_gb(mi.vms), 3),
                    )
                )
            except Exception as e:
                logger.debug(f"psutil stats error: {e}")

        # GPU
        if self.gpu_available and torch is not None:
            try:
                # mem_get_info returns (free, total) in bytes
                free_b, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
                used_b = total_b - free_b
                usage.update(
                    dict(
                        vram_total_gb=round(_bytes_to_gb(total_b), 3),
                        vram_used_gb=round(_bytes_to_gb(used_b), 3),
                        vram_free_gb=round(_bytes_to_gb(free_b), 3),
                        vram_used_percent=float(used_b / total_b * 100.0) if total_b else 0.0,
                    )
                )
            except Exception as e:
                logger.debug(f"mem_get_info failed: {e}")

            # torch allocator stats
            try:
                idx = self.cuda_idx or 0
                allocated = torch.cuda.memory_allocated(idx)
                reserved = torch.cuda.memory_reserved(idx)
                usage["torch_allocated_gb"] = round(_bytes_to_gb(allocated), 3)
                usage["torch_reserved_gb"] = round(_bytes_to_gb(reserved), 3)
                # inactive split (2.x)
                try:
                    inactive = torch.cuda.memory_stats(idx).get("inactive_split_bytes.all.current", 0)
                    usage["torch_inactive_split_gb"] = round(_bytes_to_gb(inactive), 3)
                except Exception:
                    pass
            except Exception as e:
                logger.debug(f"allocator stats failed: {e}")

            usage["applied_fraction"] = self.applied_fraction

        # Update peak tracker
        current = usage.get("vram_used_gb", usage.get("system_used_gb", 0.0))
        try:
            if float(current) > float(self.stats["peak_memory_usage"]):
                self.stats["peak_memory_usage"] = float(current)
        except Exception:
            pass

        return usage

    def limit_cuda_memory(self, fraction: Optional[float] = None, max_gb: Optional[float] = None) -> Optional[float]:
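        """
        Cap this process's CUDA allocations via set_per_process_memory_fraction.

        Pass either `fraction` (clamped to 0.10..0.95) or `max_gb` (converted to a
        fraction of the detected device total). Returns the applied fraction, or
        None if no CUDA device is active or the cap could not be set.
        """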
        if not self.gpu_available or torch is None:
            return None

        # derive fraction from max_gb if provided
        if max_gb is not None:
            try:
                _, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
                total_gb = _bytes_to_gb(total_b)
                if total_gb <= 0:
                    return None
                fraction = min(max(0.10, max_gb / total_gb), 0.95)
            except Exception as e:
                logger.debug(f"fraction from max_gb failed: {e}")
                return None

        if fraction is None:
            fraction = 0.80
        fraction = float(max(0.10, min(0.95, fraction)))

        try:
            torch.cuda.set_per_process_memory_fraction(fraction, device=self.cuda_idx or 0)
            self.applied_fraction = fraction
            return fraction
        except Exception as e:
            logger.debug(f"set_per_process_memory_fraction failed: {e}")
            return None

    def cleanup(self) -> None:
        """Light cleanup used frequently between steps."""
        try:
            gc.collect()
        except Exception:
            pass
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        self.stats["cleanup_count"] += 1

    def cleanup_basic(self) -> None:
        """Alias kept for compatibility."""
        self.cleanup()

    def cleanup_aggressive(self) -> None:
        """Aggressive cleanup for OOM recovery or big scene switches."""
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.synchronize(self.cuda_idx or 0)
            except Exception:
                pass
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                torch.cuda.reset_peak_memory_stats(self.cuda_idx or 0)
            except Exception:
                pass
            try:
                if hasattr(torch.cuda, "ipc_collect"):
                    torch.cuda.ipc_collect()
            except Exception:
                pass
        try:
            gc.collect()
            gc.collect()
        except Exception:
            pass
        self.stats["cleanup_count"] += 1

    def register_cleanup_callback(self, callback: Callable):
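        """Register a callback invoked during aggressive cleanup (e.g. to drop model caches)."""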
        self.cleanup_callbacks.append(callback)

    def start_monitoring(self, interval_seconds: float = 30.0, pressure_callback: Optional[Callable] = None):
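        """Start a daemon thread that checks memory pressure every `interval_seconds`
        and runs aggressive cleanup when the pressure level becomes critical."""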
        if self.monitoring_active:
            logger.warning("Memory monitoring already active")
            return
        self.monitoring_active = True

        def loop():
            while self.monitoring_active:
                try:
                    pressure = self.check_memory_pressure()
                    if pressure["under_pressure"]:
                        logger.warning(
                            f"Memory pressure: {pressure['pressure_level']} "
                            f"({pressure['usage_percent']:.1f}%)"
                        )
                        if pressure_callback:
                            try:
                                pressure_callback(pressure)
                            except Exception as e:
                                logger.error(f"Pressure callback failed: {e}")
                        if pressure["pressure_level"] == "critical":
                            self.cleanup_aggressive()
                except Exception as e:
                    logger.error(f"Memory monitoring error: {e}")
                # Sleep in short slices so stop_monitoring() can return promptly
                # even with long intervals.
                slept = 0.0
                while self.monitoring_active and slept < interval_seconds:
                    time.sleep(0.5)
                    slept += 0.5

        self.monitoring_thread = threading.Thread(target=loop, daemon=True)
        self.monitoring_thread.start()
        logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")

    def stop_monitoring(self):
        if self.monitoring_active:
            self.monitoring_active = False
            if self.monitoring_thread and self.monitoring_thread.is_alive():
                self.monitoring_thread.join(timeout=5.0)
            logger.info("Memory monitoring stopped")

    def check_memory_pressure(self, threshold_percent: float = 85.0) -> Dict[str, Any]:
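        """Classify current usage as normal / warning / critical against `threshold_percent`."""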
        usage = self.get_memory_usage()
        info = {
            "under_pressure": False,
            "pressure_level": "normal",
            "usage_percent": 0.0,
            "recommendations": [],
        }

        if self.gpu_available:
            percent = usage.get("vram_used_percent", 0.0)
            info["usage_percent"] = percent
            if percent >= threshold_percent:
                info["under_pressure"] = True
                if percent >= 95:
                    info["pressure_level"] = "critical"
                    info["recommendations"] += [
                        "Run aggressive memory cleanup",
                        "Reduce frame cache / chunk size",
                        "Lower resolution or disable previews",
                    ]
                else:
                    info["pressure_level"] = "warning"
                    info["recommendations"] += [
                        "Run cleanup",
                        "Monitor memory usage",
                        "Reduce keyframe interval",
                    ]
        else:
            percent = usage.get("system_percent", 0.0)
            info["usage_percent"] = percent
            if percent >= threshold_percent:
                info["under_pressure"] = True
                if percent >= 95:
                    info["pressure_level"] = "critical"
                    info["recommendations"] += [
                        "Close other processes",
                        "Reduce resolution",
                        "Split video into chunks",
                    ]
                else:
                    info["pressure_level"] = "warning"
                    info["recommendations"] += [
                        "Run cleanup",
                        "Monitor usage",
                        "Reduce processing footprint",
                    ]
        return info

    def estimate_memory_requirement(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, float]:
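        """
        Rough memory estimate, assuming uint8 RGB frames plus a 3x overhead for
        masks/intermediates and fixed model/system allowances.

        Example: 1920x1080 with 5 frames in memory is ~0.09 GB of frame data, so
        the ~6.1 GB total is dominated by the 4 GB model and 2 GB system terms.
        """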
        bytes_per_frame = video_width * video_height * 3
        overhead_multiplier = 3.0  # masks/intermediates
        frames_gb = _bytes_to_gb(bytes_per_frame * frames_in_memory * overhead_multiplier)
        estimate = {
            "frames_memory_gb": round(frames_gb, 3),
            "model_memory_gb": 4.0,
            "system_overhead_gb": 2.0,
        }
        estimate["total_estimated_gb"] = round(
            estimate["frames_memory_gb"] + estimate["model_memory_gb"] + estimate["system_overhead_gb"], 3
        )
        return estimate

    def can_process_video(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, Any]:
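        """Compare the estimate against free VRAM (or available system RAM on CPU/MPS)."""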
        estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
        usage = self.get_memory_usage()
        if self.gpu_available:
            available = usage.get("vram_free_gb", 0.0)
        else:
            available = usage.get("system_available_gb", 0.0)

        can = estimate["total_estimated_gb"] <= available
        return {
            "can_process": can,
            "estimated_memory_gb": estimate["total_estimated_gb"],
            "available_memory_gb": available,
            "memory_margin_gb": round(available - estimate["total_estimated_gb"], 3),
            "recommendations": [] if can else [
                "Reduce resolution or duration",
                "Process in smaller chunks",
                "Run aggressive cleanup before start",
            ],
        }

    def get_stats(self) -> Dict[str, Any]:
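        """Return lifetime counters and current configuration for diagnostics."""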
        return {
            "cleanup_count": self.stats["cleanup_count"],
            "peak_memory_usage_gb": self.stats["peak_memory_usage"],
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "applied_fraction": self.applied_fraction,
            "monitoring_active": self.monitoring_active,
            "callbacks_registered": len(self.cleanup_callbacks),
        }

    def __del__(self):
        try:
            self.stop_monitoring()
            self.cleanup_aggressive()
        except Exception:
            pass
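

if __name__ == "__main__":
    # Minimal manual smoke test (illustrative only; not part of the
    # BackgroundFX Pro pipeline). Picks the best available device, prints a
    # usage snapshot and a 1080p feasibility check, then runs a light cleanup.
    logging.basicConfig(level=logging.INFO)
    if torch is not None and torch.cuda.is_available():
        _dev = "cuda"
    elif torch is not None and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        _dev = "mps"
    else:
        _dev = "cpu"
    mm = MemoryManager(_dev)
    print(mm.get_memory_usage())
    print(mm.can_process_video(1920, 1080))
    mm.cleanup()
    print(mm.get_stats())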