|
|
|
|
|
|
|
|
import os |
|
|
import uuid |
|
|
import shutil |
|
|
import subprocess |
|
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException |
|
|
from fastapi.responses import FileResponse |
|
|
from fastapi.concurrency import run_in_threadpool |
|
|
|
|
|
|
|
|
# Root of the SeedVR checkout. The inference subprocess is launched with this as
# its working directory. It can be overridden with the SEEDVR_DIR environment
# variable; the default is the original hard-coded path.
SEEDVR_DIR = os.environ.get("SEEDVR_DIR", "/app/SeedVR")


# FastAPI application exposing the inference endpoints defined below.
app = FastAPI()
|
|
|
|
|
def run_inference_blocking(
    input_video_path: str,
    output_dir: str,
    seed: int,
    res_h: int,
    res_w: int,
    *,
    nproc_per_node: int = 4,
    sp_size: int = 1,
) -> str:
    """
    Run the SeedVR2 3B inference script through ``torchrun`` and return the result path.

    The call blocks until the subprocess exits. Async callers should therefore run
    it in a worker thread (e.g. via ``run_in_threadpool``).

    Args:
        input_video_path: Path of the input file. The script receives the
            *directory* containing it, expressed relative to ``SEEDVR_DIR``.
        output_dir: Directory the script writes into, also passed relative to
            ``SEEDVR_DIR``.
        seed: Value forwarded as ``--seed``.
        res_h: Value forwarded as ``--res_h``.
        res_w: Value forwarded as ``--res_w``.
        nproc_per_node: Number of processes ``torchrun`` launches (default 4).
        sp_size: Value forwarded as ``--sp_size`` (default 1).

    Returns:
        Path of the first regular ``.mp4``/``.png`` file in ``output_dir``, in
        sorted filename order.

    Raises:
        HTTPException: status 500 in three cases: the subprocess cannot be
            started, it exits with a non-zero status, or it leaves no
            recognised output file.
    """
    # The subprocess runs with cwd=SEEDVR_DIR, so both paths are made relative to it.
    input_folder_relative = os.path.relpath(os.path.dirname(input_video_path), SEEDVR_DIR)
    output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)

    command = [
        "torchrun",
        f"--nproc-per-node={nproc_per_node}",
        "projects/inference_seedvr2_3b.py",
        "--video_path", input_folder_relative,
        "--output_dir", output_folder_relative,
        "--seed", str(seed),
        "--res_h", str(res_h),
        "--res_w", str(res_w),
        "--sp_size", str(sp_size),
    ]

    print(f"Executando comando: {' '.join(command)}")
    try:
        subprocess.run(command, cwd=SEEDVR_DIR, check=True, capture_output=True, text=True)
    except FileNotFoundError as e:
        # Raised e.g. when the torchrun executable or the working directory is missing.
        raise HTTPException(
            status_code=500, detail=f"Não foi possível iniciar a inferência: {e}"
        ) from e
    except subprocess.CalledProcessError as e:
        print("Erro na execução do subprocesso!")
        print(f"Stdout: {e.stdout}")
        print(f"Stderr: {e.stderr}")
        raise HTTPException(status_code=500, detail=f"A inferência falhou: {e.stderr}") from e

    # os.listdir order is arbitrary, so sort to make the chosen result deterministic.
    # Only regular files are accepted, so a directory named "*.mp4" is never returned.
    output_files = sorted(
        name
        for name in os.listdir(output_dir)
        if name.lower().endswith((".mp4", ".png"))
        and os.path.isfile(os.path.join(output_dir, name))
    )
    if not output_files:
        raise HTTPException(
            status_code=500,
            detail="A inferência foi concluída, mas nenhum arquivo de saída foi encontrado.",
        )

    return os.path.join(output_dir, output_files[0])
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Liveness endpoint: report that the service is up and point clients at /infer/."""
    status_message = (
        "API de Inferência SeedVR2 está online. "
        "Use o endpoint /infer/ para processar vídeos."
    )
    return {"message": status_message}
|
|
|
|
|
|
|
|
@app.post("/infer/", response_class=FileResponse)
async def create_inference_job(
    video: UploadFile = File(...),
    seed: int = Form(666),
    res_h: int = Form(720),
    res_w: int = Form(1280),
):
    """
    Receive a video and parameters, run inference, and return the processed file.

    Each request gets its own ``/app/temp_inputs/<job_id>`` and
    ``/app/temp_outputs/<job_id>`` directories. When anything fails they are
    removed immediately. On success they are removed by a response background
    task, which runs only after the body has been sent. This matters because
    ``FileResponse`` reads the file lazily while streaming; deleting it before
    returning would break the response.

    Args:
        video: Uploaded input file (multipart form field ``video``).
        seed: Seed forwarded to the inference script.
        res_h: Height forwarded to the inference script.
        res_w: Width forwarded to the inference script.

    Returns:
        A ``FileResponse`` streaming the path returned by ``run_inference_blocking``.

    Raises:
        HTTPException: propagated from ``run_inference_blocking`` when inference fails.
    """
    # Local import keeps this block independent of the module-level import list.
    from fastapi import BackgroundTasks

    job_id = str(uuid.uuid4())
    input_dir = os.path.join("/app", "temp_inputs", job_id)
    output_dir = os.path.join("/app", "temp_outputs", job_id)
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    def cleanup() -> None:
        """Remove this job's temporary input and output directories."""
        print("Limpando diretórios temporários...")
        shutil.rmtree(input_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)

    # video.filename comes from the client. Keep only its final path component so
    # names like "../../x" or "/abs/x" cannot escape input_dir. Fall back to a
    # fixed name when it is missing or degenerate.
    safe_name = os.path.basename(video.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "input.mp4"
    input_video_path = os.path.join(input_dir, safe_name)

    def save_upload() -> None:
        """Copy the uploaded stream to input_video_path (blocking file I/O)."""
        with open(input_video_path, "wb") as buffer:
            shutil.copyfileobj(video.file, buffer)

    try:
        # Both steps block, so they run in worker threads instead of on the event loop.
        await run_in_threadpool(save_upload)
        result_path = await run_in_threadpool(
            run_inference_blocking,
            input_video_path=input_video_path,
            output_dir=output_dir,
            seed=seed,
            res_h=res_h,
            res_w=res_w,
        )

        # The script may emit either a video or an image; label the payload accordingly.
        media_type = "image/png" if result_path.lower().endswith(".png") else "video/mp4"

        # Defer deletion until the response body has been fully streamed.
        cleanup_tasks = BackgroundTasks()
        cleanup_tasks.add_task(cleanup)
        response = FileResponse(
            path=result_path,
            media_type=media_type,
            filename=os.path.basename(result_path),
            background=cleanup_tasks,
        )
    except BaseException:
        # Nothing will be streamed (error or cancellation): remove the job files now.
        cleanup()
        raise

    return response