Eueuiaa commited on
Commit
52d1c8b
·
verified ·
1 Parent(s): 06e4a29

Update api/seedvr_server.py

Browse files
Files changed (1) hide show
  1. api/seedvr_server.py +245 -80
api/seedvr_server.py CHANGED
@@ -1,111 +1,276 @@
 
 
1
  import os
2
- import shutil
3
- import subprocess
4
  import sys
5
  import time
6
- import mimetypes
 
 
7
  from pathlib import Path
8
- from typing import List, Optional, Tuple
9
 
10
  from huggingface_hub import hf_hub_download
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class SeedVRServer:
13
  def __init__(self, **kwargs):
14
- self.SEEDVR_ROOT = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
15
- # Apontamos para o nosso diretório de checkpoints customizado
16
- self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
 
17
  self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
18
  self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
19
  self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
20
  self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
21
- self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
22
-
23
- print("🚀 SeedVRServer (FP16) inicializando e preparando o ambiente...")
24
- for p in [self.SEEDVR_ROOT.parent, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
25
  p.mkdir(parents=True, exist_ok=True)
26
-
27
  self.setup_dependencies()
28
- print(" SeedVRServer (FP16) pronto.")
29
 
30
  def setup_dependencies(self):
31
- self._ensure_repo()
32
- # O monkey patch agora é feito pelo start_seedvr.sh, não mais aqui.
33
- self._ensure_model()
34
-
35
- def _ensure_repo(self) -> None:
36
  if not (self.SEEDVR_ROOT / ".git").exists():
37
  print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
38
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
39
  else:
40
  print("[SeedVRServer] Repositório SeedVR já existe.")
41
-
42
- def _ensure_model(self) -> None:
43
- """Baixa os arquivos de modelo FP16 otimizados e suas dependências."""
44
- print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
45
 
 
 
46
  model_files = {
47
- "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses", "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
48
- "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B", "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B"
49
  }
50
-
51
  for filename, repo_id in model_files.items():
52
  if not (self.CKPTS_ROOT / filename).exists():
53
- print(f"Baixando {filename} de {repo_id}...")
54
- hf_hub_download(repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT), cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN"))
55
- print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
56
-
57
- def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
58
- ts = f"{int(time.time())}_{os.urandom(4).hex()}"
59
- job_input_dir = self.INPUT_ROOT / f"job_{ts}"
60
- out_dir = self.OUTPUT_ROOT / f"run_{ts}"
61
- job_input_dir.mkdir(parents=True, exist_ok=True)
62
- out_dir.mkdir(parents=True, exist_ok=True)
63
- shutil.copy2(input_file, job_input_dir / Path(input_file).name)
64
- return job_input_dir, out_dir
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- def run_inference(self, filepath: str, *, seed: int, resh: int, resw: int, spsize: int, fps: Optional[float] = None):
67
- script = self.SEEDVR_ROOT / "inference_cli.py"
68
- job_input_dir, outdir = self._prepare_job(filepath)
69
- mediatype, _ = mimetypes.guess_type(filepath)
70
- is_image = mediatype and mediatype.startswith("image")
71
-
72
- effective_nproc = 1 if is_image else self.NUM_GPUS_TOTAL
73
- effective_spsize = 1 if is_image else spsize
74
-
75
- output_filename = f"result_{Path(filepath).stem}.mp4" if not is_image else f"{Path(filepath).stem}_upscaled"
76
- output_filepath = outdir / output_filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
78
 
 
 
 
 
 
 
 
79
 
80
- cmd = [
81
- "torchrun", "--standalone", "--nnodes=1",
82
- f"--nproc-per-node={effective_nproc}",
83
- str(script),
84
- "--video_path", str(filepath),
85
- "--output", str(output_filepath),
86
- "--model_dir", str(self.CKPTS_ROOT),
87
- "--seed", str(seed),
88
- "--cuda_device", "0",
89
- "--resolution", str(resh),
90
- "--batch_size", str(effective_spsize),
91
- "--model", "seedvr2_ema_3b_fp16.safetensors",
92
- "--preserve_vram",
93
- "--debug",
94
- "--output_format", "video" if not is_image else "png",
95
- ]
96
-
97
-
98
- print("SeedVRServer Comando:", " ".join(cmd))
99
- try:
100
- subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=os.environ.copy(), stdout=sys.stdout, stderr=sys.stderr)
101
- # Constrói a tupla de retorno de forma determinística
102
- if is_image:
103
- # CLI salva PNGs em diretório args.output (tratado como diretório quando outputformat=png)
104
- image_dir = output_filepath if output_filepath.suffix == "" else output_filepath.with_suffix("")
105
- return str(image_dir), None, outdir
106
- else:
107
- # CLI salva vídeo exatamente em output_filepath
108
- return None, str(output_filepath), outdir
109
- except Exception as e:
110
- print(f"[UI ERROR] A inferência falhou: {e}")
111
- return None, None, None
 
 
 
 
 
 
1
+ # api/seedvr_server.py
2
+
3
  import os
 
 
