eeuuia committed on
Commit
b2348be
·
verified ·
1 Parent(s): 5ae8585

Update api/ltx/ltx_aduc_pipeline.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_aduc_pipeline.py +26 -15
api/ltx/ltx_aduc_pipeline.py CHANGED
@@ -150,8 +150,7 @@ class LtxAducPipeline:
150
  elif stg_mode_str.lower() in ["stg_r", "residual"]: stg_strategy = SkipLayerStrategy.Residual
151
  elif stg_mode_str.lower() in ["stg_t", "transformer_block"]: stg_strategy = SkipLayerStrategy.TransformerBlock
152
 
153
-
154
-
155
  height_padded = ((kwargs['height'] - 1) // 8 + 1) * 8
156
  width_padded = ((kwargs['width'] - 1) // 8 + 1) * 8
157
  downscale_factor = self.config.get("downscale_factor", 0.6666666)
@@ -161,21 +160,37 @@ class LtxAducPipeline:
161
  x_height = int(height_padded * downscale_factor)
162
  downscaled_height = x_height - (x_height % vae_scale_factor)
163
 
164
-
165
  call_kwargs = {
166
  "height": downscaled_height,
167
  "width": downscaled_width,
168
- "skip_initial_inference_steps": 0, "skip_final_inference_steps": 0, "num_inference_steps": 20,
 
 
169
  "negative_prompt": kwargs['negative_prompt'],
170
- "guidance_scale": 4, "stg_scale": self.config.get("stg_scale", 4),
171
- "rescaling_scale": self.config.get("rescaling_scale", 0.7), "skip_layer_strategy": stg_strategy,
172
- "skip_block_list": self.config.get("skip_block_list", None), "frame_rate": int(DEFAULT_FPS),
 
 
173
  "generator": torch.Generator(device=self.main_device).manual_seed(self._get_random_seed()),
174
- "output_type": "latent", "media_items": None, "decode_timestep": self.config.get("decode_timestep", None),
175
- "decode_noise_scale": self.config.get("decode_noise_scale", None), "stochastic_sampling": self.config.get("stochastic_sampling", None),
176
- "image_cond_noise_scale": 0.15, "is_video": True, "vae_per_channel_normalize": True,
177
- "mixed_precision": (self.config["precision"] == "mixed_precision"), "offload_to_cpu": False,
 
 
 
 
178
  "enhance_prompt": False,
 
 
 
 
 
 
 
 
 
179
  }
180
 
181
  ltx_configs_override = kwargs.get("ltx_configs_override", {})
@@ -185,11 +200,7 @@ class LtxAducPipeline:
185
  # --- ETAPA 1: GERAÇÃO DE CHUNKS E SALVAMENTO ---
186
  for i, chunk_prompt in enumerate(prompt_list):
187
  logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
188
- current_frames_base = frames_per_chunk if i < num_chunks - 1 else total_frames - ((num_chunks - 1) * frames_per_chunk)
189
- current_frames = current_frames_base + (overlap_frames if i > 0 else 0)
190
- current_frames = self._align(current_frames, alignment_rule='n*8+1')
191
  call_kwargs["prompt"] = chunk_prompt
192
- call_kwargs["num_frames"] = current_frames
193
 
194
  with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
195
  chunk_latents = self.pipeline(**call_kwargs).images
 
150
  elif stg_mode_str.lower() in ["stg_r", "residual"]: stg_strategy = SkipLayerStrategy.Residual
151
  elif stg_mode_str.lower() in ["stg_t", "transformer_block"]: stg_strategy = SkipLayerStrategy.TransformerBlock
152
 
153
+
 
154
  height_padded = ((kwargs['height'] - 1) // 8 + 1) * 8
155
  width_padded = ((kwargs['width'] - 1) // 8 + 1) * 8
156
  downscale_factor = self.config.get("downscale_factor", 0.6666666)
 
160
  x_height = int(height_padded * downscale_factor)
161
  downscaled_height = x_height - (x_height % vae_scale_factor)
162
 
 
163
  call_kwargs = {
164
  "height": downscaled_height,
165
  "width": downscaled_width,
166
+ "skip_initial_inference_steps": 3,
167
+ "skip_final_inference_steps": 0,
168
+ "num_inference_steps": 30,
169
  "negative_prompt": kwargs['negative_prompt'],
170
+ "guidance_scale": self.config.get("guidance_scale", [1, 1, 6, 8, 6, 1, 1]),
171
+ "stg_scale": self.config.get("stg_scale", [0, 0, 4, 4, 4, 2, 1]),
172
+ "rescaling_scale": self.config.get("rescaling_scale", [1, 1, 0.5, 0.5, 1, 1, 1]),
173
+ "skip_block_list": self.config.get("skip_block_list", [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]),
174
+ "frame_rate": int(DEFAULT_FPS),
175
  "generator": torch.Generator(device=self.main_device).manual_seed(self._get_random_seed()),
176
+ "output_type": "latent",
177
+ "media_items": None,
178
+ "decode_timestep": self.config.get("decode_timestep", 0.05),
179
+ "decode_noise_scale": self.config.get("decode_noise_scale", 0.025),
180
+ "stochastic_sampling": self.config.get("stochastic_sampling", False),
181
+ "is_video": True,
182
+ "vae_per_channel_normalize": True,
183
+ "offload_to_cpu": False,
184
  "enhance_prompt": False,
185
+ "num_frames": total_frames,
186
+ "downscale_factor": self.config.get("downscale_factor", 0.6666666),
187
+ "rescaling_scale": self.config.get("rescaling_scale", [1, 1, 0.5, 0.5, 1, 1, 1]),
188
+ "guidance_timesteps": self.config.get("guidance_timesteps", [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]),
189
+ "skip_block_list": self.config.get("skip_block_list", [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]),
190
+ "sampler": self.config.get("sampler", "from_checkpoint"),
191
+ "precision": self.config.get("precision", "float8_e4m3fn"),
192
+ "stochastic_sampling": self.config.get("stochastic_sampling", False),
193
+ "cfg_star_rescale": self.config.get("cfg_star_rescale", True),
194
  }
195
 
196
  ltx_configs_override = kwargs.get("ltx_configs_override", {})
 
200
  # --- ETAPA 1: GERAÇÃO DE CHUNKS E SALVAMENTO ---
201
  for i, chunk_prompt in enumerate(prompt_list):
202
  logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
 
 
 
203
  call_kwargs["prompt"] = chunk_prompt
 
204
 
205
  with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
206
  chunk_latents = self.pipeline(**call_kwargs).images