euiia commited on
Commit
906ad5a
·
verified ·
1 Parent(s): 0a8b5be

Rename ltx_manager_helpers.py to managers/ltx_manager_helpers.py

Browse files
ltx_manager_helpers.py → managers/ltx_manager_helpers.py RENAMED
@@ -1,9 +1,13 @@
1
  # ltx_manager_helpers.py
2
- # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
3
  #
4
- # Este programa é software livre: você pode redistribuí-lo e/ou modificá-lo
5
- # sob os termos da Licença Pública Geral Affero GNU...
6
- # AVISO DE PATENTE PENDENTE: Consulte NOTICE.md.
 
 
 
 
 
7
 
8
  import torch
9
  import gc
@@ -13,25 +17,124 @@ import logging
13
  import huggingface_hub
14
  import time
15
  import threading
16
- import json
17
- from typing import Optional, List
18
 
19
  from optimization import optimize_ltx_worker, can_optimize_fp8
20
  from hardware_manager import hardware_manager
21
  from inference import create_ltx_video_pipeline, calculate_padding
22
- from ltx_video.pipelines.pipeline_ltx_video import LatentConditioningItem, LTXMultiScalePipeline
 
 
 
23
 
24
  logger = logging.getLogger(__name__)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class LtxWorker:
27
  """
28
- Representa uma única instância da pipeline LTX-Video em um dispositivo específico.
29
- Gerencia o carregamento do modelo para a CPU e a movimentação de/para a GPU.
30
  """
31
  def __init__(self, device_id, ltx_config_file):
32
  self.cpu_device = torch.device('cpu')
33
  self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
34
- logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")
35
 
36
  with open(ltx_config_file, "r") as file:
37
  self.config = yaml.safe_load(file)
@@ -40,7 +143,7 @@ class LtxWorker:
40
 
41
  models_dir = "downloaded_models_gradio"
42
 
43
- logger.info(f"LTX Worker ({self.device}): Carregando modelo para a CPU...")
44
  model_path = os.path.join(models_dir, self.config["checkpoint_path"])
45
  if not os.path.exists(model_path):
46
  model_path = huggingface_hub.hf_hub_download(
@@ -53,51 +156,66 @@ class LtxWorker:
53
  text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
54
  sampler=self.config["sampler"], device='cpu'
55
  )
56
- logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo destilado? {self.is_distilled}")
57
 
58
  def to_gpu(self):
59
- """Move o pipeline para a GPU designada E OTIMIZA SE POSSÍVEL."""
60
  if self.device.type == 'cpu': return
61
- logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
62
  self.pipeline.to(self.device)
63
 
64
  if self.device.type == 'cuda' and can_optimize_fp8():
65
- logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Iniciando otimização...")
66
  optimize_ltx_worker(self)
67
- logger.info(f"LTX Worker ({self.device}): Otimização concluída.")
68
  elif self.device.type == 'cuda':
69
- logger.info(f"LTX Worker ({self.device}): Otimização FP8 não suportada ou desativada.")
70
 
71
  def to_cpu(self):
72
- """Move o pipeline de volta para a CPU e libera a memória da GPU."""
73
  if self.device.type == 'cpu': return