4
  import sys
5
  import time
6
+ import subprocess
7
+ import queue
8
+ import multiprocessing as mp
9
  from pathlib import Path
10
+ from typing import Optional, Callable
11
 
12
  from huggingface_hub import hf_hub_download
13
 
14
+ # -------------------------------------------------------------
15
+ # 1. CONFIGURAÇÃO DE AMBIENTE E CUDA
16
+ # -------------------------------------------------------------
17
+
18
+ # Garante o uso seguro de CUDA com multiprocessing para estabilidade.
19
+ if mp.get_start_method(allow_none=True) != 'spawn':
20
+ mp.set_start_method('spawn', force=True)
21
+
22
+ # Configuração de alocação de memória da VRAM
23
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
24
+
25
+ # Adiciona dinamicamente o caminho do repositório clonado ao sys.path.
26
+ SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
27
+ if str(SEEDVR_REPO_PATH) not in sys.path:
28
+ sys.path.insert(0, str(SEEDVR_REPO_PATH))
29
+
30
+ # Importações pesadas (torch, etc.) são feitas após a configuração do ambiente.
31
+ import torch
32
+ import cv2
33
+ import numpy as np
34
+ from datetime import datetime
35
+
36
+ # -------------------------------------------------------------
37
+ # 2. FUNÇÕES AUXILIARES DE PROCESSAMENTO (Workers e I/O)
38
+ # -------------------------------------------------------------
39
+
40
+ def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
41
+ """Extrai quadros de um vídeo e os converte para o formato de tensor."""
42
+ if debug: print(f"🎬 Extraindo frames de: {video_path}")
43
+ if not os.path.exists(video_path): raise FileNotFoundError(f"Arquivo de vídeo não encontrado: {video_path}")
44
+
45
+ cap = cv2.VideoCapture(video_path)
46
+ if not cap.isOpened(): raise ValueError(f"Não foi possível abrir o arquivo de vídeo: {video_path}")
47
+
48
+ fps = cap.get(cv2.CAP_PROP_FPS)
49
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
50
+ if debug: print(f"📊 Info do vídeo: {frame_count} frames, {fps:.2f} FPS")
51
+
52
+ frames = []
53
+ frames_loaded = 0
54
+ for i in range(frame_count):
55
+ ret, frame = cap.read()
56
+ if not ret: break
57
+ if i < skip_first_frames: continue
58
+ if load_cap and frames_loaded >= load_cap: break
59
+
60
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
61
+ frames.append(frame.astype(np.float32) / 255.0)
62
+ frames_loaded += 1
63
+ cap.release()
64
+
65
+ if not frames: raise ValueError(f"Nenhum frame foi extraído do vídeo: {video_path}")
66
+ if debug: print(f"✅ {len(frames)} frames extraídos com sucesso.")
67
+ return torch.from_numpy(np.stack(frames)).to(torch.float16), fps
68
+
69
+ def save_frames_to_video(frames_tensor, output_path, fps=30.0, debug=False):
70
+ """Salva um tensor de quadros em um arquivo de vídeo."""
71
+ if debug: print(f"🎬 Salvando {frames_tensor.shape[0]} frames em: {output_path}")
72
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
73
+
74
+ frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)
75
+ T, H, W, _ = frames_np.shape
76
+
77
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
78
+ out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
79
+ if not out.isOpened(): raise ValueError(f"Não foi possível criar o arquivo de vídeo: {output_path}")
80
+
81
+ for frame in frames_np:
82
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
83
+ out.release()
84
+ if debug: print(f"✅ Vídeo salvo com sucesso: {output_path}")
85
+
86
+ def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
87
+ """Processo filho (worker) que executa o upscaling em uma GPU dedicada."""
88
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
89
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
90
+
91
+ import torch
92
+ from src.core.model_manager import configure_runner
93
+ from src.core.generation import generation_loop
94
+
95
+ try:
96
+ frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
97
+
98
+ callback = (lambda b, t, _, m: progress_queue.put((proc_idx, b, t, m))) if progress_queue else None
99
+
100
+ runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
101
+ result_tensor = generation_loop(
102
+ runner=runner, images=frames_tensor, cfg_scale=1.0, seed=shared_args["seed"],
103
+ res_w=shared_args["resolution"], batch_size=shared_args["batch_size"],
104
+ preserve_vram=shared_args["preserve_vram"], temporal_overlap=0,
105
+ debug=shared_args["debug"], progress_callback=callback
106
+ )
107
+ return_queue.put((proc_idx, result_tensor.cpu().numpy()))
108
+ except Exception as e:
109
+ import traceback
110
+ error_msg = f"ERRO no worker {proc_idx}: {e}\n{traceback.format_exc()}"
111
+ print(error_msg)
112
+ if progress_queue: progress_queue.put((proc_idx, -1, -1, error_msg))
113
+ return_queue.put((proc_idx, error_msg))
114
+
115
+ # -------------------------------------------------------------
116
+ # 3. CLASSE DO SERVIDOR PRINCIPAL
117
+ # -------------------------------------------------------------
118
+
119
  class SeedVRServer:
