import hashlib
import json
import time
import threading
from collections import OrderedDict

import torch

from ..hloc import logger

class ARCSizeAwareModelCache:
    def __init__(
        self,
        max_gpu_mem: float = 48e9,
        max_cpu_mem: float = 48e9,
        device_priority: list = ["cuda", "cpu"],
        auto_empty_cache: bool = True,
    ):
        """
        Initialize the model cache.

        Args:
            max_gpu_mem: Maximum GPU memory allowed, in bytes.
            max_cpu_mem: Maximum CPU memory allowed, in bytes.
            device_priority: Ordered list of devices to try when placing a model.
            auto_empty_cache: Whether to call torch.cuda.empty_cache() when out of memory.
        """
        self.t1 = OrderedDict()
        self.t2 = OrderedDict()
        self.b1 = OrderedDict()
        self.b2 = OrderedDict()
        self.max_gpu = max_gpu_mem
        self.max_cpu = max_cpu_mem
        self.current_gpu = 0
        self.current_cpu = 0
        self.p = 0
        self.adaptive_factor = 0.5
        self.device_priority = device_priority
        self.lock = threading.Lock()
        self.auto_empty_cache = auto_empty_cache
        logger.info(
            "ARCSizeAwareModelCache initialized. Max GPU: {} GB, Max CPU: {} GB".format(
                max_gpu_mem / 1e9, max_cpu_mem / 1e9
            )
        )
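
    # Cache bookkeeping, following the usual ARC layout:
    #   t1 - entries loaded once recently; t2 - entries hit at least twice;
    #   b1/b2 - histories of entries demoted from t1/t2, used to detect re-requests;
    #   p - adaptive target size for t1, adjusted in _update_access.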
    def _release_model(self, model_entry):
        """
        Release a model from memory.

        Args:
            model_entry: A dictionary containing the model, device and other information.

        Notes:
            If the device is CUDA and auto_empty_cache is True,
            torch.cuda.empty_cache() is called after releasing the model.
        """
        # Drop the entry's own reference so the model can actually be garbage collected.
        model_entry.pop("model", None)
        device = model_entry["device"]
        if device == "cuda":
            torch.cuda.synchronize()
            if self.auto_empty_cache:
                torch.cuda.empty_cache()
    def generate_key(self, model_key, model_conf: dict) -> str:
        loader_identifier = f"{model_key}"
        unique_str = f"{loader_identifier}-{json.dumps(model_conf, sort_keys=True)}"
        return hashlib.sha256(unique_str.encode()).hexdigest()
    def _get_device(self, model_size: int) -> str:
        for device in self.device_priority:
            if device == "cuda" and torch.cuda.is_available():
                if self.current_gpu + model_size <= self.max_gpu:
                    return "cuda"
            elif device == "cpu":
                if self.current_cpu + model_size <= self.max_cpu:
                    return "cpu"
        return "cpu"
    def _calculate_model_size(self, model):
        return sum(p.numel() * p.element_size() for p in model.parameters()) + sum(
            b.numel() * b.element_size() for b in model.buffers()
        )
    def _update_access(self, key: str, size: int, device: str):
        if key in self.b1:
            self.p = min(
                self.p + max(1, len(self.b2) // len(self.b1)),
                len(self.t1) + len(self.t2),
            )
            self.b1.pop(key)
            self._replace(False)
        elif key in self.b2:
            self.p = max(self.p - max(1, len(self.b1) // len(self.b2)), 0)
            self.b2.pop(key)
            self._replace(True)
        if key in self.t1:
            self.t1.pop(key)
        self.t2[key] = {
            "size": size,
            "device": device,
            "access_count": 1,
            "last_accessed": time.time(),
        }
    def _replace(self, in_t2: bool):
        if len(self.t1) > 0 and (
            (len(self.t1) > self.p) or (in_t2 and len(self.t1) == self.p)
        ):
            k, v = self.t1.popitem(last=False)
            self.b1[k] = v
        elif len(self.t2) > 0:
            k, v = self.t2.popitem(last=False)
            self.b2[k] = v
    def _calculate_weight(self, entry) -> float:
        # Lower weight (few accesses, large size) means the entry is evicted sooner.
        return entry["access_count"] / entry["size"]
    def _evict_models(self, required_size: int, target_device: str) -> bool:
        candidates = []
        for k, v in list(self.t1.items()) + list(self.t2.items()):
            if v["device"] == target_device:
                candidates.append((k, v))
        candidates.sort(key=lambda x: self._calculate_weight(x[1]))
        freed = 0
        for k, v in candidates:
            self._release_model(v)
            freed += v["size"]
            if k in self.t1:
                self.t1.pop(k)
            if k in self.t2:
                self.t2.pop(k)
            if v["device"] == "cuda":
                self.current_gpu -= v["size"]
            else:
                self.current_cpu -= v["size"]
            if freed >= required_size:
                return True
        if target_device == "cuda":
            return self._cross_device_evict(required_size, "cuda")
        return False
    def _cross_device_evict(self, required_size: int, target_device: str) -> bool:
        all_entries = list(self.t1.items()) + list(self.t2.items())
        # Penalise entries on the target device so models on other devices go first.
        all_entries.sort(
            key=lambda x: self._calculate_weight(x[1])
            + (0.5 if x[1]["device"] == target_device else 0)
        )
        freed = 0
        for k, v in all_entries:
            self._release_model(v)
            freed += v["size"]
            if k in self.t1:
                self.t1.pop(k)
            if k in self.t2:
                self.t2.pop(k)
            if v["device"] == "cuda":
                self.current_gpu -= v["size"]
            else:
                self.current_cpu -= v["size"]
            if freed >= required_size:
                return True
        return False
    def load_model(self, model_key, model_loader_func, model_conf: dict):
        key = self.generate_key(model_key, model_conf)
        with self.lock:
            # Cache hit: promote the entry to t2 and refresh its statistics.
            if key in self.t1 or key in self.t2:
                entry = self.t1.pop(key, None) or self.t2.pop(key)
                entry["access_count"] += 1
                entry["last_accessed"] = time.time()
                self.t2[key] = entry
                return entry["model"]

            raw_model = model_loader_func(model_conf)
            model_size = self._calculate_model_size(raw_model)
            device = self._get_device(model_size)

            if device == "cuda" and self.auto_empty_cache:
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # Evict until the model fits on the chosen device, falling back to CPU.
            while True:
                current_mem = self.current_gpu if device == "cuda" else self.current_cpu
                max_mem = self.max_gpu if device == "cuda" else self.max_cpu
                if current_mem + model_size <= max_mem:
                    break
                if not self._evict_models(model_size, device):
                    if device == "cuda":
                        device = "cpu"
                    else:
                        raise RuntimeError("Out of memory")

            try:
                model = raw_model.to(device)
            except RuntimeError as e:
                if device == "cuda" and "CUDA out of memory" in str(e):
                    torch.cuda.empty_cache()
                    model = raw_model.to(device)
                else:
                    raise

            new_entry = {
                "model": model,
                "size": model_size,
                "device": device,
                "access_count": 1,
                "last_accessed": time.time(),
            }
            if key in self.b1 or key in self.b2:
                self.t2[key] = new_entry
                self._replace(True)
            else:
                self.t1[key] = new_entry
                self._replace(False)

            if device == "cuda":
                self.current_gpu += model_size
            else:
                self.current_cpu += model_size
            return model
    def clear_device_cache(self, device: str):
        with self.lock:
            for cache in (self.t1, self.t2, self.b1, self.b2):
                for k in list(cache.keys()):
                    if cache[k]["device"] == device:
                        entry = cache.pop(k)
                        # Keep the memory counters in sync with what is dropped.
                        if device == "cuda":
                            self.current_gpu -= entry["size"]
                        else:
                            self.current_cpu -= entry["size"]
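
# Illustrative usage only (not part of the original module): a loader is any callable
# that takes the config dict and returns an nn.Module, and the cache key is derived
# from model_key plus the JSON-serialised config, e.g.
#
#   cache = ARCSizeAwareModelCache(max_gpu_mem=48e9, max_cpu_mem=48e9)
#   model = cache.load_model("my_matcher", build_my_matcher, {"weights": "outdoor"})
#
# where "my_matcher" and build_my_matcher are placeholder names.
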
class LRUModelCache:
    def __init__(
        self,
        max_gpu_mem: float = 8e9,
        max_cpu_mem: float = 12e9,
        device_priority: list = ["cuda", "cpu"],
    ):
        self.cache = OrderedDict()
        self.max_gpu = max_gpu_mem
        self.max_cpu = max_cpu_mem
        self.current_gpu = 0
        self.current_cpu = 0
        # A re-entrant lock is needed: _handle_oom runs while load_model already
        # holds the lock and re-enters load_model and clear_device_cache.
        self.lock = threading.RLock()
        self.device_priority = device_priority
    def generate_key(self, model_key, model_conf: dict) -> str:
        loader_identifier = f"{model_key}"
        unique_str = f"{loader_identifier}-{json.dumps(model_conf, sort_keys=True)}"
        return hashlib.sha256(unique_str.encode()).hexdigest()

    def get_device(self) -> str:
        for device in self.device_priority:
            if device == "cuda" and torch.cuda.is_available():
                if self.current_gpu < self.max_gpu:
                    return device
            elif device == "cpu":
                if self.current_cpu < self.max_cpu:
                    return device
        return "cpu"

    def _calculate_model_size(self, model):
        param_size = sum(p.numel() * p.element_size() for p in model.parameters())
        buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
        return param_size + buffer_size
    def load_model(self, model_key, model_loader_func, model_conf: dict):
        key = self.generate_key(model_key, model_conf)
        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)  # update LRU
                return self.cache[key]["model"]

            device = self.get_device()
            if device == "cuda":
                torch.cuda.empty_cache()

            try:
                raw_model = model_loader_func(model_conf)
            except Exception as e:
                raise RuntimeError(f"Model loading failed: {str(e)}")

            try:
                model = raw_model.to(device)
            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    return self._handle_oom(model_key, model_loader_func, model_conf)
                raise

            model_size = self._calculate_model_size(model)
            while (
                device == "cuda" and (self.current_gpu + model_size > self.max_gpu)
            ) or (device == "cpu" and (self.current_cpu + model_size > self.max_cpu)):
                if not self._free_space(model_size, device):
                    raise RuntimeError("Insufficient memory even after cache cleanup")

            if device == "cuda":
                self.current_gpu += model_size
            else:
                self.current_cpu += model_size

            self.cache[key] = {
                "model": model,
                "size": model_size,
                "device": device,
                "timestamp": time.time(),
            }
            return model
    def _free_space(self, required_size: int, device: str) -> bool:
        # Evict entries on the requested device in LRU order until enough memory is free.
        for key in list(self.cache.keys()):
            if self.cache[key]["device"] == device:
                if device == "cuda":
                    self.current_gpu -= self.cache[key]["size"]
                else:
                    self.current_cpu -= self.cache[key]["size"]
                del self.cache[key]
                if (
                    device == "cuda"
                    and self.current_gpu + required_size <= self.max_gpu
                ) or (
                    device == "cpu"
                    and self.current_cpu + required_size <= self.max_cpu
                ):
                    return True
        return False
    def _handle_oom(self, model_key, model_loader_func, model_conf: dict):
        # Called from load_model while the (re-entrant) lock is already held.
        with self.lock:
            self.clear_device_cache("cuda")
            torch.cuda.empty_cache()
            try:
                return self.load_model(model_key, model_loader_func, model_conf)
            except RuntimeError:
                # Last resort: retry on CPU only, then restore the original priority.
                original_priority = self.device_priority
                self.device_priority = ["cpu"]
                try:
                    return self.load_model(model_key, model_loader_func, model_conf)
                finally:
                    self.device_priority = original_priority
    def clear_device_cache(self, device: str):
        with self.lock:
            keys_to_remove = [
                k for k, v in self.cache.items() if v["device"] == device
            ]
            for k in keys_to_remove:
                if device == "cuda":
                    self.current_gpu -= self.cache[k]["size"]
                else:
                    self.current_cpu -= self.cache[k]["size"]
                del self.cache[k]
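
# Minimal usage sketch (an assumption, not shipped behaviour): it wires a tiny
# hypothetical loader into LRUModelCache to show the call pattern. Repeated calls
# with the same key and config return the already-cached model instance.
def _lru_cache_usage_sketch():
    import torch.nn as nn

    def demo_loader(conf: dict):
        # Hypothetical loader: builds a small linear layer from the config.
        return nn.Linear(conf["in_features"], conf["out_features"])

    cache = LRUModelCache(max_gpu_mem=8e9, max_cpu_mem=12e9)
    conf = {"in_features": 16, "out_features": 4}
    first = cache.load_model("demo_linear", demo_loader, conf)
    second = cache.load_model("demo_linear", demo_loader, conf)
    assert first is second  # the second call is served from the cache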