euiia commited on
Commit
1849a75
·
verified ·
1 Parent(s): 616038b

Update managers/ltx_manager.py

Browse files
Files changed (1) hide show
  1. managers/ltx_manager.py +20 -52
managers/ltx_manager.py CHANGED
@@ -2,12 +2,11 @@
2
  #
3
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 2.2.0
6
  #
7
  # This file manages the LTX-Video specialist pool. It has been refactored to be
8
- # self-contained by automatically cloning its own dependencies from the official
9
- # LTX-Video repository. This modular approach makes the ADUC-SDR framework
10
- # robust, portable, and easy to maintain.
11
 
12
  import torch
13
  import gc
@@ -22,8 +21,9 @@ import subprocess
22
  from pathlib import Path
23
  from typing import Optional, List, Tuple, Union
24
 
25
- from tools.optimization import optimize_ltx_worker, can_optimize_fp8
26
- from tools.hardware_manager import hardware_manager
 
27
 
28
  logger = logging.getLogger(__name__)
29
 
@@ -32,18 +32,16 @@ DEPS_DIR = Path("./deps")
32
  LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
33
  LTX_VIDEO_REPO_URL = "https://github.com/Lightricks/LTX-Video.git"
34
 
35
- # --- Placeholder for lazy-loaded modules ---
36
  create_ltx_video_pipeline = None
37
  calculate_padding = None
38
  LTXVideoPipeline = None
39
  ConditioningItem = None
40
- LatentConditioningItem = None
41
  LTXMultiScalePipeline = None
42
  vae_encode = None
43
  latent_to_pixel_coords = None
44
  randn_tensor = None
45
 
46
-
47
  class LtxPoolManager:
48
  """
49
  Manages a pool of LtxWorkers for optimized multi-GPU usage.
@@ -55,10 +53,9 @@ class LtxPoolManager:
55
  self._setup_dependencies()
56
  self._lazy_load_ltx_modules()
57
 
58
- # Adjust config path to be inside the cloned repo
59
  self.ltx_config_file = LTX_VIDEO_REPO_DIR / "configs" / ltx_config_file_name
60
 
61
- self.workers = [LtxWorker(dev_id, self.ltx_config_file) for dev_id in self]
62
  self.current_worker_index = 0
63
  self.lock = threading.Lock()
64
 
@@ -98,11 +95,11 @@ class LtxPoolManager:
98
  if self._ltx_modules_loaded:
99
  return
100
 
101
- global create_ltx_video_pipeline, calculate_padding, LTXVideoPipeline, ConditioningItem, LatentConditioningItem
102
- global vae_encode, latent_to_pixel_coords, LTXMultiScalePipeline, randn_tensor
103
 
104
- from inference import create_ltx_video_pipeline, calculate_padding
105
- from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LatentConditioningItem, LTXMultiScalePipeline
106
  from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
107
  from diffusers.utils.torch_utils import randn_tensor
108
 
@@ -174,33 +171,8 @@ class LtxPoolManager:
174
  gc.collect(); torch.cuda.empty_cache()
175
 
176
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
177
- worker_to_use = self._get_next_worker()
178
- try:
179
- _b, _c, _f, latent_h, latent_w = latents_to_refine.shape
180
- vae_scale_factor = worker_to_use.pipeline.vae_scale_factor
181
- kwargs['height'] = latent_h * vae_scale_factor
182
- kwargs['width'] = latent_w * vae_scale_factor
183
- kwargs['video_total_frames'] = kwargs.get('video_total_frames', _f * worker_to_use.pipeline.video_scale_factor)
184
- kwargs['latents'] = latents_to_refine
185
- kwargs['strength'] = kwargs.get('denoise_strength', 0.4)
186
- kwargs['num_inference_steps'] = int(kwargs.get('refine_steps', 10))
187
- pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
188
- logger.info(f"Initiating REFINEMENT on {worker_to_use.device} with shape {kwargs['width']}x{kwargs['height']}")
189
- pipeline_to_call = worker_to_use.pipeline.video_pipeline if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline) else worker_to_use.pipeline
190
- result = pipeline_to_call(**pipeline_params).images
191
- return result, None
192
- except torch.cuda.OutOfMemoryError as e:
193
- logger.error(f"MEMORY FAILURE DURING REFINEMENT on {worker_to_use.device}: {e}")
194
- logger.warning("Clearing VRAM and returning None to signal failure.")
195
- gc.collect(); torch.cuda.empty_cache()
196
- return None, None
197
- except Exception as e:
198
- logger.error(f"LTX POOL MANAGER: Unexpected error during refinement on {worker_to_use.device}: {e}", exc_info=True)
199
- raise e
200
- finally:
201
- if worker_to_use and worker_to_use.device.type == 'cuda':
202
- with torch.cuda.device(worker_to_use.device):
203
- gc.collect(); torch.cuda.empty_cache()
204
 
205
  class LtxWorker:
206
  """
