eeuuia commited on
Commit
140e6ff
·
verified ·
1 Parent(s): 52c58b6

Update api/ltx/ltx_aduc_pipeline.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_aduc_pipeline.py +68 -77
api/ltx/ltx_aduc_pipeline.py CHANGED
@@ -1,24 +1,21 @@
1
  # FILE: api/ltx/ltx_aduc_pipeline.py
2
- # DESCRIPTION: A high-level client for submitting LTX video generation jobs to the pool manager.
3
- # Its sole responsibility is to orchestrate the generation of a final LATENT tensor from prompts
4
- # and initial conditions, without handling pixel decoding.
5
 
6
  import logging
7
  import time
8
  import torch
9
  import random
10
  from typing import List, Optional, Tuple, Dict
 
 
11
  from pathlib import Path
12
  import sys
13
- import os
14
 
15
- # O cliente importa o MANAGER para submeter trabalhos ao pool de workers.
16
  from api.ltx.ltx_aduc_manager import ltx_aduc_manager
17
 
18
- # O cliente precisa da definição de LatentConditioningItem, que agora vive no cliente VAE.
19
- from api.ltx.vae_aduc_pipeline import LatentConditioningItem
20
-
21
- # Adiciona o path do LTX-Video para importação de tipos (para anotação da função de job).
22
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
23
  def add_deps_to_path():
24
  repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
@@ -27,101 +24,105 @@ def add_deps_to_path():
27
  add_deps_to_path()
28
 
29
  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # ==============================================================================
32
  # --- FUNÇÕES DE TRABALHO (Jobs a serem executados no Pool LTX) ---
33
  # ==============================================================================
34
 
35
- def _job_generate_latent_chunk(
36
- pipeline: LTXVideoPipeline,
37
- autocast_dtype: torch.dtype,
38
- **kwargs
39
- ) -> torch.Tensor:
40
- """
41
- Função de trabalho que executa a geração de um único chunk (cena) de vídeo latente.
42
- Esta função é executada DENTRO de um LTXMainWorker, na GPU principal.
43
-
44
- Args:
45
- pipeline (LTXVideoPipeline): A instância do pipeline fornecida pelo worker.
46
- autocast_dtype (torch.dtype): A precisão de computação (ex: bfloat16) fornecida pelo worker.
47
- **kwargs: Dicionário contendo todos os parâmetros para a geração (prompt, height, width, etc.).
48
- """
49
- # Cria um gerador na mesma GPU do pipeline para consistência.
 
50
  generator = torch.Generator(device=pipeline.device).manual_seed(kwargs['seed'])
 
51
 
52
- # Monta os argumentos finais para a chamada do pipeline a partir dos kwargs recebidos.
53
- pipeline_kwargs = {
54
- "generator": generator,
55
- "output_type": "latent", # Ponto chave: sempre solicitamos a saída em formato latente.
56
- **kwargs
57
- }
58
-
59
- logging.info(f"[LTX Job] Generating chunk with {kwargs['num_frames']} frames for prompt: '{kwargs['prompt'][:50]}...'")
60
-
61
- # Executa a geração dentro do contexto de autocast com a precisão definida pelo worker.
62
  with torch.autocast(device_type=pipeline.device.type, dtype=autocast_dtype):
63
- latents_raw = pipeline(**pipeline_kwargs).images
64
-
65
- # Retorna o tensor latente na CPU para liberar a VRAM do worker para o próximo job.
66
  return latents_raw.cpu()
67
 
68
  # ==============================================================================
69
- # --- A CLASSE CLIENTE (Interface Pública para Geração de Vídeo Latente) ---
70
  # ==============================================================================
71
 
72
  class LtxAducPipeline:
73
  """
74
- Cliente de alto nível para orquestrar a geração de vídeo latente.
75
- Ele quebra a tarefa em chunks e submete cada um como um trabalho ao LTXAducManager.
76
  """
77
  def __init__(self):
78
- logging.info("✅ LTX ADUC Pipeline (Client) initialized and ready to submit jobs.")
79
  self.FRAMES_ALIGNMENT = 8
80
- pass
81
 
82
  def _get_random_seed(self) -> int:
