Update api/ltx/vae_aduc_pipeline.py

api/ltx/vae_aduc_pipeline.py  (+248 -88)  CHANGED
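The diff below replaces the PIL-based _preprocess_input helper and the verbose huggingface_hub logging setup with a latent-space service API: a LatentConditioningItem dataclass plus encode_video, decode_and_resize_video and prepare_conditioning methods on the VaeAducPipeline singleton, which now borrows its VAE, patchifier and transformer from the already-loaded LTXPoolManager. A minimal round trip through the new public methods could look like the following sketch; the shapes are arbitrary and it assumes the LTX pool manager has been initialized first (otherwise __init__ raises).

    import torch
    from api.ltx.vae_aduc_pipeline import vae_aduc_pipeline

    frames = torch.rand(1, 3, 9, 480, 704)            # (B, C, F, H, W) pixels in [0, 1]
    latents = vae_aduc_pipeline.encode_video(frames)  # latent tensor, returned on the CPU
    video = vae_aduc_pipeline.decode_and_resize_video(
        latents, target_height=720, target_width=1280
    )                                                 # pixel video, resized, on the CPU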
@@ -6,50 +6,60 @@
  import os
  import sys
  import time
-
  import threading
  from pathlib import Path
- from typing import List, Union, Tuple,
- import
  import torch
  import numpy as np
  from PIL import Image
- from
-
- from api.ltx.ltx_aduc_manager import ltx_aduc_manager

  from utils.debug_utils import log_function_io

  import logging
  import warnings
  warnings.filterwarnings("ignore", category=UserWarning)
  warnings.filterwarnings("ignore", category=FutureWarning)
  warnings.filterwarnings("ignore", message=".*")
- from huggingface_hub import logging as ll
- ll.set_verbosity_error()
- ll.set_verbosity_warning()
- ll.set_verbosity_info()
- ll.set_verbosity_debug()


  # ==============================================================================
- # --- IMPORTS
  # ==============================================================================

- # Add the path for the LTX libraries
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
  if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
      sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))

-
- from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
  from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder


  # ==============================================================================
- # --- VAE SERVICE CLASS ---
  # ==============================================================================

  class VaeAducPipeline:
      _instance = None
      _lock = threading.Lock()
@@ -61,110 +71,260 @@ class VaeAducPipeline:
          cls._instance._initialized = False
          return cls._instance

      def __init__(self):
-         if self._initialized: return
          with self._lock:
-             if self._initialized: return
-
-             logging.info("⚙️ Initializing VaeServer Singleton...")
              t0 = time.time()
-
-             # 1. Get the dedicated VAE device from the central manager
              self.device = gpu_manager.get_ltx_vae_device()

-             # 2. Get the VAE model already loaded by the LTXPoolManager
-             # This ensures consistency and avoids loading the model twice.
              try:
                  from api.ltx.ltx_aduc_manager import ltx_aduc_manager
-
-
-
              except Exception as e:
-                 logging.critical(f"Failed to get
                  raise

-
-             self.vae.to(self.device)
-             self.vae.eval()
              self.dtype = self.vae.dtype
-
              self._initialized = True
-             logging.info(f"✅
-
-
-     def _cleanup_gpu(self):
-         """Clears the VRAM on the VAE's GPU."""
-         if torch.cuda.is_available():
-             with torch.cuda.device(self.device):
-                 torch.cuda.empty_cache()
-
-     @log_function_io
-     def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
-         """Prepares a PIL image or a tensor into the pixel format the VAE expects for encoding."""
-         if isinstance(item, Image.Image):
-             from PIL import ImageOps
-             img = item.convert("RGB")
-             processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
-             image_np = np.array(processed_img).astype(np.float32) / 255.0
-             tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
-         elif isinstance(item, torch.Tensor):
-             if item.ndim == 4 and item.shape[0] == 1: tensor = item.squeeze(0)
-             elif item.ndim == 3: tensor = item
-             else: raise ValueError(f"Input tensor must have 3 or 4 dimensions (CHW or BCHW), but got {item.ndim}")
-         else:
-             raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")
-
-         # Convert to 5D (B, C, F, H, W) and normalize to [-1, 1]
-         tensor_5d = tensor.unsqueeze(0).unsqueeze(2)
-         return (tensor_5d * 2.0) - 1.0

      @log_function_io
-     def
          self,
-
-
-
-         target_resolution: Tuple[int, int]
-     ) -> List[LatentConditioningItem]:
          """
-         [
-
          """
-
-

-
-

-         conditioning_items = []
          try:
-
-
-
-
-
-
-
-
          finally:
              self._cleanup_gpu()

      @log_function_io
      def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
          """Decodes a latent tensor into a pixel tensor, returning it on the CPU."""
          t0 = time.time()
          try:
              latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
-
-             timestep_tensor = torch.tensor([decode_timestep] *

-
-
-
-
-
              return pixels.cpu()
          finally:
              self._cleanup_gpu()

  vae_aduc_pipeline = VaeAducPipeline()
  import os
  import sys
  import time
+ import copy
  import threading
  from pathlib import Path
+ from typing import List, Union, Tuple, Optional
+ from dataclasses import dataclass
+
  import torch
  import numpy as np
  from PIL import Image
+ from einops import rearrange
+ import torch.nn.functional as F

+ from managers.gpu_manager import gpu_manager
  from utils.debug_utils import log_function_io
+ from diffusers.utils.torch_utils import randn_tensor

  import logging
  import warnings
+
+ # --- Logging and warnings configuration ---
  warnings.filterwarnings("ignore", category=UserWarning)
  warnings.filterwarnings("ignore", category=FutureWarning)
  warnings.filterwarnings("ignore", message=".*")

+ try:
+     from huggingface_hub import logging as hf_logging
+     hf_logging.set_verbosity_error()
+ except ImportError:
+     pass

  # ==============================================================================
+ # --- IMPORTS AND TYPE DEFINITIONS ---
  # ==============================================================================

  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
  if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
      sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))

+ from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode, latent_to_pixel_coords
  from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
+ from pipeline_ltx_video import ConditioningItem as PipelineConditioningItem
+

+ @dataclass
+ class LatentConditioningItem:
+     latent_tensor: torch.Tensor
+     media_frame_number: int
+     conditioning_strength: float

  # ==============================================================================
+ # --- MAIN VAE SERVICE CLASS ---
  # ==============================================================================

+ @log_function_io
  class VaeAducPipeline:
      _instance = None
      _lock = threading.Lock()
          cls._instance._initialized = False
          return cls._instance

+     @log_function_io
      def __init__(self):
+         if hasattr(self, '_initialized') and self._initialized: return
          with self._lock:
+             if hasattr(self, '_initialized') and self._initialized: return
+             logging.info("⚙️ Initializing VaeAducPipeline Singleton...")
              t0 = time.time()
              self.device = gpu_manager.get_ltx_vae_device()

              try:
                  from api.ltx.ltx_aduc_manager import ltx_aduc_manager
+                 main_pipeline = ltx_aduc_manager.get_pipeline()
+                 if main_pipeline is None:
+                     raise RuntimeError("LTXPoolManager must be initialized before VaeAducPipeline.")
+                 self.vae: CausalVideoAutoencoder = main_pipeline.vae
+                 self.patchifier = main_pipeline.patchifier
+                 self.transformer = main_pipeline.transformer
+                 self.vae_scale_factor = main_pipeline.vae_scale_factor
              except Exception as e:
+                 logging.critical(f"Failed to get components from LTXPoolManager. Error: {e}", exc_info=True)
                  raise

+             self.vae.to(self.device).eval()
              self.dtype = self.vae.dtype
              self._initialized = True
+             logging.info(f"✅ VaeAducPipeline ready. Components are 'hot' on {self.device}. Startup time: {time.time() - t0:.2f}s")
+
+     # --- PUBLIC SERVICE METHODS ---

      @log_function_io
+     def encode_video(
          self,
+         video_tensor: torch.Tensor,
+         vae_per_channel_normalize: bool = True
+     ) -> torch.Tensor:
          """
+         [NEW] Encodes a video tensor (pixels) into the latent space.
+
+         Args:
+             video_tensor (torch.Tensor): Video tensor in (B, C, F, H, W) format and [0, 1] range.
+             vae_per_channel_normalize (bool): Whether to normalize the latents per channel.
+
+         Returns:
+             torch.Tensor: The resulting latent tensor on the CPU.
          """
+         logging.info(f"VaeAducPipeline: Encoding video with shape {video_tensor.shape}")
+         if not (video_tensor.ndim == 5):
+             raise ValueError(f"Input video tensor must be 5D (B, C, F, H, W), but got shape {video_tensor.shape}")

+         # Normalize the tensor from [0, 1] to [-1, 1]
+         video_tensor_normalized = (video_tensor * 2.0) - 1.0

          try:
+             video_gpu = video_tensor_normalized.to(self.device, dtype=self.dtype)
+             with torch.no_grad():
+                 latents = vae_encode(
+                     video_gpu,
+                     self.vae,
+                     vae_per_channel_normalize=vae_per_channel_normalize
+                 )
+             logging.info(f"VaeAducPipeline: Successfully encoded video to latents of shape {latents.shape}")
+             return latents.cpu()
          finally:
              self._cleanup_gpu()

+     @log_function_io
+     def decode_and_resize_video(
+         self,
+         latent_tensor: torch.Tensor,
+         target_height: int,
+         target_width: int,
+         decode_timestep: float = 0.05
+     ) -> torch.Tensor:
+         """
+         [NEW] Decodes a latent tensor to pixels and resizes it to the final resolution.
+
+         Args:
+             latent_tensor (torch.Tensor): The latent tensor to decode.
+             target_height (int): The final video height.
+             target_width (int): The final video width.
+             decode_timestep (float): Timestep for the VAE decoder, if applicable.
+
+         Returns:
+             torch.Tensor: The pixel video tensor, resized and on the CPU.
+         """
+         logging.info(f"VaeAducPipeline: Decoding latents {latent_tensor.shape} and resizing to {target_height}x{target_width}")
+
+         # 1. Decode to pixels (using the existing function)
+         # The result already comes back on the CPU
+         pixel_video = self.decode_to_pixels(latent_tensor, decode_timestep)
+
+         # 2. Resize to the final size
+         num_frames = pixel_video.shape[2]
+         current_height, current_width = pixel_video.shape[3:]
+
+         if current_height == target_height and current_width == target_width:
+             logging.info("VaeAducPipeline: Resizing skipped, already at target resolution.")
+             return pixel_video
+
+         # Apply interpolation to resize
+         videos_flat = rearrange(pixel_video, "b c f h w -> (b f) c h w")
+         videos_resized = F.interpolate(
+             videos_flat,
+             size=(target_height, target_width),
+             mode="bilinear",
+             align_corners=False,
+         )
+         final_video = rearrange(videos_resized, "(b f) c h w -> b c f h w", f=num_frames)
+
+         logging.info(f"VaeAducPipeline: Resized video to final shape {final_video.shape}")
+         return final_video
+
      @log_function_io
      def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
          """Decodes a latent tensor into a pixel tensor, returning it on the CPU."""
          t0 = time.time()
          try:
              latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
+             num_items = latent_tensor_gpu.shape[0]
+             timestep_tensor = torch.tensor([decode_timestep] * num_items, device=self.device, dtype=self.dtype)

+             with torch.no_grad():
+                 pixels = vae_decode(
+                     latent_tensor_gpu, self.vae, is_video=True,
+                     timestep=timestep_tensor, vae_per_channel_normalize=True
+                 )
+             logging.info(f"VaeAducPipeline: Decoded latents {latent_tensor.shape} in {time.time() - t0:.2f}s.")
              return pixels.cpu()
          finally:
              self._cleanup_gpu()

+     @log_function_io
+     def prepare_conditioning(
+         self,
+         conditioning_items: Optional[List[Union[PipelineConditioningItem, LatentConditioningItem]]],
+         init_latents: torch.Tensor,
+         num_frames: int,
+         height: int,
+         width: int,
+         vae_per_channel_normalize: bool = True,
+         generator: Optional[torch.Generator] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
+         """Prepares conditioning tensors from a list of pixel or latent items."""
+         init_latents = init_latents.to(self.device, dtype=self.dtype)
+
+         if not conditioning_items:
+             latents_p, coords_l = self.patchifier.patchify(latents=init_latents)
+             coords_p = self._latent_to_pixel_coords(coords_l)
+             return latents_p.cpu(), coords_p.cpu(), None, 0
+
+         mask = torch.zeros(init_latents.shape[0], *init_latents.shape[2:], dtype=torch.float32, device=self.device)
+         extra_latents, extra_coords, extra_masks = [], [], []
+         num_extra_latents = 0
+
+         is_latent_mode = isinstance(conditioning_items[0], LatentConditioningItem)
+
+         with torch.no_grad():
+             if is_latent_mode:
+                 for item in conditioning_items:
+                     latents = item.latent_tensor.to(device=self.device, dtype=self.dtype)
+                     if item.media_frame_number == 0:
+                         f, h, w = latents.shape[-3:]
+                         init_latents[..., :f, :h, :w] = torch.lerp(init_latents[..., :f, :h, :w], latents, item.conditioning_strength)
+                         mask[..., :f, :h, :w] = item.conditioning_strength
+                     else:
+                         latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
+                         extra_latents.append(latents_p); extra_coords.append(coords_p); extra_masks.append(new_mask)
+                         num_extra_latents += num_new
+             else:
+                 for item in conditioning_items:
+                     item_resized = self._resize_conditioning_item(item, height, width)
+                     media_item = item_resized.media_item.to(self.device, dtype=self.dtype)
+                     latents = vae_encode(media_item, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
+
+                     if item.media_frame_number == 0:
+                         latents_pos, lx, ly = self._get_latent_spatial_position(latents, item_resized, height, width)
+                         f, h, w = latents_pos.shape[-3:]
+                         init_latents[..., :f, ly:ly+h, lx:lx+w] = torch.lerp(init_latents[..., :f, ly:ly+h, lx:lx+w], latents_pos, item.conditioning_strength)
+                         mask[..., :f, ly:ly+h, lx:lx+w] = item.conditioning_strength
+                     else:
+                         if media_item.shape[2] > 1:
+                             init_latents, mask, latents = self._handle_non_first_sequence(init_latents, mask, latents, item)
+                         if latents is not None:
+                             latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
+                             extra_latents.append(latents_p); extra_coords.append(coords_p); extra_masks.append(new_mask)
+                             num_extra_latents += num_new
+
+         # --- Final consolidation ---
+         latents_p, coords_l = self.patchifier.patchify(latents=init_latents)
+         coords_p = self._latent_to_pixel_coords(coords_l)
+         mask_p, _ = self.patchifier.patchify(latents=mask.unsqueeze(1))
+         mask_p = mask_p.squeeze(-1)
+
+         if extra_latents:
+             latents_p = torch.cat([*extra_latents, latents_p], dim=1)
+             coords_p = torch.cat([*extra_coords, coords_p], dim=2)
+             mask_p = torch.cat([*extra_masks, mask_p], dim=1)
+
+             use_flash = getattr(self.transformer.config, 'use_tpu_flash_attention', False)
+             if use_flash:
+                 latents_p = latents_p[:, :-num_extra_latents]
+                 coords_p = coords_p[:, :, :-num_extra_latents]
+                 mask_p = mask_p[:, :-num_extra_latents]
+
+         return latents_p.cpu(), coords_p.cpu(), mask_p.cpu(), num_extra_latents
+
+     # --- PRIVATE HELPER METHODS ---
+     def _cleanup_gpu(self):
+         if torch.cuda.is_available():
+             with torch.cuda.device(self.device): torch.cuda.empty_cache()
+
+     def _latent_to_pixel_coords(self, c): return latent_to_pixel_coords(c, self.vae, self.transformer.config.causal_temporal_positioning)
+
+     @staticmethod
+     def _resize_tensor(m, h, w):
+         if m.shape[-2:] != (h, w):
+             n = m.shape[2]
+             flat = rearrange(m, "b c n h w -> (b n) c h w")
+             resized = F.interpolate(flat, (h, w), mode="bilinear", align_corners=False)
+             return rearrange(resized, "(b n) c h w -> b c n h w", n=n)
+         return m
+
+     def _resize_conditioning_item(self, i, h, w):
+         n = copy.copy(i); n.media_item = self._resize_tensor(i.media_item, h, w); return n
+
+     def _get_latent_spatial_position(self, l, i, h, w, strip=True):
+         s, hi, wi = self.vae_scale_factor, i.media_item.shape[-2], i.media_item.shape[-1]
+         xs = (w - wi) // 2 if i.media_x is None else i.media_x
+         ys = (h - hi) // 2 if i.media_y is None else i.media_y
+         if strip:
+             if xs > 0: xs += s; l = l[..., :, 1:]
+             if ys > 0: ys += s; l = l[..., 1:, :]
+             if (xs + wi) < w: l = l[..., :, :-1]
+             if (ys + hi) < h: l = l[..., :-1, :]
+         return l, xs // s, ys // s
+
+     def _handle_non_first_sequence(self, il, m, l, i, np=2, mode="concat"):
+         fl, flp = l.shape[2], np
+         if fl > flp:
+             s, e = i.media_frame_number // 8 + flp, i.media_frame_number // 8 + fl
+             il[..., s:e, :, :] = torch.lerp(il[..., s:e, :, :], l[..., flp:, :, :], i.conditioning_strength)
+             m[..., s:e, :, :] = i.conditioning_strength
+             if mode == "concat": l = l[..., :flp, :, :]
+             else: l = None
+         return il, m, l
+
+     def _process_extra_item(self, l, i, g):
+         n = randn_tensor(l.shape, generator=g, device=self.device, dtype=self.dtype)
+         l = torch.lerp(n, l, i.conditioning_strength)
+         lp, cl = self.patchifier.patchify(l)
+         cp = self._latent_to_pixel_coords(cl); cp[:, 0] += i.media_frame_number
+         nl = lp.shape[1]
+         nm = torch.full(lp.shape[:2], i.conditioning_strength, dtype=torch.float32, device=self.device)
+         return lp, cp, nm, nl
+
+ # --- Singleton instantiation ---
  vae_aduc_pipeline = VaeAducPipeline()
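
With the latent path, callers can blend pre-computed latents into the initial noise without re-encoding pixels. A minimal sketch of that flow, using placeholder latent shapes (the real channel count and spatial dimensions come from the LTX VAE), might look like this:

    import torch
    from api.ltx.vae_aduc_pipeline import vae_aduc_pipeline, LatentConditioningItem

    init_latents = torch.zeros(1, 128, 8, 16, 24)      # (B, C, F_lat, H_lat, W_lat), placeholder shape
    anchor = LatentConditioningItem(
        latent_tensor=torch.randn(1, 128, 1, 16, 24),  # first-frame latent blended at strength 0.8
        media_frame_number=0,
        conditioning_strength=0.8,
    )
    latents_p, coords_p, mask_p, n_extra = vae_aduc_pipeline.prepare_conditioning(
        conditioning_items=[anchor],
        init_latents=init_latents,
        num_frames=33, height=480, width=704,
    )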
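Both decode_and_resize_video and _resize_tensor lean on the same frame-wise resize idiom: fold the frame axis into the batch, interpolate spatially, then restore the 5D layout. In isolation, with made-up shapes:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    video = torch.rand(1, 3, 9, 240, 352)                                   # (B, C, F, H, W)
    flat = rearrange(video, "b c f h w -> (b f) c h w")                     # fold frames into batch
    flat = F.interpolate(flat, size=(480, 704), mode="bilinear", align_corners=False)
    resized = rearrange(flat, "(b f) c h w -> b c f h w", f=video.shape[2]) # restore 5D layout
    assert resized.shape == (1, 3, 9, 480, 704)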