74
- logger.info(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
75
  self.pipeline.to('cpu')
76
  gc.collect()
77
  if torch.cuda.is_available(): torch.cuda.empty_cache()
78
 
79
  def generate_video_fragment_internal(self, **kwargs):
80
- """Invoca a pipeline de geração."""
81
  return self.pipeline(**kwargs).images
82
 
83
  class LtxPoolManager:
84
  """
85
- Gerencia um pool de LtxWorkers para otimizar o uso de múltiplas GPUs.
86
- MODO "HOT START": Mantém todos os modelos carregados na VRAM para latência mínima.
87
  """
88
  def __init__(self, device_ids, ltx_config_file):
89
- logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
90
  self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
91
  self.current_worker_index = 0
92
  self.lock = threading.Lock()
93
 
 
 
 
94
  if all(w.device.type == 'cuda' for w in self.workers):
95
- logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
96
  for worker in self.workers:
97
  worker.to_gpu()
98
- logger.info("LTX POOL MANAGER: Todas as GPUs estão quentes e prontas.")
99
  else:
100
- logger.info("LTX POOL MANAGER: Operando em modo CPU ou misto. O pré-aquecimento de GPU foi ignorado.")
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  def _get_next_worker(self):
103
  with self.lock:
@@ -106,7 +224,7 @@ class LtxPoolManager:
106
  return worker
107
 
108
  def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
109
- """Prepara o dicionário de parâmetros para a pipeline, tratando casos especiais como modelos destilados."""
110
  pipeline_params = {
111
  "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
112
  "frame_rate": kwargs.get('video_fps', 24),
@@ -117,7 +235,6 @@ class LtxPoolManager:
117
  "rescaling_scale": kwargs.get('rescaling_scale', 0.15), "num_inference_steps": kwargs.get('num_inference_steps', 20),
118
  "output_type": "latent"
119
  }
120
-
121
  if 'latents' in kwargs:
122
  pipeline_params["latents"] = kwargs['latents'].to(worker.device, dtype=worker.pipeline.transformer.dtype)
123
  if 'strength' in kwargs:
@@ -128,37 +245,31 @@ class LtxPoolManager:
128
  item.latent_tensor = item.latent_tensor.to(worker.device)
129
  final_conditioning_items.append(item)
130
  pipeline_params["conditioning_items"] = final_conditioning_items
131
-
132
  if worker.is_distilled:
133
- logger.info(f"Worker {worker.device} está usando um modelo destilado. Usando timesteps fixos.")
134
  fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
135
  pipeline_params["timesteps"] = fixed_timesteps
136
  if fixed_timesteps:
137
  pipeline_params["num_inference_steps"] = len(fixed_timesteps)
138
-
139
  return pipeline_params
140
 
141
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
 
142
  worker_to_use = self._get_next_worker()
143
  try:
144
- # [CORREÇÃO] A lógica de padding é específica para a geração do zero.
145
  height, width = kwargs['height'], kwargs['width']
146
  padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
147
  padding_vals = calculate_padding(height, width, padded_h, padded_w)
148
  kwargs['height'], kwargs['width'] = padded_h, padded_w
149
-
150
  pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
151
-
152
- logger.info(f"Iniciando GERAÇÃO em {worker_to_use.device} com shape {padded_w}x{padded_h}")
153
-
154
  if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
155
  result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
156
  else:
157
  result = worker_to_use.generate_video_fragment_internal(**pipeline_params)
158
-
159
  return result, padding_vals
160
  except Exception as e:
161
- logger.error(f"LTX POOL MANAGER: Erro durante a geração em {worker_to_use.device}: {e}", exc_info=True)
162
  raise e
163
  finally:
164
  if worker_to_use and worker_to_use.device.type == 'cuda':
@@ -166,47 +277,41 @@ class LtxPoolManager:
166
  gc.collect(); torch.cuda.empty_cache()
167
 
168
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
 
169
  worker_to_use = self._get_next_worker()
170
  try:
171
- # [CORREÇÃO] A lógica de dimensionamento para refinamento deriva da forma do latente.
172
  _b, _c, _f, latent_h, latent_w = latents_to_refine.shape
173
  vae_scale_factor = worker_to_use.pipeline.vae_scale_factor
174
-
175
- # Garante que as dimensões correspondam EXATAMENTE ao latente fornecido.
176
  kwargs['height'] = latent_h * vae_scale_factor
177
  kwargs['width'] = latent_w * vae_scale_factor
178
  kwargs['video_total_frames'] = kwargs.get('video_total_frames', _f * worker_to_use.pipeline.video_scale_factor)
179
  kwargs['latents'] = latents_to_refine
180
  kwargs['strength'] = kwargs.get('denoise_strength', 0.4)
181
  kwargs['num_inference_steps'] = int(kwargs.get('refine_steps', 10))
182
-
183
  pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
184
-
185
- logger.info(f"Iniciando REFINAMENTO em {worker_to_use.device} com shape {kwargs['width']}x{kwargs['height']}")
186
-
187
  pipeline_to_call = worker_to_use.pipeline.video_pipeline if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline) else worker_to_use.pipeline
