EuuIia commited on
Commit
a5720bf
·
verified ·
1 Parent(s): 1300afb

Rename seed_server.py to api/seedvr_server.py

Browse files
Files changed (2) hide show
  1. api/seedvr_server.py +144 -0
  2. seed_server.py +0 -187
api/seedvr_server.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Hugging Face's logo
2
+
3
+
4
+
5
+ euIaxs22
6
+ /
7
+ Aduc-sdr-2_5
8
+
9
+ like
10
+ 0
11
+ App
12
+ Files
13
+ Community
14
+ Aduc-sdr-2_5
15
+ /
16
+ services
17
+ /
18
+ seed_server.py
19
+
20
+ euIaxs22's picture
21
+ euIaxs22
22
+ Update services/seed_server.py
23
+ 1cd2f56
24
+ verified
25
+ raw
26
+
27
+ Copy download link
28
+ history
29
+ blame
30
+ contribute
31
+ delete
32
+
33
+ 5.25 kB
34
+ import os
35
+ import shutil
36
+ import subprocess
37
+ import sys
38
+ import time
39
+ import mimetypes
40
+ from pathlib import Path
41
+ from typing import List, Optional, Tuple
42
+
43
+ from huggingface_hub import hf_hub_download
44
+
45
+ class SeedVRServer:
46
+ def __init__(self, **kwargs):
47
+ self.SEEDVR_ROOT = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
48
+ # Apontamos para o nosso diretório de checkpoints customizado
49
+ self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
50
+ self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
51
+ self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
52
+ self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
53
+ self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
54
+ self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
55
+
56
+ print("🚀 SeedVRServer (FP16) inicializando e preparando o ambiente...")
57
+ for p in [self.SEEDVR_ROOT.parent, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
58
+ p.mkdir(parents=True, exist_ok=True)
59
+
60
+ self.setup_dependencies()
61
+ print("✅ SeedVRServer (FP16) pronto.")
62
+
63
+ def setup_dependencies(self):
64
+ self._ensure_repo()
65
+ # O monkey patch agora é feito pelo start_seedvr.sh, não mais aqui.
66
+ self._ensure_model()
67
+
68
+ def _ensure_repo(self) -> None:
69
+ if not (self.SEEDVR_ROOT / ".git").exists():
70
+ print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
71
+ subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
72
+ else:
73
+ print("[SeedVRServer] Repositório SeedVR já existe.")
74
+
75
+ def _ensure_model(self) -> None:
76
+ """Baixa os arquivos de modelo FP16 otimizados e suas dependências."""
77
+ print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
78
+
79
+ model_files = {
80
+ "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses", "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
81
+ "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B", "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B"
82
+ }
83
+
84
+ for filename, repo_id in model_files.items():
85
+ if not (self.CKPTS_ROOT / filename).exists():
86
+ print(f"Baixando {filename} de {repo_id}...")
87
+ hf_hub_download(repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT), cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN"))
88
+ print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
89
+
90
+ def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
91
+ ts = f"{int(time.time())}_{os.urandom(4).hex()}"
92
+ job_input_dir = self.INPUT_ROOT / f"job_{ts}"
93
+ out_dir = self.OUTPUT_ROOT / f"run_{ts}"
94
+ job_input_dir.mkdir(parents=True, exist_ok=True)
95
+ out_dir.mkdir(parents=True, exist_ok=True)
96
+ shutil.copy2(input_file, job_input_dir / Path(input_file).name)
97
+ return job_input_dir, out_dir
98
+
99
+ def run_inference(self, filepath: str, *, seed: int, resh: int, resw: int, spsize: int, fps: Optional[float] = None):
100
+ script = self.SEEDVR_ROOT / "inference_cli.py"
101
+ job_input_dir, outdir = self._prepare_job(filepath)
102
+ mediatype, _ = mimetypes.guess_type(filepath)
103
+ is_image = mediatype and mediatype.startswith("image")
104
+
105
+ effective_nproc = 1 if is_image else self.NUM_GPUS_TOTAL
106
+ effective_spsize = 1 if is_image else spsize
107
+
108
+ output_filename = f"result_{Path(filepath).stem}.mp4" if not is_image else f"{Path(filepath).stem}_upscaled"
109
+ output_filepath = outdir / output_filename
110
+
111
+
112
+
113
+ cmd = [
114
+ "torchrun", "--standalone", "--nnodes=1",
115
+ f"--nproc-per-node={effective_nproc}",
116
+ str(script),
117
+ "--video_path", str(filepath),
118
+ "--output", str(output_filepath),
119
+ "--model_dir", str(self.CKPTS_ROOT),
120
+ "--seed", str(seed),
121
+ "--cuda_device", "0",
122
+ "--resolution", str(resh),
123
+ "--batch_size", str(effective_spsize),
124
+ "--model", "seedvr2_ema_3b_fp16.safetensors",
125
+ "--preserve_vram",
126
+ "--debug",
127
+ "--output_format", "video" if not is_image else "png",
128
+ ]
129
+
130
+
131
+ print("SeedVRServer Comando:", " ".join(cmd))
132
+ try:
133
+ subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=os.environ.copy(), stdout=sys.stdout, stderr=sys.stderr)
134
+ # Constrói a tupla de retorno de forma determinística
135
+ if is_image:
136
+ # CLI salva PNGs em diretório args.output (tratado como diretório quando outputformat=png)
137
+ image_dir = output_filepath if output_filepath.suffix == "" else output_filepath.with_suffix("")
138
+ return str(image_dir), None, outdir
139
+ else:
140
+ # CLI salva vídeo exatamente em output_filepath
141
+ return None, str(output_filepath), outdir
142
+ except Exception as e:
143
+ print(f"[UI ERROR] A inferência falhou: {e}")
144
+ return None, None, None
seed_server.py DELETED
@@ -1,187 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- SeedVR Server (CLI torchrun)
4
-
5
- - Garante repositório SeedVR e checkpoints baixados via snapshot_download.
6
- - Cria symlink SeedVR/ckpts/SeedVR2-3B -> CKPTS_ROOT.
7
- - Executa projects/inference_seedvr2_3b.py com torchrun e NUM_GPUS.
8
- - API: run_inference(file_path, seed, res_h, res_w, sp_size) -> (video_out, image_out, out_dir).
9
- """
10
-
11
- import os
12
- import shutil
13
- import subprocess
14
- from pathlib import Path
15
- from typing import Optional, Tuple, List
16
- import time
17
- import mimetypes
18
-
19
- from huggingface_hub import snapshot_download # requerido no container
20
-
21
- class SeedVRServer:
22
- def __init__(
23
- self,
24
- *,
25
- seedvr_root: Optional[str] = None,
26
- ckpts_root: Optional[str] = None,
27
- output_root: Optional[str] = None,
28
- input_root: Optional[str] = None,
29
- repo_url: Optional[str] = None,
30
- repo_id: Optional[str] = None,
31
- num_gpus: Optional[int] = None,
32
- ):
33
- # Paths e envs
34
- self.SEEDVR_ROOT = Path(seedvr_root or os.getenv("SEEDVR_ROOT", "/app/SeedVR"))
35
- self.CKPTS_ROOT = Path(ckpts_root or os.getenv("CKPTS_ROOT", "/app/ckpts/SeedVR2-3B"))
36
- self.OUTPUT_ROOT = Path(output_root or os.getenv("OUTPUT_ROOT", "/app/outputs"))
37
- self.INPUT_ROOT = Path(input_root or os.getenv("INPUT_ROOT", "/app/inputs"))
38
- self.REPO_URL = repo_url or os.getenv("SEEDVR_GIT_URL", "https://github.com/ByteDance-Seed/SeedVR.git")
39
- self.REPO_ID = repo_id or os.getenv("SEEDVR_REPO_ID", "ByteDance-Seed/SeedVR2-3B")
40
- self.NUM_GPUS = int(num_gpus or os.getenv("NUM_GPUS", "8"))
41
- self.HF_HOME = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
42
- self.HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") or None
43
-
44
- # Diretórios necessários
45
- for p in [self.SEEDVR_ROOT, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME]:
46
- p.mkdir(parents=True, exist_ok=True)
47
-
48
- # Bootstrap direto
49
- self._ensure_repo()
50
- self._ensure_model()
51
- self._ensure_ckpt_symlink()
52
-
53
- # ---------- Preparação ----------
54
- def _ensure_repo(self) -> None:
55
- if not (self.SEEDVR_ROOT / ".git").exists():
56
- print(f"[seed_server] cloning repo into {self.SEEDVR_ROOT}")
57
- subprocess.run(["git", "clone", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
58
- else:
59
- print(f"[seed_server] repo present at {self.SEEDVR_ROOT}")
60
-
61
- def _ensure_model(self) -> None:
62
- print(f"[seed_server] downloading model {self.REPO_ID} into {self.CKPTS_ROOT} (snapshot_download)")
63
- self.CKPTS_ROOT.mkdir(parents=True, exist_ok=True)
64
- snapshot_download(
65
- repo_id=self.REPO_ID,
66
- cache_dir=str(self.HF_HOME),
67
- local_dir=str(self.CKPTS_ROOT),
68
- local_dir_use_symlinks=False,
69
- resume_download=True,
70
- allow_patterns=["*.json", "*.safetensors", "*.pth", "*.bin", "*.py", "*.md", "*.txt"],
71
- token=self.HF_TOKEN,
72
- )
73
- print("[seed_server] model ready")
74
-
75
- def _ensure_ckpt_symlink(self) -> None:
76
- ckpts_repo_dir = self.SEEDVR_ROOT / "ckpts"
77
- ckpts_repo_dir.mkdir(parents=True, exist_ok=True)
78
- link = ckpts_repo_dir / "SeedVR2-3B"
79
- try:
80
- if link.is_symlink():
81
- try:
82
- if link.resolve() != self.CKPTS_ROOT:
83
- link.unlink()
84
- except Exception:
85
- link.unlink(missing_ok=True)
86
- if not link.exists():
87
- link.symlink_to(self.CKPTS_ROOT, target_is_directory=True)
88
- print(f"[seed_server] symlink ok: {link} -> {self.CKPTS_ROOT}")
89
- except Exception as e:
90
- print("[seed_server] warn: ckpt symlink failed:", e)
91
-
92
- # ---------- Util ----------
93
- @staticmethod
94
- def _is_video(path: str) -> bool:
95
- mime, _ = mimetypes.guess_type(path)
96
- return (mime or "").startswith("video") or str(path).lower().endswith(".mp4")
97
-
98
- @staticmethod
99
- def _is_image(path: str) -> bool:
100
- mime, _ = mimetypes.guess_type(path)
101
- if mime and mime.startswith("image"):
102
- return True
103
- return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
104
-
105
- def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
106
- ts = int(time.time())
107
- job_dir = self.INPUT_ROOT / f"job_{ts}"
108
- out_dir = self.OUTPUT_ROOT / f"run_{ts}"
109
- job_dir.mkdir(parents=True, exist_ok=True)
110
- out_dir.mkdir(parents=True, exist_ok=True)
111
-
112
- #####
113
- shutil.copy2(input_file, job_dir / Path(input_file).name)
114
- return out_dir, out_dir
115
-
116
- # ---------- Execução ----------
117
- def run_inference(
118
- self,
119
- file_path: str,
120
- *,
121
- seed: int = 42,
122
- res_h: int = 720,
123
- res_w: int = 1280,
124
- sp_size: int = 4,
125
- extra_args: Optional[List[str]] = None,
126
- ) -> Tuple[Optional[str], Optional[str], Path]:
127
- """
128
- Executa inferência via torchrun com NUM_GPUS:
129
- - file_path: vídeo .mp4 ou imagem .png/.jpg/.jpeg/.webp
130
- - Retorna (video_out, image_out, out_dir). Um dos dois primeiros será não-nulo.
131
- """
132
- if not Path(file_path).exists():
133
- raise FileNotFoundError(f"input not found: {file_path}")
134
-
135
- script = self.SEEDVR_ROOT / "projects" / "inference_seedvr2_3b.py"
136
- if not script.exists():
137
- raise FileNotFoundError(f"inference script not found: {script}")
138
-
139
- job_dir, out_dir = self._prepare_job(file_path)
140
- self._ensure_ckpt_symlink()
141
-
142
- out_dir.mkdir(parents=True, exist_ok=True)
143
- os.chmod(out_dir, 777)
144
-
145
- job_dir.mkdir(parents=True, exist_ok=True)
146
- os.chmod(job_dir, 777)
147
-
148
- cmd = [
149
- "torchrun",
150
- f"--nproc-per-node={self.NUM_GPUS}",
151
- str(script),
152
- "--video_path", str(job_dir),
153
- "--output_dir", str(out_dir),
154
- "--seed", str(seed),
155
- "--res_h", str(res_h),
156
- "--res_w", str(res_w),
157
- "--sp_size", str(sp_size),
158
- ]
159
- if extra_args:
160
- cmd.extend(extra_args)
161
-
162
- env = os.environ.copy()
163
- env.setdefault("HF_HOME", str(self.HF_HOME))
164
- env.setdefault("NCCL_P2P_LEVEL", os.getenv("NCCL_P2P_LEVEL", "NVL"))
165
- #env.setdefault("NCCL_ASYNC_ERROR_HANDLING", os.getenv("NCCL_ASYNC_ERROR_HANDLING", "1"))
166
- env.setdefault("OMP_NUM_THREADS", os.getenv("OMP_NUM_THREADS", "8"))
167
-
168
- print("[seed_server] running:", " ".join(cmd))
169
- try:
170
- subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=env)
171
- except subprocess.CalledProcessError as e:
172
- print("[seed_server] torchrun error:", e)
173
- return None, None, out_dir
174
-
175
- # Buscar artefatos
176
- videos = sorted(out_dir.rglob("*.mp4"), key=lambda p: p.stat().st_mtime)
177
- # Cobrir formatos comuns caso upstream mude
178
- if not videos:
179
- videos = sorted([*out_dir.rglob("*.mov"), *out_dir.rglob("*.avi")], key=lambda p: p.stat().st_mtime)
180
- images = sorted(
181
- [*out_dir.rglob("*.png"), *out_dir.rglob("*.jpg"), *out_dir.rglob("*.jpeg"), *out_dir.rglob("*.webp")],
182
- key=lambda p: p.stat().st_mtime
183
- )
184
-
185
- video_out = str(videos[-1]) if videos else None
186
- image_out = str(images[-1]) if images else None
187
- return video_out, image_out, out_dir