Spaces:
Paused
Paused
| # FILE: api/ltx/ltx_utils.py | |
| # DESCRIPTION: A pure utility library for the LTX ecosystem. | |
| # Contains the official low-level builder function for the complete pipeline | |
| # and other stateless helper functions. | |
| import os | |
| import random | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, Tuple, Union | |
| import torchvision.transforms.functional as TVF | |
| from PIL import Image | |
| import torch | |
| from safetensors import safe_open | |
| from transformers import T5EncoderModel, T5Tokenizer | |
| # ============================================================================== | |
| # --- CONFIGURAÇÃO DE PATH E IMPORTS DA BIBLIOTECA LTX --- | |
| # ============================================================================== | |
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path():
    """Make the LTX-Video checkout importable by prepending it to ``sys.path``.

    The resolved repository path is inserted at position 0 only when it is
    not already listed, so calling this more than once has no further effect.
    """
    resolved_repo = str(LTX_VIDEO_REPO_DIR.resolve())
    if resolved_repo in sys.path:
        # Already registered; nothing to insert and nothing to log.
        return
    sys.path.insert(0, resolved_repo)
    logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {resolved_repo}")


# Runs at import time so the ``ltx_video`` imports that follow can resolve.
add_deps_to_path()
# Import the LTX-Video components this module builds on. If any of them is
# missing, log the failure and re-raise with a pointer to the expected
# checkout location.
try:
    from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.transformers.transformer3d import Transformer3DModel
    from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
    from ltx_video.schedulers.rf import RectifiedFlowScheduler
except ImportError as e:
    logging.critical("Failed to import a core LTX-Video library component.", exc_info=True)
    # Chain explicitly so the original failure is recorded as the direct cause.
    raise ImportError(
        f"Could not import from LTX-Video library. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}"
    ) from e
| # ============================================================================== | |
| # --- FUNÇÃO HELPER 'create_transformer' (Essencial) --- | |
| # ============================================================================== | |
def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    """Load the Transformer3D model from ``ckpt_path`` at the requested precision.

    Args:
        ckpt_path: Checkpoint location handed to
            ``Transformer3DModel.from_pretrained``.
        precision: ``"float8_e4m3fn"`` loads with ``dtype=torch.float8_e4m3fn``
            and applies the Q8-kernels patch; ``"bfloat16"`` loads and then
            casts to ``torch.bfloat16``; any other value loads the model
            without a cast.

    Returns:
        The loaded transformer model.

    Raises:
        ValueError: If ``"float8_e4m3fn"`` is requested but the ``q8_kernels``
            package cannot be imported.
    """
    if precision == "float8_e4m3fn":
        # Guard only the optional dependency import, so an ImportError raised
        # while loading or patching the model is not misreported as a
        # missing Q8-Kernels installation.
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )
        except ImportError as err:
            raise ValueError(
                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from the project's wheels."
            ) from err
        transformer = Transformer3DModel.from_pretrained(ckpt_path, dtype=torch.float8_e4m3fn)
        patch_transformer_for_q8_kernels(transformer)
        return transformer
    if precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    return Transformer3DModel.from_pretrained(ckpt_path)
| # ============================================================================== | |
| # --- BUILDER DE BAIXO NÍVEL OFICIAL --- | |
| # ============================================================================== | |
def build_complete_pipeline_on_cpu(checkpoint_path: str, config: Dict) -> LTXVideoPipeline:
    """Assemble the complete LTX pipeline, VAE included, with its modules on the CPU.

    Args:
        checkpoint_path: Safetensors checkpoint used for the transformer,
            scheduler and VAE, and whose metadata is read for
            ``allowed_inference_steps``.
        config: Must contain ``"text_encoder_model_name_or_path"``; the
            optional ``"precision"`` entry defaults to ``"bfloat16"``.

    Returns:
        An ``LTXVideoPipeline`` built from the loaded components, with all
        prompt-enhancer arguments set to ``None``.
    """
    logging.info(f"Building complete LTX pipeline from checkpoint: {Path(checkpoint_path).name}")

    # The checkpoint metadata may carry a JSON "config" entry; fall back to
    # an empty object when the metadata or the entry is absent.
    with safe_open(checkpoint_path, framework="pt") as ckpt_file:
        ckpt_metadata = ckpt_file.metadata() or {}
        embedded_config = json.loads(ckpt_metadata.get("config", "{}"))
        allowed_inference_steps = embedded_config.get("allowed_inference_steps")

    precision = config.get("precision", "bfloat16")

    # create_transformer applies the precision-specific loading logic.
    transformer = create_transformer(checkpoint_path, precision).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(checkpoint_path)

    text_encoder_source = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_source, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_source, subfolder="tokenizer")

    patchifier = SymmetricPatchifier(patch_size=1)
    vae = CausalVideoAutoencoder.from_pretrained(checkpoint_path).to("cpu")

    if precision == "bfloat16":
        # The transformer was already cast inside create_transformer.
        for module in (text_encoder, vae):
            module.to(torch.bfloat16)

    return LTXVideoPipeline(
        transformer=transformer,
        patchifier=patchifier,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        vae=vae,  # included so the pipeline is self-sufficient
        allowed_inference_steps=allowed_inference_steps,
        prompt_enhancer_image_caption_model=None,
        prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None,
        prompt_enhancer_llm_tokenizer=None,
    )
