# FILE: api/ltx/ltx_aduc_pipeline.py
# DESCRIPTION: A unified high-level client for submitting ALL LTX-related jobs (generation and VAE)
#              to the LTXAducManager pool.

import logging
import time
import torch
import random
from typing import List, Optional, Tuple, Dict
from PIL import Image
from dataclasses import dataclass
from pathlib import Path
import sys

from api.ltx.ltx_utils import load_image_to_tensor_with_resize_and_crop  # Pixel-loading helper from ltx_utils

# The client imports the MANAGER in order to submit all jobs.
from api.ltx.ltx_aduc_manager import ltx_aduc_manager

# Add the LTX-Video repository to sys.path for low-level imports and types.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)


add_deps_to_path()

from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode

# ==============================================================================
# --- STRUCTURE DEFINITIONS ---
# ==============================================================================

@dataclass
class LatentConditioningItem:
    """Data structure for passing conditioned latents to the generation job."""
    latent_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float

# ==============================================================================
# --- JOB FUNCTIONS (executed by workers in the LTX pool) ---
# ==============================================================================

def _job_encode_media(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, pixel_tensor: torch.Tensor) -> torch.Tensor:
    """Job that uses the pipeline's VAE to encode a pixel tensor into latents."""
    vae = pipeline.vae
    pixel_tensor_gpu = pixel_tensor.to(vae.device, dtype=vae.dtype)
    latents = vae_encode(pixel_tensor_gpu, vae, vae_per_channel_normalize=True)
    return latents.cpu()


def _job_decode_latent(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, latent_tensor: torch.Tensor) -> torch.Tensor:
    """Job that uses the pipeline's VAE to decode a latent tensor back to pixels."""
    vae = pipeline.vae
    latent_tensor_gpu = latent_tensor.to(vae.device, dtype=vae.dtype)
    pixels = vae_decode(latent_tensor_gpu, vae, is_video=True, vae_per_channel_normalize=True)
    return pixels.cpu()


def _job_generate_latent_chunk(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, **kwargs) -> torch.Tensor:
    """Job that uses the main pipeline to generate one latent video chunk."""
    generator = torch.Generator(device=pipeline.device).manual_seed(kwargs['seed'])
    pipeline_kwargs = {"generator": generator, "output_type": "latent", **kwargs}
    with torch.autocast(device_type=pipeline.device.type, dtype=autocast_dtype):
        latents_raw = pipeline(**pipeline_kwargs).images
    return latents_raw.cpu()

# ==============================================================================
# --- THE UNIFIED CLIENT CLASS ---
# ==============================================================================

class LtxAducPipeline:
    """
    Unified client that orchestrates all LTX tasks, including generation and VAE work.
    """

    def __init__(self):
        logging.info("✅ Unified LTX/VAE ADUC Pipeline (Client) initialized.")
        self.FRAMES_ALIGNMENT = 8

    def _get_random_seed(self) -> int:
        return random.randint(0, 2**32 - 1)

    def _align(self, dim: int, alignment: int = 8) -> int:
        return ((dim + alignment - 1) // alignment) * alignment

    # --- API methods for the orchestrator ---

    def encode_to_conditioning_items(self, media_list: List, params: List, resolution: Tuple[int, int]) -> List[LatentConditioningItem]:
        """Converts a list of images into a list of LatentConditioningItem."""
        pixel_tensors = [load_image_to_tensor_with_resize_and_crop(m, resolution[0], resolution[1]) for m in media_list]
        items = []
        for i, pt in enumerate(pixel_tensors):
            latent_tensor = ltx_aduc_manager.submit_job(_job_encode_media, pixel_tensor=pt)
            frame_number, strength = params[i]
            items.append(LatentConditioningItem(
                latent_tensor=latent_tensor,
                media_frame_number=frame_number,
                conditioning_strength=strength
            ))
        return items

    def decode_to_pixels(self, latent_tensor: torch.Tensor) -> torch.Tensor:
        """Decodes a latent tensor into a pixel tensor."""
        return ltx_aduc_manager.submit_job(_job_decode_latent, latent_tensor=latent_tensor)

    def generate_latents(
        self,
        prompt_list: List[str],
        duration_in_seconds: float,
        common_ltx_args: Dict,
        initial_conditioning_items: Optional[List[LatentConditioningItem]] = None
    ) -> Tuple[Optional[torch.Tensor], Optional[int]]:
        """Generates a full latent video from a list of prompts, one chunk per scene."""
        t0 = time.time()
        logging.info(f"LTX Client received a generation job for {len(prompt_list)} scenes.")
        used_seed = self._get_random_seed()
        num_chunks = len(prompt_list)
        total_frames = self._align(int(duration_in_seconds * 24))
        frames_per_chunk_base = total_frames // num_chunks if num_chunks > 0 else total_frames
        overlap_frames = self._align(9) if num_chunks > 1 else 0
        final_latents_list = []
        overlap_condition_item = None

        for i, chunk_prompt in enumerate(prompt_list):
            current_conditions = []
            if i == 0 and initial_conditioning_items:
                current_conditions.extend(initial_conditioning_items)
            if overlap_condition_item:
                current_conditions.append(overlap_condition_item)

            num_frames_for_chunk = frames_per_chunk_base
            if i == num_chunks - 1:
                # The last chunk absorbs any rounding remainder so the requested total is respected.
                processed_frames = sum(f.shape[2] for f in final_latents_list)
                num_frames_for_chunk = total_frames - processed_frames
            num_frames_for_chunk = self._align(num_frames_for_chunk)
            if num_frames_for_chunk <= 0:
                continue

            job_specific_args = {
                "prompt": chunk_prompt,
                "num_frames": num_frames_for_chunk,
                "seed": used_seed + i,
                "conditioning_items": current_conditions
            }
            final_job_args = {**common_ltx_args, **job_specific_args}
            chunk_latents = ltx_aduc_manager.submit_job(_job_generate_latent_chunk, **final_job_args)
            if chunk_latents is None:
                logging.error(f"Failed to generate latents for scene {i+1}. Aborting.")
                return None, used_seed

            if i < num_chunks - 1:
                # Keep the tail of this chunk as a conditioning item so the next chunk continues from it.
                overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                overlap_condition_item = LatentConditioningItem(
                    latent_tensor=overlap_latents, media_frame_number=0, conditioning_strength=1.0)
                final_latents_list.append(chunk_latents[:, :, :-overlap_frames, :, :])
            else:
                final_latents_list.append(chunk_latents)

        if not final_latents_list:
            logging.warning("No latent chunks were generated.")
            return None, used_seed

        final_latents = torch.cat(final_latents_list, dim=2)
        logging.info(f"LTX Client job finished in {time.time() - t0:.2f}s. Final latent shape: {final_latents.shape}")
        return final_latents, used_seed

# --- CLIENT SINGLETON INSTANCE ---
ltx_aduc_pipeline = LtxAducPipeline()
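
# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's public surface).
# It shows how an orchestrator might chain encode -> generate -> decode through
# this client. The input path, resolution, and the keys inside `common_ltx_args`
# are assumptions; the real argument set depends on the LTXVideoPipeline
# signature and on how the LTXAducManager pool is configured.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Encode a single reference image as a conditioning item for frame 0 (hypothetical path).
    conditioning = ltx_aduc_pipeline.encode_to_conditioning_items(
        media_list=["/data/inputs/start_frame.png"],
        params=[(0, 1.0)],  # (media_frame_number, conditioning_strength)
        resolution=(512, 768),
    )

    # Generate a two-scene latent video; the `common_ltx_args` keys below are assumed, not guaranteed.
    latents, seed = ltx_aduc_pipeline.generate_latents(
        prompt_list=["a sailboat drifting at dawn", "the sailboat reaches the shore"],
        duration_in_seconds=6.0,
        common_ltx_args={"height": 512, "width": 768, "num_inference_steps": 30},
        initial_conditioning_items=conditioning,
    )

    if latents is not None:
        pixels = ltx_aduc_pipeline.decode_to_pixels(latents)
        logging.info(f"Used seed {seed}; decoded pixel tensor shape: {tuple(pixels.shape)}")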