euiia committed · verified
Commit b269a3a · 1 Parent(s): 563a353

Create vae_manager.py

Files changed (1)
  1. managers/vae_manager.py +87 -0
managers/vae_manager.py ADDED
@@ -0,0 +1,87 @@
+ # managers/vae_manager.py
+ #
+ # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
+ #
+ # Version: 1.0.0
+ #
+ # This file defines the VaeManager specialist. Its purpose is to abstract all
+ # direct interactions with the Variational Autoencoder (VAE) model. It handles
+ # the model's state (CPU/GPU memory), provides clean interfaces for encoding and
+ # decoding, and ensures that the heavy VAE model only occupies VRAM when actively
+ # performing a task, freeing up resources for other specialists.
+
+ import torch
+ import logging
+ import gc
+ from typing import Generator
+
+ # Import the source of the VAE model and the low-level functions
+ from managers.ltx_manager import ltx_manager_singleton
+ from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
+
+ logger = logging.getLogger(__name__)
+
+ class VaeManager:
+     """
+     A specialist for managing the LTX VAE model. It provides high-level methods
+     for encoding pixels to latents and decoding latents to pixels, while managing
+     the model's presence on the GPU to conserve VRAM.
+     """
+     def __init__(self, vae_model):
+         self.vae = vae_model
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.cpu_device = torch.device('cpu')
+
+         # Initialize the VAE on the CPU to keep VRAM free at startup
+         self.vae.to(self.cpu_device)
+         logger.info("VaeManager initialized. VAE model is on CPU.")
+
+     def to_gpu(self):
+         """Moves the VAE model to the active GPU."""
+         if self.device == 'cpu': return
+         logger.info("VaeManager: Moving VAE to GPU...")
+         self.vae.to(self.device)
+
+     def to_cpu(self):
+         """Moves the VAE model to the CPU and clears VRAM cache."""
+         if self.device == 'cpu': return
+         logger.info("VaeManager: Unloading VAE from GPU...")
+         self.vae.to(self.cpu_device)
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     @torch.no_grad()
+     def encode(self, pixel_tensor: torch.Tensor) -> torch.Tensor:
+         """
+         Encodes a pixel-space tensor to the latent space.
+         Manages moving the VAE to and from the GPU.
+         """
+         try:
+             self.to_gpu()
+             pixel_tensor = pixel_tensor.to(self.device, dtype=self.vae.dtype)
+             latents = vae_encode(pixel_tensor, self.vae, vae_per_channel_normalize=True)
+             return latents.to(self.cpu_device)  # Return to CPU to free VRAM
+         finally:
+             self.to_cpu()
+
+     @torch.no_grad()
+     def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
+         """
+         Decodes a latent-space tensor to pixels.
+         Manages moving the VAE to and from the GPU.
+         """
+         try:
+             self.to_gpu()
+             latent_tensor = latent_tensor.to(self.device, dtype=self.vae.dtype)
+             timestep_tensor = torch.tensor([decode_timestep] * latent_tensor.shape[0], device=self.device, dtype=latent_tensor.dtype)
+             pixels = vae_decode(latent_tensor, self.vae, is_video=True, timestep=timestep_tensor, vae_per_channel_normalize=True)
+             return pixels.to(self.cpu_device)  # Return to CPU to free VRAM
+         finally:
+             self.to_cpu()
+
+ # --- Singleton Instance ---
+ # The VaeManager must use the exact same VAE instance as the LTX pipeline to ensure
+ # latent space compatibility. We source it directly from the already-initialized ltx_manager.
+ source_vae_model = ltx_manager_singleton.workers[0].pipeline.vae
+ vae_manager_singleton = VaeManager(source_vae_model)
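
A minimal usage sketch of the API this commit adds (not part of the commit itself): the 5D pixel tensor shape and the [-1, 1] value range below are assumptions about what the LTX VAE expects, and importing the module assumes ltx_manager_singleton initializes successfully; only the manager's own public methods are exercised.

import torch
from managers.vae_manager import vae_manager_singleton

# Hypothetical input: 1 clip, 3 channels, 9 frames, 512x512 pixels,
# values in [-1, 1] (shape and range are assumptions, not from this commit).
pixels = torch.randn(1, 3, 9, 512, 512)

latents = vae_manager_singleton.encode(pixels)  # VAE moves to GPU, encodes, unloads
frames = vae_manager_singleton.decode(latents)  # same load/unload cycle for decoding
print(latents.shape, frames.shape)              # both results come back on the CPU

Returning results on the CPU and unloading the VAE after every call trades some host-device transfer latency for a VRAM footprint that stays near zero between calls, which matches the file's stated goal of freeing resources for other specialists.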