""" |
|
|
Memory Manager for BackgroundFX Pro |
|
|
- Safe on CPU/CUDA/MPS (mostly CUDA/T4 on Spaces) |
|
|
- Accepts `device` as str or torch.device |
|
|
- Optional per-process VRAM cap (env or method) |
|
|
- Detailed usage reporting (CPU/RAM + VRAM + torch allocator) |
|
|
- Light and aggressive cleanup paths |
|
|
- Background monitor (optional) |
|
|
|
|
|
Env switches: |
|
|
BFX_DISABLE_LIMIT=1 -> do not set VRAM fraction automatically |
|
|
BFX_CUDA_FRACTION=0.80 -> fraction to cap per-process VRAM (0.10..0.95) |
|
|
""" |
from __future__ import annotations

import gc
import os
import time
import logging
import threading
from typing import Dict, Any, Optional, Callable

# psutil and torch are optional; without them the manager degrades to
# best-effort reporting and no-op GPU cleanup.
try:
    import psutil
except Exception:
    psutil = None

try:
    import torch
except Exception:
    torch = None

logger = logging.getLogger(__name__)

class MemoryManagerError(Exception):
    pass


def _bytes_to_gb(x: int | float) -> float:
    try:
        return float(x) / (1024**3)
    except Exception:
        return 0.0


def _normalize_device(dev) -> "torch.device":
    if torch is None:
        # torch unavailable: return a minimal stand-in exposing .type / .index.
        class _Fake:
            type = "cpu"
            index = None

        return _Fake()

    if isinstance(dev, str):
        return torch.device(dev)
    if hasattr(dev, "type"):
        return dev
    return torch.device("cpu")


def _cuda_index(device) -> Optional[int]:
    if getattr(device, "type", "cpu") != "cuda":
        return None
    idx = getattr(device, "index", None)
    if idx is None:
        # A bare "cuda" device means the default device, index 0.
        return 0
    return int(idx)

class MemoryManager:
    """
    Comprehensive memory management with VRAM cap + cleanup utilities.
    """

    def __init__(self, device, memory_limit_gb: Optional[float] = None):
        self.device = _normalize_device(device)
        self.device_type = getattr(self.device, "type", "cpu")
        self.cuda_idx = _cuda_index(self.device)

        self.gpu_available = bool(
            torch and self.device_type == "cuda" and torch.cuda.is_available()
        )
        self.mps_available = bool(
            torch and self.device_type == "mps" and getattr(torch.backends, "mps", None)
            and torch.backends.mps.is_available()
        )

        self.memory_limit_gb = memory_limit_gb
        self.cleanup_callbacks: list[Callable] = []
        self.monitoring_active = False
        self.monitoring_thread: Optional[threading.Thread] = None
        self.stats = {
            "cleanup_count": 0,
            "peak_memory_usage": 0.0,
            "total_allocated": 0.0,
            "total_freed": 0.0,
        }
        self.applied_fraction: Optional[float] = None

        self._initialize_memory_limits()
        self._maybe_apply_vram_fraction()
        logger.info(f"MemoryManager initialized (device={self.device}, cuda={self.gpu_available})")

    def _initialize_memory_limits(self):
        try:
            if self.gpu_available:
                props = torch.cuda.get_device_properties(self.cuda_idx or 0)
                total_gb = _bytes_to_gb(props.total_memory)
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.80)
                logger.info(
                    f"CUDA memory limit baseline ~{self.memory_limit_gb:.1f}GB "
                    f"(device total {total_gb:.1f}GB)"
                )
            elif self.mps_available:
                vm = psutil.virtual_memory() if psutil else None
                total_gb = _bytes_to_gb(vm.total) if vm else 0.0
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.50)
                logger.info(f"MPS memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
            else:
                vm = psutil.virtual_memory() if psutil else None
                total_gb = _bytes_to_gb(vm.total) if vm else 0.0
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.60)
                logger.info(f"CPU memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
        except Exception as e:
            logger.warning(f"Memory limit init failed: {e}")
            if self.memory_limit_gb is None:
                self.memory_limit_gb = 4.0

    def _maybe_apply_vram_fraction(self):
        if not self.gpu_available or torch is None:
            return
        if os.environ.get("BFX_DISABLE_LIMIT", ""):
            return
        frac_env = os.environ.get("BFX_CUDA_FRACTION", "").strip()
        try:
            fraction = float(frac_env) if frac_env else 0.80
        except Exception:
            fraction = 0.80
        applied = self.limit_cuda_memory(fraction=fraction)
        if applied:
            logger.info(f"Per-process CUDA memory fraction set to {applied:.2f} on device {self.cuda_idx or 0}")
    def get_memory_usage(self) -> Dict[str, Any]:
        usage: Dict[str, Any] = {
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "timestamp": time.time(),
        }

        # System RAM, swap and process footprint (requires psutil).
        if psutil:
            try:
                vm = psutil.virtual_memory()
                usage.update(
                    dict(
                        system_total_gb=round(_bytes_to_gb(vm.total), 3),
                        system_available_gb=round(_bytes_to_gb(vm.available), 3),
                        system_used_gb=round(_bytes_to_gb(vm.used), 3),
                        system_percent=float(vm.percent),
                    )
                )
                swap = psutil.swap_memory()
                usage.update(
                    dict(
                        swap_total_gb=round(_bytes_to_gb(swap.total), 3),
                        swap_used_gb=round(_bytes_to_gb(swap.used), 3),
                        swap_percent=float(swap.percent),
                    )
                )
                proc = psutil.Process()
                mi = proc.memory_info()
                usage.update(
                    dict(
                        process_rss_gb=round(_bytes_to_gb(mi.rss), 3),
                        process_vms_gb=round(_bytes_to_gb(mi.vms), 3),
                    )
                )
            except Exception as e:
                logger.debug(f"psutil stats error: {e}")

        # Device-level VRAM (driver view) plus torch allocator statistics.
        if self.gpu_available and torch is not None:
            try:
                free_b, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
                used_b = total_b - free_b
                usage.update(
                    dict(
                        vram_total_gb=round(_bytes_to_gb(total_b), 3),
                        vram_used_gb=round(_bytes_to_gb(used_b), 3),
                        vram_free_gb=round(_bytes_to_gb(free_b), 3),
                        vram_used_percent=float(used_b / total_b * 100.0) if total_b else 0.0,
                    )
                )
            except Exception as e:
                logger.debug(f"mem_get_info failed: {e}")

            try:
                idx = self.cuda_idx or 0
                allocated = torch.cuda.memory_allocated(idx)
                reserved = torch.cuda.memory_reserved(idx)
                usage["torch_allocated_gb"] = round(_bytes_to_gb(allocated), 3)
                usage["torch_reserved_gb"] = round(_bytes_to_gb(reserved), 3)
                try:
                    inactive = torch.cuda.memory_stats(idx).get("inactive_split_bytes.all.current", 0)
                    usage["torch_inactive_split_gb"] = round(_bytes_to_gb(inactive), 3)
                except Exception:
                    pass
            except Exception as e:
                logger.debug(f"allocator stats failed: {e}")

        usage["applied_fraction"] = self.applied_fraction

        # Track the peak of whichever figure is most relevant for this device.
        current = usage.get("vram_used_gb", usage.get("system_used_gb", 0.0))
        try:
            if float(current) > float(self.stats["peak_memory_usage"]):
                self.stats["peak_memory_usage"] = float(current)
        except Exception:
            pass

        return usage

    def limit_cuda_memory(self, fraction: Optional[float] = None, max_gb: Optional[float] = None) -> Optional[float]:
        """Cap this process's VRAM use by fraction or absolute GB; returns the applied fraction."""
        if not self.gpu_available or torch is None:
            return None

        # Translate an absolute cap into a fraction of the device total.
        if max_gb is not None:
            try:
                _, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
                total_gb = _bytes_to_gb(total_b)
                if total_gb <= 0:
                    return None
                fraction = min(max(0.10, max_gb / total_gb), 0.95)
            except Exception as e:
                logger.debug(f"fraction from max_gb failed: {e}")
                return None

        if fraction is None:
            fraction = 0.80
        fraction = float(max(0.10, min(0.95, fraction)))

        try:
            torch.cuda.set_per_process_memory_fraction(fraction, device=self.cuda_idx or 0)
            self.applied_fraction = fraction
            return fraction
        except Exception as e:
            logger.debug(f"set_per_process_memory_fraction failed: {e}")
            return None
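
    # Worked example (illustrative figures): limit_cuda_memory(max_gb=12.0) on a
    # 16 GB card resolves to fraction = 12 / 16 = 0.75, which is then clamped to
    # [0.10, 0.95] before torch.cuda.set_per_process_memory_fraction() is applied.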
    def cleanup(self) -> None:
        """Light cleanup used frequently between steps."""
        try:
            gc.collect()
        except Exception:
            pass
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        self.stats["cleanup_count"] += 1

    def cleanup_basic(self) -> None:
        """Alias kept for compatibility."""
        self.cleanup()

    def cleanup_aggressive(self) -> None:
        """Aggressive cleanup for OOM recovery or big scene switches."""
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.synchronize(self.cuda_idx or 0)
            except Exception:
                pass
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                torch.cuda.reset_peak_memory_stats(self.cuda_idx or 0)
            except Exception:
                pass
            try:
                if hasattr(torch.cuda, "ipc_collect"):
                    torch.cuda.ipc_collect()
            except Exception:
                pass
        try:
            gc.collect()
            gc.collect()
        except Exception:
            pass
        self.stats["cleanup_count"] += 1

    def register_cleanup_callback(self, callback: Callable):
        self.cleanup_callbacks.append(callback)

    def start_monitoring(self, interval_seconds: float = 30.0, pressure_callback: Optional[Callable] = None):
        if self.monitoring_active:
            logger.warning("Memory monitoring already active")
            return
        self.monitoring_active = True

        def loop():
            while self.monitoring_active:
                try:
                    pressure = self.check_memory_pressure()
                    if pressure["under_pressure"]:
                        logger.warning(
                            f"Memory pressure: {pressure['pressure_level']} "
                            f"({pressure['usage_percent']:.1f}%)"
                        )
                        if pressure_callback:
                            try:
                                pressure_callback(pressure)
                            except Exception as e:
                                logger.error(f"Pressure callback failed: {e}")
                        if pressure["pressure_level"] == "critical":
                            self.cleanup_aggressive()
                except Exception as e:
                    logger.error(f"Memory monitoring error: {e}")
                time.sleep(interval_seconds)

        self.monitoring_thread = threading.Thread(target=loop, daemon=True)
        self.monitoring_thread.start()
        logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")
    def stop_monitoring(self):
        if self.monitoring_active:
            self.monitoring_active = False
            if self.monitoring_thread and self.monitoring_thread.is_alive():
                self.monitoring_thread.join(timeout=5.0)
            logger.info("Memory monitoring stopped")

    def check_memory_pressure(self, threshold_percent: float = 85.0) -> Dict[str, Any]:
        usage = self.get_memory_usage()
        info = {
            "under_pressure": False,
            "pressure_level": "normal",
            "usage_percent": 0.0,
            "recommendations": [],
        }

        if self.gpu_available:
            percent = usage.get("vram_used_percent", 0.0)
            info["usage_percent"] = percent
            if percent >= threshold_percent:
                info["under_pressure"] = True
                if percent >= 95:
                    info["pressure_level"] = "critical"
                    info["recommendations"] += [
                        "Run aggressive memory cleanup",
                        "Reduce frame cache / chunk size",
                        "Lower resolution or disable previews",
                    ]
                else:
                    info["pressure_level"] = "warning"
                    info["recommendations"] += [
                        "Run cleanup",
                        "Monitor memory usage",
                        "Reduce keyframe interval",
                    ]
        else:
            percent = usage.get("system_percent", 0.0)
            info["usage_percent"] = percent
            if percent >= threshold_percent:
                info["under_pressure"] = True
                if percent >= 95:
                    info["pressure_level"] = "critical"
                    info["recommendations"] += [
                        "Close other processes",
                        "Reduce resolution",
                        "Split video into chunks",
                    ]
                else:
                    info["pressure_level"] = "warning"
                    info["recommendations"] += [
                        "Run cleanup",
                        "Monitor usage",
                        "Reduce processing footprint",
                    ]
        return info

    def estimate_memory_requirement(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, float]:
        # Rough estimate: 8-bit RGB frames, with a 3x allowance for intermediates.
        bytes_per_frame = video_width * video_height * 3
        overhead_multiplier = 3.0
        frames_gb = _bytes_to_gb(bytes_per_frame * frames_in_memory * overhead_multiplier)
        estimate = {
            "frames_memory_gb": round(frames_gb, 3),
            "model_memory_gb": 4.0,
            "system_overhead_gb": 2.0,
        }
        estimate["total_estimated_gb"] = round(
            estimate["frames_memory_gb"] + estimate["model_memory_gb"] + estimate["system_overhead_gb"], 3
        )
        return estimate
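
    # Worked example (illustrative): one 1920x1080 RGB frame is 1920*1080*3 bytes
    # (~6 MB), so 5 frames with the 3x overhead come to ~0.09 GB; adding the fixed
    # 4.0 GB model and 2.0 GB system allowances gives total_estimated_gb ~ 6.09.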
    def can_process_video(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, Any]:
        estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
        usage = self.get_memory_usage()
        if self.gpu_available:
            available = usage.get("vram_free_gb", 0.0)
        else:
            available = usage.get("system_available_gb", 0.0)

        can = estimate["total_estimated_gb"] <= available
        return {
            "can_process": can,
            "estimated_memory_gb": estimate["total_estimated_gb"],
            "available_memory_gb": available,
            "memory_margin_gb": round(available - estimate["total_estimated_gb"], 3),
            "recommendations": [] if can else [
                "Reduce resolution or duration",
                "Process in smaller chunks",
                "Run aggressive cleanup before start",
            ],
        }

    def get_stats(self) -> Dict[str, Any]:
        return {
            "cleanup_count": self.stats["cleanup_count"],
            "peak_memory_usage_gb": self.stats["peak_memory_usage"],
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "applied_fraction": self.applied_fraction,
            "monitoring_active": self.monitoring_active,
            "callbacks_registered": len(self.cleanup_callbacks),
        }

    def __del__(self):
        try:
            self.stop_monitoring()
            self.cleanup_aggressive()
        except Exception:
            pass
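

# Minimal smoke-test sketch (illustrative; not part of the library API). It runs a
# quick report/cleanup cycle on whatever device is available when this module is
# executed directly, and degrades safely if torch or psutil is missing.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    dev = "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu"
    mm = MemoryManager(dev)
    print("usage:", mm.get_memory_usage())
    print("pressure:", mm.check_memory_pressure())
    print("can process 1080p:", mm.can_process_video(1920, 1080))
    mm.cleanup_aggressive()
    print("stats:", mm.get_stats())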