# utils/hardware/device_manager.py
"""
Device Manager for BackgroundFX Pro
Handles device detection, optimization, and hardware compatibility
"""
# CRITICAL: Set OMP_NUM_THREADS before ANY other imports to prevent libgomp error
import os
if 'OMP_NUM_THREADS' not in os.environ:
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['MKL_NUM_THREADS'] = '4'

import sys
import platform
import logging
from typing import Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum

import torch
import psutil
import cpuinfo

logger = logging.getLogger(__name__)


class DeviceType(Enum):
    """Enumeration of supported device types"""
    CUDA = "cuda"
    MPS = "mps"
    CPU = "cpu"


@dataclass
class DeviceInfo:
    """Information about a compute device"""
    type: DeviceType
    index: int
    name: str
    memory_total: int
    memory_available: int
    compute_capability: Optional[Tuple[int, int]] = None


class DeviceManager:
    """Manages compute devices and system optimization"""

    _instance = None

    def __init__(self):
        """Initialize device manager"""
        self.devices = []
        self.optimal_device = None
        self.cpu_info = None
        self.system_info = {}

        # Initialize device detection
        self._detect_devices()
        self._gather_system_info()
        self._determine_optimal_device()

    def _detect_devices(self):
        """Detect available compute devices"""
        self.devices = []

        # Check for CUDA devices
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                self.devices.append(DeviceInfo(
                    type=DeviceType.CUDA,
                    index=i,
                    name=props.name,
                    memory_total=props.total_memory,
                    memory_available=props.total_memory - torch.cuda.memory_allocated(i),
                    compute_capability=(props.major, props.minor)
                ))

        # Check for MPS (Apple Silicon)
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            # MPS uses unified memory and doesn't expose detailed device info like CUDA,
            # so system RAM figures are used as an approximation
            self.devices.append(DeviceInfo(
                type=DeviceType.MPS,
                index=0,
                name="Apple Silicon GPU",
                memory_total=psutil.virtual_memory().total,
                memory_available=psutil.virtual_memory().available
            ))

        # CPU is always available
        try:
            cpu_info = cpuinfo.get_cpu_info()
            cpu_name = cpu_info.get('brand_raw', 'Unknown CPU')
        except Exception:
            cpu_name = platform.processor() or "Unknown CPU"

        self.devices.append(DeviceInfo(
            type=DeviceType.CPU,
            index=0,
            name=cpu_name,
            memory_total=psutil.virtual_memory().total,
            memory_available=psutil.virtual_memory().available
        ))

    def _gather_system_info(self):
        """Gather system information"""
        try:
            self.cpu_info = cpuinfo.get_cpu_info()
        except Exception:
            self.cpu_info = {}

        self.system_info = {
            'platform': platform.system(),
            'platform_release': platform.release(),
            'platform_version': platform.version(),
            'architecture': platform.machine(),
            'processor': platform.processor(),
            'cpu_count': psutil.cpu_count(logical=False),
            'cpu_count_logical': psutil.cpu_count(logical=True),
            'ram_total': psutil.virtual_memory().total,
            'ram_available': psutil.virtual_memory().available,
            'python_version': sys.version,
            'torch_version': torch.__version__,
        }

    def _determine_optimal_device(self):
        """Determine the optimal device for computation"""
        # Priority: CUDA > MPS > CPU
        cuda_devices = [d for d in self.devices if d.type == DeviceType.CUDA]
        mps_devices = [d for d in self.devices if d.type == DeviceType.MPS]
        cpu_devices = [d for d in self.devices if d.type == DeviceType.CPU]

        if cuda_devices:
            # Choose the CUDA device with the most available memory
            self.optimal_device = max(cuda_devices, key=lambda d: d.memory_available)
        elif mps_devices:
            self.optimal_device = mps_devices[0]
        else:
            self.optimal_device = cpu_devices[0]

        logger.info(f"Optimal device: {self.optimal_device.name} ({self.optimal_device.type.value})")

    def get_optimal_device(self) -> str:
        """Get the optimal device string for PyTorch"""
        if self.optimal_device.type == DeviceType.CUDA:
            return f"cuda:{self.optimal_device.index}"
        elif self.optimal_device.type == DeviceType.MPS:
            return "mps"
        else:
            return "cpu"

    def fix_cuda_compatibility(self):
        """Apply CUDA compatibility fixes"""
        if not torch.cuda.is_available():
            logger.info("CUDA not available, skipping compatibility fixes")
            return

        try:
            # Synchronous kernel launches make CUDA errors surface at the failing call,
            # trading some performance for easier debugging
            os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

            # Allow TF32 math: it speeds up matmul/conv on Ampere (compute capability 8.0+)
            # and newer GPUs, and has no effect on older hardware
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

            # Limit allocator split size for stability
            if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ:
                os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

            logger.info("CUDA compatibility settings applied")
        except Exception as e:
            logger.warning(f"Error applying CUDA compatibility fixes: {e}")

    def setup_optimal_threading(self):
        """Configure optimal threading for the system"""
        try:
            # Skip if already configured (to avoid overwriting the early setting)
            current_omp = os.environ.get('OMP_NUM_THREADS')
            if current_omp and current_omp.isdigit() and int(current_omp) > 0:
                logger.info(f"Threading already configured: OMP_NUM_THREADS={current_omp}")
                # Just ensure PyTorch uses the same settings
                torch.set_num_threads(int(current_omp))
                # Ensure MKL matches OMP if it's not set
                if 'MKL_NUM_THREADS' not in os.environ:
                    os.environ['MKL_NUM_THREADS'] = current_omp
                return

            # Get physical CPU count
            physical_cores = psutil.cpu_count(logical=False)
            if physical_cores is None:
                physical_cores = 4  # Default fallback

            # Cap at 8 threads; physical_cores is a positive integer at this point
            num_threads = max(1, min(physical_cores, 8))
            os.environ['OMP_NUM_THREADS'] = str(num_threads)

            # Set MKL threads for Intel processors
            if 'intel' in self.system_info.get('processor', '').lower():
                os.environ['MKL_NUM_THREADS'] = os.environ['OMP_NUM_THREADS']

            # Set PyTorch intra-op threads to match
            torch.set_num_threads(num_threads)

            # When CUDA is in use, keep CPU-side inter-op parallelism modest
            if torch.cuda.is_available():
                torch.set_num_interop_threads(2)

            logger.info(f"Threading configured: OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')}")
        except Exception as e:
            logger.warning(f"Error setting up threading: {e}")
            # Set safe defaults
            if 'OMP_NUM_THREADS' not in os.environ:
                os.environ['OMP_NUM_THREADS'] = '4'
            if 'MKL_NUM_THREADS' not in os.environ:
                os.environ['MKL_NUM_THREADS'] = '4'

    def get_system_diagnostics(self) -> Dict[str, Any]:
        """Get comprehensive system diagnostics"""
        diagnostics = {
            'system': self.system_info.copy(),
            'devices': [],
            'optimal_device': None,
            'threading': {
                'omp_num_threads': os.environ.get('OMP_NUM_THREADS', 'not set'),
                'mkl_num_threads': os.environ.get('MKL_NUM_THREADS', 'not set'),
                'torch_num_threads': torch.get_num_threads(),
            }
        }

        # Add device information
        for device in self.devices:
            device_info = {
                'type': device.type.value,
                'index': device.index,
                'name': device.name,
                'memory_total_gb': device.memory_total / (1024**3),
                'memory_available_gb': device.memory_available / (1024**3),
            }
            if device.compute_capability:
                device_info['compute_capability'] = f"{device.compute_capability[0]}.{device.compute_capability[1]}"
            diagnostics['devices'].append(device_info)

        # Add optimal device
        if self.optimal_device:
            diagnostics['optimal_device'] = {
                'type': self.optimal_device.type.value,
                'name': self.optimal_device.name,
                'pytorch_device': self.get_optimal_device()
            }

        # Add CUDA-specific diagnostics
        if torch.cuda.is_available():
            diagnostics['cuda'] = {
                'available': True,
                'version': torch.version.cuda,
                'device_count': torch.cuda.device_count(),
                'current_device': torch.cuda.current_device() if torch.cuda.is_initialized() else None,
            }
        else:
            diagnostics['cuda'] = {'available': False}

        # Add MPS-specific diagnostics
        if hasattr(torch.backends, 'mps'):
            diagnostics['mps'] = {
                'available': torch.backends.mps.is_available(),
                'built': torch.backends.mps.is_built()
            }
        else:
            diagnostics['mps'] = {'available': False}

        return diagnostics

    def get_device_for_model(self, model_size_gb: float = 2.0) -> str:
        """Get appropriate device based on model size requirements"""
        required_memory = model_size_gb * 1024**3 * 1.5  # 1.5x for overhead

        # Check CUDA devices first
        cuda_devices = [d for d in self.devices if d.type == DeviceType.CUDA]
        for device in cuda_devices:
            if device.memory_available > required_memory:
                return f"cuda:{device.index}"

        # Check MPS
        mps_devices = [d for d in self.devices if d.type == DeviceType.MPS]
        if mps_devices and mps_devices[0].memory_available > required_memory:
            return "mps"

        # Fallback to CPU
        return "cpu"


