Update api/ltx/ltx_aduc_pipeline.py
api/ltx/ltx_aduc_pipeline.py  +62 -58
@@ -57,6 +57,8 @@ FRAMES_ALIGNMENT = 8
 repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
 if repo_path not in sys.path:
     sys.path.insert(0, repo_path)
+from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+
 
 # ==============================================================================
 # --- CLASSE DE SERVIÇO (O ORQUESTRADOR) ---
@@ -130,11 +132,58 @@ class LtxAducPipeline:
             strengths=[item[2] for item in initial_media_items],
             target_resolution=(kwargs['height'], kwargs['width'])
         )
+
+        height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
+        downscale_factor = self.config.get("downscale_factor", 0.6666666)
+        vae_scale_factor = self.pipeline.vae_scale_factor
+        downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
+        downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
+
+        call_kwargs = self.config.get("first_pass", {}).copy()
 
-
-
-
+        stg_mode_str = self.config.get("stg_mode", "attention_values")
+        if stg_mode_str.lower() in ["stg_av", "attention_values"]:
+            call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionValues
+        elif stg_mode_str.lower() in ["stg_as", "attention_skip"]:
+            call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionSkip
+        elif stg_mode_str.lower() in ["stg_r", "residual"]:
+            call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.Residual
+        elif stg_mode_str.lower() in ["stg_t", "transformer_block"]:
+            call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.TransformerBlock
+
+        call_kwargs.update({
+            "skip_initial_inference_steps": 0,
+            "skip_final_inference_steps": 0,
+            "num_inference_steps": 20,
+            "negative_prompt": kwargs['negative_prompt'],
+            "height": downscaled_height,
+            "width": downscaled_width,
+            "guidance_scale": 4,
+            "stg_scale": self.config.get("stg_scale"),
+            "rescaling_scale": self.config.get("rescaling_scale"),
+            "skip_block_list": self.config.get("skip_block_list"),
+            "frame_rate": int(DEFAULT_FPS),
+            "generator": torch.Generator(device=self.main_device).manual_seed(self._get_random_seed()),
+            "output_type": "latent",
+            "media_items": None,
+            "decode_timestep": self.config["decode_timestep"],
+            "decode_noise_scale": self.config["decode_noise_scale"],
+            "stochastic_sampling": self.config["stochastic_sampling"],
+            "image_cond_noise_scale": 0.15,
+            "is_video": True,
+            "vae_per_channel_normalize": True,
+            "mixed_precision": (self.config["precision"] == "mixed_precision"),
+            "offload_to_cpu": False,
+            "enhance_prompt": False,
+        })
 
+        ltx_configs_override = self.config.get("ltx_configs_override", {}).copy()
+        call_kwargs.update(ltx_configs_override)
+
+        if initial_conditions is not None:
+            call_kwargs["conditioning_items"] = initial_conditions
+
+        temp_latent_paths = []
         try:
             for i, chunk_prompt in enumerate(prompt_list):
                 logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
@@ -143,12 +192,14 @@
                 current_frames = current_frames_base + (overlap_frames if i > 0 else 0)
                 current_frames = self._align(current_frames, alignment_rule='n*8+1')
 
-
-
-
-
+                call_kwargs.pop("prompt", None)
+                call_kwargs.pop("num_frames", None)
+                call_kwargs["prompt"] = chunk_prompt
+                call_kwargs["num_frames"] = current_frames
 
-
+                with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
+                    chunk_latents = self.pipeline(**call_kwargs).images
+
                 if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for scene {i+1}.")
 
                 if is_narrative and i < num_chunks - 1:
@@ -158,10 +209,10 @@
                         media_frame_number=0,
                         conditioning_strength=1.0
                     )
-
-
+                    call_kwargs.pop("conditioning_items", None)
+                    call_kwargs["conditioning_items"] = overlap_condition_item
                 else:
-
+                    call_kwargs.pop("conditioning_items", None)
 
                 if i > 0: chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]
 
@@ -218,53 +269,6 @@
         # Usa o logger de debug para imprimir a mensagem completa
        logging.info("\n".join(log_str))
 
-
-    @log_function_io
-    def _generate_single_chunk_low(
-        self, **kwargs,
-    ) -> Optional[torch.Tensor]:
-        """[WORKER] Calls the pipeline to generate a single chunk of latents."""
-        height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
-        downscale_factor = self.config.get("downscale_factor", 0.6666666)
-        vae_scale_factor = self.pipeline.vae_scale_factor
-        downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
-        downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
-
-        call_kwargs = {
-            "cfg_star_rescale": "true",
-            "prompt": kwargs["prompt"],
-            "negative_prompt": kwargs['negative_prompt'],
-            "height": downscaled_height,
-            "width": downscaled_width,
-            "num_frames": kwargs["num_frames"],
-            "frame_rate": int(DEFAULT_FPS),
-            "generator": torch.Generator(device=self.main_device).manual_seed(kwargs['seed']),
-            "output_type": "latent",
-            "media_items": None,
-            "decode_timestep": self.config["decode_timestep"],
-            "decode_noise_scale": self.config["decode_noise_scale"],
-            "stochastic_sampling": self.config["stochastic_sampling"],
-            "image_cond_noise_scale": 0.05,
-            "is_video": True,
-            "vae_per_channel_normalize": True,
-            "mixed_precision": (self.config["precision"] == "mixed_precision"),
-            "offload_to_cpu": False,
-            "enhance_prompt": False,
-        }
-
-        call_kwargs.pop("num_inference_steps", None)
-        call_kwargs.pop("second_pass", None)
-        first_pass_config = self.config.get("first_pass", {}).copy()
-        call_kwargs.update(first_pass_config)
-        ltx_configs_override = kwargs.get("ltx_configs_override", {}).copy()
-        call_kwargs.update(ltx_configs_override)
-        call_kwargs['conditioning_items'] = kwargs["conditioning_items"]
-
-        with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
-            latents_raw = self.pipeline(**call_kwargs).images
-
-        return latents_raw.to(self.main_device)
-
     @log_function_io
     def _finalize_generation(self, final_latents: torch.Tensor, base_filename: str, seed: int) -> Tuple[str, str]:
         """Delegates final decoding and encoding to specialist services."""
|