# VideoBackgroundReplacer / memory_manager.py
# Committed by MogensR in 84a78ca ("Create memory_manager.py").
"""
Memory Management Module
Handles memory cleanup, monitoring, and GPU resource management
"""
import gc
import os
import psutil
import torch
import time
import logging
import threading
from typing import Dict, Any, Optional, Callable
from exceptions import MemoryError, ResourceExhaustionError
logger = logging.getLogger(__name__)
class MemoryManager:
    """
    Comprehensive memory management for video processing applications.

    Reports device, system and process memory usage. Frees memory on demand
    (garbage collection plus releasing the torch allocator cache) and
    classifies memory pressure. It can also poll for pressure from a
    background daemon thread.

    Args:
        device: Torch device to manage. ``'cuda'`` and ``'mps'`` count as GPU
            devices; every other device type is accounted against system RAM.
        memory_limit_gb: Optional memory budget in GiB. If ``None``, a
            device-dependent default is derived at construction time.
    """

    # Divisor for every *_gb value reported by this class (binary GiB).
    _BYTES_PER_GB = 1024 ** 3
    # Usage percentage at or above which pressure is classified as critical.
    _CRITICAL_PRESSURE_PERCENT = 95

    def __init__(self, device: torch.device, memory_limit_gb: Optional[float] = None):
        self.device = device
        self.gpu_available = device.type in ['cuda', 'mps']
        self.memory_limit_gb = memory_limit_gb
        # Callables run (in registration order) by cleanup_aggressive().
        self.cleanup_callbacks = []
        self.monitoring_active = False
        self.monitoring_thread = None
        # Stop signal for the current monitor thread. A new Event is created
        # on every start so a thread from an earlier run can never be revived.
        self._monitor_stop_event: Optional[threading.Event] = None
        self.stats = {
            'cleanup_count': 0,
            'peak_memory_usage': 0.0,
            # NOTE: the two totals below are not updated anywhere in this class.
            'total_allocated': 0.0,
            'total_freed': 0.0
        }
        self._initialize_memory_limits()
        logger.info(f"MemoryManager initialized for device: {device}")

    def _initialize_memory_limits(self):
        """
        Derive ``self.memory_limit_gb`` when the caller did not supply one.

        The defaults are:

        * CUDA: 80% of total device memory, or 4.0 GiB if the device cannot
          be queried.
        * MPS: 50% of system RAM, because MPS shares unified memory with the
          system.
        * CPU: 60% of system RAM.

        An explicitly supplied limit is never overwritten, not even on the
        CUDA error path.
        """
        gb = self._BYTES_PER_GB
        if self.device.type == 'cuda':
            try:
                device_idx = self.device.index or 0
                device_props = torch.cuda.get_device_properties(device_idx)
                total_memory_gb = device_props.total_memory / gb
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = total_memory_gb * 0.8
                logger.info(f"CUDA memory limit set to {self.memory_limit_gb:.1f}GB "
                            f"(total: {total_memory_gb:.1f}GB)")
            except Exception as e:
                logger.warning(f"Could not get CUDA memory info: {e}")
                # Conservative fallback, but only when no limit was given.
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = 4.0
        elif self.device.type == 'mps':
            system_memory_gb = psutil.virtual_memory().total / gb
            if self.memory_limit_gb is None:
                self.memory_limit_gb = system_memory_gb * 0.5
            logger.info(f"MPS memory limit set to {self.memory_limit_gb:.1f}GB "
                        f"(system: {system_memory_gb:.1f}GB)")
        else:  # CPU
            system_memory_gb = psutil.virtual_memory().total / gb
            if self.memory_limit_gb is None:
                self.memory_limit_gb = system_memory_gb * 0.6
            logger.info(f"CPU memory limit set to {self.memory_limit_gb:.1f}GB "
                        f"(system: {system_memory_gb:.1f}GB)")

    def get_memory_usage(self) -> Dict[str, Any]:
        """
        Collect a snapshot of memory usage.

        Always contains ``device_type``, ``memory_limit_gb`` and
        ``timestamp``. Depending on the device, and on which probes succeed,
        it also contains:

        * CUDA: ``gpu_*`` keys (allocated/reserved/total/free, percentages
          and peak values).
        * MPS: ``system_memory_gb``, ``system_utilization_percent`` and, when
          the installed torch exposes it, ``mps_allocated_gb``.
        * All devices: ``system_*`` and ``swap_*`` keys from psutil, plus
          ``process_rss_gb`` / ``process_vms_gb`` for the current process.

        Each probe is best-effort: a failure is logged and its keys are
        omitted. As a side effect, ``self.stats['peak_memory_usage']`` is
        raised to the current reading if that reading is higher.
        """
        gb = self._BYTES_PER_GB
        usage = {
            'device_type': self.device.type,
            'memory_limit_gb': self.memory_limit_gb,
            'timestamp': time.time()
        }

        try:
            if self.device.type == 'cuda':
                device_idx = self.device.index or 0
                allocated = torch.cuda.memory_allocated(device_idx)
                reserved = torch.cuda.memory_reserved(device_idx)
                total = torch.cuda.get_device_properties(device_idx).total_memory
                usage.update({
                    'gpu_allocated_gb': allocated / gb,
                    'gpu_reserved_gb': reserved / gb,
                    'gpu_total_gb': total / gb,
                    'gpu_utilization_percent': (allocated / total) * 100,
                    'gpu_reserved_percent': (reserved / total) * 100,
                    # Only this process's allocator reservation is subtracted.
                    'gpu_free_gb': (total - reserved) / gb
                })
                max_allocated = torch.cuda.max_memory_allocated(device_idx)
                max_reserved = torch.cuda.max_memory_reserved(device_idx)
                usage.update({
                    'gpu_max_allocated_gb': max_allocated / gb,
                    'gpu_max_reserved_gb': max_reserved / gb
                })
            elif self.device.type == 'mps':
                # MPS uses unified memory, so system memory is the main signal.
                vm = psutil.virtual_memory()
                usage.update({
                    'system_memory_gb': vm.total / gb,
                    'system_available_gb': vm.available / gb,
                    'system_used_gb': vm.used / gb,
                    'system_utilization_percent': vm.percent
                })
                # Newer torch releases also report bytes held by MPS tensors.
                mps = getattr(torch, 'mps', None)
                if hasattr(mps, 'current_allocated_memory'):
                    usage['mps_allocated_gb'] = mps.current_allocated_memory() / gb
        except Exception as e:
            logger.warning(f"Error getting GPU memory usage: {e}")

        try:
            vm = psutil.virtual_memory()
            swap = psutil.swap_memory()
            usage.update({
                'system_total_gb': vm.total / gb,
                'system_available_gb': vm.available / gb,
                'system_used_gb': vm.used / gb,
                'system_percent': vm.percent,
                'swap_total_gb': swap.total / gb,
                'swap_used_gb': swap.used / gb,
                'swap_percent': swap.percent
            })
        except Exception as e:
            logger.warning(f"Error getting system memory usage: {e}")

        try:
            memory_info = psutil.Process().memory_info()
            usage.update({
                'process_rss_gb': memory_info.rss / gb,  # Physical memory
                'process_vms_gb': memory_info.vms / gb,  # Virtual memory
            })
        except Exception as e:
            logger.warning(f"Error getting process memory usage: {e}")

        # On CPU/MPS this falls back to system-wide used memory, not only
        # this process's usage.
        current_usage = usage.get('gpu_allocated_gb', usage.get('system_used_gb', 0))
        if current_usage > self.stats['peak_memory_usage']:
            self.stats['peak_memory_usage'] = current_usage

        return usage

    def _empty_device_cache(self):
        """Return cached, unused allocator blocks to the driver (best-effort on MPS)."""
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
        elif self.device.type == 'mps':
            mps = getattr(torch, 'mps', None)
            if hasattr(mps, 'empty_cache'):
                try:
                    mps.empty_cache()
                except Exception as e:
                    logger.debug(f"MPS cache release skipped: {e}")

    def cleanup_basic(self):
        """
        Lightweight cleanup: one GC pass plus releasing the device allocator cache.

        Failures are logged as warnings and never raised. ``cleanup_count``
        is only incremented on success.
        """
        try:
            gc.collect()
            self._empty_device_cache()
            self.stats['cleanup_count'] += 1
            logger.debug("Basic memory cleanup completed")
        except Exception as e:
            logger.warning(f"Basic memory cleanup failed: {e}")

    def cleanup_aggressive(self):
        """
        Thorough cleanup, slower than ``cleanup_basic``.

        It runs these steps in order:

        1. Run every registered cleanup callback; a callback's failure is
           logged and skipped.
        2. Run three GC passes.
        3. Synchronize the managed device so queued work finishes.
        4. Release the allocator cache.
        5. On CUDA, reset the peak-memory statistics.

        Raises:
            MemoryError: if any step outside the callbacks fails; the original
                exception is attached as ``__cause__``.
        """
        try:
            start_time = time.time()

            for callback in self.cleanup_callbacks:
                try:
                    callback()
                except Exception as e:
                    logger.warning(f"Cleanup callback failed: {e}")

            for _ in range(3):
                gc.collect()

            if self.device.type == 'cuda':
                device_idx = self.device.index or 0
                # Drain queued kernels on the managed device (not just the
                # current one) before handing cached blocks back.
                torch.cuda.synchronize(device_idx)
                self._empty_device_cache()
                torch.cuda.reset_peak_memory_stats(device_idx)
            elif self.device.type == 'mps':
                mps = getattr(torch, 'mps', None)
                if hasattr(mps, 'synchronize'):
                    try:
                        mps.synchronize()
                    except Exception as e:
                        logger.debug(f"MPS synchronize skipped: {e}")
                self._empty_device_cache()

            cleanup_time = time.time() - start_time
            self.stats['cleanup_count'] += 1
            logger.debug(f"Aggressive memory cleanup completed in {cleanup_time:.2f}s")
        except Exception as e:
            logger.error(f"Aggressive memory cleanup failed: {e}")
            raise MemoryError("aggressive_cleanup", str(e)) from e

    def check_memory_pressure(self, threshold_percent: float = 85.0) -> Dict[str, Any]:
        """
        Classify current memory pressure.

        CUDA devices use ``gpu_utilization_percent`` (allocated / total);
        all other devices use the system-wide ``system_percent``. Missing
        readings count as 0.

        Args:
            threshold_percent: Usage at or above this is "under pressure".

        Returns:
            Dict with these keys:

            * ``under_pressure`` (bool).
            * ``pressure_level``: ``'normal'``, ``'warning'``, or
              ``'critical'`` (critical means usage >= 95%).
            * ``recommendations``: a list of strings.
            * ``usage_percent``: the reading used.
        """
        usage = self.get_memory_usage()

        if self.device.type == 'cuda':
            usage_percent = usage.get('gpu_utilization_percent', 0)
            critical_recs = [
                'Reduce batch size immediately',
                'Enable gradient checkpointing',
                'Consider switching to CPU processing'
            ]
            warning_recs = [
                'Run aggressive memory cleanup',
                'Reduce keyframe interval',
                'Monitor memory usage closely'
            ]
        else:  # CPU or MPS - use system memory
            usage_percent = usage.get('system_percent', 0)
            critical_recs = [
                'Free system memory immediately',
                'Close unnecessary applications',
                'Reduce video processing quality'
            ]
            warning_recs = [
                'Run memory cleanup',
                'Monitor system memory',
                'Consider processing in smaller chunks'
            ]

        pressure_info = {
            'under_pressure': False,
            'pressure_level': 'normal',  # normal, warning, critical
            'recommendations': [],
            'usage_percent': usage_percent
        }

        if usage_percent >= threshold_percent:
            pressure_info['under_pressure'] = True
            if usage_percent >= self._CRITICAL_PRESSURE_PERCENT:
                pressure_info['pressure_level'] = 'critical'
                pressure_info['recommendations'] = critical_recs
            else:
                pressure_info['pressure_level'] = 'warning'
                pressure_info['recommendations'] = warning_recs

        return pressure_info

    def auto_cleanup_if_needed(self, pressure_threshold: float = 80.0) -> bool:
        """
        Run a cleanup if memory pressure is detected.

        Critical pressure triggers ``cleanup_aggressive`` (which may raise
        ``MemoryError``); warning-level pressure triggers ``cleanup_basic``.

        Returns:
            True if a cleanup was run, False otherwise.
        """
        pressure = self.check_memory_pressure(pressure_threshold)
        if not pressure['under_pressure']:
            return False

        cleanup_method = (
            self.cleanup_aggressive
            if pressure['pressure_level'] == 'critical'
            else self.cleanup_basic
        )
        logger.info(f"Auto-cleanup triggered due to {pressure['pressure_level']} "
                    f"memory pressure ({pressure['usage_percent']:.1f}%)")
        cleanup_method()
        return True

    def register_cleanup_callback(self, callback: Callable):
        """Register a zero-argument callable to run at the start of ``cleanup_aggressive``."""
        self.cleanup_callbacks.append(callback)
        logger.debug("Cleanup callback registered")

    def start_monitoring(self, interval_seconds: float = 30.0,
                         pressure_callback: Optional[Callable] = None):
        """
        Start a background daemon thread that polls memory pressure.

        Every ``interval_seconds`` the thread calls
        ``check_memory_pressure()`` with its default threshold. On pressure
        it does three things:

        * logs a warning;
        * calls ``pressure_callback(pressure_info)``, if one was given;
        * runs ``cleanup_aggressive()`` when the pressure is critical.

        Errors inside the loop are logged and polling continues. If
        monitoring is already active, this only logs a warning.

        Raises:
            ValueError: if ``interval_seconds`` is negative.
        """
        if interval_seconds < 0:
            raise ValueError("interval_seconds must be non-negative")
        if self.monitoring_active:
            logger.warning("Memory monitoring already active")
            return

        # Each run gets its own Event: stop_monitoring() sets it, waking the
        # thread immediately instead of waiting out a full sleep interval.
        stop_event = threading.Event()
        self._monitor_stop_event = stop_event
        self.monitoring_active = True

        def monitor_loop():
            while self.monitoring_active and not stop_event.is_set():
                try:
                    pressure = self.check_memory_pressure()
                    if pressure['under_pressure']:
                        logger.warning(f"Memory pressure detected: {pressure['pressure_level']} "
                                       f"({pressure['usage_percent']:.1f}%)")
                        if pressure_callback:
                            try:
                                pressure_callback(pressure)
                            except Exception as e:
                                logger.error(f"Pressure callback failed: {e}")
                        if pressure['pressure_level'] == 'critical':
                            self.cleanup_aggressive()
                except Exception as e:
                    logger.error(f"Memory monitoring error: {e}")
                stop_event.wait(interval_seconds)

        self.monitoring_thread = threading.Thread(target=monitor_loop, daemon=True)
        self.monitoring_thread.start()
        logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")

    def stop_monitoring(self):
        """
        Stop background monitoring started by ``start_monitoring``.

        Signals the monitor thread and waits up to 5 seconds for it to exit.
        It does nothing if monitoring is not active. When called from inside
        the monitor thread itself (e.g. from a pressure callback), no join is
        attempted; joining the current thread would raise.
        """
        if not self.monitoring_active:
            return
        self.monitoring_active = False
        if self._monitor_stop_event is not None:
            self._monitor_stop_event.set()
        thread = self.monitoring_thread
        if (thread is not None and thread.is_alive()
                and thread is not threading.current_thread()):
            thread.join(timeout=5.0)
        logger.info("Memory monitoring stopped")

    def estimate_memory_requirement(self, video_width: int, video_height: int,
                                    frames_in_memory: int = 5) -> Dict[str, float]:
        """
        Roughly estimate the memory needed to process a video.

        Args:
            video_width: Frame width in pixels.
            video_height: Frame height in pixels.
            frames_in_memory: Number of frames assumed resident at once.

        Returns:
            Dict with these keys:

            * ``frames_memory_gb``: frames at 3 bytes/pixel (RGB) times a 3x
              overhead factor for masks and intermediate results.
            * ``model_memory_gb``: fixed 4.0.
            * ``system_overhead_gb``: fixed 2.0.
            * ``total_estimated_gb``: the sum of the three values above.
        """
        bytes_per_frame = video_width * video_height * 3
        overhead_multiplier = 3.0  # masks, intermediate results, etc.

        estimated_memory = {
            'frames_memory_gb': (bytes_per_frame * frames_in_memory * overhead_multiplier)
                                / self._BYTES_PER_GB,
            'model_memory_gb': 4.0,  # Rough estimate for SAM2 + MatAnyone
            'system_overhead_gb': 2.0,
            'total_estimated_gb': 0.0
        }
        estimated_memory['total_estimated_gb'] = sum([
            estimated_memory['frames_memory_gb'],
            estimated_memory['model_memory_gb'],
            estimated_memory['system_overhead_gb']
        ])
        return estimated_memory

    def can_process_video(self, video_width: int, video_height: int,
                          frames_in_memory: int = 5) -> Dict[str, Any]:
        """
        Check whether the estimated requirement fits in currently available memory.

        Available memory is ``gpu_free_gb`` on CUDA and
        ``system_available_gb`` otherwise (0 when the reading is missing).

        Returns:
            Dict with these keys:

            * ``can_process``.
            * ``estimated_memory_gb``.
            * ``available_memory_gb``.
            * ``memory_margin_gb``: available minus estimated.
            * ``recommendations``: non-empty when processing does not fit, or
              when the margin is under 1 GiB.
        """
        estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
        current_usage = self.get_memory_usage()

        if self.device.type == 'cuda':
            available_memory = current_usage.get('gpu_free_gb', 0)
        else:
            available_memory = current_usage.get('system_available_gb', 0)

        required = estimate['total_estimated_gb']
        can_process = required <= available_memory
        result = {
            'can_process': can_process,
            'estimated_memory_gb': required,
            'available_memory_gb': available_memory,
            'memory_margin_gb': available_memory - required,
            'recommendations': []
        }

        if not can_process:
            deficit = required - available_memory
            result['recommendations'] = [
                f"Free {deficit:.1f}GB of memory",
                "Reduce video resolution",
                "Process in smaller chunks",
                "Use lower quality settings"
            ]
        elif result['memory_margin_gb'] < 1.0:
            result['recommendations'] = [
                "Memory margin is low",
                "Monitor memory usage during processing",
                "Consider reducing batch size"
            ]
        return result

    def get_optimization_suggestions(self) -> Dict[str, Any]:
        """
        Suggest optimizations based on the current usage percentage.

        GPU utilization is used when present, otherwise ``system_percent``.
        Priority is ``'high'`` at >= 90%, ``'medium'`` at >= 75%, and
        ``'low'`` below that.
        """
        usage = self.get_memory_usage()
        suggestions = {
            'current_usage_percent': usage.get('gpu_utilization_percent',
                                               usage.get('system_percent', 0)),
            'suggestions': [],
            'priority': 'low'  # low, medium, high
        }

        usage_percent = suggestions['current_usage_percent']
        if usage_percent >= 90:
            suggestions['priority'] = 'high'
            suggestions['suggestions'].extend([
                'Run aggressive memory cleanup immediately',
                'Reduce batch size to 1',
                'Enable gradient checkpointing if available',
                'Consider switching to CPU processing'
            ])
        elif usage_percent >= 75:
            suggestions['priority'] = 'medium'
            suggestions['suggestions'].extend([
                'Run memory cleanup regularly',
                'Monitor memory usage closely',
                'Reduce keyframe interval',
                'Use mixed precision if supported'
            ])
        elif usage_percent >= 50:
            suggestions['priority'] = 'low'
            suggestions['suggestions'].extend([
                'Current usage is acceptable',
                'Regular cleanup should be sufficient',
                'Monitor for memory leaks during long operations'
            ])
        else:
            suggestions['suggestions'] = [
                'Memory usage is optimal',
                'No immediate action required'
            ]
        return suggestions

    def get_stats(self) -> Dict[str, Any]:
        """Return cleanup/peak statistics and the manager's current configuration."""
        return {
            'cleanup_count': self.stats['cleanup_count'],
            'peak_memory_usage_gb': self.stats['peak_memory_usage'],
            'monitoring_active': self.monitoring_active,
            'device_type': self.device.type,
            'memory_limit_gb': self.memory_limit_gb,
            'registered_callbacks': len(self.cleanup_callbacks)
        }

    def __del__(self):
        """
        Best-effort teardown: stop monitoring, then run an aggressive cleanup.

        Every exception is deliberately swallowed, because an exception
        raised from ``__del__`` cannot be handled by the caller. Note that
        while a monitor thread is running, its loop closure references this
        object.
        """
        try:
            self.stop_monitoring()
            self.cleanup_aggressive()
        except Exception:
            pass