188
  result = pipeline_to_call(**pipeline_params).images
189
  return result, None
190
-
191
  except torch.cuda.OutOfMemoryError as e:
192
- logger.error(f"FALHA DE MEMÓRIA DURANTE O REFINAMENTO em {worker_to_use.device}: {e}")
193
- logger.warning("Limpando VRAM e retornando None para sinalizar a falha.")
194
  gc.collect(); torch.cuda.empty_cache()
195
  return None, None
196
  except Exception as e:
197
- logger.error(f"LTX POOL MANAGER: Erro inesperado durante o refinamento em {worker_to_use.device}: {e}", exc_info=True)
198
  raise e
199
  finally:
200
  if worker_to_use and worker_to_use.device.type == 'cuda':
201
  with torch.cuda.device(worker_to_use.device):
202
  gc.collect(); torch.cuda.empty_cache()
203
 
204
- # --- Instanciação Singleton ---
205
- logger.info("Lendo config.yaml para inicializar o LTX Pool Manager...")
206
  with open("config.yaml", 'r') as f:
207
  config = yaml.safe_load(f)
208
  ltx_gpus_required = config['specialists']['ltx']['gpus_required']
209
  ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required)
210
  ltx_config_path = config['specialists']['ltx']['config_file']
211
  ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file=ltx_config_path)
212
- logger.info("Especialista de Vídeo (LTX) pronto.")
 
1
  # ltx_manager_helpers.py
 
2
  #
3
+ # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
+ #
5
+ # Version: 2.1.0
6
+ #
7
+ # This file manages the LTX-Video specialist pool. It now includes a crucial
8
+ # "monkey patch" for the LTX pipeline's `prepare_conditioning` method. This approach
9
+ # isolates our ADUC-specific modifications from the original library code, ensuring
10
+ # better maintainability and respecting the principle of separation of concerns.
11
 
12
  import torch
13
  import gc
 
17
  import huggingface_hub
18
  import time
19
  import threading
20
+ from typing import Optional, List, Tuple, Union
 
21
 
22
  from optimization import optimize_ltx_worker, can_optimize_fp8
23
  from hardware_manager import hardware_manager
24
  from inference import create_ltx_video_pipeline, calculate_padding
25
+ # We need these types for our patch
26
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LatentConditioningItem
27
+ from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
28
+ from diffusers.utils.torch_utils import randn_tensor
29
 
30
  logger = logging.getLogger(__name__)
31
 
