EuuIia committed on
Commit
1300afb
·
verified ·
1 Parent(s): eec60ab

Upload seed_server.py

Browse files
Files changed (1) hide show
  1. seed_server.py +187 -0
seed_server.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SeedVR Server (CLI torchrun)
4
+
5
+ - Garante repositório SeedVR e checkpoints baixados via snapshot_download.
6
+ - Cria symlink SeedVR/ckpts/SeedVR2-3B -> CKPTS_ROOT.
7
+ - Executa projects/inference_seedvr2_3b.py com torchrun e NUM_GPUS.
8
+ - API: run_inference(file_path, seed, res_h, res_w, sp_size) -> (video_out, image_out, out_dir).
9
+ """
10
+
11
+ import os
12
+ import shutil
13
+ import subprocess
14
+ from pathlib import Path
15
+ from typing import Optional, Tuple, List
16
+ import time
17
+ import mimetypes
18
+
19
+ from huggingface_hub import snapshot_download # requerido no container
20
+
21
class SeedVRServer:
    """Bootstrap and drive SeedVR2-3B inference through ``torchrun``.

    On construction the server:

    * creates the working directories,
    * clones the SeedVR git repository if it is not present,
    * downloads the model checkpoints with ``snapshot_download``,
    * links ``<SEEDVR_ROOT>/ckpts/SeedVR2-3B`` to the checkpoint directory.

    ``run_inference`` then launches ``projects/inference_seedvr2_3b.py`` from
    the cloned repository with ``torchrun`` on ``NUM_GPUS`` processes.
    """

    def __init__(
        self,
        *,
        seedvr_root: Optional[str] = None,
        ckpts_root: Optional[str] = None,
        output_root: Optional[str] = None,
        input_root: Optional[str] = None,
        repo_url: Optional[str] = None,
        repo_id: Optional[str] = None,
        num_gpus: Optional[int] = None,
    ):
        """Resolve configuration and run the bootstrap steps.

        Every argument falls back to an environment variable and then to a
        built-in default when it is ``None`` (or otherwise falsy):

        * ``seedvr_root``  -> ``SEEDVR_ROOT``    -> ``/app/SeedVR``
        * ``ckpts_root``   -> ``CKPTS_ROOT``     -> ``/app/ckpts/SeedVR2-3B``
        * ``output_root``  -> ``OUTPUT_ROOT``    -> ``/app/outputs``
        * ``input_root``   -> ``INPUT_ROOT``     -> ``/app/inputs``
        * ``repo_url``     -> ``SEEDVR_GIT_URL`` -> the ByteDance-Seed GitHub URL
        * ``repo_id``      -> ``SEEDVR_REPO_ID`` -> ``ByteDance-Seed/SeedVR2-3B``
        * ``num_gpus``     -> ``NUM_GPUS``       -> ``8``

        Raises:
            subprocess.CalledProcessError: if ``git clone`` fails.
        """
        # Paths and environment-driven settings.
        self.SEEDVR_ROOT = Path(seedvr_root or os.getenv("SEEDVR_ROOT", "/app/SeedVR"))
        self.CKPTS_ROOT = Path(ckpts_root or os.getenv("CKPTS_ROOT", "/app/ckpts/SeedVR2-3B"))
        self.OUTPUT_ROOT = Path(output_root or os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(input_root or os.getenv("INPUT_ROOT", "/app/inputs"))
        self.REPO_URL = repo_url or os.getenv("SEEDVR_GIT_URL", "https://github.com/ByteDance-Seed/SeedVR.git")
        self.REPO_ID = repo_id or os.getenv("SEEDVR_REPO_ID", "ByteDance-Seed/SeedVR2-3B")
        self.NUM_GPUS = int(num_gpus or os.getenv("NUM_GPUS", "8"))
        self.HF_HOME = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        self.HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") or None

        # Required directories.
        for p in [self.SEEDVR_ROOT, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME]:
            p.mkdir(parents=True, exist_ok=True)

        # Bootstrap eagerly so the instance is ready to run right away.
        self._ensure_repo()
        self._ensure_model()
        self._ensure_ckpt_symlink()

    # ---------- Setup ----------
    def _ensure_repo(self) -> None:
        """Clone the SeedVR repository into ``SEEDVR_ROOT`` unless a ``.git`` entry exists there."""
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[seed_server] cloning repo into {self.SEEDVR_ROOT}")
            subprocess.run(["git", "clone", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
        else:
            print(f"[seed_server] repo present at {self.SEEDVR_ROOT}")

    def _ensure_model(self) -> None:
        """Download the checkpoint files for ``REPO_ID`` into ``CKPTS_ROOT``."""
        print(f"[seed_server] downloading model {self.REPO_ID} into {self.CKPTS_ROOT} (snapshot_download)")
        self.CKPTS_ROOT.mkdir(parents=True, exist_ok=True)
        # NOTE(review): `local_dir_use_symlinks` and `resume_download` are
        # reported as deprecated in recent huggingface_hub releases - confirm
        # against the version pinned in the container before removing them.
        snapshot_download(
            repo_id=self.REPO_ID,
            cache_dir=str(self.HF_HOME),
            local_dir=str(self.CKPTS_ROOT),
            local_dir_use_symlinks=False,
            resume_download=True,
            allow_patterns=["*.json", "*.safetensors", "*.pth", "*.bin", "*.py", "*.md", "*.txt"],
            token=self.HF_TOKEN,
        )
        print("[seed_server] model ready")

    def _ensure_ckpt_symlink(self) -> None:
        """Make ``<SEEDVR_ROOT>/ckpts/SeedVR2-3B`` a symlink to ``CKPTS_ROOT``.

        Best effort: any failure while inspecting or creating the link is
        reported with a warning instead of being raised.
        """
        ckpts_repo_dir = self.SEEDVR_ROOT / "ckpts"
        ckpts_repo_dir.mkdir(parents=True, exist_ok=True)
        link = ckpts_repo_dir / "SeedVR2-3B"
        try:
            # Compare fully resolved paths; CKPTS_ROOT may be relative or may
            # itself go through symlinks, which would never equal
            # link.resolve() if compared unresolved.
            target = self.CKPTS_ROOT.resolve()
            if link.is_symlink():
                try:
                    if link.resolve() != target:
                        link.unlink()
                except Exception:
                    # The link could not be resolved; drop it and recreate.
                    link.unlink(missing_ok=True)
            elif link.exists():
                # A real file/directory occupies the link path; leave it
                # alone, but do not claim the symlink is in place.
                print(f"[seed_server] warn: {link} exists and is not a symlink; leaving it untouched")
                return
            if not link.exists():
                # Use the absolute target: a relative symlink target would be
                # interpreted relative to the link's own directory.
                link.symlink_to(target, target_is_directory=True)
            print(f"[seed_server] symlink ok: {link} -> {self.CKPTS_ROOT}")
        except Exception as e:  # deliberately broad: this step is best effort
            print("[seed_server] warn: ckpt symlink failed:", e)

    # ---------- Utilities ----------
    @staticmethod
    def _is_video(path: str) -> bool:
        """Return True if the guessed MIME type is ``video/*`` or the name ends with ``.mp4``."""
        mime, _ = mimetypes.guess_type(path)
        return (mime or "").startswith("video") or str(path).lower().endswith(".mp4")

    @staticmethod
    def _is_image(path: str) -> bool:
        """Return True if the guessed MIME type is ``image/*`` or the name has a known image suffix."""
        mime, _ = mimetypes.guess_type(path)
        if mime and mime.startswith("image"):
            return True
        return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".webp"))

    @staticmethod
    def _newest(paths) -> Optional[str]:
        """Return the most recently modified path in *paths* as a string, or None if empty."""
        ordered = sorted(paths, key=lambda p: p.stat().st_mtime)
        return str(ordered[-1]) if ordered else None

    def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
        """Create per-job input/output directories and stage the input file.

        The directories are named ``job_<ts>`` / ``run_<ts>`` where ``<ts>``
        is the current Unix time in seconds; if either name is already taken
        (e.g. two jobs started within the same second) a ``_<n>`` suffix is
        appended so jobs never share directories.

        Returns:
            ``(job_dir, out_dir)``: the directory holding the copied input
            and the (empty) directory where results should be written.
        """
        ts = int(time.time())
        suffix = str(ts)
        attempt = 0
        job_dir = self.INPUT_ROOT / f"job_{suffix}"
        out_dir = self.OUTPUT_ROOT / f"run_{suffix}"
        while job_dir.exists() or out_dir.exists():
            attempt += 1
            suffix = f"{ts}_{attempt}"
            job_dir = self.INPUT_ROOT / f"job_{suffix}"
            out_dir = self.OUTPUT_ROOT / f"run_{suffix}"
        job_dir.mkdir(parents=True, exist_ok=True)
        out_dir.mkdir(parents=True, exist_ok=True)

        shutil.copy2(input_file, job_dir / Path(input_file).name)
        # Return the *input* directory first: the caller passes it to the
        # inference script as --video_path.
        return job_dir, out_dir

    # ---------- Execution ----------
    def run_inference(
        self,
        file_path: str,
        *,
        seed: int = 42,
        res_h: int = 720,
        res_w: int = 1280,
        sp_size: int = 4,
        extra_args: Optional[List[str]] = None,
    ) -> Tuple[Optional[str], Optional[str], Path]:
        """Run the SeedVR2-3B inference script through ``torchrun``.

        Args:
            file_path: input file (per the module notes, an ``.mp4`` video or
                a ``.png``/``.jpg``/``.jpeg``/``.webp`` image).
            seed, res_h, res_w, sp_size: forwarded to the script as
                ``--seed``, ``--res_h``, ``--res_w`` and ``--sp_size``.
            extra_args: additional command-line arguments appended verbatim.

        Returns:
            ``(video_out, image_out, out_dir)``. ``video_out`` / ``image_out``
            are the most recently modified video / image found under
            ``out_dir`` (or ``None``). Both are ``None`` when ``torchrun``
            exits with a non-zero status.

        Raises:
            FileNotFoundError: if ``file_path`` or the inference script does
                not exist.
        """
        if not Path(file_path).exists():
            raise FileNotFoundError(f"input not found: {file_path}")

        script = self.SEEDVR_ROOT / "projects" / "inference_seedvr2_3b.py"
        if not script.exists():
            raise FileNotFoundError(f"inference script not found: {script}")

        job_dir, out_dir = self._prepare_job(file_path)
        self._ensure_ckpt_symlink()

        # Open up both directories. The mode must be written in octal:
        # decimal 777 is 0o1411, not rwxrwxrwx.
        os.chmod(out_dir, 0o777)
        os.chmod(job_dir, 0o777)

        cmd = [
            "torchrun",
            f"--nproc-per-node={self.NUM_GPUS}",
            str(script),
            "--video_path", str(job_dir),
            "--output_dir", str(out_dir),
            "--seed", str(seed),
            "--res_h", str(res_h),
            "--res_w", str(res_w),
            "--sp_size", str(sp_size),
        ]
        if extra_args:
            cmd.extend(extra_args)

        # `env` starts as a copy of os.environ, so setdefault only fills in
        # values that the parent environment does not already define.
        env = os.environ.copy()
        env.setdefault("HF_HOME", str(self.HF_HOME))
        env.setdefault("NCCL_P2P_LEVEL", "NVL")
        env.setdefault("OMP_NUM_THREADS", "8")

        print("[seed_server] running:", " ".join(cmd))
        try:
            subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=env)
        except subprocess.CalledProcessError as e:
            print("[seed_server] torchrun error:", e)
            return None, None, out_dir

        # Collect artefacts; fall back to other common video containers in
        # case the upstream script changes its output format.
        videos = list(out_dir.rglob("*.mp4"))
        if not videos:
            videos = [*out_dir.rglob("*.mov"), *out_dir.rglob("*.avi")]
        images = [
            *out_dir.rglob("*.png"),
            *out_dir.rglob("*.jpg"),
            *out_dir.rglob("*.jpeg"),
            *out_dir.rglob("*.webp"),
        ]

        video_out = self._newest(videos)
        image_out = self._newest(images)
        return video_out, image_out, out_dir