File size: 6,124 Bytes
53cc24b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import gradio as gr
import subprocess
import os
import sys
import threading
from huggingface_hub import snapshot_download
# --- 1. CONFIGURAÇÃO INICIAL (Executa apenas uma vez) ---
# Diretório base para o projeto SeedVR
SEEDVR_DIR = "SeedVR"
def setup():
"""
Clona o repositório, instala dependências especiais e baixa o modelo.
Esta função é executada uma vez quando o Space é iniciado.
"""
print("--- Iniciando configuração do ambiente ---")
# Etapa 1: Clonar o repositório SeedVR se ainda não existir
if not os.path.exists(SEEDVR_DIR):
print(f"Clonando o repositório SeedVR de https://github.com/bytedance-seed/SeedVR.git...")
subprocess.run(["git", "clone", "https://github.com/bytedance-seed/SeedVR.git"], check=True)
else:
print("Repositório SeedVR já existe.")
# Mudando para o diretório do projeto para os próximos comandos
os.chdir(SEEDVR_DIR)
# Etapa 2: Instalar dependências que exigem comandos específicos
print("Instalando flash_attn...")
subprocess.run([sys.executable, "-m", "pip", "install", "flash_attn==2.5.9.post1", "--no-build-isolation"], check=True)
# Nota sobre o Apex: O arquivo .whl precisa estar no seu repositório Hugging Face
apex_whl_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
if os.path.exists(apex_whl_path):
print(f"Instalando {apex_whl_path}...")
subprocess.run([sys.executable, "-m", "pip", "install", apex_whl_path], check=True)
else:
print(f"AVISO: Arquivo '{apex_whl_path}' não encontrado. A instalação do Apex foi pulada. Por favor, adicione este arquivo ao seu repositório.")
# Etapa 3: Baixar o modelo do Hugging Face Hub
save_dir = "ckpts/"
repo_id = "ByteDance-Seed/SeedVR2-3B"
cache_dir = os.path.join(save_dir, "cache")
if not os.path.exists(os.path.join(save_dir, "README.md")): # Checa se o download já foi feito
print(f"Baixando o modelo '{repo_id}' para '{save_dir}'...")
snapshot_download(
cache_dir=cache_dir,
local_dir=save_dir,
repo_id=repo_id,
local_dir_use_symlinks=False,
resume_download=True,
allow_patterns=["*.json", "*.safetensors", "*.pth", "*.bin", "*.py", "*.md", "*.txt"],
)
else:
print("Modelo já foi baixado.")
print("--- Configuração do ambiente concluída ---")
# Retornar ao diretório raiz original
os.chdir("..")
# Executa a configuração
setup()
# --- 2. LÓGICA DA INFERÊNCIA ---
def run_inference(video_path, seed, res_h, res_w, sp_size, progress=gr.Progress(track_tqdm=True)):
"""
Executa o script de inferência do SeedVR usando torchrun.
"""
if video_path is None:
return None, "Por favor, faça o upload de um arquivo de vídeo de entrada."
input_folder = os.path.dirname(video_path.name)
output_folder = "outputs"
os.makedirs(output_folder, exist_ok=True)
# Determinar o número de GPUs disponíveis. Para 4xL40s, será 4.
num_gpus = 4
command = [
"torchrun",
f"--nproc-per-node={num_gpus}",
"projects/inference_seedvr2_3b.py",
"--video_path", input_folder,
"--output_dir", f"../{output_folder}", # Navega para fora do dir SeedVR
"--seed", str(seed),
"--res_h", str(res_h),
"--res_w", str(res_w),
"--sp_size", str(sp_size),
]
log_output = ""
try:
print(f"Executando comando: {' '.join(command)}")
# Executar o comando dentro do diretório SeedVR
process = subprocess.Popen(
command,
cwd=SEEDVR_DIR,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
encoding='utf-8'
)
# Capturar e exibir a saída em tempo real
while True:
line = process.stdout.readline()
if not line:
break
log_output += line
print(line, end='')
yield None, log_output
process.wait()
if process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode, command, output=log_output)
# Encontrar os arquivos de vídeo gerados
result_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if f.endswith(('.mp4', '.avi', '.mov'))]
if not result_files:
return None, log_output + "\n\nERRO: Nenhum arquivo de vídeo foi gerado."
return result_files, log_output
except subprocess.CalledProcessError as e:
error_message = f"Erro ao executar a inferência.\nOutput:\n{e.output}"
print(error_message)
return None, error_message
except Exception as e:
return None, f"Ocorreu um erro inesperado: {str(e)}"
# --- 3. INTERFACE GRAPHEMICA (GRADIO) ---
with gr.Blocks() as demo:
gr.Markdown("# 🎥 Interface de Inferência para SeedVR2")
gr.Markdown("Faça o upload de um vídeo, ajuste os parâmetros e clique em 'Gerar Vídeo' para iniciar a inferência.")
with gr.Row():
with gr.Column(scale=1):
video_input = gr.File(label="Vídeo de Entrada (.mp4, .mov, etc.)")
seed_input = gr.Number(label="Seed", value=123)
res_h_input = gr.Number(label="Altura da Saída (res_h)", value=320)
res_w_input = gr.Number(label="Largura da Saída (res_w)", value=512)
sp_size_input = gr.Number(label="Tamanho do passo espacial (sp_size)", value=1)
run_button = gr.Button("Gerar Vídeo", variant="primary")
with gr.Column(scale=2):
gallery_output = gr.Gallery(label="Vídeo Gerado", show_label=True, elem_id="gallery")
log_display = gr.Textbox(label="Logs de Execução", lines=15, interactive=False)
run_button.click(
fn=run_inference,
inputs=[video_input, seed_input, res_h_input, res_w_input, sp_size_input],
outputs=[gallery_output, log_display]
)
if __name__ == "__main__":
demo.launch() |