#!/usr/bin/env python3
"""
Memory Manager for BackgroundFX Pro
- Safe on CPU/CUDA/MPS (mostly CUDA/T4 on Spaces)
- Accepts `device` as str or torch.device
- Optional per-process VRAM cap (env or method)
- Detailed usage reporting (CPU/RAM + VRAM + torch allocator)
- Light and aggressive cleanup paths
- Background monitor (optional)
Env switches:
BFX_DISABLE_LIMIT=1 -> do not set VRAM fraction automatically
BFX_CUDA_FRACTION=0.80 -> fraction to cap per-process VRAM (0.10..0.95)
"""
from __future__ import annotations
import gc
import os
import time
import logging
import threading
from typing import Dict, Any, Optional, Callable
# Optional deps
try:
import psutil
except Exception: # pragma: no cover
psutil = None
try:
import torch
except Exception: # pragma: no cover
torch = None
logger = logging.getLogger(__name__)
# ---- local exception to avoid shadowing built-in MemoryError ----
class MemoryManagerError(Exception):
pass
def _bytes_to_gb(x: int | float) -> float:
try:
return float(x) / (1024**3)
except Exception:
return 0.0
def _normalize_device(dev) -> "torch.device":
if torch is None:
# fake CPU device
class _Fake:
type = "cpu"
index = None
return _Fake() # type: ignore[return-value]
if isinstance(dev, str):
return torch.device(dev)
if hasattr(dev, "type"):
return dev
# default CPU
return torch.device("cpu")
def _cuda_index(device) -> Optional[int]:
if getattr(device, "type", "cpu") != "cuda":
return None
idx = getattr(device, "index", None)
if idx is None:
# normalize bare "cuda" to 0
return 0
return int(idx)
class MemoryManager:
"""
Comprehensive memory management with VRAM cap + cleanup utilities.
"""
def __init__(self, device, memory_limit_gb: Optional[float] = None):
self.device = _normalize_device(device)
self.device_type = getattr(self.device, "type", "cpu")
self.cuda_idx = _cuda_index(self.device)
self.gpu_available = bool(
torch and self.device_type == "cuda" and torch.cuda.is_available()
)
self.mps_available = bool(
torch and self.device_type == "mps" and getattr(torch.backends, "mps", None)
and torch.backends.mps.is_available()
)
self.memory_limit_gb = memory_limit_gb
self.cleanup_callbacks: list[Callable] = []
self.monitoring_active = False
self.monitoring_thread: Optional[threading.Thread] = None
self.stats = {
"cleanup_count": 0,
"peak_memory_usage": 0.0,
"total_allocated": 0.0,
"total_freed": 0.0,
}
self.applied_fraction: Optional[float] = None
self._initialize_memory_limits()
self._maybe_apply_vram_fraction()
logger.info(f"MemoryManager initialized (device={self.device}, cuda={self.gpu_available})")
# -------------------------------
# init helpers
# -------------------------------
def _initialize_memory_limits(self):
try:
if self.gpu_available:
props = torch.cuda.get_device_properties(self.cuda_idx or 0)
total_gb = _bytes_to_gb(props.total_memory)
if self.memory_limit_gb is None:
self.memory_limit_gb = max(0.5, total_gb * 0.80) # default 80%
logger.info(
f"CUDA memory limit baseline ~{self.memory_limit_gb:.1f}GB "
f"(device total {total_gb:.1f}GB)"
)
elif self.mps_available:
vm = psutil.virtual_memory() if psutil else None
total_gb = _bytes_to_gb(vm.total) if vm else 0.0
if self.memory_limit_gb is None:
self.memory_limit_gb = max(0.5, total_gb * 0.50)
logger.info(f"MPS memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
else:
vm = psutil.virtual_memory() if psutil else None
total_gb = _bytes_to_gb(vm.total) if vm else 0.0
if self.memory_limit_gb is None:
self.memory_limit_gb = max(0.5, total_gb * 0.60)
logger.info(f"CPU memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
except Exception as e:
logger.warning(f"Memory limit init failed: {e}")
if self.memory_limit_gb is None:
self.memory_limit_gb = 4.0 # conservative fallback
def _maybe_apply_vram_fraction(self):
if not self.gpu_available or torch is None:
return
if os.environ.get("BFX_DISABLE_LIMIT", ""):
return
frac_env = os.environ.get("BFX_CUDA_FRACTION", "").strip()
try:
fraction = float(frac_env) if frac_env else 0.80
except Exception:
fraction = 0.80
applied = self.limit_cuda_memory(fraction=fraction)
if applied:
logger.info(f"Per-process CUDA memory fraction set to {applied:.2f} on device {self.cuda_idx or 0}")
# -------------------------------
# public API
# -------------------------------
    def get_memory_usage(self) -> Dict[str, Any]:
        """Return a snapshot of system RAM/swap/process usage plus VRAM and torch allocator stats when on CUDA."""
usage: Dict[str, Any] = {
"device_type": self.device_type,
"memory_limit_gb": self.memory_limit_gb,
"timestamp": time.time(),
}
# CPU / system
if psutil:
try:
vm = psutil.virtual_memory()
usage.update(
dict(
system_total_gb=round(_bytes_to_gb(vm.total), 3),
system_available_gb=round(_bytes_to_gb(vm.available), 3),
system_used_gb=round(_bytes_to_gb(vm.used), 3),
system_percent=float(vm.percent),
)
)
swap = psutil.swap_memory()
usage.update(
dict(
swap_total_gb=round(_bytes_to_gb(swap.total), 3),
swap_used_gb=round(_bytes_to_gb(swap.used), 3),
swap_percent=float(swap.percent),
)
)
proc = psutil.Process()
mi = proc.memory_info()
usage.update(
dict(
process_rss_gb=round(_bytes_to_gb(mi.rss), 3),
process_vms_gb=round(_bytes_to_gb(mi.vms), 3),
)
)
except Exception as e:
logger.debug(f"psutil stats error: {e}")
# GPU
if self.gpu_available and torch is not None:
try:
# mem_get_info returns (free, total) in bytes
free_b, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
used_b = total_b - free_b
usage.update(
dict(
vram_total_gb=round(_bytes_to_gb(total_b), 3),
vram_used_gb=round(_bytes_to_gb(used_b), 3),
vram_free_gb=round(_bytes_to_gb(free_b), 3),
vram_used_percent=float(used_b / total_b * 100.0) if total_b else 0.0,
)
)
except Exception as e:
logger.debug(f"mem_get_info failed: {e}")
# torch allocator stats
try:
idx = self.cuda_idx or 0
allocated = torch.cuda.memory_allocated(idx)
reserved = torch.cuda.memory_reserved(idx)
usage["torch_allocated_gb"] = round(_bytes_to_gb(allocated), 3)
usage["torch_reserved_gb"] = round(_bytes_to_gb(reserved), 3)
# inactive split (2.x)
try:
inactive = torch.cuda.memory_stats(idx).get("inactive_split_bytes.all.current", 0)
usage["torch_inactive_split_gb"] = round(_bytes_to_gb(inactive), 3)
except Exception:
pass
except Exception as e:
logger.debug(f"allocator stats failed: {e}")
usage["applied_fraction"] = self.applied_fraction
# Update peak tracker
current = usage.get("vram_used_gb", usage.get("system_used_gb", 0.0))
try:
if float(current) > float(self.stats["peak_memory_usage"]):
self.stats["peak_memory_usage"] = float(current)
except Exception:
pass
return usage
    def limit_cuda_memory(self, fraction: Optional[float] = None, max_gb: Optional[float] = None) -> Optional[float]:
        """Cap per-process VRAM via torch.cuda.set_per_process_memory_fraction.

        Accepts either `fraction` (clamped to 0.10..0.95) or an absolute `max_gb`;
        returns the fraction actually applied, or None if no cap was set.
        """
if not self.gpu_available or torch is None:
return None
# derive fraction from max_gb if provided
if max_gb is not None:
try:
_, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
total_gb = _bytes_to_gb(total_b)
if total_gb <= 0:
return None
fraction = min(max(0.10, max_gb / total_gb), 0.95)
except Exception as e:
logger.debug(f"fraction from max_gb failed: {e}")
return None
if fraction is None:
fraction = 0.80
fraction = float(max(0.10, min(0.95, fraction)))
try:
torch.cuda.set_per_process_memory_fraction(fraction, device=self.cuda_idx or 0)
self.applied_fraction = fraction
return fraction
except Exception as e:
logger.debug(f"set_per_process_memory_fraction failed: {e}")
return None
def cleanup(self) -> None:
"""Light cleanup used frequently between steps."""
try:
gc.collect()
except Exception:
pass
if self.gpu_available and torch is not None:
try:
torch.cuda.empty_cache()
except Exception:
pass
self.stats["cleanup_count"] += 1
def cleanup_basic(self) -> None:
"""Alias kept for compatibility."""
self.cleanup()
    def cleanup_aggressive(self) -> None:
        """Aggressive cleanup for OOM recovery or big scene switches."""
        # Run registered cleanup callbacks first so callers can release their own
        # caches/references before the CUDA allocator is flushed.
        for callback in list(self.cleanup_callbacks):
            try:
                callback()
            except Exception as e:
                logger.debug(f"Cleanup callback failed: {e}")
        if self.gpu_available and torch is not None:
try:
torch.cuda.synchronize(self.cuda_idx or 0)
except Exception:
pass
try:
torch.cuda.empty_cache()
except Exception:
pass
try:
torch.cuda.reset_peak_memory_stats(self.cuda_idx or 0)
except Exception:
pass
try:
if hasattr(torch.cuda, "ipc_collect"):
torch.cuda.ipc_collect()
except Exception:
pass
try:
gc.collect(); gc.collect()
except Exception:
pass
self.stats["cleanup_count"] += 1
    def register_cleanup_callback(self, callback: Callable) -> None:
        """Register a zero-argument callable that is invoked during aggressive cleanup."""
        self.cleanup_callbacks.append(callback)
    def start_monitoring(self, interval_seconds: float = 30.0, pressure_callback: Optional[Callable] = None):
        """Start a daemon thread that polls memory pressure every `interval_seconds` seconds and runs aggressive cleanup when pressure is critical."""
if self.monitoring_active:
logger.warning("Memory monitoring already active")
return
self.monitoring_active = True
def loop():
while self.monitoring_active:
try:
pressure = self.check_memory_pressure()
if pressure["under_pressure"]:
logger.warning(
f"Memory pressure: {pressure['pressure_level']} "
f"({pressure['usage_percent']:.1f}%)"
)
if pressure_callback:
try:
pressure_callback(pressure)
except Exception as e:
logger.error(f"Pressure callback failed: {e}")
if pressure["pressure_level"] == "critical":
self.cleanup_aggressive()
except Exception as e:
logger.error(f"Memory monitoring error: {e}")
time.sleep(interval_seconds)
self.monitoring_thread = threading.Thread(target=loop, daemon=True)
self.monitoring_thread.start()
logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")
def stop_monitoring(self):
if self.monitoring_active:
self.monitoring_active = False
if self.monitoring_thread and self.monitoring_thread.is_alive():
self.monitoring_thread.join(timeout=5.0)
logger.info("Memory monitoring stopped")
    def check_memory_pressure(self, threshold_percent: float = 85.0) -> Dict[str, Any]:
        """Classify current VRAM (or system RAM) usage as normal/warning/critical and return mitigation suggestions."""
usage = self.get_memory_usage()
info = {
"under_pressure": False,
"pressure_level": "normal",
"usage_percent": 0.0,
"recommendations": [],
}
if self.gpu_available:
percent = usage.get("vram_used_percent", 0.0)
info["usage_percent"] = percent
if percent >= threshold_percent:
info["under_pressure"] = True
if percent >= 95:
info["pressure_level"] = "critical"
info["recommendations"] += [
"Run aggressive memory cleanup",
"Reduce frame cache / chunk size",
"Lower resolution or disable previews",
]
else:
info["pressure_level"] = "warning"
info["recommendations"] += [
"Run cleanup",
"Monitor memory usage",
"Reduce keyframe interval",
]
else:
percent = usage.get("system_percent", 0.0)
info["usage_percent"] = percent
if percent >= threshold_percent:
info["under_pressure"] = True
if percent >= 95:
info["pressure_level"] = "critical"
info["recommendations"] += [
"Close other processes",
"Reduce resolution",
"Split video into chunks",
]
else:
info["pressure_level"] = "warning"
info["recommendations"] += [
"Run cleanup",
"Monitor usage",
"Reduce processing footprint",
]
return info
    def estimate_memory_requirement(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, float]:
        """Rough estimate assuming 8-bit RGB frames plus fixed model and system overhead budgets."""
        bytes_per_frame = video_width * video_height * 3  # uint8 RGB
        overhead_multiplier = 3.0  # masks / intermediates / working copies
        frames_gb = _bytes_to_gb(bytes_per_frame * frames_in_memory * overhead_multiplier)
        estimate = {
            "frames_memory_gb": round(frames_gb, 3),
            "model_memory_gb": 4.0,  # rough budget for loaded models
            "system_overhead_gb": 2.0,  # runtime / framework overhead
        }
estimate["total_estimated_gb"] = round(
estimate["frames_memory_gb"] + estimate["model_memory_gb"] + estimate["system_overhead_gb"], 3
)
return estimate
def can_process_video(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, Any]:
estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
usage = self.get_memory_usage()
if self.gpu_available:
available = usage.get("vram_free_gb", 0.0)
else:
available = usage.get("system_available_gb", 0.0)
can = estimate["total_estimated_gb"] <= available
return {
"can_process": can,
"estimated_memory_gb": estimate["total_estimated_gb"],
"available_memory_gb": available,
"memory_margin_gb": round(available - estimate["total_estimated_gb"], 3),
"recommendations": [] if can else [
"Reduce resolution or duration",
"Process in smaller chunks",
"Run aggressive cleanup before start",
],
}
def get_stats(self) -> Dict[str, Any]:
return {
"cleanup_count": self.stats["cleanup_count"],
"peak_memory_usage_gb": self.stats["peak_memory_usage"],
"device_type": self.device_type,
"memory_limit_gb": self.memory_limit_gb,
"applied_fraction": self.applied_fraction,
"monitoring_active": self.monitoring_active,
"callbacks_registered": len(self.cleanup_callbacks),
}
def __del__(self):
try:
self.stop_monitoring()
self.cleanup_aggressive()
except Exception:
pass
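

# Minimal usage sketch (illustrative; the 1080p figures and this __main__ guard are
# examples, not part of the module's public contract): build a manager, print a
# usage snapshot, and check whether a clip of a given size fits in memory.
if __name__ == "__main__":  # pragma: no cover
    logging.basicConfig(level=logging.INFO)
    device = "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu"
    manager = MemoryManager(device)
    print(manager.get_memory_usage())
    print(manager.can_process_video(1920, 1080, frames_in_memory=5))
    manager.cleanup_aggressive()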