32
+
33
+ # --- MONKEY PATCHING SECTION ---
34
+ # This section contains our custom logic that will override the default
35
+ # behavior of the LTX pipeline at runtime.
36
+
37
+ def _aduc_prepare_conditioning_patch(
38
+ self: LTXVideoPipeline, # 'self' will be the instance of the LTXVideoPipeline
39
+ conditioning_items: Optional[List[Union[ConditioningItem, "LatentConditioningItem"]]],
40
+ init_latents: torch.Tensor,
41
+ num_frames: int,
42
+ height: int,
43
+ width: int,
44
+ vae_per_channel_normalize: bool = False,
45
+ generator=None,
46
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
47
+ """
48
+ This is our custom version of the `prepare_conditioning` method.
49
+ It correctly handles both standard ConditioningItem (from pixels) and our
50
+ ADUC-specific LatentConditioningItem (from latents), which the original
51
+ method does not. This function will replace the original one at runtime.
52
+ """
53
+ if not conditioning_items:
54
+ init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
55
+ init_pixel_coords = latent_to_pixel_coords(
56
+ init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning
57
+ )
58
+ return init_latents, init_pixel_coords, None, 0
59
+
60
+ init_conditioning_mask = torch.zeros(init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device)
61
+ extra_conditioning_latents = []
62
+ extra_conditioning_pixel_coords = []
63
+ extra_conditioning_mask = []
64
+ extra_conditioning_num_latents = 0
65
+
66
+ is_latent_mode = hasattr(conditioning_items[0], 'latent_tensor')
67
+
68
+ if is_latent_mode:
69
+ for item in conditioning_items:
70
+ media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
71
+ media_frame_number = item.media_frame_number
72
+ strength = item.conditioning_strength
73
+ n_latent_frames = media_item_latents.shape[2]
74
+
75
+ if media_frame_number == 0:
76
+ f_l, h_l, w_l = media_item_latents.shape[-3:]
77
+ init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
78
+ init_conditioning_mask[:, :f_l, :h_l, :w_l] = strength
79
+ else:
80
+ noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
81
+ media_item_latents = torch.lerp(noise, media_item_latents, strength)
82
+ patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
83
+ pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
84
+ pixel_coords[:, 0] += media_frame_number
85
+ extra_conditioning_num_latents += patched_latents.shape[1]
86
+ new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
87
+ extra_conditioning_latents.append(patched_latents)
88
+ extra_conditioning_pixel_coords.append(pixel_coords)
89
+ extra_conditioning_mask.append(new_mask)
90
+ else: # Original pixel-based logic
91
+ for item in conditioning_items:
92
+ if not isinstance(item, ConditioningItem): continue
93
+ item = self._resize_conditioning_item(item, height, width)
94
+ media_item_latents = vae_encode(
95
+ item.media_item.to(dtype=self.vae.dtype, device=self.vae.device),
96
+ self.vae, vae_per_channel_normalize=vae_per_channel_normalize
97
+ ).to(dtype=init_latents.dtype)
98
+ media_frame_number = item.media_frame_number
99
+ strength = item.conditioning_strength
100
+ n_pixel_frames = item.media_item.shape[2]
101
+ if media_frame_number == 0:
102
+ media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
103
+ f_l, h_l, w_l = media_item_latents.shape[-3:]
104
+ init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, strength)
105
+ init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
106
+ else:
107
+ # ... (this part of the original logic can be included if needed) ...
108
+ pass # For ADUC, we primarily use latent anchors for non-zero frames.
109
+
110
+ init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
111
+ init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
112
+ init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
113
+ init_conditioning_mask = init_conditioning_mask.squeeze(-1)
114
+
115
+ if extra_conditioning_latents:
116
+ init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
117
+ init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
118
+ init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
119
+ if self.transformer.use_tpu_flash_attention:
120
+ init_latents = init_latents[:, :-extra_conditioning_num_latents]
121
+ init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
122
+ init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
123
+
124
+ return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
125
+
126
+ # --- END OF MONKEY PATCHING SECTION ---
127
+
128
+
129
  class LtxWorker:
130
  """
131
+ Represents a single instance of the LTX-Video pipeline on a specific device.
132
+ Manages model loading to CPU and movement to/from GPU.
133
  """
134
  def __init__(self, device_id, ltx_config_file):
135
  self.cpu_device = torch.device('cpu')
136
  self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
137
+ logger.info(f"LTX Worker ({self.device}): Initializing with config '{ltx_config_file}'...")
138
 
139
  with open(ltx_config_file, "r") as file:
140
  self.config = yaml.safe_load(file)
 
143
 
144
  models_dir = "downloaded_models_gradio"
145
 
146
+ logger.info(f"LTX Worker ({self.device}): Loading model to CPU...")
147
  model_path = os.path.join(models_dir, self.config["checkpoint_path"])
148
  if not os.path.exists(model_path):
