Update api/ltx_server.py

api/ltx_server.py  CHANGED  (+64 -60)
@@ -246,7 +246,12 @@ def log_tensor_info(tensor, name="Tensor"):
     print("------------------------------------------\n")
 
 
-
+@dataclass
+class LatentConditioningItem:
+    """Data item for conditioning the LTX pipeline."""
+    latent_tensor: torch.Tensor
+    media_frame_number: int
+    conditioning_strength: float
 
 
 # --- 5. MAIN SERVICE CLASS ---
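The `LatentConditioningItem` dataclass added above is the unit the patched `prepare_conditioning` consumes. A minimal usage sketch, assuming the (batch, channels, frames, height, width) latent layout implied by the patch's `shape[-3:]` unpacking; the names, shapes, and values here are illustrative, not part of the commit:

# Hypothetical usage sketch; relies on the LatentConditioningItem
# dataclass defined in the hunk above.
import torch

# A latent clip in (batch, channels, frames, height, width) layout,
# matching init_latents[:, 0, ...] indexing in the patched function.
anchor_latents = torch.randn(1, 128, 2, 16, 16)

items = [
    # Blended directly into the start of init_latents (media_frame_number == 0).
    LatentConditioningItem(latent_tensor=anchor_latents,
                           media_frame_number=0,
                           conditioning_strength=1.0),
    # Carried as extra conditioning tokens anchored at a later frame.
    LatentConditioningItem(latent_tensor=anchor_latents,
                           media_frame_number=24,
                           conditioning_strength=0.5),
]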
@@ -289,73 +294,72 @@ class VideoService:
         print(f"[DEBUG] VideoService ready. boot_time={time.perf_counter()-t0:.3f}s")
 
 
+    def _aduc_prepare_conditioning_patch(
+        self: "LTXVideoPipeline",
+        conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
+        init_latents: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        vae_per_channel_normalize: bool = False,
+        generator=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+        if not conditioning_items:
+            init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
+            init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+            return init_latents, init_pixel_coords, None, 0
+
+        init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
+        extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
+        extra_conditioning_num_latents = 0
+
+        for item in conditioning_items:
+            if not isinstance(item, LatentConditioningItem):
+                logger.warning("ADUC patch: conditioning item is not a LatentConditioningItem and will be ignored.")
+                continue
+
+            media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
+            media_frame_number, strength = item.media_frame_number, item.conditioning_strength
+
+            if media_frame_number == 0:
+                f_l, h_l, w_l = media_item_latents.shape[-3:]
+                init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
+                init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
+            else:
+                noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
+                media_item_latents = torch.lerp(noise, media_item_latents, strength)
+                patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
+                pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+                pixel_coords[:, 0] += media_frame_number
+                extra_conditioning_num_latents += patched_latents.shape[1]
+                new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
+                extra_conditioning_latents.append(patched_latents)
+                extra_conditioning_pixel_coords.append(pixel_coords)
+                extra_conditioning_mask.append(new_mask)
+
         init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
         init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+        init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
+        init_conditioning_mask = init_conditioning_mask.squeeze(-1)
 
+        if extra_conditioning_latents:
+            init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
+            init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
+            init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
 
+        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
 
+    def _apply_ltx_pipeline_patches(self):
+        """Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility."""
+        logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to the LTX pipeline...")
+        for worker in self.workers:
+            worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
+        logger.info("LTX POOL MANAGER: All pipeline instances were patched successfully.")
 
+        self._apply_ltx_pipeline_patches()
 
 
     def _log_gpu_memory(self, stage_name: str):
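Both conditioning paths in the patched function reduce to `torch.lerp(start, end, weight) = start + weight * (end - start)`: at `media_frame_number == 0` the initial latents are pulled toward the reference by `conditioning_strength`, while for later frames fresh noise is pulled toward the reference before being patchified into extra tokens. A standalone sketch of that blending behavior:

import torch

start = torch.zeros(3)              # e.g. pure noise, or the initial latents
reference = torch.full((3,), 10.0)  # the conditioning latents
for strength in (0.0, 0.5, 1.0):
    print(strength, torch.lerp(start, reference, strength))
# 0.0 -> tensor([0., 0., 0.])     conditioning ignored
# 0.5 -> tensor([5., 5., 5.])     halfway blend
# 1.0 -> tensor([10., 10., 10.])  reference injected exactly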
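The `__get__` call in `_apply_ltx_pipeline_patches` relies on the fact that plain Python functions are descriptors: `func.__get__(obj, cls)` returns a method bound to `obj`, so each worker pipeline gets the patch installed as if it were its own `prepare_conditioning`. A toy sketch of the same pattern (class and function names here are illustrative, not the LTX API):

class Pipeline:
    def prepare(self):
        return "original"

def patched_prepare(self):
    # `self` is whatever instance the function was bound to.
    return f"patched on {type(self).__name__}"

p = Pipeline()
# Plain functions are descriptors: __get__ returns a bound method, mirroring
# _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline).
p.prepare = patched_prepare.__get__(p, Pipeline)
print(p.prepare())  # -> patched on Pipeline

Assigning the bound method on the instance shadows the class attribute for that instance only, which is why the commit loops over the workers and patches each `worker.pipeline` individually rather than replacing the method on the class.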