Rename managers/ltx_manager_helpers.py to managers/ltx_manager.py
Browse files
managers/{ltx_manager_helpers.py → ltx_manager.py}
RENAMED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
#
|
| 2 |
#
|
| 3 |
# Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
|
| 4 |
#
|
|
@@ -25,6 +25,7 @@ from inference import create_ltx_video_pipeline, calculate_padding
|
|
| 25 |
# We need these types for our patch
|
| 26 |
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LatentConditioningItem
|
| 27 |
from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
|
|
|
|
| 28 |
from diffusers.utils.torch_utils import randn_tensor
|
| 29 |
|
| 30 |
logger = logging.getLogger(__name__)
|
|
@@ -70,8 +71,7 @@ def _aduc_prepare_conditioning_patch(
|
|
| 70 |
media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
|
| 71 |
media_frame_number = item.media_frame_number
|
| 72 |
strength = item.conditioning_strength
|
| 73 |
-
|
| 74 |
-
|
| 75 |
if media_frame_number == 0:
|
| 76 |
f_l, h_l, w_l = media_item_latents.shape[-3:]
|
| 77 |
init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
|
|
@@ -87,7 +87,7 @@ def _aduc_prepare_conditioning_patch(
|
|
| 87 |
extra_conditioning_latents.append(patched_latents)
|
| 88 |
extra_conditioning_pixel_coords.append(pixel_coords)
|
| 89 |
extra_conditioning_mask.append(new_mask)
|
| 90 |
-
else: # Original pixel-based logic
|
| 91 |
for item in conditioning_items:
|
| 92 |
if not isinstance(item, ConditioningItem): continue
|
| 93 |
item = self._resize_conditioning_item(item, height, width)
|
|
@@ -97,15 +97,14 @@ def _aduc_prepare_conditioning_patch(
|
|
| 97 |
).to(dtype=init_latents.dtype)
|
| 98 |
media_frame_number = item.media_frame_number
|
| 99 |
strength = item.conditioning_strength
|
| 100 |
-
n_pixel_frames = item.media_item.shape[2]
|
| 101 |
if media_frame_number == 0:
|
| 102 |
media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
|
| 103 |
f_l, h_l, w_l = media_item_latents.shape[-3:]
|
| 104 |
init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, strength)
|
| 105 |
init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
|
| 106 |
else:
|
| 107 |
-
|
| 108 |
-
pass
|
| 109 |
|
| 110 |
init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
|
| 111 |
init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
|
|
@@ -194,7 +193,6 @@ class LtxPoolManager:
|
|
| 194 |
self.current_worker_index = 0
|
| 195 |
self.lock = threading.Lock()
|
| 196 |
|
| 197 |
-
# <<< NEW: APPLY PATCH AFTER INITIALIZATION >>>
|
| 198 |
self._apply_ltx_pipeline_patches()
|
| 199 |
|
| 200 |
if all(w.device.type == 'cuda' for w in self.workers):
|
|
@@ -208,12 +206,9 @@ class LtxPoolManager:
|
|
| 208 |
def _apply_ltx_pipeline_patches(self):
|
| 209 |
"""
|
| 210 |
Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility.
|
| 211 |
-
This is where the monkey patching happens.
|
| 212 |
"""
|
| 213 |
logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to LTX pipeline...")
|
| 214 |
for worker in self.workers:
|
| 215 |
-
# The __get__ method binds our standalone function to the pipeline instance,
|
| 216 |
-
# making it behave like a regular method (so 'self' works correctly).
|
| 217 |
worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
|
| 218 |
logger.info("LTX POOL MANAGER: All pipeline instances have been patched successfully.")
|
| 219 |
|
|
@@ -224,7 +219,6 @@ class LtxPoolManager:
|
|
| 224 |
return worker
|
| 225 |
|
| 226 |
def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
|
| 227 |
-
# This function remains unchanged
|
| 228 |
pipeline_params = {
|
| 229 |
"height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
|
| 230 |
"frame_rate": kwargs.get('video_fps', 24),
|
|
@@ -254,7 +248,6 @@ class LtxPoolManager:
|
|
| 254 |
return pipeline_params
|
| 255 |
|
| 256 |
def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
|
| 257 |
-
# This function remains unchanged
|
| 258 |
worker_to_use = self._get_next_worker()
|
| 259 |
try:
|
| 260 |
height, width = kwargs['height'], kwargs['width']
|
|
@@ -277,7 +270,6 @@ class LtxPoolManager:
|
|
| 277 |
gc.collect(); torch.cuda.empty_cache()
|
| 278 |
|
| 279 |
def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
|
| 280 |
-
# This function remains unchanged
|
| 281 |
worker_to_use = self._get_next_worker()
|
| 282 |
try:
|
| 283 |
_b, _c, _f, latent_h, latent_w = latents_to_refine.shape
|
|
|
|
| 1 |
+
# managers/ltx_manager.py
|
| 2 |
#
|
| 3 |
# Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
|
| 4 |
#
|
|
|
|
| 25 |
# We need these types for our patch
|
| 26 |
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LatentConditioningItem
|
| 27 |
from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
|
| 28 |
+
from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline
|
| 29 |
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
|
| 31 |
logger = logging.getLogger(__name__)
|
|
|
|
| 71 |
media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
|
| 72 |
media_frame_number = item.media_frame_number
|
| 73 |
strength = item.conditioning_strength
|
| 74 |
+
|
|
|
|
| 75 |
if media_frame_number == 0:
|
| 76 |
f_l, h_l, w_l = media_item_latents.shape[-3:]
|
| 77 |
init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
|
|
|
|
| 87 |
extra_conditioning_latents.append(patched_latents)
|
| 88 |
extra_conditioning_pixel_coords.append(pixel_coords)
|
| 89 |
extra_conditioning_mask.append(new_mask)
|
| 90 |
+
else: # Original pixel-based logic for fallback
|
| 91 |
for item in conditioning_items:
|
| 92 |
if not isinstance(item, ConditioningItem): continue
|
| 93 |
item = self._resize_conditioning_item(item, height, width)
|
|
|
|
| 97 |
).to(dtype=init_latents.dtype)
|
| 98 |
media_frame_number = item.media_frame_number
|
| 99 |
strength = item.conditioning_strength
|
|
|
|
| 100 |
if media_frame_number == 0:
|
| 101 |
media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
|
| 102 |
f_l, h_l, w_l = media_item_latents.shape[-3:]
|
| 103 |
init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, strength)
|
| 104 |
init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
|
| 105 |
else:
|
| 106 |
+
logger.warning("Pixel-based conditioning for non-zero frames is not fully implemented in this patch.")
|
| 107 |
+
pass
|
| 108 |
|
| 109 |
init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
|
| 110 |
init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
|
|
|
|
| 193 |
self.current_worker_index = 0
|
| 194 |
self.lock = threading.Lock()
|
| 195 |
|
|
|
|
| 196 |
self._apply_ltx_pipeline_patches()
|
| 197 |
|
| 198 |
if all(w.device.type == 'cuda' for w in self.workers):
|
|
|
|
| 206 |
def _apply_ltx_pipeline_patches(self):
|
| 207 |
"""
|
| 208 |
Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility.
|
|
|
|
| 209 |
"""
|
| 210 |
logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to LTX pipeline...")
|
| 211 |
for worker in self.workers:
|
|
|
|
|
|
|
| 212 |
worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
|
| 213 |
logger.info("LTX POOL MANAGER: All pipeline instances have been patched successfully.")
|
| 214 |
|
|
|
|
| 219 |
return worker
|
| 220 |
|
| 221 |
def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
|
|
|
|
| 222 |
pipeline_params = {
|
| 223 |
"height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
|
| 224 |
"frame_rate": kwargs.get('video_fps', 24),
|
|
|
|
| 248 |
return pipeline_params
|
| 249 |
|
| 250 |
def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
|
|
|
|
| 251 |
worker_to_use = self._get_next_worker()
|
| 252 |
try:
|
| 253 |
height, width = kwargs['height'], kwargs['width']
|
|
|
|
| 270 |
gc.collect(); torch.cuda.empty_cache()
|
| 271 |
|
| 272 |
def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
|
|
|
|
| 273 |
worker_to_use = self._get_next_worker()
|
| 274 |
try:
|
| 275 |
_b, _c, _f, latent_h, latent_w = latents_to_refine.shape
|