| # ============================================================================== | |
| # --- FUNÇÕES AUXILIARES GENÉRICAS --- | |
| # ============================================================================== | |
| # FILE: api/ltx/ltx_utils.py | |
| # DESCRIPTION: A pure utility library for the LTX ecosystem. | |
| # Contains the official low-level builder function for the complete pipeline | |
| # and other stateless helper functions. | |
| import os | |
| import random | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, Tuple | |
| import torch | |
| from safetensors import safe_open | |
| from transformers import T5EncoderModel, T5Tokenizer | |
| # ============================================================================== | |
| # --- CONFIGURAÇÃO DE PATH E IMPORTS DA BIBLIOTECA LTX --- | |
| # ============================================================================== | |
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path():
    """Make the LTX-Video checkout importable by prepending it to ``sys.path``.

    The resolved repository path is inserted at position 0 only when it is
    not already listed, so calling this more than once has no further effect.
    """
    resolved_repo = str(LTX_VIDEO_REPO_DIR.resolve())
    if resolved_repo in sys.path:
        # Already registered; nothing to insert and nothing to log.
        return
    sys.path.insert(0, resolved_repo)
    logging.info(f"[ltx_utils] LTX-Video repository added to sys.path: {resolved_repo}")


# Runs at import time so the ``ltx_video`` imports that follow can resolve.
add_deps_to_path()
# Import the LTX-Video components this module builds on. If any of them is
# missing, log the failure and re-raise with a pointer to the expected
# checkout location.
try:
    from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.transformers.transformer3d import Transformer3DModel
    from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
    from ltx_video.schedulers.rf import RectifiedFlowScheduler
except ImportError as e:
    logging.critical("Failed to import a core LTX-Video library component.", exc_info=True)
    # Chain explicitly so the original failure is recorded as the direct cause.
    raise ImportError(
        f"Could not import from LTX-Video library. Check repo integrity at '{LTX_VIDEO_REPO_DIR}'. Error: {e}"
    ) from e
| # ============================================================================== | |
| # --- FUNÇÃO HELPER 'create_transformer' (Essencial) --- | |
| # ============================================================================== | |
def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    """Load the Transformer3D model from ``ckpt_path`` at the requested precision.

    Args:
        ckpt_path: Checkpoint location handed to
            ``Transformer3DModel.from_pretrained``.
        precision: ``"float8_e4m3fn"`` loads with ``dtype=torch.float8_e4m3fn``
            and applies the Q8-kernels patch; ``"bfloat16"`` loads and then
            casts to ``torch.bfloat16``; any other value loads the model
            without a cast.

    Returns:
        The loaded transformer model.

    Raises:
        ValueError: If ``"float8_e4m3fn"`` is requested but the ``q8_kernels``
            package cannot be imported.
    """
    if precision == "float8_e4m3fn":
        # Guard only the optional dependency import, so an ImportError raised
        # while loading or patching the model is not misreported as a
        # missing Q8-Kernels installation.
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )
        except ImportError as err:
            raise ValueError(
                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from the project's wheels."
            ) from err
        transformer = Transformer3DModel.from_pretrained(ckpt_path, dtype=torch.float8_e4m3fn)
        patch_transformer_for_q8_kernels(transformer)
        return transformer
    if precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    return Transformer3DModel.from_pretrained(ckpt_path)
