euiia committed on
Commit
532ae25
·
verified ·
1 Parent(s): e52c86e

Rename managers/ltx_manager_helpers.py to managers/ltx_manager.py

Browse files
managers/{ltx_manager_helpers.py → ltx_manager.py} RENAMED
@@ -1,4 +1,4 @@
1
- # ltx_manager_helpers.py
2
  #
3
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
  #
@@ -25,6 +25,7 @@ from inference import create_ltx_video_pipeline, calculate_padding
25
  # We need these types for our patch
26
  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LatentConditioningItem
27
  from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
 
28
  from diffusers.utils.torch_utils import randn_tensor
29
 
30
  logger = logging.getLogger(__name__)
@@ -70,8 +71,7 @@ def _aduc_prepare_conditioning_patch(
70
  media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
71
  media_frame_number = item.media_frame_number
72
  strength = item.conditioning_strength
73
- n_latent_frames = media_item_latents.shape[2]
74
-
75
  if media_frame_number == 0:
76
  f_l, h_l, w_l = media_item_latents.shape[-3:]
77
  init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
@@ -87,7 +87,7 @@ def _aduc_prepare_conditioning_patch(
87
  extra_conditioning_latents.append(patched_latents)
88
  extra_conditioning_pixel_coords.append(pixel_coords)
89
  extra_conditioning_mask.append(new_mask)
90
- else: # Original pixel-based logic
91
  for item in conditioning_items:
92
  if not isinstance(item, ConditioningItem): continue
93
  item = self._resize_conditioning_item(item, height, width)
@@ -97,15 +97,14 @@ def _aduc_prepare_conditioning_patch(
97
  ).to(dtype=init_latents.dtype)
98
  media_frame_number = item.media_frame_number
99
  strength = item.conditioning_strength
100
- n_pixel_frames = item.media_item.shape[2]
101
  if media_frame_number == 0:
102
  media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
103
  f_l, h_l, w_l = media_item_latents.shape[-3:]
104
  init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, strength)
105
  init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
106
  else:
107
- # ... (this part of the original logic can be included if needed) ...
108
- pass # For ADUC, we primarily use latent anchors for non-zero frames.
109
 
110
  init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
111
  init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
@@ -194,7 +193,6 @@ class LtxPoolManager:
194
  self.current_worker_index = 0
195
  self.lock = threading.Lock()
196
 
197
- # <<< NEW: APPLY PATCH AFTER INITIALIZATION >>>
198
  self._apply_ltx_pipeline_patches()
199
 
200
  if all(w.device.type == 'cuda' for w in self.workers):
@@ -208,12 +206,9 @@ class LtxPoolManager:
208
  def _apply_ltx_pipeline_patches(self):
209
  """
210
  Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility.
211
- This is where the monkey patching happens.
212
  """
213
  logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to LTX pipeline...")
214
  for worker in self.workers:
215
- # The __get__ method binds our standalone function to the pipeline instance,
216
- # making it behave like a regular method (so 'self' works correctly).
217
  worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
218
  logger.info("LTX POOL MANAGER: All pipeline instances have been patched successfully.")
219
 
@@ -224,7 +219,6 @@ class LtxPoolManager:
224
  return worker
225
 
226
  def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
227
- # This function remains unchanged
228
  pipeline_params = {
229
  "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
230
  "frame_rate": kwargs.get('video_fps', 24),
@@ -254,7 +248,6 @@ class LtxPoolManager:
254
  return pipeline_params
255
 
256
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
257
- # This function remains unchanged
258
  worker_to_use = self._get_next_worker()
259
  try:
260
  height, width = kwargs['height'], kwargs['width']
@@ -277,7 +270,6 @@ class LtxPoolManager:
277
  gc.collect(); torch.cuda.empty_cache()
278
 
279
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
280
- # This function remains unchanged
281
  worker_to_use = self._get_next_worker()
282
  try:
283
  _b, _c, _f, latent_h, latent_w = latents_to_refine.shape
 
1
+ # managers/ltx_manager.py
2
  #
3
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
4
  #
 
25
  # We need these types for our patch
26
  from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LatentConditioningItem
27
  from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
28
+ from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline
29
  from diffusers.utils.torch_utils import randn_tensor
30
 
31
  logger = logging.getLogger(__name__)
 
71
  media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
72
  media_frame_number = item.media_frame_number
73
  strength = item.conditioning_strength
74
+
 
75
  if media_frame_number == 0:
76
  f_l, h_l, w_l = media_item_latents.shape[-3:]
77
  init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
 
87
  extra_conditioning_latents.append(patched_latents)
88
  extra_conditioning_pixel_coords.append(pixel_coords)
89
  extra_conditioning_mask.append(new_mask)
90
+ else: # Original pixel-based logic for fallback
91
  for item in conditioning_items:
92
  if not isinstance(item, ConditioningItem): continue
93
  item = self._resize_conditioning_item(item, height, width)
 
97
  ).to(dtype=init_latents.dtype)
98
  media_frame_number = item.media_frame_number
99
  strength = item.conditioning_strength
 
100
  if media_frame_number == 0:
101
  media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
102
  f_l, h_l, w_l = media_item_latents.shape[-3:]
103
  init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, strength)
104
  init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
105
  else:
106
+ logger.warning("Pixel-based conditioning for non-zero frames is not fully implemented in this patch.")
107
+ pass
108
 
109
  init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
110
  init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
 
193
  self.current_worker_index = 0
194
  self.lock = threading.Lock()
195
 
 
196
  self._apply_ltx_pipeline_patches()
197
 
198
  if all(w.device.type == 'cuda' for w in self.workers):
 
206
  def _apply_ltx_pipeline_patches(self):
207
  """
208
  Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility.
 
209
  """
210
  logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to LTX pipeline...")
211
  for worker in self.workers:
 
 
212
  worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
213
  logger.info("LTX POOL MANAGER: All pipeline instances have been patched successfully.")
214
 
 
219
  return worker
220
 
221
  def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
 
222
  pipeline_params = {
223
  "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
224
  "frame_rate": kwargs.get('video_fps', 24),
 
248
  return pipeline_params
249
 
250
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
 
251
  worker_to_use = self._get_next_worker()
252
  try:
253
  height, width = kwargs['height'], kwargs['width']
 
270
  gc.collect(); torch.cuda.empty_cache()
271
 
272
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
 
273
  worker_to_use = self._get_next_worker()
274
  try:
275
  _b, _c, _f, latent_h, latent_w = latents_to_refine.shape