eeuuia committed
Commit 7809765 · verified · 1 Parent(s): 891ecfe

Update api/ltx/ltx_aduc_pipeline.py

Files changed (1)
  1. api/ltx/ltx_aduc_pipeline.py +62 -58
api/ltx/ltx_aduc_pipeline.py CHANGED
@@ -57,6 +57,8 @@ FRAMES_ALIGNMENT = 8
 repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
 if repo_path not in sys.path:
     sys.path.insert(0, repo_path)
+from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+
 
 # ==============================================================================
 # --- SERVICE CLASS (THE ORCHESTRATOR) ---
@@ -130,11 +132,58 @@ class LtxAducPipeline:
             strengths=[item[2] for item in initial_media_items],
             target_resolution=(kwargs['height'], kwargs['width'])
         )
-        temp_latent_paths = []
-        overlap_condition_item: Optional[LatentConditioningItem] = None
-        current_conditions = initial_conditions
+
+        height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
+        downscale_factor = self.config.get("downscale_factor", 0.6666666)
+        vae_scale_factor = self.pipeline.vae_scale_factor
+        downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
+        downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
+
+        # Start from the first-pass config, then resolve the STG skip-layer strategy.
+        call_kwargs = self.config.get("first_pass", {}).copy()
+
+        stg_mode_str = self.config.get("stg_mode", "attention_values")
+        if stg_mode_str.lower() in ["stg_av", "attention_values"]:
+            call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionValues
+        elif stg_mode_str.lower() in ["stg_as", "attention_skip"]:
+            call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionSkip
+        elif stg_mode_str.lower() in ["stg_r", "residual"]:
+            call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.Residual
+        elif stg_mode_str.lower() in ["stg_t", "transformer_block"]:
+            call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.TransformerBlock
+
+        # update() rather than reassignment, so the first-pass and skip-layer
+        # settings above are not discarded.
+        call_kwargs.update({
+            "skip_initial_inference_steps": 0,
+            "skip_final_inference_steps": 0,
+            "num_inference_steps": 20,
+            "negative_prompt": kwargs['negative_prompt'],
+            "height": downscaled_height,
+            "width": downscaled_width,
+            "guidance_scale": 4,
+            "stg_scale": self.config.get("stg_scale"),
+            "rescaling_scale": self.config.get("rescaling_scale"),
+            "skip_block_list": self.config.get("skip_block_list"),
+            "frame_rate": int(DEFAULT_FPS),
+            "generator": torch.Generator(device=self.main_device).manual_seed(self._get_random_seed()),
+            "output_type": "latent",
+            "media_items": None,
+            "decode_timestep": self.config["decode_timestep"],
+            "decode_noise_scale": self.config["decode_noise_scale"],
+            "stochastic_sampling": self.config["stochastic_sampling"],
+            "image_cond_noise_scale": 0.15,
+            "is_video": True,
+            "vae_per_channel_normalize": True,
+            "mixed_precision": (self.config["precision"] == "mixed_precision"),
+            "offload_to_cpu": False,
+            "enhance_prompt": False,
+        })
+
+        ltx_configs_override = self.config.get("ltx_configs_override", {}).copy()
+        call_kwargs.update(ltx_configs_override)
+
+        if initial_conditions is not None:
+            call_kwargs["conditioning_items"] = initial_conditions
+
+        temp_latent_paths = []
         try:
             for i, chunk_prompt in enumerate(prompt_list):
                 logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
@@ -143,12 +192,14 @@ class LtxAducPipeline:
                 current_frames = current_frames_base + (overlap_frames if i > 0 else 0)
                 current_frames = self._align(current_frames, alignment_rule='n*8+1')
 
-                kwargs.pop("prompt", None)
-                kwargs.pop("num_frames", None)
-                kwargs["prompt"] = chunk_prompt
-                kwargs["num_frames"] = current_frames
-
-                chunk_latents = self._generate_single_chunk_low(**kwargs)
+                call_kwargs.pop("prompt", None)
+                call_kwargs.pop("num_frames", None)
+                call_kwargs["prompt"] = chunk_prompt
+                call_kwargs["num_frames"] = current_frames
+
+                with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
+                    chunk_latents = self.pipeline(**call_kwargs).images
+
                 if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for scene {i+1}.")
 
                 if is_narrative and i < num_chunks - 1:
@@ -158,10 +209,10 @@ class LtxAducPipeline:
                         media_frame_number=0,
                         conditioning_strength=1.0
                     )
-                    kwargs.pop("conditioning_items", None)
-                    kwargs["conditioning_items"] = overlap_condition_item
+                    call_kwargs.pop("conditioning_items", None)
+                    call_kwargs["conditioning_items"] = overlap_condition_item
                 else:
-                    kwargs.pop("conditioning_items", None)
+                    call_kwargs.pop("conditioning_items", None)
 
                 if i > 0: chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]
@@ -218,53 +269,6 @@ class LtxAducPipeline:
         # Uses the debug logger to print the complete message
         logging.info("\n".join(log_str))
 
-
-    @log_function_io
-    def _generate_single_chunk_low(
-        self, **kwargs,
-    ) -> Optional[torch.Tensor]:
-        """[WORKER] Calls the pipeline to generate a single chunk of latents."""
-        height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
-        downscale_factor = self.config.get("downscale_factor", 0.6666666)
-        vae_scale_factor = self.pipeline.vae_scale_factor
-        downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
-        downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
-
-        call_kwargs = {
-            "cfg_star_rescale": "true",
-            "prompt": kwargs["prompt"],
-            "negative_prompt": kwargs['negative_prompt'],
-            "height": downscaled_height,
-            "width": downscaled_width,
-            "num_frames": kwargs["num_frames"],
-            "frame_rate": int(DEFAULT_FPS),
-            "generator": torch.Generator(device=self.main_device).manual_seed(kwargs['seed']),
-            "output_type": "latent",
-            "media_items": None,
-            "decode_timestep": self.config["decode_timestep"],
-            "decode_noise_scale": self.config["decode_noise_scale"],
-            "stochastic_sampling": self.config["stochastic_sampling"],
-            "image_cond_noise_scale": 0.05,
-            "is_video": True,
-            "vae_per_channel_normalize": True,
-            "mixed_precision": (self.config["precision"] == "mixed_precision"),
-            "offload_to_cpu": False,
-            "enhance_prompt": False,
-        }
-
-        call_kwargs.pop("num_inference_steps", None)
-        call_kwargs.pop("second_pass", None)
-        first_pass_config = self.config.get("first_pass", {}).copy()
-        call_kwargs.update(first_pass_config)
-        ltx_configs_override = kwargs.get("ltx_configs_override", {}).copy()
-        call_kwargs.update(ltx_configs_override)
-        call_kwargs['conditioning_items'] = kwargs["conditioning_items"]
-
-        with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
-            latents_raw = self.pipeline(**call_kwargs).images
-
-        return latents_raw.to(self.main_device)
-
     @log_function_io
     def _finalize_generation(self, final_latents: torch.Tensor, base_filename: str, seed: int) -> Tuple[str, str]:
         """Delegates final decoding and encoding to specialist services."""