120
  def __init__(self, **kwargs):
121
+ """Inicializa o servidor, define os caminhos e prepara o ambiente."""
122
+ print("⚙️ SeedVRServer inicializando...")
123
+ self.SEEDVR_ROOT = SEEDVR_REPO_PATH
124
+ self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
125
  self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
126
  self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
127
  self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
128
  self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
129
+ self.NUM_GPUS_TOTAL = torch.cuda.device_count()
130
+
131
+ for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
 
132
  p.mkdir(parents=True, exist_ok=True)
133
+
134
  self.setup_dependencies()
135
+ print("📦 SeedVRServer pronto.")
136
 
137
  def setup_dependencies(self):
138
+ """Garante que o repositório e os modelos estão presentes."""
139
+ # Clona o repositório do SeedVR se não existir
 
 
 
140
  if not (self.SEEDVR_ROOT / ".git").exists():
141
  print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
142
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
143
  else:
144
  print("[SeedVRServer] Repositório SeedVR já existe.")
 
 
 
 
145
 
146
+ # Baixa os checkpoints do Hugging Face se não existirem
147
+ print(f"[SeedVRServer] Verificando checkpoints em {self.CKPTS_ROOT}...")
148
  model_files = {
149
+ "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
150
+ "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses"
151
  }
 
152
  for filename, repo_id in model_files.items():
153
  if not (self.CKPTS_ROOT / filename).exists():
154
+ print(f"Baixando {filename}...")
155
+ from huggingface_hub import hf_hub_download
156
+ hf_hub_download(
157
+ repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
158
+ cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
159
+ )
160
+ print("[SeedVRServer] Checkpoints estão no local correto.")
161
+
162
+ def run_inference(
163
+ self,
164
+ file_path: str, *,
165
+ seed: int,
166
+ resolution: int,
167
+ batch_size: int,
168
+ model: str = "seedvr2_ema_3b_fp16.safetensors",
169
+ fps: Optional[float] = None,
170
+ debug: bool = False,
171
+ preserve_vram: bool = True,
172
+ progress: Optional[Callable] = None
173
+ ) -> str:
174
+ """
175
+ Executa o pipeline completo de upscaling de vídeo e retorna o caminho do arquivo de saída.
176
+ """
177
+ if progress: progress(0.01, "⌛ Inicializando...")
178
 
