"""
Memory Management Module

Handles memory cleanup, monitoring, and GPU resource management for video
processing applications running on CUDA, Apple MPS, or CPU.
"""

import gc
import logging
import os
import threading
import time
from typing import Any, Callable, Dict, Optional

import psutil
import torch

# NOTE: this import shadows the builtin ``MemoryError`` inside this module, so
# ``raise MemoryError(...)`` below raises the project exception.
from exceptions import MemoryError, ResourceExhaustionError

logger = logging.getLogger(__name__)

# Bytes in one GiB; every "_gb" value in this module uses this divisor.
_BYTES_PER_GB = 1024 ** 3


def _call_mps(function_name: str) -> bool:
    """Call ``torch.mps.<function_name>()`` if this torch build provides it.

    Older torch builds lack ``torch.mps`` (or some of its functions), so the
    lookup is guarded instead of assumed.

    Returns:
        True if the function existed and was called, False otherwise.
    """
    mps_module = getattr(torch, 'mps', None)
    func = getattr(mps_module, function_name, None)
    if callable(func):
        func()
        return True
    return False


class MemoryManager:
    """
    Comprehensive memory management for video processing applications.

    Reports device/system/process memory usage, runs basic or aggressive
    cleanup passes (including user-registered callbacks), estimates whether a
    video fits in available memory, and can poll memory pressure from a
    background daemon thread.
    """

    # Recommendations returned by check_memory_pressure, per pressure level.
    _GPU_PRESSURE_RECOMMENDATIONS = {
        'critical': (
            'Reduce batch size immediately',
            'Enable gradient checkpointing',
            'Consider switching to CPU processing',
        ),
        'warning': (
            'Run aggressive memory cleanup',
            'Reduce keyframe interval',
            'Monitor memory usage closely',
        ),
    }
    _SYSTEM_PRESSURE_RECOMMENDATIONS = {
        'critical': (
            'Free system memory immediately',
            'Close unnecessary applications',
            'Reduce video processing quality',
        ),
        'warning': (
            'Run memory cleanup',
            'Monitor system memory',
            'Consider processing in smaller chunks',
        ),
    }

    def __init__(self, device: torch.device, memory_limit_gb: Optional[float] = None):
        """
        Args:
            device: Torch device the pipeline runs on ('cuda', 'mps', or
                anything else, which is treated as CPU).
            memory_limit_gb: Explicit memory budget in GB. When None, a
                device-dependent default is derived in
                ``_initialize_memory_limits``.
        """
        self.device = device
        self.gpu_available = device.type in ['cuda', 'mps']
        self.memory_limit_gb = memory_limit_gb
        self.cleanup_callbacks = []
        self.monitoring_active = False
        self.monitoring_thread: Optional[threading.Thread] = None
        # Stop signal for the current monitor thread. A fresh Event is
        # created on each start_monitoring() so a lingering old thread can
        # never be revived by a restart.
        self._monitor_stop_event: Optional[threading.Event] = None
        self.stats = {
            'cleanup_count': 0,
            'peak_memory_usage': 0.0,
            'total_allocated': 0.0,
            'total_freed': 0.0,
        }

        # Initialize memory monitoring
        self._initialize_memory_limits()
        logger.info(f"MemoryManager initialized for device: {device}")

    def _initialize_memory_limits(self):
        """Derive a default memory limit from the device if none was given.

        The defaults are:

        - CUDA: 80% of the device's total memory, or 4.0 GB if the device
          cannot be queried.
        - MPS: 50% of system RAM, because MPS uses unified memory.
        - CPU: 60% of system RAM.

        A caller-supplied ``memory_limit_gb`` is never overwritten.
        """
        if self.device.type == 'cuda':
            try:
                device_idx = self.device.index or 0
                device_props = torch.cuda.get_device_properties(device_idx)
                total_memory_gb = device_props.total_memory / _BYTES_PER_GB

                # Use 80% of GPU memory as default limit if not specified
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = total_memory_gb * 0.8
                logger.info(f"CUDA memory limit set to {self.memory_limit_gb:.1f}GB "
                            f"(total: {total_memory_gb:.1f}GB)")
            except Exception as e:
                logger.warning(f"Could not get CUDA memory info: {e}")
                # Only fall back when the caller did not choose a limit.
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = 4.0  # Conservative fallback

        elif self.device.type == 'mps':
            # MPS uses unified memory, so check system memory
            system_memory_gb = psutil.virtual_memory().total / _BYTES_PER_GB
            if self.memory_limit_gb is None:
                # Use 50% of system memory for MPS as it shares with system
                self.memory_limit_gb = system_memory_gb * 0.5
            logger.info(f"MPS memory limit set to {self.memory_limit_gb:.1f}GB "
                        f"(system: {system_memory_gb:.1f}GB)")

        else:  # CPU
            system_memory_gb = psutil.virtual_memory().total / _BYTES_PER_GB
            if self.memory_limit_gb is None:
                # Use 60% of system memory for CPU processing
                self.memory_limit_gb = system_memory_gb * 0.6
            logger.info(f"CPU memory limit set to {self.memory_limit_gb:.1f}GB "
                        f"(system: {system_memory_gb:.1f}GB)")

    def get_memory_usage(self) -> Dict[str, Any]:
        """Collect a snapshot of device, system, and process memory usage.

        All sizes are in GB (1024**3 bytes). The keys depend on the device:

        - CUDA adds ``gpu_*`` keys from torch's caching allocator.
        - MPS adds ``system_memory_gb`` and ``system_utilization_percent``
          from psutil.
        - Every device gets system RAM/swap keys and this process's RSS/VMS.

        If a section fails, the failure is logged and that section's keys
        are omitted.

        Side effect: raises ``stats['peak_memory_usage']`` if the current
        reading exceeds it.
        """
        usage = {
            'device_type': self.device.type,
            'memory_limit_gb': self.memory_limit_gb,
            'timestamp': time.time(),
        }

        try:
            if self.device.type == 'cuda':
                device_idx = self.device.index or 0

                # GPU memory as seen by this process's caching allocator
                allocated = torch.cuda.memory_allocated(device_idx)
                reserved = torch.cuda.memory_reserved(device_idx)
                total = torch.cuda.get_device_properties(device_idx).total_memory

                usage.update({
                    'gpu_allocated_gb': allocated / _BYTES_PER_GB,
                    'gpu_reserved_gb': reserved / _BYTES_PER_GB,
                    'gpu_total_gb': total / _BYTES_PER_GB,
                    'gpu_utilization_percent': (allocated / total) * 100,
                    'gpu_reserved_percent': (reserved / total) * 100,
                    # Total minus this process's reservation; memory held by
                    # other processes on the same GPU is not subtracted.
                    'gpu_free_gb': (total - reserved) / _BYTES_PER_GB,
                })

                # Peak memory tracking (reset by cleanup_aggressive)
                max_allocated = torch.cuda.max_memory_allocated(device_idx)
                max_reserved = torch.cuda.max_memory_reserved(device_idx)
                usage.update({
                    'gpu_max_allocated_gb': max_allocated / _BYTES_PER_GB,
                    'gpu_max_reserved_gb': max_reserved / _BYTES_PER_GB,
                })

            elif self.device.type == 'mps':
                # MPS shares system RAM, so report system memory here.
                vm = psutil.virtual_memory()
                usage.update({
                    'system_memory_gb': vm.total / _BYTES_PER_GB,
                    'system_available_gb': vm.available / _BYTES_PER_GB,
                    'system_used_gb': vm.used / _BYTES_PER_GB,
                    'system_utilization_percent': vm.percent,
                })

        except Exception as e:
            logger.warning(f"Error getting GPU memory usage: {e}")

        # Always include system memory info
        try:
            vm = psutil.virtual_memory()
            swap = psutil.swap_memory()
            usage.update({
                'system_total_gb': vm.total / _BYTES_PER_GB,
                'system_available_gb': vm.available / _BYTES_PER_GB,
                'system_used_gb': vm.used / _BYTES_PER_GB,
                'system_percent': vm.percent,
                'swap_total_gb': swap.total / _BYTES_PER_GB,
                'swap_used_gb': swap.used / _BYTES_PER_GB,
                'swap_percent': swap.percent,
            })
        except Exception as e:
            logger.warning(f"Error getting system memory usage: {e}")

        # Process-specific memory
        try:
            memory_info = psutil.Process().memory_info()
            usage.update({
                'process_rss_gb': memory_info.rss / _BYTES_PER_GB,  # Physical memory
                'process_vms_gb': memory_info.vms / _BYTES_PER_GB,  # Virtual memory
            })
        except Exception as e:
            logger.warning(f"Error getting process memory usage: {e}")

        # Peak tracking: CUDA uses allocated tensor memory. Other devices fall
        # back to system-wide used RAM, which includes other processes.
        current_usage = usage.get('gpu_allocated_gb', usage.get('system_used_gb', 0))
        if current_usage > self.stats['peak_memory_usage']:
            self.stats['peak_memory_usage'] = current_usage

        return usage

    def _empty_device_cache(self):
        """Release cached, unoccupied allocator memory on CUDA/MPS; no-op on CPU."""
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
        elif self.device.type == 'mps':
            _call_mps('empty_cache')

    def cleanup_basic(self):
        """Run a lightweight cleanup: one GC pass, then empty the device cache.

        Failures are logged, not raised.
        """
        try:
            gc.collect()
            self._empty_device_cache()
            self.stats['cleanup_count'] += 1
            logger.debug("Basic memory cleanup completed")
        except Exception as e:
            logger.warning(f"Basic memory cleanup failed: {e}")

    def cleanup_aggressive(self):
        """Run a thorough but slower cleanup.

        Steps, in order:

        1. Run the registered callbacks. Each callback's failure is logged
           and skipped.
        2. Run three GC passes.
        3. On CUDA: synchronize the device, empty the cache, and reset the
           peak-memory statistics.
        4. On MPS: synchronize and empty the cache, when this torch build
           supports it.

        Raises:
            MemoryError: the project exception from ``exceptions``, wrapping
                any unexpected failure.
        """
        try:
            start_time = time.perf_counter()

            # Snapshot: callbacks registered during cleanup don't join this pass.
            for callback in list(self.cleanup_callbacks):
                try:
                    callback()
                except Exception as e:
                    logger.warning(f"Cleanup callback failed: {e}")

            # Repeated passes can pick up garbage released by finalizers
            # that ran during an earlier pass.
            for _ in range(3):
                gc.collect()

            if self.device.type == 'cuda':
                device_idx = self.device.index or 0
                # Wait for queued work on *this* device before releasing cache.
                torch.cuda.synchronize(device_idx)
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats(device_idx)
            elif self.device.type == 'mps':
                _call_mps('synchronize')
                _call_mps('empty_cache')

            cleanup_time = time.perf_counter() - start_time
            self.stats['cleanup_count'] += 1
            logger.debug(f"Aggressive memory cleanup completed in {cleanup_time:.2f}s")

        except Exception as e:
            logger.error(f"Aggressive memory cleanup failed: {e}")
            raise MemoryError("aggressive_cleanup", str(e)) from e

    def check_memory_pressure(self, threshold_percent: float = 85.0,
                              critical_percent: float = 95.0) -> Dict[str, Any]:
        """Classify current memory pressure.

        CUDA uses ``gpu_utilization_percent`` (allocated / total); other
        devices use system RAM ``system_percent``. Missing readings count
        as 0.

        Args:
            threshold_percent: Usage at or above this is under pressure.
            critical_percent: Usage that is under pressure and at or above
                this is 'critical'; otherwise it is 'warning'.

        Returns:
            A dict with keys ``under_pressure`` (bool), ``pressure_level``
            ('normal', 'warning' or 'critical'), ``recommendations``
            (list of str) and ``usage_percent``.
        """
        usage = self.get_memory_usage()

        if self.device.type == 'cuda':
            usage_percent = usage.get('gpu_utilization_percent', 0)
            recommendation_table = self._GPU_PRESSURE_RECOMMENDATIONS
        else:
            # CPU or MPS - use system memory
            usage_percent = usage.get('system_percent', 0)
            recommendation_table = self._SYSTEM_PRESSURE_RECOMMENDATIONS

        pressure_info = {
            'under_pressure': False,
            'pressure_level': 'normal',  # normal, warning, critical
            'recommendations': [],
            'usage_percent': usage_percent,
        }

        if usage_percent >= threshold_percent:
            level = 'critical' if usage_percent >= critical_percent else 'warning'
            pressure_info['under_pressure'] = True
            pressure_info['pressure_level'] = level
            # Fresh list so callers can mutate it without touching the table.
            pressure_info['recommendations'] = list(recommendation_table[level])

        return pressure_info

    def auto_cleanup_if_needed(self, pressure_threshold: float = 80.0) -> bool:
        """Run a cleanup if memory is under pressure.

        Critical pressure triggers ``cleanup_aggressive``; warning pressure
        triggers ``cleanup_basic``.

        Returns:
            True if a cleanup ran, False otherwise.
        """
        pressure = self.check_memory_pressure(pressure_threshold)
        if not pressure['under_pressure']:
            return False

        cleanup_method = (
            self.cleanup_aggressive
            if pressure['pressure_level'] == 'critical'
            else self.cleanup_basic
        )
        logger.info(f"Auto-cleanup triggered due to {pressure['pressure_level']} "
                    f"memory pressure ({pressure['usage_percent']:.1f}%)")
        cleanup_method()
        return True

    def register_cleanup_callback(self, callback: Callable):
        """Register a zero-argument callable to run during ``cleanup_aggressive``."""
        self.cleanup_callbacks.append(callback)
        logger.debug("Cleanup callback registered")

    def start_monitoring(self, interval_seconds: float = 30.0,
                         pressure_callback: Optional[Callable] = None):
        """Start a daemon thread that checks memory pressure periodically.

        On each tick, if memory is under pressure, the thread:

        - logs a warning;
        - calls ``pressure_callback(pressure_info)`` if one was given
          (its errors are logged);
        - runs ``cleanup_aggressive`` when the level is 'critical'.

        Calling this while monitoring is already active only logs a warning.
        """
        if self.monitoring_active:
            logger.warning("Memory monitoring already active")
            return

        # Each run gets its own stop Event, captured by the closure below. A
        # thread from a previous run therefore exits even if monitoring was
        # restarted before it woke up.
        stop_event = threading.Event()
        self._monitor_stop_event = stop_event
        self.monitoring_active = True

        def monitor_loop():
            while self.monitoring_active and not stop_event.is_set():
                try:
                    pressure = self.check_memory_pressure()

                    if pressure['under_pressure']:
                        logger.warning(f"Memory pressure detected: {pressure['pressure_level']} "
                                       f"({pressure['usage_percent']:.1f}%)")

                        if pressure_callback:
                            try:
                                pressure_callback(pressure)
                            except Exception as e:
                                logger.error(f"Pressure callback failed: {e}")

                        # Auto-cleanup on critical pressure
                        if pressure['pressure_level'] == 'critical':
                            self.cleanup_aggressive()

                except Exception as e:
                    logger.error(f"Memory monitoring error: {e}")

                # Interruptible sleep: stop_monitoring() wakes this immediately.
                stop_event.wait(interval_seconds)

        self.monitoring_thread = threading.Thread(target=monitor_loop, daemon=True)
        self.monitoring_thread.start()
        logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")

    def stop_monitoring(self):
        """Signal the monitor thread to stop and wait up to 5 s for it to exit.

        Safe to call when monitoring is not active. Safe to call from inside
        the monitor thread, e.g. from ``pressure_callback``; in that case it
        does not join.
        """
        if not self.monitoring_active:
            return

        self.monitoring_active = False
        stop_event = getattr(self, '_monitor_stop_event', None)
        if stop_event is not None:
            stop_event.set()

        thread = self.monitoring_thread
        if (thread is not None and thread.is_alive()
                and thread is not threading.current_thread()):
            thread.join(timeout=5.0)
        logger.info("Memory monitoring stopped")

    def estimate_memory_requirement(self, video_width: int, video_height: int,
                                    frames_in_memory: int = 5) -> Dict[str, float]:
        """Roughly estimate the memory (GB) needed to process a video.

        Frames are counted as 3 bytes/pixel (RGB) times a 3x overhead for
        masks and intermediates. Fixed allowances are added: 4 GB for models
        and 2 GB for system overhead.
        """
        # Base memory per frame (RGB image)
        bytes_per_frame = video_width * video_height * 3

        # Additional overhead for processing
        overhead_multiplier = 3.0  # For masks, intermediate results, etc.

        estimated_memory = {
            'frames_memory_gb': (bytes_per_frame * frames_in_memory * overhead_multiplier) / _BYTES_PER_GB,
            'model_memory_gb': 4.0,  # Rough estimate for SAM2 + MatAnyone
            'system_overhead_gb': 2.0,
            'total_estimated_gb': 0.0,
        }

        estimated_memory['total_estimated_gb'] = sum([
            estimated_memory['frames_memory_gb'],
            estimated_memory['model_memory_gb'],
            estimated_memory['system_overhead_gb'],
        ])

        return estimated_memory

    def can_process_video(self, video_width: int, video_height: int,
                          frames_in_memory: int = 5) -> Dict[str, Any]:
        """Compare the estimated requirement with currently available memory.

        Available memory is ``gpu_free_gb`` on CUDA and
        ``system_available_gb`` otherwise; a missing reading counts as 0.
        The result contains ``can_process``, the estimate, the available
        amount, the margin, and recommendations.
        """
        estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
        current_usage = self.get_memory_usage()

        # Available memory calculation
        if self.device.type == 'cuda':
            available_memory = current_usage.get('gpu_free_gb', 0)
        else:
            available_memory = current_usage.get('system_available_gb', 0)

        required = estimate['total_estimated_gb']
        can_process = required <= available_memory

        result = {
            'can_process': can_process,
            'estimated_memory_gb': required,
            'available_memory_gb': available_memory,
            'memory_margin_gb': available_memory - required,
            'recommendations': [],
        }

        if not can_process:
            deficit = required - available_memory
            result['recommendations'] = [
                f"Free {deficit:.1f}GB of memory",
                "Reduce video resolution",
                "Process in smaller chunks",
                "Use lower quality settings",
            ]
        elif result['memory_margin_gb'] < 1.0:
            result['recommendations'] = [
                "Memory margin is low",
                "Monitor memory usage during processing",
                "Consider reducing batch size",
            ]

        return result

    def get_optimization_suggestions(self) -> Dict[str, Any]:
        """Suggest optimizations from current usage.

        Usage is ``gpu_utilization_percent`` if present, else
        ``system_percent``. Priority is 'high' at 90% or more, 'medium' at
        75% or more, and 'low' otherwise.
        """
        usage = self.get_memory_usage()

        suggestions = {
            'current_usage_percent': usage.get('gpu_utilization_percent',
                                               usage.get('system_percent', 0)),
            'suggestions': [],
            'priority': 'low',  # low, medium, high
        }

        usage_percent = suggestions['current_usage_percent']

        if usage_percent >= 90:
            suggestions['priority'] = 'high'
            suggestions['suggestions'].extend([
                'Run aggressive memory cleanup immediately',
                'Reduce batch size to 1',
                'Enable gradient checkpointing if available',
                'Consider switching to CPU processing',
            ])
        elif usage_percent >= 75:
            suggestions['priority'] = 'medium'
            suggestions['suggestions'].extend([
                'Run memory cleanup regularly',
                'Monitor memory usage closely',
                'Reduce keyframe interval',
                'Use mixed precision if supported',
            ])
        elif usage_percent >= 50:
            suggestions['suggestions'].extend([
                'Current usage is acceptable',
                'Regular cleanup should be sufficient',
                'Monitor for memory leaks during long operations',
            ])
        else:
            suggestions['suggestions'] = [
                'Memory usage is optimal',
                'No immediate action required',
            ]

        return suggestions

    def get_stats(self) -> Dict[str, Any]:
        """Return counters and configuration for this manager."""
        return {
            'cleanup_count': self.stats['cleanup_count'],
            'peak_memory_usage_gb': self.stats['peak_memory_usage'],
            'monitoring_active': self.monitoring_active,
            'device_type': self.device.type,
            'memory_limit_gb': self.memory_limit_gb,
            'registered_callbacks': len(self.cleanup_callbacks),
        }

    def __del__(self):
        """Best-effort teardown: stop monitoring, then run an aggressive cleanup.

        While monitoring is active, the monitor thread's target closure holds
        a reference to ``self``. This finalizer therefore normally runs only
        after ``stop_monitoring()``.
        """
        try:
            self.stop_monitoring()
            self.cleanup_aggressive()
        except Exception:
            pass  # Finalizers must not raise; module globals may already be torn down.