|
|
import gradio as gr |
|
|
import subprocess |
|
|
import os |
|
|
import sys |
|
|
import threading |
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SEEDVR_DIR = "SeedVR" |
|
|
|
|
|
def setup(): |
|
|
""" |
|
|
Clona o repositório, instala dependências especiais e baixa o modelo. |
|
|
Esta função é executada uma vez quando o Space é iniciado. |
|
|
""" |
|
|
print("--- Iniciando configuração do ambiente ---") |
|
|
|
|
|
|
|
|
if not os.path.exists(SEEDVR_DIR): |
|
|
print(f"Clonando o repositório SeedVR de https://github.com/bytedance-seed/SeedVR.git...") |
|
|
subprocess.run(["git", "clone", "https://github.com/bytedance-seed/SeedVR.git"], check=True) |
|
|
else: |
|
|
print("Repositório SeedVR já existe.") |
|
|
|
|
|
|
|
|
os.chdir(SEEDVR_DIR) |
|
|
|
|
|
|
|
|
print("Instalando flash_attn...") |
|
|
subprocess.run([sys.executable, "-m", "pip", "install", "flash_attn==2.5.9.post1", "--no-build-isolation"], check=True) |
|
|
|
|
|
|
|
|
apex_whl_path = "apex-0.1-cp310-cp310-linux_x86_64.whl" |
|
|
if os.path.exists(apex_whl_path): |
|
|
print(f"Instalando {apex_whl_path}...") |
|
|
subprocess.run([sys.executable, "-m", "pip", "install", apex_whl_path], check=True) |
|
|
else: |
|
|
print(f"AVISO: Arquivo '{apex_whl_path}' não encontrado. A instalação do Apex foi pulada. Por favor, adicione este arquivo ao seu repositório.") |
|
|
|
|
|
|
|
|
save_dir = "ckpts/" |
|
|
repo_id = "ByteDance-Seed/SeedVR2-3B" |
|
|
cache_dir = os.path.join(save_dir, "cache") |
|
|
|
|
|
if not os.path.exists(os.path.join(save_dir, "README.md")): |
|
|
print(f"Baixando o modelo '{repo_id}' para '{save_dir}'...") |
|
|
snapshot_download( |
|
|
cache_dir=cache_dir, |
|
|
local_dir=save_dir, |
|
|
repo_id=repo_id, |
|
|
local_dir_use_symlinks=False, |
|
|
resume_download=True, |
|
|
allow_patterns=["*.json", "*.safetensors", "*.pth", "*.bin", "*.py", "*.md", "*.txt"], |
|
|
) |
|
|
else: |
|
|
print("Modelo já foi baixado.") |
|
|
|
|
|
print("--- Configuração do ambiente concluída ---") |
|
|
|
|
|
os.chdir("..") |
|
|
|
|
|
|
|
|
setup() |
|
|
|
|
|
|
|
|
|
|
|
def run_inference(video_path, seed, res_h, res_w, sp_size, progress=gr.Progress(track_tqdm=True)): |
|
|
""" |
|
|
Executa o script de inferência do SeedVR usando torchrun. |
|
|
""" |
|
|
if video_path is None: |
|
|
return None, "Por favor, faça o upload de um arquivo de vídeo de entrada." |
|
|
|
|
|
input_folder = os.path.dirname(video_path.name) |
|
|
output_folder = "outputs" |
|
|
os.makedirs(output_folder, exist_ok=True) |
|
|
|
|
|
|
|
|
num_gpus = 4 |
|
|
|
|
|
command = [ |
|
|
"torchrun", |
|
|
f"--nproc-per-node={num_gpus}", |
|
|
"projects/inference_seedvr2_3b.py", |
|
|
"--video_path", input_folder, |
|
|
"--output_dir", f"../{output_folder}", |
|
|
"--seed", str(seed), |
|
|
"--res_h", str(res_h), |
|
|
"--res_w", str(res_w), |
|
|
"--sp_size", str(sp_size), |
|
|
] |
|
|
|
|
|
log_output = "" |
|
|
try: |
|
|
print(f"Executando comando: {' '.join(command)}") |
|
|
|
|
|
process = subprocess.Popen( |
|
|
command, |
|
|
cwd=SEEDVR_DIR, |
|
|
stdout=subprocess.PIPE, |
|
|
stderr=subprocess.STDOUT, |
|
|
text=True, |
|
|
encoding='utf-8' |
|
|
) |
|
|
|
|
|
|
|
|
while True: |
|
|
line = process.stdout.readline() |
|
|
if not line: |
|
|
break |
|
|
log_output += line |
|
|
print(line, end='') |
|
|
yield None, log_output |
|
|
|
|
|
process.wait() |
|
|
|
|
|
if process.returncode != 0: |
|
|
raise subprocess.CalledProcessError(process.returncode, command, output=log_output) |
|
|
|
|
|
|
|
|
result_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if f.endswith(('.mp4', '.avi', '.mov'))] |
|
|
|
|
|
if not result_files: |
|
|
return None, log_output + "\n\nERRO: Nenhum arquivo de vídeo foi gerado." |
|
|
|
|
|
return result_files, log_output |
|
|
|
|
|
except subprocess.CalledProcessError as e: |
|
|
error_message = f"Erro ao executar a inferência.\nOutput:\n{e.output}" |
|
|
print(error_message) |
|
|
return None, error_message |
|
|
except Exception as e: |
|
|
return None, f"Ocorreu um erro inesperado: {str(e)}" |
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
gr.Markdown("# 🎥 Interface de Inferência para SeedVR2") |
|
|
gr.Markdown("Faça o upload de um vídeo, ajuste os parâmetros e clique em 'Gerar Vídeo' para iniciar a inferência.") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
video_input = gr.File(label="Vídeo de Entrada (.mp4, .mov, etc.)") |
|
|
seed_input = gr.Number(label="Seed", value=123) |
|
|
res_h_input = gr.Number(label="Altura da Saída (res_h)", value=320) |
|
|
res_w_input = gr.Number(label="Largura da Saída (res_w)", value=512) |
|
|
sp_size_input = gr.Number(label="Tamanho do passo espacial (sp_size)", value=1) |
|
|
|
|
|
run_button = gr.Button("Gerar Vídeo", variant="primary") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
gallery_output = gr.Gallery(label="Vídeo Gerado", show_label=True, elem_id="gallery") |
|
|
log_display = gr.Textbox(label="Logs de Execução", lines=15, interactive=False) |
|
|
|
|
|
run_button.click( |
|
|
fn=run_inference, |
|
|
inputs=[video_input, seed_input, res_h_input, res_w_input, sp_size_input], |
|
|
outputs=[gallery_output, log_display] |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch() |