179
+ # --- 1. Extração de Frames ---
180
+ if progress: progress(0.05, "🎬 Extraindo frames do vídeo...")
181
+ frames_tensor, original_fps = extract_frames_from_video(file_path, debug)
182
+
183
+ # --- 2. Preparação do Processamento Multi-GPU ---
184
+ device_list = list(range(self.NUM_GPUS_TOTAL))
185
+ num_devices = len(device_list)
186
+ chunks = torch.chunk(frames_tensor, num_devices, dim=0)
187
+
188
+ manager = mp.Manager()
189
+ return_queue = manager.Queue()
190
+ progress_queue = manager.Queue() if progress else None
191
+
192
+ shared_args = {
193
+ "model": model, "model_dir": str(self.CKPTS_ROOT), "preserve_vram": preserve_vram,
194
+ "debug": debug, "seed": seed, "resolution": resolution, "batch_size": batch_size
195
+ }
196
+
197
+ # --- 3. Inicia os Workers ---
198
+ if progress: progress(0.1, f"🚀 Iniciando geração em {num_devices} GPUs...")
199
+ workers = []
200
+ for idx, device_id in enumerate(device_list):
201
+ p = mp.Process(target=_worker_process, args=(idx, device_id, chunks[idx].cpu().numpy(), shared_args, return_queue, progress_queue))
202
+ p.start()
203
+ workers.append(p)
204
+
205
+ # --- 4. Coleta de Resultados e Monitoramento de Progresso ---
206
+ results_np = [None] * num_devices
207
+ finished_workers = 0
208
+ worker_progress = [0.0] * num_devices
209
+ while finished_workers < num_devices:
210
+ # Atualiza a barra de progresso com informações da fila
211
+ if progress_queue:
212
+ while not progress_queue.empty():
213
+ try:
214
+ p_idx, b_idx, b_total, msg = progress_queue.get_nowait()
215
+ if b_idx == -1: raise RuntimeError(f"Erro no Worker {p_idx}: {msg}")
216
+ if b_total > 0: worker_progress[p_idx] = b_idx / b_total
217
+ total_progress = sum(worker_progress) / num_devices
218
+ progress(0.1 + total_progress * 0.85, desc=f"GPU {p_idx+1}/{num_devices}: {msg}")
219
+ except queue.Empty: pass
220
+
221
+ # Verifica se algum worker terminou
222
+ try:
223
+ proc_idx, result = return_queue.get(timeout=0.2)
224
+ if isinstance(result, str): raise RuntimeError(f"Worker {proc_idx} falhou: {result}")
225
+ results_np[proc_idx] = result
226
+ worker_progress[proc_idx] = 1.0
227
+ finished_workers += 1
228
+ except queue.Empty: pass
229
 
230
+ for p in workers: p.join()
231
 
232
+ if any(r is None for r in results_np):
233
+ raise RuntimeError("Um ou mais workers falharam ao retornar um resultado.")
234
+
235
+ # --- 5. Combina os resultados e salva o vídeo final ---
236
+ result_tensor = torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)
237
+
238
+ if progress: progress(0.95, "💾 Salvando o vídeo final...")
239
 
240
+ out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
241
+ out_dir.mkdir(parents=True, exist_ok=True)
242
+ output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
243
+
244
+ final_fps = fps if fps and fps > 0 else original_fps
245
+ save_frames_to_video(result_tensor, str(output_filepath), final_fps, debug)
246
+
247
+ print(f"✅ Vídeo salvo com sucesso em: {output_filepath}")
248
+ return str(output_filepath)
249
+
250
+ # -------------------------------------------------------------
251
+ # 4. PONTO DE ENTRADA PARA EXECUÇÃO
252
+ # -------------------------------------------------------------
253
+
254
+ if __name__ == "__main__":
255
+ # Bloco para testes ou inicialização autônoma.
256
+ print("🚀 Executando o servidor SeedVR em modo autônomo...")
257
+ try:
258
+ server = SeedVRServer()
259
+ print("✅ Servidor inicializado com sucesso. Pronto para receber chamadas.")
260
+ # Exemplo de como chamar a inferência (requer um arquivo de vídeo):
261
+ # input_video = "caminho/para/seu/video.mp4"
262
+ # if os.path.exists(input_video):
263
+ # server.run_inference(
264
+ # file_path=input_video,
265
+ # seed=42,
266
+ # resolution=1072,
267
+ # batch_size=4,
268
+ # progress=lambda p, desc: print(f"Progresso: {p*100:.1f}% - {desc}")
269
+ # )
270
+ # else:
271
+ # print(f"Vídeo de teste não encontrado em '{input_video}'. Pulei a execução da inferência.")
272
+ except Exception as e:
273
+ print(f"❌ Falha ao inicializar o servidor: {e}")
274
+ import traceback
275
+ traceback.print_exc()
276
+ sys.exit(1)