eeuuia committed
Commit 7c1bfd4 · verified · 1 parent: 79815c9

Update api/ltx_server_refactored_complete.py

Files changed (1): api/ltx_server_refactored_complete.py (+74 −69)
api/ltx_server_refactored_complete.py CHANGED
@@ -1,7 +1,7 @@
 # FILE: api/ltx_server_refactored_complete.py
-# DESCRIPTION: Final orchestrator for LTX-Video generation.
-# Features path resolution for cached models, dedicated VAE device logic,
-# delegation to utility modules, and advanced debug logging.
+# DESCRIPTION: Final high-level orchestrator for LTX-Video generation.
+# This version features a unified generation workflow, random seed generation,
+# delegation to specialized modules, and advanced debugging capabilities.
 
 import gc
 import json
@@ -13,7 +13,7 @@ import tempfile
 import time
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-import random
+import random  # still required: _get_random_seed() below calls random.randint
 import torch
 import yaml
 import numpy as np
@@ -24,7 +24,6 @@ from huggingface_hub import hf_hub_download
 # ==============================================================================
 
 # Logging configuration and warning suppression
-# (Can be removed if logging is configured globally)
 import warnings
 warnings.filterwarnings("ignore")
 logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
@@ -179,31 +178,46 @@ class VideoService:
         except Exception: pass
 
     # ==========================================================================
-    # --- BUSINESS LOGIC: PUBLIC ORCHESTRATORS ---
+    # --- BUSINESS LOGIC: UNIFIED PUBLIC ORCHESTRATOR ---
     # ==========================================================================
 
     @log_function_io
-    def generate_narrative_low(self, prompt: str, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
-        """Orchestrates the generation of a video from a multi-line prompt (sequence of scenes)."""
-        logging.info("Starting narrative low-res generation...")
-        used_seed = self._resolve_seed(kwargs.get("seed"))
+    def generate_low_resolution(self, prompt: str, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
+        """
+        [UNIFIED ORCHESTRATOR] Generates a low-resolution video from a prompt.
+        Handles both single-line and multi-line prompts transparently.
+        """
+        logging.info("Starting unified low-resolution generation (random seed)...")
+        used_seed = self._get_random_seed()
         seed_everything(used_seed)
+        logging.info(f"Using randomly generated seed: {used_seed}")
 
         prompt_list = [p.strip() for p in prompt.splitlines() if p.strip()]
         if not prompt_list: raise ValueError("Prompt is empty or contains no valid lines.")
+
+        is_narrative = len(prompt_list) > 1
+        logging.info(f"Generation mode detected: {'Narrative' if is_narrative else 'Simple'} ({len(prompt_list)} scene(s)).")
 
         num_chunks = len(prompt_list)
         total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
-        frames_per_chunk = (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT
-        overlap_frames = self.config.get("overlap_frames", 8)
+        frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)
+        overlap_frames = self.config.get("overlap_frames", 8) if is_narrative else 0
 
         temp_latent_paths = []
         overlap_condition_item = None
 
         try:
             for i, chunk_prompt in enumerate(prompt_list):
-                logging.info(f"Generating narrative chunk {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
-                current_frames = frames_per_chunk + (overlap_frames if i > 0 else 0)
+                logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
+
+                if i == num_chunks - 1:
+                    processed_frames = (num_chunks - 1) * frames_per_chunk
+                    current_frames = total_frames - processed_frames
+                else:
+                    current_frames = frames_per_chunk
+
+                if i > 0: current_frames += overlap_frames
+
                 current_conditions = kwargs.get("initial_conditions", []) if i == 0 else []
                 if overlap_condition_item: current_conditions.append(overlap_condition_item)
 
@@ -211,9 +225,9 @@ class VideoService:
                     prompt=chunk_prompt, num_frames=current_frames, seed=used_seed + i,
                     conditioning_items=current_conditions, **kwargs
                 )
-                if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for chunk {i+1}.")
+                if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for scene {i+1}.")
 
-                if i < num_chunks - 1:
+                if is_narrative and i < num_chunks - 1:
                     overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                     overlap_condition_item = ConditioningItem(media_item=overlap_latents, media_frame_number=0, conditioning_strength=1.0)
 
@@ -223,46 +237,23 @@ class VideoService:
                 torch.save(chunk_latents.cpu(), chunk_path)
                 temp_latent_paths.append(chunk_path)
 
-            return self._finalize_generation(temp_latent_paths, "narrative_video", used_seed)
+            base_filename = "narrative_video" if is_narrative else "single_video"
+            return self._finalize_generation(temp_latent_paths, base_filename, used_seed)
         except Exception as e:
-            logging.error(f"Error during narrative generation: {e}", exc_info=True)
+            logging.error(f"Error during unified generation: {e}", exc_info=True)
             return None, None, None
         finally:
             for path in temp_latent_paths:
                 if path.exists(): path.unlink()
             self.finalize()
 
-    @log_function_io
-    def generate_single_low(self, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
-        """Orchestrates the generation of a video from a single prompt in one go."""
-        logging.info("Starting single-prompt low-res generation...")
-        used_seed = self._resolve_seed(kwargs.get("seed"))
-        seed_everything(used_seed)
-
-        try:
-            total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0), min_frames=9)
-            final_latents = self._generate_single_chunk_low(
-                num_frames=total_frames, seed=used_seed,
-                conditioning_items=kwargs.get("initial_conditions", []), **kwargs
-            )
-            if final_latents is None: raise RuntimeError("Failed to generate latents.")
-
-            temp_latent_path = RESULTS_DIR / f"temp_single_{used_seed}.pt"
-            torch.save(final_latents.cpu(), temp_latent_path)
-            return self._finalize_generation([temp_latent_path], "single_video", used_seed)
-        except Exception as e:
-            logging.error(f"Error during single generation: {e}", exc_info=True)
-            return None, None, None
-        finally:
-            self.finalize()
-
     # ==========================================================================
     # --- WORK UNITS AND INTERNAL HELPERS ---
     # ==========================================================================
 
     @log_function_io
     def _generate_single_chunk_low(self, **kwargs) -> Optional[torch.Tensor]:
-        """Calls the pipeline to generate a single chunk of latents."""
+        """[WORKER] Calls the pipeline to generate a single chunk of latents."""
         height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
         downscale_factor = self.config.get("downscale_factor", 0.6666666)
         vae_scale_factor = self.pipeline.vae_scale_factor
@@ -271,7 +262,7 @@ class VideoService:
 
         first_pass_config = self.config.get("first_pass", {}).copy()
         if kwargs.get("ltx_configs_override"):
-            first_pass_config.update(self._prepare_guidance_overrides(kwargs["ltx_configs_override"]))
+            self._apply_ui_overrides(first_pass_config, kwargs["ltx_configs_override"])
 
         pipeline_kwargs = {
             "prompt": kwargs['prompt'], "negative_prompt": kwargs['negative_prompt'],
@@ -304,40 +295,53 @@ class VideoService:
 
     @log_function_io
     def prepare_condition_items(self, items_list: List, height: int, width: int, num_frames: int) -> List[ConditioningItem]:
+        """[UNIFIED] Prepares ConditioningItems from a mixed list of file paths and tensors."""
         if not items_list: return []
         height_padded, width_padded = self._align(height), self._align(width)
         padding_values = calculate_padding(height, width, height_padded, width_padded)
 
         conditioning_items = []
-        for media, frame, weight in items_list:
-            tensor = self._prepare_conditioning_tensor(media, height, width, padding_values)
+        for media_item, frame, weight in items_list:
+            if isinstance(media_item, str):
+                tensor = load_image_to_tensor_with_resize_and_crop(media_item, height, width)
+                tensor = torch.nn.functional.pad(tensor, padding_values)
+                tensor = tensor.to(self.main_device, dtype=self.runtime_autocast_dtype)
+            elif isinstance(media_item, torch.Tensor):
+                tensor = media_item.to(self.main_device, dtype=self.runtime_autocast_dtype)
+            else:
+                logging.warning(f"Unknown conditioning media type: {type(media_item)}. Skipping.")
+                continue
+
             safe_frame = max(0, min(int(frame), num_frames - 1))
             conditioning_items.append(ConditioningItem(tensor, safe_frame, float(weight)))
         return conditioning_items
 
-    @log_function_io
-    def _prepare_conditioning_tensor(self, media_path: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
-        tensor = load_image_to_tensor_with_resize_and_crop(media_path, height, width)
-        tensor = torch.nn.functional.pad(tensor, padding)
-        return tensor.to(self.main_device, dtype=self.runtime_autocast_dtype)
-
-    def _prepare_guidance_overrides(self, ltx_configs: Dict) -> Dict:
-        overrides = {}
-        preset = ltx_configs.get("guidance_preset", "Padrão (Recomendado)")
+    def _apply_ui_overrides(self, config_dict: Dict, overrides: Dict):
+        """Applies advanced settings from the UI to a config dictionary."""
+        # Override step counts
+        for key in ["num_inference_steps", "skip_initial_inference_steps", "skip_final_inference_steps"]:
+            ui_value = overrides.get(key)
+            if ui_value and ui_value > 0:
+                config_dict[key] = ui_value
+                logging.info(f"Override: '{key}' set to {ui_value} by UI.")
+
+        # Override guidance settings
+        preset = overrides.get("guidance_preset", "Padrão (Recomendado)")
+        guidance_overrides = {}
         if preset == "Agressivo":
-            overrides["guidance_scale"] = [1, 2, 8, 12, 8, 2, 1]
-            overrides["stg_scale"] = [0, 0, 5, 6, 5, 3, 2]
+            guidance_overrides = {"guidance_scale": [1, 2, 8, 12, 8, 2, 1], "stg_scale": [0, 0, 5, 6, 5, 3, 2]}
         elif preset == "Suave":
-            overrides["guidance_scale"] = [1, 1, 4, 5, 4, 1, 1]
-            overrides["stg_scale"] = [0, 0, 2, 2, 2, 1, 0]
+            guidance_overrides = {"guidance_scale": [1, 1, 4, 5, 4, 1, 1], "stg_scale": [0, 0, 2, 2, 2, 1, 0]}
         elif preset == "Customizado":
             try:
-                overrides["guidance_scale"] = json.loads(ltx_configs["guidance_scale_list"])
-                overrides["stg_scale"] = json.loads(ltx_configs["stg_scale_list"])
-            except (json.JSONDecodeError, KeyError) as e:
-                logging.warning(f"Failed to parse custom guidance values: {e}. Falling back to defaults.")
-        if overrides: logging.info(f"Applying '{preset}' guidance preset overrides.")
-        return overrides
+                guidance_overrides["guidance_scale"] = json.loads(overrides["guidance_scale_list"])
+                guidance_overrides["stg_scale"] = json.loads(overrides["stg_scale_list"])
+            except Exception as e:
+                logging.warning(f"Failed to parse custom guidance values: {e}. Using defaults.")
+
+        if guidance_overrides:
+            config_dict.update(guidance_overrides)
+            logging.info(f"Applying '{preset}' guidance preset overrides.")
 
     def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -361,10 +365,11 @@ class VideoService:
     def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
         num_frames = int(round(duration_s * DEFAULT_FPS))
         aligned_frames = self._align(num_frames)
-        return max(aligned_frames + 1, min_frames)
+        return max(aligned_frames, min_frames)
 
-    def _resolve_seed(self, seed: Optional[int]) -> int:
-        return random.randint(0, 2**32 - 1) if seed is None else int(seed)
+    def _get_random_seed(self) -> int:
+        """Always generates and returns a new random seed."""
+        return random.randint(0, 2**32 - 1)
 
 # ==============================================================================
 # --- SINGLETON INSTANTIATION ---
@@ -374,4 +379,4 @@ try:
     logging.info("Global VideoService orchestrator instance created successfully.")
 except Exception as e:
     logging.critical(f"Failed to initialize VideoService: {e}", exc_info=True)
-    sys.exit(1)
+    sys.exit(1)
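A note on the new chunking arithmetic: the last scene now absorbs the rounding remainder, so per-scene frame counts always sum to `total_frames`, and `max(FRAMES_ALIGNMENT, ...)` prevents a zero-frame chunk when many scenes share a short duration. A minimal standalone sketch of that split (assuming `FRAMES_ALIGNMENT = 8` for illustration; the real constant is defined elsewhere in this module):

```python
# Sanity check for the per-scene frame split in generate_low_resolution.
# FRAMES_ALIGNMENT = 8 is an assumed value for illustration only.
FRAMES_ALIGNMENT = 8

def split_frames(total_frames: int, num_chunks: int, overlap_frames: int) -> list:
    frames_per_chunk = max(
        FRAMES_ALIGNMENT,
        (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT,
    )
    chunks = []
    for i in range(num_chunks):
        if i == num_chunks - 1:
            # Last scene absorbs the remainder so the split covers total_frames exactly.
            current = total_frames - (num_chunks - 1) * frames_per_chunk
        else:
            current = frames_per_chunk
        if i > 0:
            current += overlap_frames  # re-generated frames that stitch scenes together
        chunks.append(current)
    return chunks

print(split_frames(97, 3, 8))  # [32, 40, 41]: 32 + (40-8) + (41-8) = 97 unique frames
```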
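The unified `prepare_condition_items` now accepts file paths and pre-computed tensors in the same list, which is what lets the orchestrator feed overlap latents back in alongside user-supplied images. A hypothetical call, assuming the module's singleton is exposed as `video_service` (the instance name is not shown in this diff) and a placeholder tensor shape:

```python
import torch

# Mixed conditioning inputs: a file path (loaded, padded, and moved to the
# device) and a pre-computed tensor (moved to the device as-is).
# "reference_frame.png" and the tensor shape are placeholders, and
# `video_service` is an assumed name for the module's singleton.
items = [
    ("reference_frame.png", 0, 1.0),          # anchor the first frame
    (torch.randn(1, 3, 512, 704), 48, 0.8),   # hypothetical mid-clip condition
]
conditioning = video_service.prepare_condition_items(
    items, height=512, width=704, num_frames=97
)
# Frame indices are clamped to [0, num_frames - 1]; unknown item types are
# skipped with a warning rather than raising.
```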
+ logging.info(f"Applying '{preset}' guidance preset overrides.")
345
 
346
  def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
347
  with tempfile.TemporaryDirectory() as temp_dir:
 
365
  def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
366
  num_frames = int(round(duration_s * DEFAULT_FPS))
367
  aligned_frames = self._align(num_frames)
368
+ return max(aligned_frames, min_frames)
369
 
370
+ def _get_random_seed(self) -> int:
371
+ """Always generates and returns a new random seed."""
372
+ return random.randint(0, 2**32 - 1)
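End to end, callers now use a single entry point for both modes. A usage sketch, assuming the singleton is named `video_service`; reading the three return values as (video path, latents path, seed) is an inference from the signature and the `_finalize_generation` call, not something the diff states:

```python
# A multi-line prompt triggers narrative mode (one chunk per line); a single
# line runs as one chunk. Keyword names follow the diff; height, width, and
# negative_prompt are required downstream by _generate_single_chunk_low.
prompt = "A lighthouse at dawn, waves crashing\nThe camera rises above the clouds"

video_path, latents_path, seed = video_service.generate_low_resolution(
    prompt,
    duration=6.0,
    height=512,
    width=704,
    negative_prompt="blurry, distorted",
)
if video_path is None:
    print("Generation failed; see the error log.")  # returns (None, None, None) on failure
```

Since this commit drops caller-specified seeds entirely, the returned (and logged) seed is now the only handle for reproducing a run.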