# main.py (API com FastAPI) import os import uuid import shutil import subprocess from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.responses import FileResponse from fastapi.concurrency import run_in_threadpool # Diretório base onde o código do SeedVR está SEEDVR_DIR = "/app/SeedVR" app = FastAPI() def run_inference_blocking(input_video_path: str, output_dir: str, seed: int, res_h: int, res_w: int) -> str: """ Função síncrona que executa o script torchrun. Ela bloqueia a execução, por isso deve ser chamada em um thread separado. """ # O script de inferência espera ser executado de dentro do diretório SeedVR # e que os caminhos de entrada/saída sejam relativos a ele. # Constrói o caminho relativo para a pasta de entrada input_folder_relative = os.path.relpath(os.path.dirname(input_video_path), SEEDVR_DIR) # Constrói o caminho relativo para a pasta de saída output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR) command = [ "torchrun", "--nproc-per-node=4", "projects/inference_seedvr2_3b.py", "--video_path", input_folder_relative, "--output_dir", output_folder_relative, "--seed", str(seed), "--res_h", str(res_h), "--res_w", str(res_w), "--sp_size", "1", # Mantido fixo ou pode se tornar um parâmetro ] try: print(f"Executando comando: {' '.join(command)}") # Executa o subprocesso a partir do diretório do SeedVR subprocess.run(command, cwd=SEEDVR_DIR, check=True, capture_output=True, text=True) except subprocess.CalledProcessError as e: # Se o script falhar, captura o erro e o log para depuração print("Erro na execução do subprocesso!") print(f"Stdout: {e.stdout}") print(f"Stderr: {e.stderr}") raise HTTPException(status_code=500, detail=f"A inferência falhou: {e.stderr}") # Encontra o arquivo de saída gerado output_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.png'))] if not output_files: raise HTTPException(status_code=500, detail="A inferência foi concluída, mas nenhum arquivo de saída foi encontrado.") return os.path.join(output_dir, output_files[0]) @app.get("/") async def root(): return {"message": "API de Inferência SeedVR2 está online. Use o endpoint /infer/ para processar vídeos."} @app.post("/infer/", response_class=FileResponse) async def create_inference_job( video: UploadFile = File(...), seed: int = Form(666), res_h: int = Form(720), res_w: int = Form(1280), ): """ Recebe um vídeo e parâmetros, executa a inferência e retorna o vídeo processado. """ # Cria diretórios temporários únicos para esta requisição para evitar conflitos job_id = str(uuid.uuid4()) input_dir = os.path.join("/app", "temp_inputs", job_id) output_dir = os.path.join("/app", "temp_outputs", job_id) os.makedirs(input_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True) input_video_path = os.path.join(input_dir, video.filename) try: # Salva o vídeo enviado para o disco with open(input_video_path, "wb") as buffer: shutil.copyfileobj(video.file, buffer) # Executa a função de inferência pesada em um thread separado # para não bloquear o servidor da API result_path = await run_in_threadpool( run_inference_blocking, input_video_path=input_video_path, output_dir=output_dir, seed=seed, res_h=res_h, res_w=res_w ) # Retorna o arquivo de vídeo como uma resposta para download return FileResponse(path=result_path, media_type='video/mp4', filename=os.path.basename(result_path)) finally: # Limpa os diretórios temporários após a conclusão ou falha print("Limpando diretórios temporários...") shutil.rmtree(input_dir, ignore_errors=True) shutil.rmtree(output_dir, ignore_errors=True)