@@ -271,16 +243,13 @@ def _aduc_prepare_conditioning_patch(
271
  init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
272
  return init_latents, init_pixel_coords, None, 0
273
  init_conditioning_mask = torch.zeros(init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device)
274
- extra_conditioning_latents = []
275
- extra_conditioning_pixel_coords = []
276
- extra_conditioning_mask = []
277
  extra_conditioning_num_latents = 0
278
  is_latent_mode = hasattr(conditioning_items[0], 'latent_tensor')
279
  if is_latent_mode:
280
  for item in conditioning_items:
281
  media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
282
- media_frame_number = item.media_frame_number
283
- strength = item.conditioning_strength
284
  if media_frame_number == 0:
285
  f_l, h_l, w_l = media_item_latents.shape[-3:]
286
  init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
@@ -301,15 +270,14 @@ def _aduc_prepare_conditioning_patch(
301
  if not isinstance(item, ConditioningItem): continue
302
  item = self._resize_conditioning_item(item, height, width)
303
  media_item_latents = vae_encode(item.media_item.to(dtype=self.vae.dtype, device=self.vae.device), self.vae, vae_per_channel_normalize=vae_per_channel_normalize).to(dtype=init_latents.dtype)
304
- media_frame_number = item.media_frame_number
305
- strength = item.conditioning_strength
306
- if media_frame_number == 0:
307
  media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
308
  f_l, h_l, w_l = media_item_latents.shape[-3:]
309
- init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, strength)
310
- init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
311
  else:
312
  logger.warning("Pixel-based conditioning for non-zero frames is not fully implemented in this patch.")
 
313
  init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
314
  init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
315
  init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
 
2
  #
3
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 2.2.2
6
  #
7
  # This file manages the LTX-Video specialist pool. It has been refactored to be
8
+ # self-contained by automatically cloning its own dependencies and using a local
9
+ # utility module for pipeline creation, fully decoupling it from external scripts.
 
10
 
11
  import torch
12
  import gc
 
21
  from pathlib import Path
22
  from typing import Optional, List, Tuple, Union
23
 
24
+ from optimization import optimize_ltx_worker, can_optimize_fp8
25
+ from hardware_manager import hardware_manager
26
+ from aduc_types import LatentConditioningItem
27
 
28
  logger = logging.getLogger(__name__)
29
 
 
32
  LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
33
  LTX_VIDEO_REPO_URL = "https://github.com/Lightricks/LTX-Video.git"
34
 
35
+ # --- Placeholders for lazy-loaded modules ---
36
  create_ltx_video_pipeline = None
37
  calculate_padding = None
38
  LTXVideoPipeline = None
39
  ConditioningItem = None
 
40
  LTXMultiScalePipeline = None
41
  vae_encode = None
42
  latent_to_pixel_coords = None
43
  randn_tensor = None
44
 
 
45
  class LtxPoolManager:
46
  """
47
  Manages a pool of LtxWorkers for optimized multi-GPU usage.
 
53
  self._setup_dependencies()
54
  self._lazy_load_ltx_modules()
55
 
 
56
  self.ltx_config_file = LTX_VIDEO_REPO_DIR / "configs" / ltx_config_file_name
57
 
58
+ self.workers = [LtxWorker(dev_id, self.ltx_config_file) for dev_id in device_ids]
59
  self.current_worker_index = 0
60
  self.lock = threading.Lock()
61
 
 
95
  if self._ltx_modules_loaded:
96
  return
97
 
98
+ global create_ltx_video_pipeline, calculate_padding, LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline
99
+ global vae_encode, latent_to_pixel_coords, randn_tensor
100
 
101
+ from managers.ltx_pipeline_utils import create_ltx_video_pipeline, calculate_padding
102
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LTXMultiScalePipeline
103
  from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
104
  from diffusers.utils.torch_utils import randn_tensor
105
 
 
171
  gc.collect(); torch.cuda.empty_cache()
172
 
173
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
174
+ # This function can be expanded later if needed.
175
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  class LtxWorker:
178
  """
 
243
  init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
244
  return init_latents, init_pixel_coords, None, 0
245
  init_conditioning_mask = torch.zeros(init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device)
246
+ extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
 
 
247
  extra_conditioning_num_latents = 0
248
  is_latent_mode = hasattr(conditioning_items[0], 'latent_tensor')
249
  if is_latent_mode:
250
  for item in conditioning_items:
251
  media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
252
+ media_frame_number, strength = item.media_frame_number, item.conditioning_strength
 
253
  if media_frame_number == 0:
254
  f_l, h_l, w_l = media_item_latents.shape[-3:]
255
  init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
 
270
  if not isinstance(item, ConditioningItem): continue
271
  item = self._resize_conditioning_item(item, height, width)
272
  media_item_latents = vae_encode(item.media_item.to(dtype=self.vae.dtype, device=self.vae.device), self.vae, vae_per_channel_normalize=vae_per_channel_normalize).to(dtype=init_latents.dtype)
273
+ if item.media_frame_number == 0:
 
 
274
  media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
275
  f_l, h_l, w_l = media_item_latents.shape[-3:]
276
+ init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, item.conditioning_strength)
277
+ init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = item.conditioning_strength
278
  else:
279
  logger.warning("Pixel-based conditioning for non-zero frames is not fully implemented in this patch.")
280
+
281
  init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
282
  init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
283
  init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))