Update ltx_manager_helpers.py
ltx_manager_helpers.py (+27 -58)
@@ -105,87 +105,52 @@ class LtxPoolManager:
         self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
         return worker

-    # --- [NEW] Centralized function for preparing pipeline parameters ---
     def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
         """Prepares the parameter dictionary for the pipeline, handling special cases such as distilled models."""
-
-        # Required parameters
-        height = kwargs['height']
-        width = kwargs['width']
-        num_frames = kwargs['video_total_frames']
-
-        # Parameters with default values
-        motion_prompt = kwargs.get('motion_prompt', "")
-        negative_prompt = kwargs.get('negative_prompt', "blurry, distorted, static, bad quality")
-        guidance_scale = kwargs.get('guidance_scale', 1.0)
-        stg_scale = kwargs.get('stg_scale', 0.0)
-        rescaling_scale = kwargs.get('rescaling_scale', 0.15)
-        num_inference_steps = kwargs.get('num_inference_steps', 20)
-
-        # Optional parameters (for generation or refinement)
-        latents_input = kwargs.get('latents')
-        strength = kwargs.get('strength')
-        conditioning_data = kwargs.get('conditioning_items_data')
-
-        # Prepare the conditioning items if they exist
-        final_conditioning_items = []
-        if conditioning_data:
-            for item in conditioning_data:
-                item.latent_tensor = item.latent_tensor.to(worker.device)
-                final_conditioning_items.append(item)
-
-        # Build the base parameter dictionary
         pipeline_params = {
-            "height": height, "width": width, "num_frames": num_frames,
+            "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
             "frame_rate": kwargs.get('video_fps', 24),
             "generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)),
             "is_video": True, "vae_per_channel_normalize": True,
-            "prompt": motion_prompt, "negative_prompt": negative_prompt,
-            "guidance_scale": guidance_scale, "stg_scale": stg_scale,
-            "rescaling_scale": rescaling_scale, "num_inference_steps": num_inference_steps,
+            "prompt": kwargs.get('motion_prompt', ""), "negative_prompt": kwargs.get('negative_prompt', "blurry, distorted, static, bad quality"),
+            "guidance_scale": kwargs.get('guidance_scale', 1.0), "stg_scale": kwargs.get('stg_scale', 0.0),
+            "rescaling_scale": kwargs.get('rescaling_scale', 0.15), "num_inference_steps": kwargs.get('num_inference_steps', 20),
             "output_type": "latent"
         }

-
-
-
-
-
-
+        if 'latents' in kwargs:
+            pipeline_params["latents"] = kwargs['latents'].to(worker.device, dtype=worker.pipeline.transformer.dtype)
+        if 'strength' in kwargs:
+            pipeline_params["strength"] = kwargs['strength']
+        if 'conditioning_items_data' in kwargs:
+            final_conditioning_items = []
+            for item in kwargs['conditioning_items_data']:
+                item.latent_tensor = item.latent_tensor.to(worker.device)
+                final_conditioning_items.append(item)
             pipeline_params["conditioning_items"] = final_conditioning_items

-        # --- CENTRALIZED, ERROR-PROOF LOGIC ---
-        # If the model is distilled, override the steps with the mandatory fixed timesteps.
         if worker.is_distilled:
             logger.info(f"Worker {worker.device} is using a distilled model. Using fixed timesteps.")
             fixed_timesteps = worker.config.get("first_pass", {}).get("timesteps")
             pipeline_params["timesteps"] = fixed_timesteps
             if fixed_timesteps:
                 pipeline_params["num_inference_steps"] = len(fixed_timesteps)
-
-        # Log the parameters for debugging
-        log_params = {k: v for k, v in pipeline_params.items() if k not in ['generator', 'latents', 'conditioning_items']}
-        logger.info(f"Parameters prepared for the pipeline on {worker.device}:\n{json.dumps(log_params, indent=2)}")
-
+
         return pipeline_params

-    # --- [REFACTORED] Simplified generation function ---
     def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
         worker_to_use = self._get_next_worker()
         try:
-            #
+            # [FIX] The padding logic is specific to generation from scratch.
             height, width = kwargs['height'], kwargs['width']
             padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
             padding_vals = calculate_padding(height, width, padded_h, padded_w)
-
-            # Update kwargs with the padded dimensions
-            kwargs['height'] = padded_h
-            kwargs['width'] = padded_w
+            kwargs['height'], kwargs['width'] = padded_h, padded_w

-            # Prepare the parameters using the centralized function
             pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)

-
+            logger.info(f"Starting GENERATION on {worker_to_use.device} with shape {padded_w}x{padded_h}")
+
             if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline):
                 result = worker_to_use.pipeline.video_pipeline(**pipeline_params).images
             else:
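Note on the padding math above: the generation path rounds the requested resolution up to the next multiple of 32 before preparing parameters, and calculate_padding (an LTX-Video utility) turns the difference into per-side pad amounts. Below is a minimal sketch of that arithmetic, assuming a symmetric split of the extra pixels; pad_to_multiple_of_32 is a hypothetical illustration, not the project's helper.

def pad_to_multiple_of_32(height: int, width: int) -> tuple:
    # Round each dimension up to the next multiple of 32, as in the diff.
    padded_h = ((height - 1) // 32 + 1) * 32
    padded_w = ((width - 1) // 32 + 1) * 32
    # Assumed symmetric split: (left, right, top, bottom), the ordering
    # torch.nn.functional.pad uses for the last two dimensions.
    pad_h, pad_w = padded_h - height, padded_w - width
    return padded_h, padded_w, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)

print(pad_to_multiple_of_32(704, 1280))  # already aligned: (704, 1280, (0, 0, 0, 0))
print(pad_to_multiple_of_32(713, 1270))  # rounded up: (736, 1280, (5, 5, 11, 12))

The (x - 1) // 32 + 1 form rounds up while leaving already-aligned sizes untouched, which is why the first example pads by zero.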
@@ -200,22 +165,26 @@ class LtxPoolManager:
             with torch.cuda.device(worker_to_use.device):
                 gc.collect(); torch.cuda.empty_cache()

-    # --- [REFACTORED] Simplified refinement function ---
     def refine_latents(self, latents_to_refine: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
         worker_to_use = self._get_next_worker()
         try:
-            #
+            # [FIX] The sizing logic for refinement derives from the latent's shape.
+            _b, _c, _f, latent_h, latent_w = latents_to_refine.shape
+            vae_scale_factor = worker_to_use.pipeline.vae_scale_factor
+
+            # Ensure the dimensions match the provided latent EXACTLY.
+            kwargs['height'] = latent_h * vae_scale_factor
+            kwargs['width'] = latent_w * vae_scale_factor
+            kwargs['video_total_frames'] = kwargs.get('video_total_frames', _f * worker_to_use.pipeline.video_scale_factor)
             kwargs['latents'] = latents_to_refine
             kwargs['strength'] = kwargs.get('denoise_strength', 0.4)
             kwargs['num_inference_steps'] = int(kwargs.get('refine_steps', 10))

-            # Prepare the parameters using the same centralized function
             pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)

-            logger.info("
+            logger.info(f"Starting REFINEMENT on {worker_to_use.device} with shape {kwargs['width']}x{kwargs['height']}")

             pipeline_to_call = worker_to_use.pipeline.video_pipeline if isinstance(worker_to_use.pipeline, LTXMultiScalePipeline) else worker_to_use.pipeline
-
             result = pipeline_to_call(**pipeline_params).images
             return result, None
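Note on the shape derivation above: refine_latents now recomputes height, width, and total frames from the latent tensor itself instead of trusting caller-supplied values, which is what guarantees the refinement pass matches the provided latent. A minimal sketch under assumed values follows; dims_from_latents is hypothetical, the (batch, channels, frames, height, width) layout is taken from the diff, and the scale factors 32 and 8 are placeholders for the real worker_to_use.pipeline.vae_scale_factor and video_scale_factor.

import torch

def dims_from_latents(latents: torch.Tensor,
                      vae_scale_factor: int = 32,
                      video_scale_factor: int = 8) -> dict:
    # Latent layout assumed: (batch, channels, frames, height, width).
    _b, _c, num_latent_frames, latent_h, latent_w = latents.shape
    return {
        "height": latent_h * vae_scale_factor,                       # pixel height
        "width": latent_w * vae_scale_factor,                        # pixel width
        "video_total_frames": num_latent_frames * video_scale_factor,
    }

latents = torch.randn(1, 128, 3, 22, 40)  # example shape only
print(dims_from_latents(latents))  # {'height': 704, 'width': 1280, 'video_total_frames': 24}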
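Note on the distilled-model branch in _prepare_pipeline_params: fixed timesteps from the model config take precedence, and num_inference_steps must follow the length of that schedule so the step count and the scheduler agree. Below is a standalone sketch of the override; apply_distilled_override is hypothetical, the config keys mirror the diff, and the timestep values are placeholders rather than the model's real schedule.

def apply_distilled_override(pipeline_params: dict, config: dict, is_distilled: bool) -> dict:
    if is_distilled:
        # Distilled checkpoints ship a fixed schedule under first_pass.timesteps.
        fixed_timesteps = config.get("first_pass", {}).get("timesteps")
        pipeline_params["timesteps"] = fixed_timesteps
        if fixed_timesteps:
            # Keep the step count consistent with the fixed schedule's length.
            pipeline_params["num_inference_steps"] = len(fixed_timesteps)
    return pipeline_params

params = {"num_inference_steps": 20}
config = {"first_pass": {"timesteps": [1.0, 0.75, 0.5, 0.25]}}  # placeholder values
print(apply_distilled_override(params, config, is_distilled=True))
# {'num_inference_steps': 4, 'timesteps': [1.0, 0.75, 0.5, 0.25]}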