149
  model_path = huggingface_hub.hf_hub_download(
 
156
  text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
157
  sampler=self.config["sampler"], device='cpu'
158
  )
159
+ logger.info(f"LTX Worker ({self.device}): Model ready on CPU. Is distilled model? {self.is_distilled}")
160
 
161
  def to_gpu(self):
162
+ """Moves the pipeline to the designated GPU AND optimizes if possible."""
163
  if self.device.type == 'cpu': return
164
+ logger.info(f"LTX Worker: Moving pipeline to GPU {self.device}...")
165
  self.pipeline.to(self.device)
166
 
167
  if self.device.type == 'cuda' and can_optimize_fp8():
168
+ logger.info(f"LTX Worker ({self.device}): FP8 supported GPU detected. Optimizing...")
169
  optimize_ltx_worker(self)
170
+ logger.info(f"LTX Worker ({self.device}): Optimization complete.")
171
  elif self.device.type == 'cuda':
172
+ logger.info(f"LTX Worker ({self.device}): FP8 optimization not supported or disabled.")
173
 
174
  def to_cpu(self):
175
+ """Moves the pipeline back to the CPU and frees GPU memory."""
176
  if self.device.type == 'cpu': return
177
+ logger.info(f"LTX Worker: Unloading pipeline from GPU {self.device}...")
178
  self.pipeline.to('cpu')
179
  gc.collect()
180
  if torch.cuda.is_available(): torch.cuda.empty_cache()
181
 
182
  def generate_video_fragment_internal(self, **kwargs):
183
+ """Invokes the generation pipeline."""
184
  return self.pipeline(**kwargs).images
185
 
186
  class LtxPoolManager:
187
  """
188
+ Manages a pool of LtxWorkers for optimized multi-GPU usage.
189
+ HOT START MODE: Keeps all models loaded in VRAM for minimum latency.
190
  """
191
  def __init__(self, device_ids, ltx_config_file):
192
+ logger.info(f"LTX POOL MANAGER: Creating workers for devices: {device_ids}")
193
  self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
194
  self.current_worker_index = 0
195
  self.lock = threading.Lock()
196
 
197
+ # <<< NEW: APPLY PATCH AFTER INITIALIZATION >>>
198
+ self._apply_ltx_pipeline_patches()
199
+
200
  if all(w.device.type == 'cuda' for w in self.workers):
201
+ logger.info("LTX POOL MANAGER: HOT START MODE ENABLED. Pre-warming all GPUs...")
202
  for worker in self.workers:
203
  worker.to_gpu()
204
+ logger.info("LTX POOL MANAGER: All GPUs are hot and ready.")
205
  else:
206
+ logger.info("LTX POOL MANAGER: Operating in CPU or mixed mode. GPU pre-warming skipped.")
207
+
208
+ def _apply_ltx_pipeline_patches(self):
209
+ """
210
+ Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility.
211
+ This is where the monkey patching happens.
212
+ """
213
+ logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to LTX pipeline...")
214
+ for worker in self.workers:
215
+ # The __get__ method binds our standalone function to the pipeline instance,
216
+ # making it behave like a regular method (so 'self' works correctly).
217
+ worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
218
+ logger.info("LTX POOL MANAGER: All pipeline instances have been patched successfully.")
219
 
220
  def _get_next_worker(self):
221
  with self.lock:
 
224
  return worker
225
 
226
  def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
227
+ # This function remains unchanged
228
  pipeline_params = {
229
  "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
230
  "frame_rate": kwargs.get('video_fps', 24),
 
235
  "rescaling_scale": kwargs.get('rescaling_scale', 0.15), "num_inference_steps": kwargs.get('num_inference_steps', 20),
236
  "output_type": "latent"
237
  }
 
238
  if 'latents' in kwargs:
239
  pipeline_params["latents"] = kwargs['latents'].to(worker.device, dtype=worker.pipeline.transformer.dtype)
240
  if 'strength' in kwargs:
 
245
  item.latent_tensor = item.latent_tensor.to(worker.device)