# Singleton instance holder
_device_manager_instance = None


def get_device_manager() -> DeviceManager:
    """Get or create the singleton DeviceManager instance"""
    global _device_manager_instance
    if _device_manager_instance is None:
        _device_manager_instance = DeviceManager()
    return _device_manager_instance


def get_optimal_device() -> str:
    """
    Get the optimal device string for PyTorch operations.

    Returns:
        str: Device string like 'cuda:0', 'mps', or 'cpu'
    """
    manager = get_device_manager()
    return manager.get_optimal_device()


def fix_cuda_compatibility():
    """
    Apply CUDA compatibility settings for stable operation.

    Sets environment variables and PyTorch settings for CUDA compatibility.
    """
    manager = get_device_manager()
    manager.fix_cuda_compatibility()


def setup_optimal_threading():
    """
    Configure optimal threading settings for the current system.

    Sets OMP_NUM_THREADS, MKL_NUM_THREADS, and PyTorch thread counts.
    """
    manager = get_device_manager()
    manager.setup_optimal_threading()


def get_system_diagnostics() -> Dict[str, Any]:
    """
    Get comprehensive system diagnostics information.

    Returns:
        Dict containing system info, device info, and configuration details
    """
    manager = get_device_manager()
    return manager.get_system_diagnostics()


# Initialize and configure on module import
if __name__ != "__main__":
    # When imported, automatically set up the device manager
    try:
        # Get the manager instance (threading is already configured at top of file)
        manager = get_device_manager()
        # Only run setup_optimal_threading if needed (it will check internally)
        manager.setup_optimal_threading()
    except Exception as e:
        logger.warning(f"Error during device manager initialization: {e}")