euiiiia committed on
Commit cb6fb4c · verified · 1 Parent(s): a9beee3

Update api/ltx_server_refactored.py

Files changed (1):
  1. api/ltx_server_refactored.py +199 -58
api/ltx_server_refactored.py CHANGED
@@ -19,7 +19,6 @@ import subprocess
 from pathlib import Path
 from typing import List, Dict, Optional, Tuple, Union
 
-
 # --- Logging and Warnings Configuration ---
 warnings.filterwarnings("ignore", category=UserWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -92,12 +91,55 @@ from ltx_video.schedulers.rf import RectifiedFlowScheduler
 from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
 import ltx_video.pipelines.crf_compressor as crf_compressor
 
-from ltx_video.models.autoencoders.vae_encode import (
-    get_vae_size_scale_factor,
-    latent_to_pixel_coords,
-    vae_decode,
-    vae_encode,
-)
+
+
+def load_image_to_tensor_with_resize_and_crop(
+    image_input: Union[str, Image.Image],
+    target_height: int = 512,
+    target_width: int = 768,
+    just_crop: bool = False,
+) -> torch.Tensor:
+    """Load and process an image into a tensor.
+
+    Args:
+        image_input: Either a file path (str) or a PIL Image object
+        target_height: Desired height of output tensor
+        target_width: Desired width of output tensor
+        just_crop: If True, only crop the image to the target size without resizing
+    """
+    if isinstance(image_input, str):
+        image = Image.open(image_input).convert("RGB")
+    elif isinstance(image_input, Image.Image):
+        image = image_input
+    else:
+        raise ValueError("image_input must be either a file path or a PIL Image object")
+
+    input_width, input_height = image.size
+    aspect_ratio_target = target_width / target_height
+    aspect_ratio_frame = input_width / input_height
+    if aspect_ratio_frame > aspect_ratio_target:
+        new_width = int(input_height * aspect_ratio_target)
+        new_height = input_height
+        x_start = (input_width - new_width) // 2
+        y_start = 0
+    else:
+        new_width = input_width
+        new_height = int(input_width / aspect_ratio_target)
+        x_start = 0
+        y_start = (input_height - new_height) // 2
+
+    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+    if not just_crop:
+        image = image.resize((target_width, target_height))
+
+    image = np.array(image)
+    image = cv2.GaussianBlur(image, (3, 3), 0)
+    frame_tensor = torch.from_numpy(image).float()
+    frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
+    frame_tensor = frame_tensor.permute(2, 0, 1)
+    frame_tensor = (frame_tensor / 127.5) - 1.0
+    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+    return frame_tensor.unsqueeze(0).unsqueeze(2)
 
 def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
     latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
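
A quick sketch of the new helper's contract (the image path below is hypothetical; the sizes are the defaults):

# Returns a (1, 3, 1, H, W) tensor scaled to [-1, 1], center-cropped to the
# target aspect ratio before resizing.
frame = load_image_to_tensor_with_resize_and_crop(
    "inputs/first_frame.png",  # hypothetical path
    target_height=512,
    target_width=768,
)
assert frame.shape == (1, 3, 1, 512, 768)
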
@@ -174,22 +216,6 @@ def create_ltx_video_pipeline(
     transformer = transformer.to(torch.bfloat16)
     text_encoder = text_encoder.to(torch.bfloat16)
 
-
-    # --- Coherent global precision adjustment ---
-    if precision in ["float8_e4m3fn", "bfloat16"]:
-        dtype_target = torch.bfloat16
-    elif precision == "mixed_precision":
-        dtype_target = torch.float16
-    else:
-        dtype_target = torch.float32
-
-    for m in [vae, transformer, text_encoder]:
-        m.to(dtype_target)
-
-    # keeps the overall pipeline dtype-consistent
-    pipeline_dtype = dtype_target
-
-
     # Use submodels for the pipeline
     submodel_dict = {
         "transformer": transformer,
@@ -206,14 +232,38 @@ def create_ltx_video_pipeline(
     }
 
     pipeline = LTXVideoPipeline(**submodel_dict)
-
     pipeline = pipeline.to(device)
-    pipeline.to(dtype=pipeline_dtype)
-
-
-
     return pipeline
 
+# ==============================================================================
+# 2. AUXILIARY PROCESSING FUNCTIONS
+# ==============================================================================
+
+def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) -> Tuple[int, int, int, int]:
+    """Computes the padding needed to center an image within new dimensions."""
+    pad_h = target_h - orig_h
+    pad_w = target_w - orig_w
+    pad_top = pad_h // 2
+    pad_bottom = pad_h - pad_top
+    pad_left = pad_w // 2
+    pad_right = pad_w - pad_left
+    return (pad_left, pad_right, pad_top, pad_bottom)
+
+def log_tensor_info(tensor: torch.Tensor, name: str = "Tensor"):
+    """Prints detailed information about a tensor for debugging."""
+    if not isinstance(tensor, torch.Tensor):
+        print(f"\n[INFO] '{name}' is not a tensor.")
+        return
+    print(f"\n--- Tensor Info: {name} ---")
+    print(f"  - Shape: {tuple(tensor.shape)}")
+    print(f"  - Dtype: {tensor.dtype}")
+    print(f"  - Device: {tensor.device}")
+    if tensor.numel() > 0:
+        try:
+            print(f"  - Stats: Min={tensor.min().item():.4f}, Max={tensor.max().item():.4f}, Mean={tensor.mean().item():.4f}")
+        except RuntimeError:
+            print("  - Stats: could not be computed (e.g., bool tensors).")
+    print("-" * 30)
 
 # ==============================================================================
 # 3. MAIN VIDEO SERVICE CLASS
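
calculate_padding returns (left, right, top, bottom), the order torch.nn.functional.pad uses for the last two dimensions; a worked example with made-up sizes:

pad = calculate_padding(orig_h=480, orig_w=704, target_h=480, target_w=712)
assert pad == (4, 4, 0, 0)  # 8 extra columns, split evenly left/right
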
@@ -230,7 +280,7 @@ class VideoService:
         t0 = time.perf_counter()
         print("[INFO] Initializing VideoService...")
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.config = self._load_config("ltxv-13b-0.9.8-dev-fp8.yaml")
+        self.config = self._load_config("ltxv-13b-0.9.8-distilled-fp8.yaml")
 
         self.pipeline, self.latent_upsampler = self._load_models_from_hub()
         self._move_models_to_device()
@@ -241,10 +291,6 @@ class VideoService:
             device=self.device,
             autocast_dtype=self.runtime_autocast_dtype
         )
-
-        self._apply_precision_policy()
-        #print(f"[DEBUG] runtime_autocast_dtype = {getattr(self, 'runtime_autocast_dtype', None)}")
-
         self._tmp_dirs = set()
         RESULTS_DIR.mkdir(exist_ok=True)
         print(f"[INFO] VideoService ready. Initialization time: {time.perf_counter()-t0:.2f}s")
@@ -253,6 +299,30 @@ class VideoService:
     # --- Public Methods (Service API) ---
     # --------------------------------------------------------------------------
 
+    def _prepare_condition_items(self, items_list: List[Tuple], height: int, width: int, num_frames: int) -> List[ConditioningItem]:
+        """Prepares the conditioning tensors from images or tensors."""
+        if not items_list:
+            return []
+
+        height, width = self._calculate_downscaled_dims(height, width)
+
+        height_padded = ((height - 1) // 8 + 1) * 8
+        width_padded = ((width - 1) // 8 + 1) * 8
+        padding_values = calculate_padding(height, width, height_padded, width_padded)
+
+        conditioning_items = []
+        for media, frame_idx, weight in items_list:
+            if isinstance(media, str):
+                tensor = self._prepare_conditioning_tensor_from_path(media, height, width, padding_values)
+            else:  # Assumed to be a tensor
+                tensor = media.to(self.device, dtype=self.runtime_autocast_dtype)
+
+            # Ensure the conditioning frame lies within the video bounds
+            safe_frame_idx = max(0, min(int(frame_idx), num_frames - 1))
+            conditioning_items.append(ConditioningItem(tensor, safe_frame_idx, float(weight)))
+
+        return conditioning_items
+
     def generate_low_resolution(
         self, prompt: str, negative_prompt: str,
         height: int, width: int, duration_secs: float,
@@ -263,45 +333,120 @@ class VideoService:
         Generates a low-resolution video and returns the paths to the video and the latents.
         """
         used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
+        self._seed_everething(used_seed)
+
         actual_num_frames = int(duration_secs * DEFAULT_FPS)
-        #= self._calculate_downscaled_dims(height, width)
 
+        downscaled_height, downscaled_width = self._calculate_downscaled_dims(height, width)
 
         first_pass_kwargs = {
             "prompt": prompt,
             "negative_prompt": negative_prompt,
-            "height": height,
-            "width": width,
-            "num_frames": max(24, actual_num_frames)+1,
+            "height": downscaled_height,
+            "width": downscaled_width,
+            "num_frames": max(3, actual_num_frames//8)+1,
             "frame_rate": int(DEFAULT_FPS),
             "generator": torch.Generator(device=self.device).manual_seed(used_seed),
             "output_type": "latent",
             "conditioning_items": conditioning_items,
             "guidance_scale": float(guidance_scale),
-            "is_video": True,
-            "vae_per_channel_normalize": True,
             **(self.config.get("first_pass", {}))
         }
 
         temp_dir = tempfile.mkdtemp(prefix="ltxv_low_")
         self._register_tmp_dir(temp_dir)
-        latents = self.pipeline(**first_pass_kwargs).images
-        pixel_tensor = vae_manager_singleton.decode(latents.clone(), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
-        video_path = self._save_video_from_tensor(pixel_tensor, "low_res_video", used_seed, temp_dir)
-        latents_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)
-
+
         try:
+            with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
+                latents = self.pipeline(**first_pass_kwargs).images
+                pixel_tensor = vae_manager_singleton.decode(latents.clone(), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
+                video_path = self._save_video_from_tensor(pixel_tensor, "low_res_video", used_seed, temp_dir)
+                latents_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)
+
             return video_path, latents_path, used_seed
+
+        except Exception as e:
+            print(f"[ERROR] Low-resolution generation failed: {e}")
+            traceback.print_exc()
+            raise
         finally:
             self._finalize()
 
-
+    def generate_upscale_denoise(
+        self, latents_path: str, prompt: str,
+        negative_prompt: str, height: int, width: int,
+        num_frames: float, guidance_scale: float, seed: Optional[int] = None,
+        conditioning_items: Optional[List[ConditioningItem]] = None
+    ) -> Tuple[str, str]:
+        """
+        Applies upscaling, AdaIN and denoising to low-resolution latents using a chunking process.
+        """
+        used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
+        self._seed_everething(used_seed)
+
+        temp_dir = tempfile.mkdtemp(prefix="ltxv_up_")
+        self._register_tmp_dir(temp_dir)
+
+        try:
+            latents_low = torch.load(latents_path).to(self.device)
+            with torch.autocast(device_type=self.device.split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
+                upsampled_latents = latents_low  # self._upsample_and_filter_latents(latents_low)
+
+                # chunks = self._split_latents_with_overlap(upsampled_latents)
+                # refined_chunks = []
+                # for chunk in chunks:
+                #     if chunk.shape[2] <= 1: continue  # Skip invalid chunks
+
+                chunk = upsampled_latents
+
+                second_pass_height = chunk.shape[3] * self.pipeline.vae_scale_factor
+                second_pass_width = chunk.shape[4] * self.pipeline.vae_scale_factor
+
+                second_pass_kwargs = {
+                    "prompt": prompt,
+                    "negative_prompt": negative_prompt,
+                    "height": second_pass_height,
+                    "width": second_pass_width,
+                    "frame_rate": int(DEFAULT_FPS),
+                    "num_frames": num_frames,
+                    "latents": chunk,  # The full tensor is passed here
+                    "guidance_scale": float(guidance_scale),
+                    "output_type": "latent",
+                    "generator": torch.Generator(device=self.device).manual_seed(used_seed),
+                    "conditioning_items": conditioning_items,
+                    **(self.config.get("second_pass", {}))
+                }
+                refined_chunk = self.pipeline(**second_pass_kwargs).images
+                # refined_chunks.append(refined_chunk)
+
+            del latents_low; torch.cuda.empty_cache()
+
+            final_latents = refined_chunk  # self._merge_chunks_with_overlap(refined_chunks)
+            # if LTXV_DEBUG:
+            #     log_tensor_info(final_latents, "Final upscaled/refined latents")
+
+            latents_path = self._save_latents_to_disk(final_latents, "latents_refined", used_seed)
+            pixel_tensor = vae_manager_singleton.decode(final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05)))
+            video_path = self._save_video_from_tensor(pixel_tensor, "refined_video", used_seed, temp_dir)
+
+            return video_path, latents_path
+
+        except Exception as e:
+            print(f"[ERROR] Upscale and denoise process failed: {e}")
+            traceback.print_exc()
+            raise
+        finally:
+            self._finalize()
 
     def encode_latents_to_mp4(self, latents_path: str, fps: int = int(DEFAULT_FPS)) -> str:
         """Decodes a saved latent tensor and writes it as an MP4 video."""
         latents = torch.load(latents_path)
         temp_dir = tempfile.mkdtemp(prefix="ltxv_enc_")
         self._register_tmp_dir(temp_dir)
-
+        seed = random.randint(0, 99999)  # Seed used only for the file name
+
         try:
             chunks = self._split_latents_with_overlap(latents)
             pixel_chunks = []
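
Taken together, the hunk above wires a two-pass flow: _prepare_condition_items turns (media, frame_idx, weight) tuples into clamped ConditioningItem objects, generate_low_resolution writes low-resolution latents to disk, and generate_upscale_denoise reloads and re-denoises them at the upscaled size. A hedged end-to-end sketch, assuming `service` is a constructed VideoService and the prompt, path, and sizes are illustrative:

items = service._prepare_condition_items(
    items_list=[("inputs/first_frame.png", 0, 1.0)],  # hypothetical image path
    height=480, width=704, num_frames=97,
)
video_lo, latents_lo, seed = service.generate_low_resolution(
    prompt="a sailboat at sunset", negative_prompt="blurry, distorted",
    height=480, width=704, duration_secs=4.0,
    guidance_scale=3.0, conditioning_items=items,
)
video_hi, latents_hi = service.generate_upscale_denoise(
    latents_path=latents_lo, prompt="a sailboat at sunset",
    negative_prompt="blurry, distorted", height=480, width=704,
    num_frames=97, guidance_scale=3.0, seed=seed,
)

Keyword names follow the parameters referenced in the method bodies; the full generate_low_resolution signature is truncated in this diff, so treat the calls as a sketch rather than the exact API.
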
@@ -419,7 +564,12 @@ class VideoService:
         # AdaIN filter to keep color/style consistency with the low-resolution video
         return adain_filter_latent(latents=upsampled_latents_normalized, reference_latents=latents)
 
-
+    def _prepare_conditioning_tensor_from_path(self, filepath: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
+        """Loads an image, resizes it, applies padding, and moves it to the device."""
+        tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
+        tensor = F.pad(tensor, padding)
+        return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
+
     def _calculate_downscaled_dims(self, height: int, width: int) -> Tuple[int, int]:
         """Computes the dimensions for the first (low-resolution) pass."""
         height_padded = ((height - 1) // 8 + 1) * 8
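
F.pad with a 4-tuple pads only the last two dimensions (width pair first, then height), which is why the 5D conditioning tensor can be padded directly; a toy check:

import torch
import torch.nn.functional as F

x = torch.zeros(1, 3, 1, 480, 704)  # (batch, channels, frames, height, width)
y = F.pad(x, (4, 4, 0, 0))          # (left, right, top, bottom)
assert y.shape == (1, 3, 1, 480, 712)
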
@@ -498,17 +648,8 @@ class VideoService:
         if torch.backends.mps.is_available():
             torch.mps.manual_seed(seed)
 
-    def _apply_precision_policy(self):
-        precision = str(self.config.get("precision", "bfloat16")).lower()
-        if precision in ["float8_e4m3fn", "bfloat16"]: self.runtime_autocast_dtype = torch.bfloat16
-        elif precision == "mixed_precision": self.runtime_autocast_dtype = torch.float16
-        else: self.runtime_autocast_dtype = torch.float32
-
-
-
-
 # ==============================================================================
 # 4. INSTANTIATION AND ENTRY POINT (Example)
 # ==============================================================================
-video_generation_service = VideoService()
-print("VideoService instance ready for use.")
+print("Creating the VideoService instance. Model loading will start now...")
+video_generation_service = VideoService()
 