euiia commited on
Commit
550dd1d
·
verified ·
1 Parent(s): 5da5952

Update ltx_manager_helpers.py

Browse files
Files changed (1) hide show
  1. ltx_manager_helpers.py +30 -9
ltx_manager_helpers.py CHANGED
@@ -1,5 +1,9 @@
1
- # ltx_manager_helpers.py (Com Lógica de Refinamento Especializada)
2
  # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
 
 
 
 
3
 
4
  import torch
5
  import gc
@@ -19,6 +23,10 @@ from ltx_video.pipelines.pipeline_ltx_video import LatentConditioningItem
19
  logger = logging.getLogger(__name__)
20
 
21
  class LtxWorker:
 
 
 
 
22
  def __init__(self, device_id, ltx_config_file):
23
  self.cpu_device = torch.device('cpu')
24
  self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
@@ -28,6 +36,7 @@ class LtxWorker:
28
  self.config = yaml.safe_load(file)
29
 
30
  self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")
 
31
  models_dir = "downloaded_models_gradio"
32
 
33
  logger.info(f"LTX Worker ({self.device}): Carregando modelo para a CPU...")
@@ -46,9 +55,11 @@ class LtxWorker:
46
  logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo destilado? {self.is_distilled}")
47
 
48
  def to_gpu(self):
 
49
  if self.device.type == 'cpu': return
50
  logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
51
  self.pipeline.to(self.device)
 
52
  if self.device.type == 'cuda' and can_optimize_fp8():
53
  logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Iniciando otimização...")
54
  optimize_ltx_worker(self)
@@ -57,6 +68,7 @@ class LtxWorker:
57
  logger.info(f"LTX Worker ({self.device}): Otimização FP8 não suportada ou desativada.")
58
 
59
  def to_cpu(self):
 
60
  if self.device.type == 'cpu': return
