"""
Device Manager for BackgroundFX Pro
Handles device detection, optimization, and hardware compatibility
"""

# CRITICAL: Set OMP_NUM_THREADS before ANY other imports to prevent libgomp error
import os
if 'OMP_NUM_THREADS' not in os.environ:
    os.environ['OMP_NUM_THREADS'] = '4'
    os.environ['MKL_NUM_THREADS'] = '4'

import sys
import platform
import logging
from typing import Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum

import torch
import psutil
import cpuinfo

logger = logging.getLogger(__name__)


class DeviceType(Enum):
    """Enumeration of supported device types"""
    CUDA = "cuda"
    MPS = "mps"
    CPU = "cpu"


@dataclass
class DeviceInfo:
    """Information about a compute device"""
    type: DeviceType
    index: int
    name: str
    memory_total: int
    memory_available: int
    compute_capability: Optional[Tuple[int, int]] = None


class DeviceManager:
    """Manages compute devices and system optimization"""
    
    def __init__(self):
        """Initialize device manager"""
        self.devices = []
        self.optimal_device = None
        self.cpu_info = None
        self.system_info = {}
        
        # Initialize device detection
        self._detect_devices()
        self._gather_system_info()
        self._determine_optimal_device()
    
    def _detect_devices(self):
        """Detect available compute devices"""
        self.devices = []
        
        # Check for CUDA devices
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                self.devices.append(DeviceInfo(
                    type=DeviceType.CUDA,
                    index=i,
                    name=props.name,
                    memory_total=props.total_memory,
                    # Approximate availability: subtracts only this process's
                    # PyTorch allocations, not memory held by other processes
                    memory_available=props.total_memory - torch.cuda.memory_allocated(i),
                    compute_capability=(props.major, props.minor)
                ))
        
        # Check for MPS (Apple Silicon)
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            # MPS doesn't provide detailed device info like CUDA
            self.devices.append(DeviceInfo(
                type=DeviceType.MPS,
                index=0,
                name="Apple Silicon GPU",
                memory_total=psutil.virtual_memory().total,
                memory_available=psutil.virtual_memory().available
            ))
        
        # CPU is always available
        try:
            cpu_info = cpuinfo.get_cpu_info()
            cpu_name = cpu_info.get('brand_raw', 'Unknown CPU')
        except Exception:
            cpu_name = platform.processor() or "Unknown CPU"
        
        self.devices.append(DeviceInfo(
            type=DeviceType.CPU,
            index=0,
            name=cpu_name,
            memory_total=psutil.virtual_memory().total,
            memory_available=psutil.virtual_memory().available
        ))
    
    def _gather_system_info(self):
        """Gather system information"""
        try:
            self.cpu_info = cpuinfo.get_cpu_info()
        except Exception:
            self.cpu_info = {}
        
        self.system_info = {
            'platform': platform.system(),
            'platform_release': platform.release(),
            'platform_version': platform.version(),
            'architecture': platform.machine(),
            'processor': platform.processor(),
            'cpu_count': psutil.cpu_count(logical=False),
            'cpu_count_logical': psutil.cpu_count(logical=True),
            'ram_total': psutil.virtual_memory().total,
            'ram_available': psutil.virtual_memory().available,
            'python_version': sys.version,
            'torch_version': torch.__version__,
        }
    
    def _determine_optimal_device(self):
        """Determine the optimal device for computation"""
        # Priority: CUDA > MPS > CPU
        cuda_devices = [d for d in self.devices if d.type == DeviceType.CUDA]
        mps_devices = [d for d in self.devices if d.type == DeviceType.MPS]
        cpu_devices = [d for d in self.devices if d.type == DeviceType.CPU]
        
        if cuda_devices:
            # Choose CUDA device with most available memory
            self.optimal_device = max(cuda_devices, key=lambda d: d.memory_available)
        elif mps_devices:
            self.optimal_device = mps_devices[0]
        else:
            self.optimal_device = cpu_devices[0]
        
        logger.info(f"Optimal device: {self.optimal_device.name} ({self.optimal_device.type.value})")
    
    def get_optimal_device(self) -> str:
        """Get the optimal device string for PyTorch"""
        if self.optimal_device.type == DeviceType.CUDA:
            return f"cuda:{self.optimal_device.index}"
        elif self.optimal_device.type == DeviceType.MPS:
            return "mps"
        else:
            return "cpu"
    
    def fix_cuda_compatibility(self):
        """Apply CUDA compatibility fixes"""
        if not torch.cuda.is_available():
            logger.info("CUDA not available, skipping compatibility fixes")
            return
        
        try:
            # CUDA_LAUNCH_BLOCKING forces synchronous kernel launches; this
            # trades some performance for more predictable error reporting
            os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
            
            # Allow TF32 on Ampere (compute capability >= 8.0) and newer GPUs
            # for faster matmul/conv; the flag is a no-op on older hardware
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            
            # Limit allocator block splitting to reduce memory fragmentation
            if 'PYTORCH_CUDA_ALLOC_CONF' not in os.environ:
                os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
            
            logger.info("CUDA compatibility settings applied")
        except Exception as e:
            logger.warning(f"Error applying CUDA compatibility fixes: {e}")
    
    def setup_optimal_threading(self):
        """Configure optimal threading for the system"""
        try:
            # Skip if already configured (to avoid overwriting the early setting)
            current_omp = os.environ.get('OMP_NUM_THREADS')
            if current_omp and current_omp.isdigit() and int(current_omp) > 0:
                logger.info(f"Threading already configured: OMP_NUM_THREADS={current_omp}")
                # Just ensure PyTorch uses the same settings
                torch.set_num_threads(int(current_omp))
                
                # Ensure MKL matches OMP if it's not set
                if 'MKL_NUM_THREADS' not in os.environ:
                    os.environ['MKL_NUM_THREADS'] = current_omp
                
                return
            
            # Get physical CPU count
            physical_cores = psutil.cpu_count(logical=False)
            if physical_cores is None:
                physical_cores = 4  # Default fallback
            
            # Cap at 8 threads to avoid oversubscribing many-core machines
            num_threads = str(min(physical_cores, 8))
            os.environ['OMP_NUM_THREADS'] = num_threads
            
            # Set MKL threads for Intel processors
            if 'intel' in self.system_info.get('processor', '').lower():
                os.environ['MKL_NUM_THREADS'] = os.environ['OMP_NUM_THREADS']
            
            # Set PyTorch threads
            torch.set_num_threads(int(os.environ['OMP_NUM_THREADS']))
            
            # With a GPU present, limit inter-op parallelism so CPU threads
            # don't compete with the GPU pipeline; note this call raises if
            # parallel work has already started (handled by the except below)
            if torch.cuda.is_available():
                torch.set_num_interop_threads(2)
            
            logger.info(f"Threading configured: OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')}")
            
        except Exception as e:
            logger.warning(f"Error setting up threading: {e}")
            # Set safe defaults
            if 'OMP_NUM_THREADS' not in os.environ:
                os.environ['OMP_NUM_THREADS'] = '4'
            if 'MKL_NUM_THREADS' not in os.environ:
                os.environ['MKL_NUM_THREADS'] = '4'
    
    def get_system_diagnostics(self) -> Dict[str, Any]:
        """Get comprehensive system diagnostics"""
        diagnostics = {
            'system': self.system_info.copy(),
            'devices': [],
            'optimal_device': None,
            'threading': {
                'omp_num_threads': os.environ.get('OMP_NUM_THREADS', 'not set'),
                'mkl_num_threads': os.environ.get('MKL_NUM_THREADS', 'not set'),
                'torch_num_threads': torch.get_num_threads(),
            }
        }
        
        # Add device information
        for device in self.devices:
            device_info = {
                'type': device.type.value,
                'index': device.index,
                'name': device.name,
                'memory_total_gb': device.memory_total / (1024**3),
                'memory_available_gb': device.memory_available / (1024**3),
            }
            if device.compute_capability:
                device_info['compute_capability'] = f"{device.compute_capability[0]}.{device.compute_capability[1]}"
            diagnostics['devices'].append(device_info)
        
        # Add optimal device
        if self.optimal_device:
            diagnostics['optimal_device'] = {
                'type': self.optimal_device.type.value,
                'name': self.optimal_device.name,
                'pytorch_device': self.get_optimal_device()
            }
        
        # Add CUDA-specific diagnostics
        if torch.cuda.is_available():
            diagnostics['cuda'] = {
                'available': True,
                'version': torch.version.cuda,
                'device_count': torch.cuda.device_count(),
                'current_device': torch.cuda.current_device() if torch.cuda.is_initialized() else None,
            }
        else:
            diagnostics['cuda'] = {'available': False}
        
        # Add MPS-specific diagnostics
        if hasattr(torch.backends, 'mps'):
            diagnostics['mps'] = {
                'available': torch.backends.mps.is_available(),
                'built': torch.backends.mps.is_built()
            }
        else:
            diagnostics['mps'] = {'available': False}
        
        return diagnostics
    
    def get_device_for_model(self, model_size_gb: float = 2.0) -> str:
        """Get appropriate device based on model size requirements"""
        required_memory = model_size_gb * 1024**3 * 1.5  # 1.5x for overhead
        
        # Check CUDA devices first
        cuda_devices = [d for d in self.devices if d.type == DeviceType.CUDA]
        for device in cuda_devices:
            if device.memory_available > required_memory:
                return f"cuda:{device.index}"
        
        # Check MPS
        mps_devices = [d for d in self.devices if d.type == DeviceType.MPS]
        if mps_devices and mps_devices[0].memory_available > required_memory:
            return "mps"
        
        # Fallback to CPU
        return "cpu"


