# FILE: api/ltx/ltx_aduc_pipeline.py
# DESCRIPTION: A unified high-level client for submitting ALL LTX-related jobs (generation and VAE)
# to the LTXAducManager pool.
import logging
import time
import torch
import random
from typing import List, Optional, Tuple, Dict
from PIL import Image
from dataclasses import dataclass
from pathlib import Path
import sys
from api.ltx.ltx_utils import load_image_to_tensor_with_resize_and_crop  # Helper from ltx_utils
# The client imports the MANAGER in order to submit all jobs.
from api.ltx.ltx_aduc_manager import ltx_aduc_manager
# Add the LTX-Video repository to the path for low-level imports and types.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)

add_deps_to_path()
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
# ==============================================================================
# --- STRUCTURE DEFINITIONS ---
# ==============================================================================
@dataclass
class LatentConditioningItem:
    """Data structure for passing conditioned latents to the generation job."""
    latent_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float
# ==============================================================================
# --- JOB FUNCTIONS (jobs executed on the LTX pool) ---
# ==============================================================================
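# NOTE: the LTXAducManager pool is expected to invoke each job as
# job(pipeline, autocast_dtype, **kwargs); callers therefore pass only the
# job-specific keyword arguments to ltx_aduc_manager.submit_job. This is an
# assumption inferred from the job signatures below, not from the manager itself.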
def _job_encode_media(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, pixel_tensor: torch.Tensor) -> torch.Tensor:
    """Job that uses the pipeline's VAE to encode a pixel tensor into latents."""
    vae = pipeline.vae
    pixel_tensor_gpu = pixel_tensor.to(vae.device, dtype=vae.dtype)
    latents = vae_encode(pixel_tensor_gpu, vae, vae_per_channel_normalize=True)
    return latents.cpu()

def _job_decode_latent(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, latent_tensor: torch.Tensor) -> torch.Tensor:
    """Job that uses the pipeline's VAE to decode a latent tensor back to pixels."""
    vae = pipeline.vae
    latent_tensor_gpu = latent_tensor.to(vae.device, dtype=vae.dtype)
    pixels = vae_decode(latent_tensor_gpu, vae, is_video=True, vae_per_channel_normalize=True)
    return pixels.cpu()

def _job_generate_latent_chunk(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, **kwargs) -> torch.Tensor:
    """Job that uses the main pipeline to generate one latent video chunk."""
    generator = torch.Generator(device=pipeline.device).manual_seed(kwargs['seed'])
    pipeline_kwargs = {"generator": generator, "output_type": "latent", **kwargs}
    with torch.autocast(device_type=pipeline.device.type, dtype=autocast_dtype):
        latents_raw = pipeline(**pipeline_kwargs).images
    return latents_raw.cpu()
# ==============================================================================
# --- THE UNIFIED CLIENT CLASS ---
# ==============================================================================
class LtxAducPipeline:
    """
    Unified client for orchestrating all LTX tasks, including generation and VAE work.
    """
    def __init__(self):
        logging.info("✅ Unified LTX/VAE ADUC Pipeline (Client) initialized.")
        self.FRAMES_ALIGNMENT = 8

    def _get_random_seed(self) -> int:
        return random.randint(0, 2**32 - 1)

    def _align(self, dim: int, alignment: int = 8) -> int:
        # Rounds up to the nearest multiple of `alignment`, e.g. _align(121) -> 128.
        return ((dim + alignment - 1) // alignment) * alignment
    # --- API methods for the orchestrator ---
    def encode_to_conditioning_items(self, media_list: List, params: List, resolution: Tuple[int, int]) -> List[LatentConditioningItem]:
        """Converts a list of images into a list of LatentConditioningItem."""
        pixel_tensors = [load_image_to_tensor_with_resize_and_crop(m, resolution[0], resolution[1]) for m in media_list]
        items = []
        for i, pt in enumerate(pixel_tensors):
            latent_tensor = ltx_aduc_manager.submit_job(_job_encode_media, pixel_tensor=pt)
            frame_number, strength = params[i]
            items.append(LatentConditioningItem(
                latent_tensor=latent_tensor,
                media_frame_number=frame_number,
                conditioning_strength=strength
            ))
        return items
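    # Illustrative call (hypothetical values): each image in media_list is paired
    # with a (media_frame_number, conditioning_strength) tuple from params, e.g.
    #   encode_to_conditioning_items(media_list=[img_a, img_b],
    #                                params=[(0, 1.0), (48, 0.6)],
    #                                resolution=(512, 768))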
    def decode_to_pixels(self, latent_tensor: torch.Tensor) -> torch.Tensor:
        """Decodes a latent tensor into a pixel tensor."""
        return ltx_aduc_manager.submit_job(_job_decode_latent, latent_tensor=latent_tensor)
    def generate_latents(
        self,
        prompt_list: List[str],
        duration_in_seconds: float,
        common_ltx_args: Dict,
        initial_conditioning_items: Optional[List[LatentConditioningItem]] = None
    ) -> Tuple[Optional[torch.Tensor], Optional[int]]:
        """Generates a complete latent video from a list of prompts."""
        t0 = time.time()
        logging.info(f"LTX Client received a generation job for {len(prompt_list)} scenes.")
        used_seed = self._get_random_seed()
        num_chunks = len(prompt_list)
        total_frames = self._align(int(duration_in_seconds * 24))  # assumes 24 fps output
        frames_per_chunk_base = total_frames // num_chunks if num_chunks > 0 else total_frames
        overlap_frames = self._align(9) if num_chunks > 1 else 0
        final_latents_list = []
        overlap_condition_item = None
        for i, chunk_prompt in enumerate(prompt_list):
            current_conditions = []
            if i == 0 and initial_conditioning_items:
                current_conditions.extend(initial_conditioning_items)
            if overlap_condition_item:
                current_conditions.append(overlap_condition_item)
            num_frames_for_chunk = frames_per_chunk_base
            if i == num_chunks - 1:
                # The last chunk takes whatever frames remain after the earlier chunks.
                processed_frames = sum(f.shape[2] for f in final_latents_list)
                num_frames_for_chunk = total_frames - processed_frames
            num_frames_for_chunk = self._align(num_frames_for_chunk)
            if num_frames_for_chunk <= 0:
                continue
            job_specific_args = {
                "prompt": chunk_prompt,
                "num_frames": num_frames_for_chunk,
                "seed": used_seed + i,
                "conditioning_items": current_conditions
            }
            final_job_args = {**common_ltx_args, **job_specific_args}
            chunk_latents = ltx_aduc_manager.submit_job(_job_generate_latent_chunk, **final_job_args)
            if chunk_latents is None:
                logging.error(f"Failed to generate latents for scene {i+1}. Aborting.")
                return None, used_seed
            if i < num_chunks - 1:
                # Keep this chunk's tail as a conditioning item for the next chunk,
                # then trim it from the stored output to avoid duplicated frames.
                overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                overlap_condition_item = LatentConditioningItem(
                    latent_tensor=overlap_latents, media_frame_number=0, conditioning_strength=1.0)
                final_latents_list.append(chunk_latents[:, :, :-overlap_frames, :, :])
            else:
                final_latents_list.append(chunk_latents)
        if not final_latents_list:
            logging.warning("No latent chunks were generated.")
            return None, used_seed
        final_latents = torch.cat(final_latents_list, dim=2)
        logging.info(f"LTX Client job finished in {time.time() - t0:.2f}s. Final latent shape: {final_latents.shape}")
        return final_latents, used_seed
# --- CLIENT SINGLETON INSTANCE ---
ltx_aduc_pipeline = LtxAducPipeline()
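# ==============================================================================
# --- USAGE SKETCH (illustrative only, not executed on import) ---
# ==============================================================================
# A minimal sketch of how an orchestrator might drive this client end to end,
# assuming a hypothetical reference image path "scene_ref.png" and a
# common_ltx_args dict whose keys match whatever LTXVideoPipeline.__call__
# accepts in this repository (height/width/num_inference_steps are assumptions).
def _example_usage():  # pragma: no cover
    common_ltx_args = {"height": 512, "width": 768, "num_inference_steps": 30}
    conditioning = ltx_aduc_pipeline.encode_to_conditioning_items(
        media_list=["scene_ref.png"],   # hypothetical input image
        params=[(0, 1.0)],              # (media_frame_number, conditioning_strength)
        resolution=(512, 768),
    )
    latents, seed = ltx_aduc_pipeline.generate_latents(
        prompt_list=["a slow pan over a foggy forest", "the camera rises above the treetops"],
        duration_in_seconds=8.0,
        common_ltx_args=common_ltx_args,
        initial_conditioning_items=conditioning,
    )
    if latents is not None:
        pixels = ltx_aduc_pipeline.decode_to_pixels(latents)
        logging.info(f"Decoded pixel tensor shape: {pixels.shape} (seed={seed})")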