| # ============================================================================== | |
| # --- BUILDER DE BAIXO NÍVEL OFICIAL --- | |
| # ============================================================================== | |
def build_complete_pipeline_on_cpu(checkpoint_path: str, config: Dict) -> LTXVideoPipeline:
    """Assemble the complete LTX pipeline, VAE included, with its modules on the CPU.

    Args:
        checkpoint_path: Safetensors checkpoint used for the transformer,
            scheduler and VAE, and whose metadata is read for
            ``allowed_inference_steps``.
        config: Must contain ``"text_encoder_model_name_or_path"``; the
            optional ``"precision"`` entry defaults to ``"bfloat16"``.

    Returns:
        An ``LTXVideoPipeline`` built from the loaded components, with all
        prompt-enhancer arguments set to ``None``.
    """
    logging.info(f"Building complete LTX pipeline from checkpoint: {Path(checkpoint_path).name}")

    # The checkpoint metadata may carry a JSON "config" entry; fall back to
    # an empty object when the metadata or the entry is absent.
    with safe_open(checkpoint_path, framework="pt") as ckpt_file:
        ckpt_metadata = ckpt_file.metadata() or {}
        embedded_config = json.loads(ckpt_metadata.get("config", "{}"))
        allowed_inference_steps = embedded_config.get("allowed_inference_steps")

    precision = config.get("precision", "bfloat16")

    # create_transformer applies the precision-specific loading logic.
    transformer = create_transformer(checkpoint_path, precision).to("cpu")
    scheduler = RectifiedFlowScheduler.from_pretrained(checkpoint_path)

    text_encoder_source = config["text_encoder_model_name_or_path"]
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_source, subfolder="text_encoder").to("cpu")
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_source, subfolder="tokenizer")

    patchifier = SymmetricPatchifier(patch_size=1)
    vae = CausalVideoAutoencoder.from_pretrained(checkpoint_path).to("cpu")

    if precision == "bfloat16":
        # The transformer was already cast inside create_transformer.
        for module in (text_encoder, vae):
            module.to(torch.bfloat16)

    return LTXVideoPipeline(
        transformer=transformer,
        patchifier=patchifier,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        vae=vae,  # included so the pipeline is self-sufficient
        allowed_inference_steps=allowed_inference_steps,
        prompt_enhancer_image_caption_model=None,
        prompt_enhancer_image_caption_processor=None,
        prompt_enhancer_llm_model=None,
        prompt_enhancer_llm_tokenizer=None,
    )
| # ============================================================================== | |
| # --- FUNÇÕES AUXILIARES GENÉRICAS --- | |
| # ============================================================================== | |
def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every seeding call and written, as a string,
            to the ``PYTHONHASHSEED`` environment variable.
    """
    # Imported locally: the visible module has no file-level NumPy import,
    # so the bare ``np`` reference would otherwise raise NameError.
    import numpy as np

    random.seed(seed)
    # NOTE(review): per the Python docs PYTHONHASHSEED is read at interpreter
    # start-up, so this presumably only affects processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, os.PathLike, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """Load an image, center-crop it to the target aspect ratio, resize and normalize it.

    Args:
        image_input: A file path (``str`` or ``os.PathLike``) or an already
            loaded PIL image. The image is converted to RGB either way.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        A 5D tensor with singleton dimensions inserted at positions 0 and 2,
        rescaled from ``to_tensor``'s output via ``x * 2 - 1``. Presumably
        shaped ``(1, 3, 1, target_height, target_width)`` with values in
        ``[-1, 1]`` — assumes ``crf_compressor.compress`` preserves shape and
        range; TODO confirm.

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL image.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # Close the underlying file as soon as the RGB copy exists instead of
        # leaving the handle open until garbage collection.
        with Image.open(image_input) as opened:
            image = opened.convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height

    if aspect_ratio_frame > aspect_ratio_target:
        # Source is wider than the target: keep full height, trim the sides.
        new_width, new_height = int(input_height * aspect_ratio_target), input_height
        x_start = (input_width - new_width) // 2
        image = image.crop((x_start, 0, x_start + new_width, new_height))
    else:
        # Source is taller (or equal): keep full width, trim top and bottom.
        new_height = int(input_width / aspect_ratio_target)
        y_start = (input_height - new_height) // 2
        image = image.crop((0, y_start, input_width, y_start + new_height))

    image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
    frame_tensor = TVF.to_tensor(image)

    # The CRF compression pass is optional: it is skipped with a warning when
    # the compressor cannot be imported. It is applied on a channels-last view.
    try:
        from ltx_video.pipelines import crf_compressor
        frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
        frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
        frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
    except ImportError:
        logging.warning("CRF Compressor not found. Skipping compression step.")

    frame_tensor = (frame_tensor * 2.0) - 1.0
    return frame_tensor.unsqueeze(0).unsqueeze(2)
def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every seeding call and written, as a string,
            to the ``PYTHONHASHSEED`` environment variable.
    """
    # Imported locally: the visible module has no file-level NumPy import,
    # so the bare ``np`` reference would otherwise raise NameError.
    import numpy as np

    random.seed(seed)
    # NOTE(review): per the Python docs PYTHONHASHSEED is read at interpreter
    # start-up, so this presumably only affects processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False