import torch import comfy.sample import comfy.model_management import comfy.utils import gc import logging import nodes from typing import Dict, Union import time from contextlib import contextmanager import psutil class MemoryManager: """Manages memory resources for efficient video processing.""" def __init__(self, device=None, log_level: str = "INFO"): self.logger = logging.getLogger("MemoryManager") if not self.logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(getattr(logging, log_level)) self.device = device or comfy.model_management.get_torch_device() self.logger.info(f"Using device: {self.device}") # Memory thresholds (percentages) self.warning_threshold = 85 self.critical_threshold = 95 def is_cuda_device(self) -> bool: """Check if the current device is a CUDA device.""" if isinstance(self.device, str): return self.device.startswith("cuda") elif isinstance(self.device, torch.device): return self.device.type == "cuda" return False def get_memory_stats(self) -> Dict[str, Union[int, float]]: """Get current memory statistics for the device.""" stats = {} if self.is_cuda_device() and torch.cuda.is_available(): try: t = torch.cuda.get_device_properties(0) stats["total"] = t.total_memory stats["reserved"] = torch.cuda.memory_reserved(0) stats["allocated"] = torch.cuda.memory_allocated(0) stats["free"] = stats["total"] - stats["reserved"] stats["usage_percent"] = (stats["allocated"] / stats["total"]) * 100 except Exception as e: self.logger.error(f"Error getting CUDA memory stats: {e}") stats = {"error": str(e)} else: # CPU memory stats vm = psutil.virtual_memory() stats["total"] = vm.total stats["available"] = vm.available stats["used"] = vm.used stats["free"] = vm.free stats["usage_percent"] = vm.percent return stats def is_memory_critical(self) -> bool: """Check if memory usage is at critical levels.""" stats = self.get_memory_stats() if "error" in stats: return True # Assume critical if we can't get stats return stats.get("usage_percent", 0) > self.critical_threshold @contextmanager def track_memory(self, label: str = "Operation"): """Context manager to track memory usage before and after an operation.""" if self.is_cuda_device() and torch.cuda.is_available(): start_mem = torch.cuda.memory_allocated() start_time = time.time() try: yield finally: end_mem = torch.cuda.memory_allocated() end_time = time.time() self.logger.info(f"{label} - Memory change: {(end_mem-start_mem)/1024**2:.2f}MB, Time: {end_time-start_time:.2f}s") else: start_time = time.time() try: yield finally: end_time = time.time() self.logger.info(f"{label} - Time: {end_time-start_time:.2f}s") def cleanup(self, force: bool = False): """Clean up memory resources.""" if self.is_cuda_device() and torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() if force: # More aggressive cleanup for obj in gc.get_objects(): try: if torch.is_tensor(obj) and not obj.is_cuda: del obj except: pass gc.collect() if self.is_cuda_device() and torch.cuda.is_available(): torch.cuda.empty_cache() class WanVideoKsampler: """ Video K-sampler node with memory management for processing video latents. """ @classmethod def INPUT_TYPES(cls): return { "required": { "model": ("MODEL",), "positive": ("CONDITIONING",), "negative": ("CONDITIONING",), "video_latents": ("LATENT",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 100.0}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "denoise": ("FLOAT", {"default": 1, "min": 0.0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("LATENT",) FUNCTION = "sample" CATEGORY = "sampling" def __init__(self): self.logger = logging.getLogger("WanVideoKsampler") if not self.logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO) # Initialize memory manager self.memory_manager = None def sample( self, model, video_latents: Dict[str, torch.Tensor], positive, negative, seed: int, steps: int, cfg: float, sampler_name: str, scheduler: str, denoise: float ) -> Dict[str, torch.Tensor]: """ Sample video frames with memory management. Args: model: Diffusion model video_latents: Dictionary containing latent tensors positive: Positive conditioning negative: Negative conditioning seed: Random seed steps: Number of sampling steps cfg: Classifier-free guidance scale sampler_name: Name of sampler to use scheduler: Name of scheduler to use denoise: Denoising strength Returns: Dictionary containing processed latent tensors """ start_time = time.time() device = comfy.model_management.get_torch_device() # Initialize memory manager if needed if self.memory_manager is None: self.memory_manager = MemoryManager(device) # Log latent size for debugging if isinstance(video_latents, dict) and 'samples' in video_latents: latent_samples = video_latents['samples'] total_frames = latent_samples.shape[0] self.logger.info(f"Processing latent shape: {latent_samples.shape}, total frames: {total_frames}") else: self.logger.error("Invalid latent format") raise ValueError("Expected latent dictionary with 'samples' key") self.logger.info(f"Processing with {steps} steps, {cfg} CFG, {sampler_name} sampler") try: # Process with memory tracking with self.memory_manager.track_memory("Video processing"): # Check memory usage before processing memory_stats = self.memory_manager.get_memory_stats() if "usage_percent" in memory_stats: self.logger.info(f"Memory usage before processing: {memory_stats['usage_percent']:.1f}%") # Apply sampling result = nodes.common_ksampler( model, seed, steps, cfg, sampler_name, scheduler, positive, negative, video_latents, denoise=denoise ) # Clear memory after processing self.memory_manager.cleanup() # Check memory usage after processing memory_stats = self.memory_manager.get_memory_stats() if "usage_percent" in memory_stats: self.logger.info(f"Memory usage after processing: {memory_stats['usage_percent']:.1f}%") end_time = time.time() self.logger.info(f"Complete: {total_frames} frames in {end_time - start_time:.2f}s ({(end_time - start_time) / total_frames:.2f}s per frame)") return result except Exception as e: self.logger.error(f"Error during processing: {str(e)}") # Try to release memory self.memory_manager.cleanup(force=True) # Check if it's an out-of-memory error if "CUDA out of memory" in str(e): self.logger.error("Out of memory error. Consider reducing frame count or model complexity.") raise e # Node registration NODE_CLASS_MAPPINGS = { "WanVideoKsampler": WanVideoKsampler, } NODE_DISPLAY_NAME_MAPPINGS = { "WanVideoKsampler": "Wan Video Ksampler", }