|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import logging |
|
|
import gc |
|
|
from typing import Generator |
|
|
|
|
|
|
|
|
from managers.ltx_manager import ltx_manager_singleton |
|
|
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class VaeManager:
    """
    A specialist for managing the LTX VAE model. It provides high-level methods
    for encoding pixels to latents and decoding latents to pixels, while managing
    the model's presence on the GPU to conserve VRAM.

    The VAE is parked on the CPU between calls; `encode`/`decode` move it to the
    GPU for the duration of the call only and always unload it afterwards.
    """

    def __init__(self, vae_model):
        """
        Args:
            vae_model: The LTX VAE module to manage. It is moved to the CPU
                immediately so it holds no VRAM until an encode/decode call.
        """
        self.vae = vae_model
        # Device strings (not torch.device objects) — downstream comparisons
        # and .to() calls rely on these exact values.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.cpu_device = torch.device('cpu')
        # Park the model on the CPU at startup; it is shuttled to the GPU
        # only for the duration of each encode/decode call.
        self.vae.to(self.cpu_device)
        logger.info("VaeManager initialized. VAE model is on CPU.")

    def to_gpu(self):
        """Moves the VAE model to the active GPU (no-op on CPU-only hosts)."""
        if self.device == 'cpu':
            return
        logger.info("VaeManager: Moving VAE to GPU...")
        self.vae.to(self.device)

    def to_cpu(self):
        """Moves the VAE model to the CPU and clears the VRAM cache."""
        if self.device == 'cpu':
            return
        logger.info("VaeManager: Unloading VAE from GPU...")
        self.vae.to(self.cpu_device)
        gc.collect()
        # Reaching this line implies self.device == 'cuda', i.e. CUDA is
        # available, so no extra availability guard is needed.
        torch.cuda.empty_cache()

    @torch.no_grad()
    def encode(self, pixel_tensor: torch.Tensor) -> torch.Tensor:
        """
        Encodes a pixel-space tensor to the latent space.

        Manages moving the VAE to and from the GPU; the `finally` clause
        guarantees the VAE is unloaded even if encoding fails.

        Args:
            pixel_tensor: Pixel-space input; moved to the active device and
                cast to the VAE's dtype before encoding.

        Returns:
            The latent tensor, moved back to the CPU.
        """
        try:
            self.to_gpu()
            pixel_tensor = pixel_tensor.to(self.device, dtype=self.vae.dtype)
            latents = vae_encode(pixel_tensor, self.vae, vae_per_channel_normalize=True)
            return latents.to(self.cpu_device)
        finally:
            self.to_cpu()

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """
        Decodes a latent-space tensor to pixels.

        Manages moving the VAE to and from the GPU; the `finally` clause
        guarantees the VAE is unloaded even if decoding fails.

        Args:
            latent_tensor: Latent-space input; moved to the active device and
                cast to the VAE's dtype before decoding.
            decode_timestep: Denoising timestep broadcast to every batch
                element for the decode pass.

        Returns:
            The pixel-space tensor, moved back to the CPU.
        """
        try:
            self.to_gpu()
            latent_tensor = latent_tensor.to(self.device, dtype=self.vae.dtype)
            # One timestep per batch element, as vae_decode expects.
            timestep_tensor = torch.full(
                (latent_tensor.shape[0],),
                decode_timestep,
                device=self.device,
                dtype=latent_tensor.dtype,
            )
            pixels = vae_decode(latent_tensor, self.vae, is_video=True, timestep=timestep_tensor, vae_per_channel_normalize=True)
            return pixels.to(self.cpu_device)
        finally:
            self.to_cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton wiring: reuse the VAE object already attached to the
# first LTX worker's pipeline, so this manager shares that exact model instance
# instead of loading a second copy. Runs at import time — NOTE(review): this
# assumes ltx_manager_singleton.workers is non-empty at import; confirm against
# the LTX manager's initialization order.
source_vae_model = ltx_manager_singleton.workers[0].pipeline.vae


vae_manager_singleton = VaeManager(source_vae_model)