246
  final_conditioning_items.append(item)
247
  pipeline_params["conditioning_items"] = final_conditioning_items
 
248
  if worker.is_distilled:
249
+ logger.info(f"Worker {worker.device} is using a distilled model. Using fixed timesteps.")
250
  fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
251
  pipeline_params["timesteps"] = fixed_timesteps
252
  if fixed_timesteps:
253
  pipeline_params["num_inference_steps"] = len(fixed_timesteps)
 
254
  return pipeline_params
255
 
256
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
257
+ # This function remains unchanged
258
  worker_to_use = self._get_next_worker()
259
  try:
 
260
  height, width = kwargs['height'], kwargs['width']
261
  padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
262
  padding_vals = calculate_padding(height, width, padded_h, padded_w)
263
  kwargs['height'], kwargs['width'] = padded_h, padded_w
 
264
  pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
265
+ logger.info(f"Initiating GENERATION on {worker_to_use.device} with shape {padded_w}x{padded_h}")
 
 
266
  if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
267
  result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
268
  else:
269
  result = worker_to_use.generate_video_fragment_internal(**pipeline_params)
 
270
  return result, padding_vals
271
  except Exception as e:
272
+ logger.error(f"LTX POOL MANAGER: Error during generation on {worker_to_use.device}: {e}", exc_info=True)
273
  raise e
274
  finally:
275
  if worker_to_use and worker_to_use.device.type == 'cuda':
 
277
  gc.collect(); torch.cuda.empty_cache()
278
 
279
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
280
+ # This function remains unchanged
281
  worker_to_use = self._get_next_worker()
282
  try:
 
283
  _b, _c, _f, latent_h, latent_w = latents_to_refine.shape
284
  vae_scale_factor = worker_to_use.pipeline.vae_scale_factor
 
 
285
  kwargs['height'] = latent_h * vae_scale_factor
286
  kwargs['width'] = latent_w * vae_scale_factor
287
  kwargs['video_total_frames'] = kwargs.get('video_total_frames', _f * worker_to_use.pipeline.video_scale_factor)
288
  kwargs['latents'] = latents_to_refine
289
  kwargs['strength'] = kwargs.get('denoise_strength', 0.4)
290
  kwargs['num_inference_steps'] = int(kwargs.get('refine_steps', 10))
 
291
  pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
292
+ logger.info(f"Initiating REFINEMENT on {worker_to_use.device} with shape {kwargs['width']}x{kwargs['height']}")
 
 
293
  pipeline_to_call = worker_to_use.pipeline.video_pipeline if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline) else worker_to_use.pipeline
294
  result = pipeline_to_call(**pipeline_params).images
295
  return result, None
 
296
  except torch.cuda.OutOfMemoryError as e:
297
+ logger.error(f"MEMORY FAILURE DURING REFINEMENT on {worker_to_use.device}: {e}")
298
+ logger.warning("Clearing VRAM and returning None to signal failure.")
299
  gc.collect(); torch.cuda.empty_cache()
300
  return None, None
301
  except Exception as e:
302
+ logger.error(f"LTX POOL MANAGER: Unexpected error during refinement on {worker_to_use.device}: {e}", exc_info=True)
303
  raise e
304
  finally:
305
  if worker_to_use and worker_to_use.device.type == 'cuda':
306
  with torch.cuda.device(worker_to_use.device):
307
  gc.collect(); torch.cuda.empty_cache()
308
 
309
+ # --- Singleton Instantiation ---
310
+ logger.info("Reading config.yaml to initialize LTX Pool Manager...")
311
  with open("config.yaml", 'r') as f:
312
  config = yaml.safe_load(f)
313
  ltx_gpus_required = config['specialists']['ltx']['gpus_required']
314
  ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required)
315
  ltx_config_path = config['specialists']['ltx']['config_file']
316
  ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file=ltx_config_path)
317
+ logger.info("Video Specialist (LTX) ready.")