# FILE: api/ltx/ltx_aduc_pipeline.py
# DESCRIPTION: A unified high-level client for submitting ALL LTX-related jobs (generation and VAE)
# to the LTXAducManager pool.

import logging
import time
import torch
import random
from typing import List, Optional, Tuple, Dict
from dataclasses import dataclass
from pathlib import Path
import sys

from api.ltx.ltx_utils import load_image_to_tensor_with_resize_and_crop  # Imports the helper from ltx_utils

# The client imports the MANAGER to submit all jobs.
from api.ltx.ltx_aduc_manager import ltx_aduc_manager

# Adds the LTX-Video path for low-level imports and types.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
def add_deps_to_path():
    repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)
add_deps_to_path()

from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode


# ==============================================================================
# --- STRUCTURE DEFINITIONS ---
# ==============================================================================

@dataclass
class LatentConditioningItem:
    """Estrutura de dados para passar latentes condicionados ao job de geração."""
    latent_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float

# ==============================================================================
# --- JOB FUNCTIONS (executed on the LTX pool) ---
# ==============================================================================

def _job_encode_media(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, pixel_tensor: torch.Tensor) -> torch.Tensor:
    """Job que usa o VAE do pipeline para codificar um tensor de pixel."""
    vae = pipeline.vae
    pixel_tensor_gpu = pixel_tensor.to(vae.device, dtype=vae.dtype)
    latents = vae_encode(pixel_tensor_gpu, vae, vae_per_channel_normalize=True)
    return latents.cpu()
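# NOTE: job results are moved back to the CPU before returning; presumably the
# tensors must leave the pool worker's GPU context before crossing the manager
# boundary (the same pattern is used by the jobs below).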

def _job_decode_latent(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, latent_tensor: torch.Tensor) -> torch.Tensor:
    """Job que usa o VAE do pipeline para decodificar um tensor latente."""
    vae = pipeline.vae
    latent_tensor_gpu = latent_tensor.to(vae.device, dtype=vae.dtype)
    pixels = vae_decode(latent_tensor_gpu, vae, is_video=True, vae_per_channel_normalize=True)
    return pixels.cpu()

def _job_generate_latent_chunk(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, **kwargs) -> torch.Tensor:
    """Job que usa o pipeline principal para gerar um chunk de vídeo latente."""
    generator = torch.Generator(device=pipeline.device).manual_seed(kwargs['seed'])
    pipeline_kwargs = {"generator": generator, "output_type": "latent", **kwargs}
    
    with torch.autocast(device_type=pipeline.device.type, dtype=autocast_dtype):
        latents_raw = pipeline(**pipeline_kwargs).images
        
    return latents_raw.cpu()

# ==============================================================================
# --- THE UNIFIED CLIENT CLASS ---
# ==============================================================================

class LtxAducPipeline:
    """
    Unified client for orchestrating all LTX tasks, including generation and VAE work.
    """
    def __init__(self):
        logging.info("✅ Unified LTX/VAE ADUC Pipeline (Client) initialized.")
        self.FRAMES_ALIGNMENT = 8

    def _get_random_seed(self) -> int:
        return random.randint(0, 2**32 - 1)

    def _align(self, dim: int, alignment: int = 8) -> int:
        return ((dim + alignment - 1) // alignment) * alignment
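    # Worked example: _align(9) -> ((9 + 7) // 8) * 8 == 16, so a requested
    # 9-frame overlap is rounded up to 16, while _align(240) stays 240.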

    # --- API methods for the orchestrator ---

    def encode_to_conditioning_items(self, media_list: List, params: List[Tuple[int, float]], resolution: Tuple[int, int]) -> List[LatentConditioningItem]:
        """Converts a list of images into a list of LatentConditioningItem objects."""
        pixel_tensors = [load_image_to_tensor_with_resize_and_crop(m, resolution[0], resolution[1]) for m in media_list]
        items = []
        for i, pt in enumerate(pixel_tensors):
            latent_tensor = ltx_aduc_manager.submit_job(_job_encode_media, pixel_tensor=pt)
            frame_number, strength = params[i]
            items.append(LatentConditioningItem(
                latent_tensor=latent_tensor,
                media_frame_number=frame_number,
                conditioning_strength=strength
            ))
        return items

    def decode_to_pixels(self, latent_tensor: torch.Tensor) -> torch.Tensor:
        """Decodifica um tensor latente em um tensor de pixels."""
        return ltx_aduc_manager.submit_job(_job_decode_latent, latent_tensor=latent_tensor)

    def generate_latents(
        self,
        prompt_list: List[str],
        duration_in_seconds: float,
        common_ltx_args: Dict,
        initial_conditioning_items: Optional[List[LatentConditioningItem]] = None
    ) -> Tuple[Optional[torch.Tensor], Optional[int]]:
        """Gera um vídeo latente completo a partir de uma lista de prompts."""
        t0 = time.time()
        logging.info(f"LTX Client received a generation job for {len(prompt_list)} scenes.")
        used_seed = self._get_random_seed()

        num_chunks = len(prompt_list)
        total_frames = self._align(int(duration_in_seconds * 24))
        frames_per_chunk_base = total_frames // num_chunks if num_chunks > 0 else total_frames
        overlap_frames = self._align(9) if num_chunks > 1 else 0
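        # Worked example (illustrative numbers): duration_in_seconds=10.0 at the
        # hardcoded 24 fps gives total_frames = _align(240) = 240; with three
        # prompts, frames_per_chunk_base = 240 // 3 = 80, and consecutive chunks
        # share an overlap of _align(9) = 16 frames.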

        final_latents_list = []
        overlap_condition_item = None

        for i, chunk_prompt in enumerate(prompt_list):
            current_conditions = []
            if i == 0 and initial_conditioning_items:
                current_conditions.extend(initial_conditioning_items)
            if overlap_condition_item:
                current_conditions.append(overlap_condition_item)
            
            num_frames_for_chunk = frames_per_chunk_base
            if i == num_chunks - 1:
                processed_frames = sum(f.shape[2] for f in final_latents_list)
                num_frames_for_chunk = total_frames - processed_frames
            num_frames_for_chunk = self._align(num_frames_for_chunk)
            if num_frames_for_chunk <= 0: continue

            job_specific_args = {
                "prompt": chunk_prompt,
                "num_frames": num_frames_for_chunk,
                "seed": used_seed + i,
                "conditioning_items": current_conditions
            }
            final_job_args = {**common_ltx_args, **job_specific_args}
            
            chunk_latents = ltx_aduc_manager.submit_job(_job_generate_latent_chunk, **final_job_args)

            if chunk_latents is None:
                logging.error(f"Failed to generate latents for scene {i+1}. Aborting.")
                return None, used_seed
            
            if i < num_chunks - 1:
                overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                overlap_condition_item = LatentConditioningItem(
                    latent_tensor=overlap_latents, media_frame_number=0, conditioning_strength=1.0)
                final_latents_list.append(chunk_latents[:, :, :-overlap_frames, :, :])
            else:
                final_latents_list.append(chunk_latents)
        
        if not final_latents_list:
            logging.warning("No latent chunks were generated.")
            return None, used_seed
            
        final_latents = torch.cat(final_latents_list, dim=2)
        logging.info(f"LTX Client job finished in {time.time() - t0:.2f}s. Final latent shape: {final_latents.shape}")
        
        return final_latents, used_seed

# --- CLIENT SINGLETON INSTANCE ---
ltx_aduc_pipeline = LtxAducPipeline()
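
# --- USAGE SKETCH ---
# A minimal sketch of how an orchestrator might drive this client. The input
# path, conditioning parameters, resolution order, and the keys inside
# `common_ltx_args` are illustrative assumptions, not values confirmed by this
# module; the real keys must match the LTXVideoPipeline call signature used by
# the pool.
if __name__ == "__main__":
    conditioning = ltx_aduc_pipeline.encode_to_conditioning_items(
        media_list=["/data/inputs/start.png"],  # hypothetical input image
        params=[(0, 1.0)],                      # (media_frame_number, strength)
        resolution=(512, 768),                  # assumed (height, width) order
    )
    latents, seed = ltx_aduc_pipeline.generate_latents(
        prompt_list=["a ship at sea", "the ship sails into a storm"],
        duration_in_seconds=10.0,
        common_ltx_args={"height": 512, "width": 768},  # assumed pipeline kwargs
        initial_conditioning_items=conditioning,
    )
    if latents is not None:
        pixels = ltx_aduc_pipeline.decode_to_pixels(latents)
        logging.info(f"Seed {seed} produced a pixel tensor of shape {pixels.shape}")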