import gradio as gr
import torch
import numpy as np
import random  # needed for randomize_seed below
import tempfile
import os
import yaml
import json
import threading
from pathlib import Path
# Hugging Face imports
from huggingface_hub import snapshot_download, HfFolder
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import LTXLatentUpsamplePipeline
from diffusers.models import AutoencoderKLLTXVideo, LTXVideoTransformer3DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
# Our custom pipeline and utilities
from pipeline_ltx_condition_control import LTXConditionPipeline, LTXVideoCondition
from diffusers.utils import export_to_video
from PIL import Image, ImageOps
import imageio
# --- Logging and warning configuration ---
import warnings
import logging

# 'category' must be a warning class, not a string
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")

from huggingface_hub import logging as hf_logging
hf_logging.set_verbosity_error()
# --- Service class for loading and managing the models ---
class VideoGenerationService:
    """
    Encapsulates loading and configuring the AI pipelines.
    Loads the components explicitly and modularly from a configuration file.
    """
    def __init__(self, config_path: Path):
        print("=== [Video Generation Service] Initializing... ===")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is required to run this service.")
        self.device = "cuda"
        self.torch_dtype = torch.bfloat16
        print(f"[Init] Device: {self.device}, DType: {self.torch_dtype}")

        with open(config_path, "r") as f:
            self.cfg = yaml.safe_load(f)
        print(f"[Init] Configuration loaded from: {config_path}")
        print(json.dumps(self.cfg, indent=2))

        # Parameters from the YAML file
        self.base_repo = self.cfg.get("base_repo")
        self.checkpoint_path = self.cfg.get("checkpoint_path")
        self.upscaler_repo = self.cfg.get("spatial_upscaler_model_path")

        self._initialize()
        print("=== [Video Generation Service] Initialization complete. ===")
    def _initialize(self):
        print(f"=== [Init] Downloading base repository snapshot: {self.base_repo} ===")
        local_repo_path = snapshot_download(
            repo_id=self.base_repo,
            token=os.getenv("HF_TOKEN") or HfFolder.get_token(),
            resume_download=True
        )

        print("[Init] Loading pipeline components from local files...")
        self.vae = AutoencoderKLLTXVideo.from_pretrained(local_repo_path, subfolder="vae", torch_dtype=self.torch_dtype)
        self.text_encoder = T5EncoderModel.from_pretrained(local_repo_path, subfolder="text_encoder", torch_dtype=self.torch_dtype)
        self.tokenizer = T5TokenizerFast.from_pretrained(local_repo_path, subfolder="tokenizer")
        self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(local_repo_path, subfolder="scheduler")

        # Cause of the earlier error: explicitly disable dynamic shifting for compatibility
        if hasattr(self.scheduler.config, 'use_dynamic_shifting') and self.scheduler.config.use_dynamic_shifting:
            print("[Init] Disabling 'use_dynamic_shifting' on the scheduler.")
            self.scheduler.config.use_dynamic_shifting = False

        print(f"[Init] Loading Transformer weights from: {self.checkpoint_path}")
        self.transformer = LTXVideoTransformer3DModel.from_pretrained(
            local_repo_path, subfolder="transformer", weight_name=self.checkpoint_path, torch_dtype=self.torch_dtype
        )

        print("[Init] Assembling the LTXConditionPipeline...")
        self.pipeline = LTXConditionPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer,
            scheduler=self.scheduler, transformer=self.transformer
        )
        self.pipeline.to(self.device)
        self.pipeline.vae.enable_tiling()

        print(f"[Init] Loading the spatial upsampler from: {self.upscaler_repo}")
        self.upsampler = LTXLatentUpsamplePipeline.from_pretrained(
            self.upscaler_repo, vae=self.vae, torch_dtype=self.torch_dtype
        )
        self.upsampler.to(self.device)
# --- Application initialization ---
CONFIG_PATH = Path("ltx_config.yaml")
if not CONFIG_PATH.exists():
    raise FileNotFoundError(f"Configuration file '{CONFIG_PATH}' not found. Create it before running the application.")
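
# For reference, ltx_config.yaml is expected to provide the three keys read in
# VideoGenerationService.__init__ (base_repo, checkpoint_path, spatial_upscaler_model_path).
# A minimal sketch with placeholder values; the repo IDs and file name below are
# hypothetical and must be adapted to your own checkpoints:
#
#   base_repo: "Lightricks/LTX-Video"                         # hypothetical base repo ID
#   checkpoint_path: "ltxv-transformer.safetensors"           # hypothetical transformer weight file
#   spatial_upscaler_model_path: "Lightricks/ltxv-upscaler"   # hypothetical upsampler repo ID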
# Instantiate the service that loads and holds the models
service = VideoGenerationService(config_path=CONFIG_PATH)
pipeline = service.pipeline
pipe_upsample = service.upsampler

FPS = 24

# --- Core video generation logic ---
def round_to_nearest_resolution_acceptable_by_vae(height, width, vae_spatial_compression_ratio):
    # Round the spatial dimensions down to the nearest multiple accepted by the VAE.
    # Height/width are spatial, so the spatial (not temporal) compression ratio applies.
    height = height - (height % vae_spatial_compression_ratio)
    width = width - (width % vae_spatial_compression_ratio)
    return height, width
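
# Worked example, assuming the VAE's spatial compression ratio is 32 (the usual LTX value):
# a requested 768x1152 frame downscaled by 2/3 gives 512x768, and 512 % 32 == 768 % 32 == 0,
# so (512, 768) is returned unchanged; an input of 500x700 would be rounded down to (480, 672).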
def prepare_and_generate_video(
    condition_image_1, condition_strength_1, condition_frame_index_1,
    condition_image_2, condition_strength_2, condition_frame_index_2,
    prompt, duration, negative_prompt,
    height, width, guidance_scale, seed, randomize_seed,
    progress=gr.Progress(track_tqdm=True)
):
    try:
        conditions_data = [
            (condition_image_1, condition_strength_1, condition_frame_index_1),
            (condition_image_2, condition_strength_2, condition_frame_index_2)
        ]
        if randomize_seed:
            seed = random.randint(0, 2**32 - 1)

        num_frames = int(duration * FPS) + 1
        temporal_compression = pipeline.vae_temporal_compression_ratio
        num_frames = ((num_frames - 1) // temporal_compression) * temporal_compression + 1
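        # Worked example: duration=2s at 24 FPS gives int(2 * 24) + 1 = 49 raw frames; assuming a
        # temporal compression ratio of 8 (the usual LTX value), ((49 - 1) // 8) * 8 + 1 = 49, so
        # the count already satisfies the VAE's "multiple of 8 plus 1" frame constraint.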
        # Step 1: Prepare the conditions at low resolution
        downscale_factor = 2 / 3
        downscaled_height = int(height * downscale_factor)
        downscaled_width = int(width * downscale_factor)
        downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
            downscaled_height, downscaled_width, pipeline.vae_spatial_compression_ratio
        )

        conditions_low_res = []
        for image, strength, frame_index in conditions_data:
            if image is not None:
                processed_image = ImageOps.fit(image, (downscaled_width, downscaled_height), Image.LANCZOS)
                conditions_low_res.append(LTXVideoCondition(
                    image=processed_image, strength=strength, frame_index=int(frame_index)
                ))
        pipeline_args_low_res = {"conditions": conditions_low_res} if conditions_low_res else {}
        latents = pipeline(
            prompt=prompt, negative_prompt=negative_prompt, width=downscaled_width, height=downscaled_height,
            num_frames=num_frames, guidance_scale=guidance_scale, generator=torch.Generator().manual_seed(seed),
            output_type="latent", **pipeline_args_low_res
        ).frames
        # Step 2: Latent upscaling
        upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
        upscaled_latents = pipe_upsample(latents=latents, output_type="latent").frames
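        # Note: with the 2/3 downscale followed by the 2x latent upsample, the final render is
        # roughly 4/3 of the requested height/width (e.g. a 768x1152 request is generated at
        # 512x768 and refined at 1024x1536).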
        # Step 3: Prepare the conditions at high resolution (to keep the conditioned frames unchanged)
        conditions_high_res = []
        for image, strength, frame_index in conditions_data:
            if image is not None:
                processed_image_high_res = ImageOps.fit(image, (upscaled_width, upscaled_height), Image.LANCZOS)
                conditions_high_res.append(LTXVideoCondition(
                    image=processed_image_high_res, strength=strength, frame_index=int(frame_index)
                ))
        pipeline_args_high_res = {"conditions": conditions_high_res} if conditions_high_res else {}
        final_video_frames_np = pipeline(
            prompt=prompt, negative_prompt=negative_prompt, width=upscaled_width, height=upscaled_height,
            num_frames=num_frames, denoise_strength=0.999, latents=upscaled_latents,
            guidance_scale=guidance_scale, generator=torch.Generator(device="cuda").manual_seed(seed),
            output_type="np", **pipeline_args_high_res
        ).frames[0]
        # Step 4: Export
        video_uint8_frames = [(frame * 255).astype(np.uint8) for frame in final_video_frames_np]
        output_filename = "output.mp4"
        with imageio.get_writer(output_filename, fps=FPS, quality=8, macro_block_size=1) as writer:
            for frame_idx, frame_data in enumerate(video_uint8_frames):
                progress((frame_idx + 1) / len(video_uint8_frames), desc="Encoding video frames...")
                writer.append_data(frame_data)

        return output_filename, seed
    except Exception as e:
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()
        return None, seed
# --- Gradio user interface ---
with gr.Blocks(theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"]), delete_cache=(60, 900)) as demo:
    gr.Markdown("# LTX Video Generation\n**Create videos from text and condition images.**")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", placeholder="Describe the video you want to generate...", lines=3, value="The Joker dancing in a dark room, dramatic lighting.")
            with gr.Accordion("Condition Image 1", open=True):
                condition_image_1 = gr.Image(label="Image 1", type="pil")
                with gr.Row():
                    condition_strength_1 = gr.Slider(label="Strength", minimum=0.0, maximum=1.0, step=0.05, value=1.0)
                    condition_frame_index_1 = gr.Number(label="Frame", value=0, precision=0)
            with gr.Accordion("Condition Image 2", open=False):
                condition_image_2 = gr.Image(label="Image 2", type="pil")
                with gr.Row():
                    condition_strength_2 = gr.Slider(label="Strength", minimum=0.0, maximum=1.0, step=0.05, value=1.0)
                    condition_frame_index_2 = gr.Number(label="Frame", value=0, precision=0)
            duration = gr.Slider(label="Duration (s)", minimum=1.0, maximum=10.0, step=0.5, value=2)
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", lines=2, value="worst quality, blurry, jittery, distorted")
                with gr.Row():
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=32, value=768)
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=32, value=1152)
                with gr.Row():
                    guidance_scale = gr.Slider(label="Guidance", minimum=1.0, maximum=5.0, step=0.1, value=1.0)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    seed = gr.Number(label="Seed", value=0, precision=0)
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_video = gr.Video(label="Generated Video", height=400)
            generated_seed = gr.Number(label="Seed Used", interactive=False)
    generate_btn.click(
        fn=prepare_and_generate_video,
        inputs=[
            condition_image_1, condition_strength_1, condition_frame_index_1,
            condition_image_2, condition_strength_2, condition_frame_index_2,
            prompt, duration, negative_prompt,
            height, width, guidance_scale, seed, randomize_seed,
        ],
        outputs=[output_video, generated_seed]
    )

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)