aducsdr commited on
Commit
e6d2def
·
verified ·
1 Parent(s): c3bf719

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +108 -0
main.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py (API com FastAPI)
2
+
3
+ import os
4
+ import uuid
5
+ import shutil
6
+ import subprocess
7
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
8
+ from fastapi.responses import FileResponse
9
+ from fastapi.concurrency import run_in_threadpool
10
+
11
+ # Diretório base onde o código do SeedVR está
12
+ SEEDVR_DIR = "/app/SeedVR"
13
+
14
+ app = FastAPI()
15
+
16
+ def run_inference_blocking(input_video_path: str, output_dir: str, seed: int, res_h: int, res_w: int) -> str:
17
+ """
18
+ Função síncrona que executa o script torchrun.
19
+ Ela bloqueia a execução, por isso deve ser chamada em um thread separado.
20
+ """
21
+ # O script de inferência espera ser executado de dentro do diretório SeedVR
22
+ # e que os caminhos de entrada/saída sejam relativos a ele.
23
+
24
+ # Constrói o caminho relativo para a pasta de entrada
25
+ input_folder_relative = os.path.relpath(os.path.dirname(input_video_path), SEEDVR_DIR)
26
+
27
+ # Constrói o caminho relativo para a pasta de saída
28
+ output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)
29
+
30
+ command = [
31
+ "torchrun",
32
+ "--nproc-per-node=4",
33
+ "projects/inference_seedvr2_3b.py",
34
+ "--video_path", input_folder_relative,
35
+ "--output_dir", output_folder_relative,
36
+ "--seed", str(seed),
37
+ "--res_h", str(res_h),
38
+ "--res_w", str(res_w),
39
+ "--sp_size", "1", # Mantido fixo ou pode se tornar um parâmetro
40
+ ]
41
+
42
+ try:
43
+ print(f"Executando comando: {' '.join(command)}")
44
+ # Executa o subprocesso a partir do diretório do SeedVR
45
+ subprocess.run(command, cwd=SEEDVR_DIR, check=True, capture_output=True, text=True)
46
+ except subprocess.CalledProcessError as e:
47
+ # Se o script falhar, captura o erro e o log para depuração
48
+ print("Erro na execução do subprocesso!")
49
+ print(f"Stdout: {e.stdout}")
50
+ print(f"Stderr: {e.stderr}")
51
+ raise HTTPException(status_code=500, detail=f"A inferência falhou: {e.stderr}")
52
+
53
+ # Encontra o arquivo de saída gerado
54
+ output_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.png'))]
55
+ if not output_files:
56
+ raise HTTPException(status_code=500, detail="A inferência foi concluída, mas nenhum arquivo de saída foi encontrado.")
57
+
58
+ return os.path.join(output_dir, output_files[0])
59
+
60
+
61
+ @app.get("/")
62
+ async def root():
63
+ return {"message": "API de Inferência SeedVR2 está online. Use o endpoint /infer/ para processar vídeos."}
64
+
65
+
66
+ @app.post("/infer/", response_class=FileResponse)
67
+ async def create_inference_job(
68
+ video: UploadFile = File(...),
69
+ seed: int = Form(666),
70
+ res_h: int = Form(720),
71
+ res_w: int = Form(1280),
72
+ ):
73
+ """
74
+ Recebe um vídeo e parâmetros, executa a inferência e retorna o vídeo processado.
75
+ """
76
+ # Cria diretórios temporários únicos para esta requisição para evitar conflitos
77
+ job_id = str(uuid.uuid4())
78
+ input_dir = os.path.join("/app", "temp_inputs", job_id)
79
+ output_dir = os.path.join("/app", "temp_outputs", job_id)
80
+ os.makedirs(input_dir, exist_ok=True)
81
+ os.makedirs(output_dir, exist_ok=True)
82
+
83
+ input_video_path = os.path.join(input_dir, video.filename)
84
+
85
+ try:
86
+ # Salva o vídeo enviado para o disco
87
+ with open(input_video_path, "wb") as buffer:
88
+ shutil.copyfileobj(video.file, buffer)
89
+
90
+ # Executa a função de inferência pesada em um thread separado
91
+ # para não bloquear o servidor da API
92
+ result_path = await run_in_threadpool(
93
+ run_inference_blocking,
94
+ input_video_path=input_video_path,
95
+ output_dir=output_dir,
96
+ seed=seed,
97
+ res_h=res_h,
98
+ res_w=res_w
99
+ )
100
+
101
+ # Retorna o arquivo de vídeo como uma resposta para download
102
+ return FileResponse(path=result_path, media_type='video/mp4', filename=os.path.basename(result_path))
103
+
104
+ finally:
105
+ # Limpa os diretórios temporários após a conclusão ou falha
106
+ print("Limpando diretórios temporários...")
107
+ shutil.rmtree(input_dir, ignore_errors=True)
108
+ shutil.rmtree(output_dir, ignore_errors=True)