Update api/ltx/vae_aduc_pipeline.py
api/ltx/vae_aduc_pipeline.py CHANGED (+40 -80)
@@ -46,8 +46,8 @@ if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
 
 from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode, latent_to_pixel_coords
 from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
-from
-
+from pipeline_ltx_video import ConditioningItem as PipelineConditioningItem
+
 
 @dataclass
 class LatentConditioningItem:
@@ -59,7 +59,6 @@ class LatentConditioningItem:
 # --- MAIN VAE SERVICE CLASS ---
 # ==============================================================================
 
-@log_function_io
 class VaeAducPipeline:
     _instance = None
     _lock = threading.Lock()
@@ -71,7 +70,6 @@ class VaeAducPipeline:
         cls._instance._initialized = False
         return cls._instance
 
-    @log_function_io
     def __init__(self):
         if hasattr(self, '_initialized') and self._initialized: return
         with self._lock:
@@ -101,68 +99,24 @@ class VaeAducPipeline:
     # --- PUBLIC SERVICE METHODS ---
 
     @log_function_io
-    def encode_video(
-        self,
-        video_tensor: torch.Tensor,
-        vae_per_channel_normalize: bool = True
-    ) -> torch.Tensor:
-        """
-        [NEW] Encodes a video tensor (pixels) into the latent space.
-
-        Args:
-            video_tensor (torch.Tensor): Video tensor in (B, C, F, H, W) format with values in [0, 1].
-            vae_per_channel_normalize (bool): Whether to normalize the latents per channel.
-
-        Returns:
-            torch.Tensor: The resulting latent tensor on the CPU.
-        """
+    def encode_video(self, video_tensor: torch.Tensor, vae_per_channel_normalize: bool = True) -> torch.Tensor:
         logging.info(f"VaeAducPipeline: Encoding video with shape {video_tensor.shape}")
         if not (video_tensor.ndim == 5):
             raise ValueError(f"Input video tensor must be 5D (B, C, F, H, W), but got shape {video_tensor.shape}")
-
-        # Normalize the tensor from [0, 1] to [-1, 1]
         video_tensor_normalized = (video_tensor * 2.0) - 1.0
-
         try:
             video_gpu = video_tensor_normalized.to(self.device, dtype=self.dtype)
             with torch.no_grad():
-                latents = vae_encode(
-                    video_gpu,
-                    self.vae,
-                    vae_per_channel_normalize=vae_per_channel_normalize
-                )
+                latents = vae_encode(video_gpu, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
             logging.info(f"VaeAducPipeline: Successfully encoded video to latents of shape {latents.shape}")
             return latents.cpu()
         finally:
             self._cleanup_gpu()
 
     @log_function_io
-    def decode_and_resize_video(
-        self,
-        latent_tensor: torch.Tensor,
-        target_height: int,
-        target_width: int,
-        decode_timestep: float = 0.05
-    ) -> torch.Tensor:
-        """
-        [NEW] Decodes a latent tensor to pixels and resizes it to the final resolution.
-
-        Args:
-            latent_tensor (torch.Tensor): The latent tensor to decode.
-            target_height (int): The final video height.
-            target_width (int): The final video width.
-            decode_timestep (float): Timestep for the VAE decoder, if applicable.
-
-        Returns:
-            torch.Tensor: The pixel video tensor, resized and on the CPU.
-        """
+    def decode_and_resize_video(self, latent_tensor: torch.Tensor, target_height: int, target_width: int, decode_timestep: float = 0.05) -> torch.Tensor:
         logging.info(f"VaeAducPipeline: Decoding latents {latent_tensor.shape} and resizing to {target_height}x{target_width}")
-
-        # 1. Decode to pixels (using the already existing function)
-        # The result already comes back on the CPU
         pixel_video = self.decode_to_pixels(latent_tensor, decode_timestep)
-
-        # 2. Resize to the final size
         num_frames = pixel_video.shape[2]
         current_height, current_width = pixel_video.shape[3:]
 
@@ -170,33 +124,21 @@ class VaeAducPipeline:
             logging.info("VaeAducPipeline: Resizing skipped, already at target resolution.")
             return pixel_video
 
-        # Apply interpolation to resize
         videos_flat = rearrange(pixel_video, "b c f h w -> (b f) c h w")
-        videos_resized = F.interpolate(
-            videos_flat,
-            size=(target_height, target_width),
-            mode="bilinear",
-            align_corners=False,
-        )
+        videos_resized = F.interpolate(videos_flat, size=(target_height, target_width), mode="bilinear", align_corners=False)
        final_video = rearrange(videos_resized, "(b f) c h w -> b c f h w", f=num_frames)
-
         logging.info(f"VaeAducPipeline: Resized video to final shape {final_video.shape}")
         return final_video
 
     @log_function_io
     def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
-        """Decodes a latent tensor into a pixel tensor, returned on the CPU."""
         t0 = time.time()
         try:
             latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
             num_items = latent_tensor_gpu.shape[0]
             timestep_tensor = torch.tensor([decode_timestep] * num_items, device=self.device, dtype=self.dtype)
-
             with torch.no_grad():
-                pixels = vae_decode(
-                    latent_tensor_gpu, self.vae, is_video=True,
-                    timestep=timestep_tensor, vae_per_channel_normalize=True
-                )
+                pixels = vae_decode(latent_tensor_gpu, self.vae, is_video=True, timestep=timestep_tensor, vae_per_channel_normalize=True)
             logging.info(f"VaeAducPipeline: Decoded latents {latent_tensor.shape} in {time.time() - t0:.2f}s.")
             return pixels.cpu()
         finally:
@@ -213,7 +155,6 @@ class VaeAducPipeline:
         vae_per_channel_normalize: bool = True,
         generator: Optional[torch.Generator] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
-        """Prepares conditioning tensors from a list of pixel or latent items."""
         init_latents = init_latents.to(self.device, dtype=self.dtype)
 
         if not conditioning_items:
@@ -236,9 +177,14 @@ class VaeAducPipeline:
                 init_latents[..., :f, :h, :w] = torch.lerp(init_latents[..., :f, :h, :w], latents, item.conditioning_strength)
                 mask[..., :f, :h, :w] = item.conditioning_strength
             else:
-
-
-
+                if latents.shape[2] > 1:
+                    init_latents, mask, latents = self._handle_non_first_sequence(
+                        init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
+                    )
+                if latents is not None:
+                    latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
+                    extra_latents.append(latents_p); extra_coords.append(coords_p); extra_masks.append(new_mask)
+                    num_extra_latents += num_new
         else:
             for item in conditioning_items:
                 item_resized = self._resize_conditioning_item(item, height, width)
@@ -252,7 +198,9 @@ class VaeAducPipeline:
                     mask[..., :f, ly:ly+h, lx:lx+w] = item.conditioning_strength
                 else:
                     if media_item.shape[2] > 1:
-                        init_latents, mask, latents = self._handle_non_first_sequence(
+                        init_latents, mask, latents = self._handle_non_first_sequence(
+                            init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
+                        )
                     if latents is not None:
                         latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
                         extra_latents.append(latents_p); extra_coords.append(coords_p); extra_masks.append(new_mask)
@@ -282,7 +230,7 @@ class VaeAducPipeline:
         if torch.cuda.is_available():
             with torch.cuda.device(self.device): torch.cuda.empty_cache()
 
-    def _latent_to_pixel_coords(self, c): return latent_to_pixel_coords(c, self.vae, self.transformer.config.causal_temporal_positioning)
+    def _latent_to_pixel_coords(self, c): return latent_to_pixel_coords(c, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
 
     @staticmethod
     def _resize_tensor(m, h, w):
@@ -307,15 +255,27 @@ class VaeAducPipeline:
         if (ys + hi) < h: l = l[..., :-1, :]
         return l, xs // s, ys // s
 
-    def _handle_non_first_sequence(
-
+    def _handle_non_first_sequence(
+        self,
+        init_latents: torch.Tensor,
+        mask: torch.Tensor,
+        latents: torch.Tensor,
+        media_frame_number: int,
+        conditioning_strength: float,
+        num_prefix=2,
+        mode="concat"
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        fl, flp = latents.shape[2], num_prefix
         if fl > flp:
-
-
-
-
-
-
+            start = media_frame_number // 8 + flp
+            end = start + fl - flp
+            init_latents[..., start:end, :, :] = torch.lerp(init_latents[..., start:end, :, :], latents[..., flp:, :, :], conditioning_strength)
+            mask[..., start:end, :, :] = conditioning_strength
+            if mode == "concat":
+                latents = latents[..., :flp, :, :]
+            else:
+                latents = None
+        return init_latents, mask, latents
 
     def _process_extra_item(self, l, i, g):
         n = randn_tensor(l.shape, generator=g, device=self.device, dtype=self.dtype)
@@ -326,5 +286,5 @@ class VaeAducPipeline:
         nm = torch.full(lp.shape[:2], i.conditioning_strength, dtype=torch.float32, device=self.device)
         return lp, cp, nm, nl
 
-# ---
+# --- Singleton instantiation ---
 vae_aduc_pipeline = VaeAducPipeline()
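For reference, the flattened public API round-trips as before: encode_video() still rescales pixels from [0, 1] to [-1, 1], encodes on the GPU, and returns latents on the CPU, while decode_and_resize_video() decodes and then bilinearly resizes. A minimal usage sketch, assuming the module is importable as api.ltx.vae_aduc_pipeline and that a VAE is configured when the singleton loads; the shapes are illustrative only:

import torch
from api.ltx.vae_aduc_pipeline import vae_aduc_pipeline

# Dummy clip: batch 1, RGB, 9 frames, 256x384, pixel values in [0, 1].
video = torch.rand(1, 3, 9, 256, 384)

# Returns latents on the CPU; raises ValueError if the tensor is not 5D.
latents = vae_aduc_pipeline.encode_video(video)

# Decodes to pixels, then resizes every frame to 512x768 with bilinear interpolation.
restored = vae_aduc_pipeline.decode_and_resize_video(latents, target_height=512, target_width=768)
print(latents.shape, restored.shape)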
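The newly spelled-out _handle_non_first_sequence() maps a pixel-space frame number to a latent index by dividing by the VAE's temporal downscale factor (the literal 8 above) and skipping num_prefix prefix latents, then blends the remaining latents into init_latents with torch.lerp. A standalone sketch of that indexing on dummy tensors; the helper name and shapes are illustrative, not part of the commit:

import torch

def blend_non_first_sequence(init_latents, mask, latents, media_frame_number, strength, num_prefix=2, mode="concat"):
    # Mirrors the indexing of the method added in this commit, minus the class plumbing.
    fl, flp = latents.shape[2], num_prefix
    if fl > flp:
        start = media_frame_number // 8 + flp  # pixel frame -> latent index (temporal downscale of 8)
        end = start + fl - flp
        init_latents[..., start:end, :, :] = torch.lerp(init_latents[..., start:end, :, :], latents[..., flp:, :, :], strength)
        mask[..., start:end, :, :] = strength
        latents = latents[..., :flp, :, :] if mode == "concat" else None
    return init_latents, mask, latents

# A clip conditioned at pixel frame 24 with 5 latent frames lands on latent indices 5..7.
init = torch.zeros(1, 128, 16, 8, 8)
mask = torch.zeros(1, 1, 16, 8, 8)
cond = torch.ones(1, 128, 5, 8, 8)
init, mask, prefix = blend_non_first_sequence(init, mask, cond, media_frame_number=24, strength=0.9)
print(mask[0, 0, :, 0, 0])  # 0.9 at indices 5, 6, 7; zeros elsewhere
print(prefix.shape)         # the two prefix latents kept for "concat" mode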