Eueuiaa committed
Commit 8cf736e · verified · 1 Parent(s): 8f9509f

Update api/ltx_server.py

Files changed (1):
  api/ltx_server.py  +64 -60
api/ltx_server.py CHANGED
@@ -246,7 +246,12 @@ def log_tensor_info(tensor, name="Tensor"):
     print("------------------------------------------\n")
 
 
-
+@dataclass
+class LatentConditioningItem:
+    """Data item for conditioning the LTX pipeline."""
+    latent_tensor: torch.Tensor
+    media_frame_number: int
+    conditioning_strength: float
 
 
 # --- 5. MAIN SERVICE CLASS ---
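For context, `LatentConditioningItem` is a plain dataclass, so callers construct items directly. A minimal sketch of how one might be built, assuming the dataclass above is in scope; the tensor shape (batch, channels, frames, height, width) and the values are illustrative assumptions, not taken from this commit:

import torch

# Hypothetical VAE-encoded latent for a single conditioning frame.
latent = torch.randn(1, 128, 1, 16, 16)

# Anchor it at frame 0 and blend it in at 90% strength.
item = LatentConditioningItem(
    latent_tensor=latent,
    media_frame_number=0,
    conditioning_strength=0.9,
)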
@@ -289,73 +294,72 @@ class VideoService:
         print(f"[DEBUG] VideoService ready. boot_time={time.perf_counter()-t0:.3f}s")
 
 
-    def _aduc_prepare_conditioning_patch(
-        self: "LTXVideoPipeline",
-        conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
-        init_latents: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        vae_per_channel_normalize: bool = False,
-        generator=None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-        if not conditioning_items:
-            init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-            init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-            return init_latents, init_pixel_coords, None, 0
-
-        init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
-        extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
-        extra_conditioning_num_latents = 0
-
-        for item in conditioning_items:
-            if not isinstance(item, LatentConditioningItem):
-                logger.warning("ADUC patch: conditioning item is not a LatentConditioningItem and will be ignored.")
-                continue
-
-            media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
-            media_frame_number, strength = item.media_frame_number, item.conditioning_strength
-
-            if media_frame_number == 0:
-                f_l, h_l, w_l = media_item_latents.shape[-3:]
-                init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
-                init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
-            else:
-                noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
-                media_item_latents = torch.lerp(noise, media_item_latents, strength)
-                patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
-                pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-                pixel_coords[:, 0] += media_frame_number
-                extra_conditioning_num_latents += patched_latents.shape[1]
-                new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
-                extra_conditioning_latents.append(patched_latents)
-                extra_conditioning_pixel_coords.append(pixel_coords)
-                extra_conditioning_mask.append(new_mask)
-
-        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-        init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
-        init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-
-        if extra_conditioning_latents:
-            init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
-            init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
-            init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
-
-        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
-
-
-    def _apply_ltx_pipeline_patches(self):
-        """Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility."""
-        logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to the LTX pipeline...")
-        for worker in self.workers:
-            worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
-        logger.info("LTX POOL MANAGER: All pipeline instances were patched successfully.")
-
-        self._apply_ltx_pipeline_patches()
+    def _aduc_prepare_conditioning_patch(
+        self: "LTXVideoPipeline",
+        conditioning_items: Optional[List[Union["ConditioningItem", "LatentConditioningItem"]]],
+        init_latents: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        vae_per_channel_normalize: bool = False,
+        generator=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+        if not conditioning_items:
+            init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
+            init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+            return init_latents, init_pixel_coords, None, 0
+
+        init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
+        extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
+        extra_conditioning_num_latents = 0
+
+        for item in conditioning_items:
+            if not isinstance(item, LatentConditioningItem):
+                logger.warning("ADUC patch: conditioning item is not a LatentConditioningItem and will be ignored.")
+                continue
+
+            media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
+            media_frame_number, strength = item.media_frame_number, item.conditioning_strength
+
+            if media_frame_number == 0:
+                f_l, h_l, w_l = media_item_latents.shape[-3:]
+                init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
+                init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
+            else:
+                noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
+                media_item_latents = torch.lerp(noise, media_item_latents, strength)
+                patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
+                pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+                pixel_coords[:, 0] += media_frame_number
+                extra_conditioning_num_latents += patched_latents.shape[1]
+                new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
+                extra_conditioning_latents.append(patched_latents)
+                extra_conditioning_pixel_coords.append(pixel_coords)
+                extra_conditioning_mask.append(new_mask)
+
+        init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
+        init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+        init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
+        init_conditioning_mask = init_conditioning_mask.squeeze(-1)
+
+        if extra_conditioning_latents:
+            init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
+            init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
+            init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
+
+        return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
+
+    def _apply_ltx_pipeline_patches(self):
+        """Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility."""
+        logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to the LTX pipeline...")
+        for worker in self.workers:
+            worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
+        logger.info("LTX POOL MANAGER: All pipeline instances were patched successfully.")
+
+        self._apply_ltx_pipeline_patches()
 
 
     def _log_gpu_memory(self, stage_name: str):
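The patching above relies on Python's descriptor protocol: a plain function's `__get__` returns a method bound to a given instance, which is how `_aduc_prepare_conditioning_patch` can replace `prepare_conditioning` on each live pipeline without subclassing `LTXVideoPipeline`. A self-contained sketch of the same technique, with illustrative names:

class Pipeline:
    def prepare(self):
        return "original"

def patched_prepare(self):
    # 'self' is the Pipeline instance the function gets bound to.
    return "patched"

pipe = Pipeline()
# Functions are descriptors: __get__(instance, cls) returns a bound method,
# so the override applies to this one instance only.
pipe.prepare = patched_prepare.__get__(pipe, Pipeline)
assert pipe.prepare() == "patched"
assert Pipeline().prepare() == "original"  # other instances are untouched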
 
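Both conditioning paths in `_aduc_prepare_conditioning_patch` reduce to a linear interpolation: `torch.lerp(a, b, w) = a + w * (b - a)`, so `conditioning_strength=1.0` reproduces the conditioning latents exactly and `0.0` leaves the target (the initial latents, or fresh noise) unchanged. A standalone illustration, with arbitrary shapes and torch.randn_like standing in for diffusers' randn_tensor:

import torch

latents = torch.randn(1, 128, 8, 16, 16)  # hypothetical conditioning latents
noise = torch.randn_like(latents)          # stand-in for randn_tensor(...)
strength = 0.75

# Blend the conditioning latents into noise, as the non-zero-frame branch does.
blended = torch.lerp(noise, latents, strength)
assert torch.allclose(blended, noise + strength * (latents - noise))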