SeedVR2 / main.py
aducsdr's picture
Create main.py
e6d2def verified
raw
history blame
4.16 kB
# main.py (API com FastAPI)
import os
import uuid
import shutil
import subprocess
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse
from fastapi.concurrency import run_in_threadpool
# Diretório base onde o código do SeedVR está
SEEDVR_DIR = "/app/SeedVR"
app = FastAPI()
def run_inference_blocking(input_video_path: str, output_dir: str, seed: int, res_h: int, res_w: int) -> str:
"""
Função síncrona que executa o script torchrun.
Ela bloqueia a execução, por isso deve ser chamada em um thread separado.
"""
# O script de inferência espera ser executado de dentro do diretório SeedVR
# e que os caminhos de entrada/saída sejam relativos a ele.
# Constrói o caminho relativo para a pasta de entrada
input_folder_relative = os.path.relpath(os.path.dirname(input_video_path), SEEDVR_DIR)
# Constrói o caminho relativo para a pasta de saída
output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)
command = [
"torchrun",
"--nproc-per-node=4",
"projects/inference_seedvr2_3b.py",
"--video_path", input_folder_relative,
"--output_dir", output_folder_relative,
"--seed", str(seed),
"--res_h", str(res_h),
"--res_w", str(res_w),
"--sp_size", "1", # Mantido fixo ou pode se tornar um parâmetro
]
try:
print(f"Executando comando: {' '.join(command)}")
# Executa o subprocesso a partir do diretório do SeedVR
subprocess.run(command, cwd=SEEDVR_DIR, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
# Se o script falhar, captura o erro e o log para depuração
print("Erro na execução do subprocesso!")
print(f"Stdout: {e.stdout}")
print(f"Stderr: {e.stderr}")
raise HTTPException(status_code=500, detail=f"A inferência falhou: {e.stderr}")
# Encontra o arquivo de saída gerado
output_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.png'))]
if not output_files:
raise HTTPException(status_code=500, detail="A inferência foi concluída, mas nenhum arquivo de saída foi encontrado.")
return os.path.join(output_dir, output_files[0])
@app.get("/")
async def root():
return {"message": "API de Inferência SeedVR2 está online. Use o endpoint /infer/ para processar vídeos."}
@app.post("/infer/", response_class=FileResponse)
async def create_inference_job(
video: UploadFile = File(...),
seed: int = Form(666),
res_h: int = Form(720),
res_w: int = Form(1280),
):
"""
Recebe um vídeo e parâmetros, executa a inferência e retorna o vídeo processado.
"""
# Cria diretórios temporários únicos para esta requisição para evitar conflitos
job_id = str(uuid.uuid4())
input_dir = os.path.join("/app", "temp_inputs", job_id)
output_dir = os.path.join("/app", "temp_outputs", job_id)
os.makedirs(input_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
input_video_path = os.path.join(input_dir, video.filename)
try:
# Salva o vídeo enviado para o disco
with open(input_video_path, "wb") as buffer:
shutil.copyfileobj(video.file, buffer)
# Executa a função de inferência pesada em um thread separado
# para não bloquear o servidor da API
result_path = await run_in_threadpool(
run_inference_blocking,
input_video_path=input_video_path,
output_dir=output_dir,
seed=seed,
res_h=res_h,
res_w=res_w
)
# Retorna o arquivo de vídeo como uma resposta para download
return FileResponse(path=result_path, media_type='video/mp4', filename=os.path.basename(result_path))
finally:
# Limpa os diretórios temporários após a conclusão ou falha
print("Limpando diretórios temporários...")
shutil.rmtree(input_dir, ignore_errors=True)
shutil.rmtree(output_dir, ignore_errors=True)