# Singleton instance holder
_device_manager_instance = None


def get_device_manager() -> DeviceManager:
    """Get or create the singleton DeviceManager instance"""
    global _device_manager_instance
    if _device_manager_instance is None:
        _device_manager_instance = DeviceManager()
    return _device_manager_instance


def get_optimal_device() -> str:
    """
    Get the optimal device string for PyTorch operations.
    
    Returns:
        str: Device string like 'cuda:0', 'mps', or 'cpu'
    """
    manager = get_device_manager()
    return manager.get_optimal_device()


def fix_cuda_compatibility():
    """
    Apply CUDA compatibility settings for stable operation.
    Sets environment variables and PyTorch settings for CUDA compatibility.
    """
    manager = get_device_manager()
    manager.fix_cuda_compatibility()


def setup_optimal_threading():
    """
    Configure optimal threading settings for the current system.
    Sets OMP_NUM_THREADS, MKL_NUM_THREADS, and PyTorch thread counts.
    """
    manager = get_device_manager()
    manager.setup_optimal_threading()


def get_system_diagnostics() -> Dict[str, Any]:
    """
    Get comprehensive system diagnostics information.
    
    Returns:
        Dict containing system info, device info, and configuration details
    """
    manager = get_device_manager()
    return manager.get_system_diagnostics()


# Initialize and configure on module import
if __name__ != "__main__":
    # When imported, automatically set up the device manager
    try:
        # Get the manager instance (threading is already configured at top of file)
        manager = get_device_manager()
        # Only run setup_optimal_threading if needed (it will check internally)
        manager.setup_optimal_threading()
    except Exception as e:
        logger.warning(f"Error during device manager initialization: {e}")