83
- """Gera e retorna uma nova semente aleatória para garantir variedade."""
84
  return random.randint(0, 2**32 - 1)
85
 
86
  def _align(self, dim: int, alignment: int = 8) -> int:
87
- """Alinha uma dimensão para o múltiplo mais próximo para compatibilidade com o modelo."""
88
  return ((dim + alignment - 1) // alignment) * alignment
89
 
90
- def __call__(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  self,
92
  prompt_list: List[str],
93
  duration_in_seconds: float,
94
  common_ltx_args: Dict,
95
  initial_conditioning_items: Optional[List[LatentConditioningItem]] = None
96
  ) -> Tuple[Optional[torch.Tensor], Optional[int]]:
97
- """
98
- Ponto de entrada principal para gerar um vídeo latente completo.
99
-
100
- Args:
101
- prompt_list (List[str]): Lista de prompts, onde cada um representa uma cena.
102
- duration_in_seconds (float): Duração total desejada para o vídeo.
103
- common_ltx_args (Dict): Dicionário com argumentos comuns a todas as cenas (height, width, etc.).
104
- initial_conditioning_items (Optional[List[LatentConditioningItem]]): Itens para condicionar a primeira cena.
105
-
106
- Returns:
107
- Tuple[Optional[torch.Tensor], Optional[int]]:
108
- - O tensor latente final completo (na CPU).
109
- - A semente principal usada para a geração.
110
- """
111
  t0 = time.time()
112
  logging.info(f"LTX Client received a generation job for {len(prompt_list)} scenes.")
113
-
114
  used_seed = self._get_random_seed()
115
- logging.info(f"Generation seed set to: {used_seed}")
116
 
117
- # --- Lógica de Divisão de Chunks e Sobreposição ("Eco Cinético") ---
118
  num_chunks = len(prompt_list)
119
  total_frames = self._align(int(duration_in_seconds * 24))
120
  frames_per_chunk_base = total_frames // num_chunks if num_chunks > 0 else total_frames
121
  overlap_frames = self._align(9) if num_chunks > 1 else 0
122
 
123
  final_latents_list = []
124
- overlap_condition_item: Optional[LatentConditioningItem] = None
125
 
126
  for i, chunk_prompt in enumerate(prompt_list):
127
  current_conditions = []
@@ -130,7 +131,6 @@ class LtxAducPipeline:
130
  if overlap_condition_item:
131
  current_conditions.append(overlap_condition_item)
132
 
133
- # Calcula o número de frames para o chunk atual, garantindo que o último preencha o restante.
134
  num_frames_for_chunk = frames_per_chunk_base
135
  if i == num_chunks - 1:
136
  processed_frames = sum(f.shape[2] for f in final_latents_list)
@@ -138,33 +138,24 @@ class LtxAducPipeline:
138
  num_frames_for_chunk = self._align(num_frames_for_chunk)
139
  if num_frames_for_chunk <= 0: continue
140
 
141
- # --- Preparação e Submissão do Job ---
142
  job_specific_args = {
143
  "prompt": chunk_prompt,
144
  "num_frames": num_frames_for_chunk,
145
- "seed": used_seed + i, # Incrementa a semente para cada cena
146
  "conditioning_items": current_conditions
147
  }
148
  final_job_args = {**common_ltx_args, **job_specific_args}
149
 
150
- chunk_latents = ltx_aduc_manager.submit_job(
151
- job_type='ltx',
152
- job_func=_job_generate_latent_chunk,
153
- **final_job_args
154
- )
155
 
156
  if chunk_latents is None:
157
  logging.error(f"Failed to generate latents for scene {i+1}. Aborting.")
158
  return None, used_seed
159
 
160
- # --- Gerenciamento da Sobreposição ---
161
  if i < num_chunks - 1:
162
  overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
163
  overlap_condition_item = LatentConditioningItem(
164
- latent_tensor=overlap_latents,
165
- media_frame_number=0,
166
- conditioning_strength=1.0
167
- )
168
  final_latents_list.append(chunk_latents[:, :, :-overlap_frames, :, :])
169
  else:
170
  final_latents_list.append(chunk_latents)
@@ -183,4 +174,4 @@ try:
183
  ltx_aduc_pipeline = LtxAducPipeline()
184
  except Exception as e:
185
  logging.critical("CRITICAL: Failed to initialize the LtxAducPipeline client.", exc_info=True)
186
- ltx_aduc_pipeline = None
 
1
  # FILE: api/ltx/ltx_aduc_pipeline.py
2
+ # DESCRIPTION: A unified high-level client for submitting ALL LTX-related jobs (generation and VAE)
3
+ # to the LTXAducManager pool.
 
4
 
5
  import logging
6
  import time
7
  import torch
8
  import random
9
  from typing import List, Optional, Tuple, Dict
10
+ from PIL import Image
11
+ from dataclasses import dataclass
12
  from pathlib import Path
13
  import sys
 
14
 
15
+ # O cliente importa o MANAGER para submeter todos os trabalhos.
16
  from api.ltx.ltx_aduc_manager import ltx_aduc_manager
17
 
18
+ # Adiciona o path do LTX-Video para importações de baixo nível e tipos.
 
 
 
19
  LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
20
  def add_deps_to_path():
21
  repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
 
24
  add_deps_to_path()
25
 
26
  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
27
+ from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
28
+ from api.ltx.ltx_utils import load_image_to_tensor_with_resize_and_crop # Importa o helper de ltx_utils
29
+
30
+ # ==============================================================================
31
+ # --- DEFINIÇÕES DE ESTRUTURA ---
32
+ # ==============================================================================
33
+
34
+ @dataclass
35
+ class LatentConditioningItem:
36
+ """Estrutura de dados para passar latentes condicionados ao job de geração."""
37
+ latent_tensor: torch.Tensor
38
+ media_frame_number: int
39
+ conditioning_strength: float
40
 
41
  # ==============================================================================
42
  # --- FUNÇÕES DE TRABALHO (Jobs a serem executados no Pool LTX) ---
43
  # ==============================================================================
44
 
45
+ def _job_encode_media(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, pixel_tensor: torch.Tensor) -> torch.Tensor:
46
+ """Job que usa o VAE do pipeline para codificar um tensor de pixel."""
47
+ vae = pipeline.vae
48
+ pixel_tensor_gpu = pixel_tensor.to(vae.device, dtype=vae.dtype)
49
+ latents = vae_encode(pixel_tensor_gpu, vae, vae_per_channel_normalize=True)
50
+ return latents.cpu()
51
+
52
+ def _job_decode_latent(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, latent_tensor: torch.Tensor) -> torch.Tensor:
53
+ """Job que usa o VAE do pipeline para decodificar um tensor latente."""
54
+ vae = pipeline.vae
55
+ latent_tensor_gpu = latent_tensor.to(vae.device, dtype=vae.dtype)
56
+ pixels = vae_decode(latent_tensor_gpu, vae, is_video=True, vae_per_channel_normalize=True)
57
+ return pixels.cpu()
58
+
59
+ def _job_generate_latent_chunk(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, **kwargs) -> torch.Tensor:
60
+ """Job que usa o pipeline principal para gerar um chunk de vídeo latente."""
61
  generator = torch.Generator(device=pipeline.device).manual_seed(kwargs['seed'])
62
+ pipeline_kwargs = {"generator": generator, "output_type": "latent", **kwargs}
63
 
 
 
 
 
 
 
 
 
 
 
64
  with torch.autocast(device_type=pipeline.device.type, dtype=autocast_dtype):
65
+ latents_raw = pipeline(**pipeline_kwargs).images
66
+
 
67
  return latents_raw.cpu()
68
 
69
  # ==============================================================================
70
+ # --- A CLASSE CLIENTE UNIFICADA ---
71
  # ==============================================================================
72
 
73
  class LtxAducPipeline:
74
  """
75
+ Cliente unificado para orquestrar todas as tarefas LTX, incluindo geração e VAE.
 
76
  """
77
  def __init__(self):
78
+ logging.info("✅ Unified LTX/VAE ADUC Pipeline (Client) initialized.")
79
  self.FRAMES_ALIGNMENT = 8
 
80
 
81
  def _get_random_seed(self) -> int:
 
82
  return random.randint(0, 2**32 - 1)
83
 
84
  def _align(self, dim: int, alignment: int = 8) -> int:
 
85
  return ((dim + alignment - 1) // alignment) * alignment
86
 
87
+ # --- Métodos de API para o Orquestrador ---
88
+
89
+ def encode_to_conditioning_items(self, media_list: List, params: List, resolution: Tuple[int, int]) -> List[LatentConditioningItem]:
90
+ """Converte uma lista de imagens em uma lista de LatentConditioningItem."""
91
+ pixel_tensors = [load_image_to_tensor_with_resize_and_crop(m, resolution[0], resolution[1]) for m in media_list]
92
+ items = []
93
+ for i, pt in enumerate(pixel_tensors):
94
+ latent_tensor = ltx_aduc_manager.submit_job(_job_encode_media, pixel_tensor=pt)
95
+ frame_number, strength = params[i]
96
+ items.append(LatentConditioningItem(
97
+ latent_tensor=latent_tensor,
98
+ media_frame_number=frame_number,
99
+ conditioning_strength=strength
100
+ ))
101
+ return items
102
+
103
+ def decode_to_pixels(self, latent_tensor: torch.Tensor) -> torch.Tensor:
104
+ """Decodifica um tensor latente em um tensor de pixels."""
105
+ return ltx_aduc_manager.submit_job(_job_decode_latent, latent_tensor=latent_tensor)
106
+
107
+ def generate_latents(
108
  self,
109
  prompt_list: List[str],
110
  duration_in_seconds: float,
111
  common_ltx_args: Dict,
112
  initial_conditioning_items: Optional[List[LatentConditioningItem]] = None
113
  ) -> Tuple[Optional[torch.Tensor], Optional[int]]:
114
+ """Gera um vídeo latente completo a partir de uma lista de prompts."""
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  t0 = time.time()
116
  logging.info(f"LTX Client received a generation job for {len(prompt_list)} scenes.")
 
117
  used_seed = self._get_random_seed()
 
118
 
 
119
  num_chunks = len(prompt_list)
120
  total_frames = self._align(int(duration_in_seconds * 24))
121
  frames_per_chunk_base = total_frames // num_chunks if num_chunks > 0 else total_frames
122
  overlap_frames = self._align(9) if num_chunks > 1 else 0
123
 
124
  final_latents_list = []
125
+ overlap_condition_item = None
126
 
127
  for i, chunk_prompt in enumerate(prompt_list):
128
  current_conditions = []
 
131
  if overlap_condition_item:
132
  current_conditions.append(overlap_condition_item)
133
 
 
134
  num_frames_for_chunk = frames_per_chunk_base
135
  if i == num_chunks - 1:
136
  processed_frames = sum(f.shape[2] for f in final_latents_list)
 
138
  num_frames_for_chunk = self._align(num_frames_for_chunk)
139
  if num_frames_for_chunk <= 0: continue
140
 
 
141
  job_specific_args = {
142
  "prompt": chunk_prompt,
143
  "num_frames": num_frames_for_chunk,
144
+ "seed": used_seed + i,
145
  "conditioning_items": current_conditions
146
  }
147
  final_job_args = {**common_ltx_args, **job_specific_args}
148
 
149
+ chunk_latents = ltx_aduc_manager.submit_job(_job_generate_latent_chunk, **final_job_args)
 
 
 
 
150
 
151
  if chunk_latents is None:
152
  logging.error(f"Failed to generate latents for scene {i+1}. Aborting.")
153
  return None, used_seed
154
 
 
155
  if i < num_chunks - 1:
156
  overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
157
  overlap_condition_item = LatentConditioningItem(
158
+ latent_tensor=overlap_latents, media_frame_number=0, conditioning_strength=1.0)
 
 
 
159
  final_latents_list.append(chunk_latents[:, :, :-overlap_frames, :, :])
160
  else:
161
  final_latents_list.append(chunk_latents)
 
174
  ltx_aduc_pipeline = LtxAducPipeline()
175
  except Exception as e:
176
  logging.critical("CRITICAL: Failed to initialize the LtxAducPipeline client.", exc_info=True)
177
+ ltx_aduc_pipeline = None```