EuuIia commited on
Commit
ed88963
·
verified ·
1 Parent(s): 55327ec

Upload app_seedvr.py

Browse files
Files changed (1) hide show
  1. app_seedvr.py +117 -0
app_seedvr.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SeedVR UI (Gradio) — CLI torchrun
4
+
5
+ - Upload único: vídeo (.mp4) ou imagem (.png/.jpg/.jpeg/.webp).
6
+ - Parâmetros: seed, res_h, res_w, sp_size.
7
+ - Executa via torchrun com NUM_GPUS (do ambiente).
8
+ - Exibe vídeo se a entrada for vídeo; imagem se for imagem.
9
+ """
10
+
11
+ import os
12
+ import mimetypes
13
+ from pathlib import Path
14
+ from typing import Optional
15
+
16
+ import gradio as gr
17
+
18
+ from services.seed_server import SeedVRServer
19
+
20
+ # Instância única do servidor (clona repo, baixa modelo, cria symlink)
21
+ server = SeedVRServer()
22
+
23
+ # Paths padrão (para allowed_paths e debug)
24
+ OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
25
+ CKPTS_ROOT = Path(os.getenv("CKPTS_ROOT", "/app/ckpts/SeedVR2-3B"))
26
+
27
+ def _is_video(path: str) -> bool:
28
+ mime, _ = mimetypes.guess_type(path)
29
+ return (mime or "").startswith("video") or str(path).lower().endswith(".mp4")
30
+
31
+ def _is_image(path: str) -> bool:
32
+ mime, _ = mimetypes.guess_type(path)
33
+ if mime and mime.startswith("image"):
34
+ return True
35
+ return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
36
+
37
+ def ui_infer(
38
+ input_path: Optional[str],
39
+ seed: int,
40
+ res_h: int,
41
+ res_w: int,
42
+ sp_size: int,
43
+ ):
44
+ if not input_path or not Path(input_path).exists():
45
+ gr.Warning("Arquivo de entrada ausente ou inválido.")
46
+ return None, None
47
+
48
+ is_vid = _is_video(input_path)
49
+ is_img = _is_image(input_path)
50
+ if not (is_vid or is_img):
51
+ gr.Warning("Tipo de arquivo não suportado. Envie .mp4, .png, .jpg, .jpeg ou .webp.")
52
+ return None, None
53
+
54
+ try:
55
+ video_out, image_out, _ = server.run_inference(
56
+ file_path=input_path,
57
+ seed=int(seed),
58
+ res_h=int(res_h),
59
+ res_w=int(res_w),
60
+ sp_size=int(sp_size),
61
+ )
62
+ except Exception as e:
63
+ gr.Warning(f"Erro na inferência: {e}")
64
+ return None, None
65
+
66
+ if is_vid:
67
+ if video_out and Path(video_out).exists():
68
+ return None, video_out
69
+ if image_out and Path(image_out).exists():
70
+ return image_out, None
71
+ gr.Warning("Nenhum resultado encontrado.")
72
+ return None, None
73
+ else:
74
+ if image_out and Path(image_out).exists():
75
+ return image_out, None
76
+ if video_out and Path(video_out).exists():
77
+ return None, video_out
78
+ gr.Warning("Nenhum resultado encontrado.")
79
+ return None, None
80
+
81
+ with gr.Blocks(title="SeedVR (CLI torchrun)") as demo:
82
+ gr.Markdown(
83
+ "\n".join([
84
+ "# SeedVR — Restauração (CLI torchrun)",
85
+ "- Envie um vídeo (.mp4) ou uma imagem (.png/.jpg/.jpeg/.webp).",
86
+ "- A execução utiliza torchrun com múltiplas GPUs.",
87
+ ])
88
+ )
89
+
90
+ with gr.Row():
91
+ inp = gr.File(label="Entrada (vídeo .mp4 ou imagem)", type="filepath")
92
+
93
+ with gr.Row():
94
+ seed = gr.Number(label="Seed", value=int(os.getenv("SEED", "42")), precision=0)
95
+ res_h = gr.Number(label="Altura (res_h)", value=int(os.getenv("RES_H", "720")), precision=0)
96
+ res_w = gr.Number(label="Largura (res_w)", value=int(os.getenv("RES_W", "1280")), precision=0)
97
+ sp_size = gr.Number(label="sp_size", value=int(os.getenv("SP_SIZE", "4")), precision=0)
98
+
99
+ run = gr.Button("Restaurar", variant="primary")
100
+
101
+ out_image = gr.Image(label="Resultado (imagem)")
102
+ out_video = gr.Video(label="Resultado (vídeo)")
103
+
104
+ run.click(
105
+ ui_infer,
106
+ inputs=[inp, seed, res_h, res_w, sp_size],
107
+ outputs=[out_image, out_video],
108
+ )
109
+
110
+ if __name__ == "__main__":
111
+ demo.launch(
112
+ server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
113
+ server_port=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))),
114
+ allowed_paths=[str(OUTPUT_ROOT), str(CKPTS_ROOT)],
115
+ show_error=True,
116
+ )
117
+