61
  logger.info(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
62
  self.pipeline.to('cpu')
@@ -64,14 +76,20 @@ class LtxWorker:
64
  if torch.cuda.is_available(): torch.cuda.empty_cache()
65
 
66
  def generate_video_fragment_internal(self, **kwargs):
 
67
  return self.pipeline(**kwargs).images
68
 
69
  class LtxPoolManager:
 
 
 
 
70
  def __init__(self, device_ids, ltx_config_file):
71
  logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
72
  self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
73
  self.current_worker_index = 0
74
  self.lock = threading.Lock()
 
75
  if all(w.device.type == 'cuda' for w in self.workers):
76
  logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
77
  for worker in self.workers:
@@ -93,12 +111,19 @@ class LtxPoolManager:
93
  padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
94
  padding_vals = calculate_padding(height, width, padded_h, padded_w)
95
 
96
- conditioning_items = [item.to(worker_to_use.device) for item in kwargs.get('conditioning_items_data', [])]
 
 
 
 
 
 
 
97
 
98
  pipeline_params = {
99
  "height": padded_h, "width": padded_w, "num_frames": kwargs['video_total_frames'],
100
  "frame_rate": kwargs['video_fps'], "generator": torch.Generator(device=worker_to_use.device).manual_seed(int(time.time()) + kwargs['current_fragment_index']),
101
- "conditioning_items": conditioning_items, "is_video": True, "vae_per_channel_normalize": True,
102
  "prompt": kwargs['motion_prompt'], "negative_prompt": "blurry, distorted, static, bad quality",
103
  "guidance_scale": kwargs['guidance_scale'], "stg_scale": kwargs['stg_scale'],
104
  "rescaling_scale": kwargs['rescaling_scale'], "num_inference_steps": kwargs['num_inference_steps']
@@ -120,9 +145,6 @@ class LtxPoolManager:
120
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
121
  worker_to_use = self._get_next_worker()
122
  try:
123
- # --- [INÍCIO DA CORREÇÃO] ---
124
- # Para refinamento, as dimensões são derivadas DIRETAMENTE do tensor latente.
125
- # Não há padding. A resolução em pixels é passada, mas a forma latente é a fonte da verdade.
126
  height, width, num_frames = kwargs['height'], kwargs['width'], kwargs['video_total_frames']
127
 
128
  pipeline_params = {
@@ -131,16 +153,15 @@ class LtxPoolManager:
131
  "generator": torch.Generator(device=worker_to_use.device).manual_seed(int(time.time()) + kwargs['current_fragment_index']),
132
  "is_video": True, "vae_per_channel_normalize": True,
133
  "prompt": kwargs['motion_prompt'], "negative_prompt": "blurry, distorted, static, bad quality",
134
- "guidance_scale": kwargs.get('guidance_scale', 1.0), # Força 1.0 para refinamento incondicional se não especificado
135
  "num_inference_steps": int(kwargs.get('refine_steps', 10)),
136
  "strength": kwargs.get('denoise_strength', 0.4),
137
  "output_type": "latent"
138
  }
139
- # --- [FIM DA CORREÇÃO] ---
140
 
141
  logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise)...")
142
  result = worker_to_use.generate_video_fragment_internal(**pipeline_params)
143
- return result, None # Nenhum padding é aplicado no refinamento
144
  except Exception as e:
145
  logger.error(f"LTX POOL MANAGER: Erro durante o refinamento em {worker_to_use.device}: {e}", exc_info=True)
146
  raise e
 
1
+ # ltx_manager_helpers.py
2
  # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
3
+ #
4
+ # Este programa é software livre: você pode redistribuí-lo e/ou modificá-lo
5
+ # sob os termos da Licença Pública Geral Affero GNU...
6
+ # AVISO DE PATENTE PENDENTE: Consulte NOTICE.md.
7
 
8
  import torch
9
  import gc
 
23
  logger = logging.getLogger(__name__)
24
 
25
  class LtxWorker:
26
+ """
27
+ Representa uma única instância da pipeline LTX-Video em um dispositivo específico.
28
+ Gerencia o carregamento do modelo para a CPU e a movimentação de/para a GPU.
29
+ """
30
  def __init__(self, device_id, ltx_config_file):
31
  self.cpu_device = torch.device('cpu')
32
  self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
 
36
  self.config = yaml.safe_load(file)
37
 
38
  self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")
39
+
40
  models_dir = "downloaded_models_gradio"
41
 
42
  logger.info(f"LTX Worker ({self.device}): Carregando modelo para a CPU...")
 
55
  logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo destilado? {self.is_distilled}")
56
 
57
  def to_gpu(self):
58
+ """Move o pipeline para a GPU designada E OTIMIZA SE POSSÍVEL."""
59
  if self.device.type == 'cpu': return
60
  logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
61
  self.pipeline.to(self.device)
62
+
63
  if self.device.type == 'cuda' and can_optimize_fp8():
64
  logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Iniciando otimização...")
65
  optimize_ltx_worker(self)
 
68
  logger.info(f"LTX Worker ({self.device}): Otimização FP8 não suportada ou desativada.")
69
 
70
  def to_cpu(self):
71
+ """Move o pipeline de volta para a CPU e libera a memória da GPU."""
72
  if self.device.type == 'cpu': return
73
  logger.info(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
74
  self.pipeline.to('cpu')
 
76
  if torch.cuda.is_available(): torch.cuda.empty_cache()
77
 
78
  def generate_video_fragment_internal(self, **kwargs):
79
+ """Invoca a pipeline de geração."""
80
  return self.pipeline(**kwargs).images
81
 
82
  class LtxPoolManager:
83
+ """
84
+ Gerencia um pool de LtxWorkers para otimizar o uso de múltiplas GPUs.
85
+ MODO "HOT START": Mantém todos os modelos carregados na VRAM para latência mínima.
86
+ """
87
  def __init__(self, device_ids, ltx_config_file):
88
  logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
89
  self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
90
  self.current_worker_index = 0
91
  self.lock = threading.Lock()
92
+
93
  if all(w.device.type == 'cuda' for w in self.workers):
94
  logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
95
  for worker in self.workers:
 
111
  padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
112
  padding_vals = calculate_padding(height, width, padded_h, padded_w)
113
 
114
+ # --- [INÍCIO DA CORREÇÃO] ---
115
+ # Move o tensor DENTRO de cada item de condicionamento para o dispositivo do worker.
116
+ conditioning_data = kwargs.get('conditioning_items_data', [])
117
+ final_conditioning_items = []
118
+ for item in conditioning_data:
119
+ item.latent_tensor = item.latent_tensor.to(worker_to_use.device)
120
+ final_conditioning_items.append(item)
121
+ # --- [FIM DA CORREÇÃO] ---
122
 
123
  pipeline_params = {
124
  "height": padded_h, "width": padded_w, "num_frames": kwargs['video_total_frames'],
125
  "frame_rate": kwargs['video_fps'], "generator": torch.Generator(device=worker_to_use.device).manual_seed(int(time.time()) + kwargs['current_fragment_index']),
126
+ "conditioning_items": final_conditioning_items, "is_video": True, "vae_per_channel_normalize": True,
127
  "prompt": kwargs['motion_prompt'], "negative_prompt": "blurry, distorted, static, bad quality",
128
  "guidance_scale": kwargs['guidance_scale'], "stg_scale": kwargs['stg_scale'],
129
  "rescaling_scale": kwargs['rescaling_scale'], "num_inference_steps": kwargs['num_inference_steps']
 
145
  def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
146
  worker_to_use = self._get_next_worker()
147
  try:
 
 
 
148
  height, width, num_frames = kwargs['height'], kwargs['width'], kwargs['video_total_frames']
149
 
150
  pipeline_params = {
 
153
  "generator": torch.Generator(device=worker_to_use.device).manual_seed(int(time.time()) + kwargs['current_fragment_index']),
154
  "is_video": True, "vae_per_channel_normalize": True,
155
  "prompt": kwargs['motion_prompt'], "negative_prompt": "blurry, distorted, static, bad quality",
156
+ "guidance_scale": kwargs.get('guidance_scale', 1.0),
157
  "num_inference_steps": int(kwargs.get('refine_steps', 10)),
158
  "strength": kwargs.get('denoise_strength', 0.4),
159
  "output_type": "latent"
160
  }
 
161
 
162
  logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise)...")
163
  result = worker_to_use.generate_video_fragment_internal(**pipeline_params)
164
+ return result, None
165
  except Exception as e:
166
  logger.error(f"LTX POOL MANAGER: Erro durante o refinamento em {worker_to_use.device}: {e}", exc_info=True)
167
  raise e