aducsdr commited on
Commit
6feecd4
·
verified ·
1 Parent(s): 971c7f9

Upload seedvr_manager (2).py

Browse files
aduc_framework/managers/seedvr_manager (2).py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # managers/seedvr_manager.py
2
+ #
3
+ # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
+ #
5
+ # Version: 4.0.0 (Root Installer & Executor)
6
+ #
7
+ # This version fully adopts the logic from the functional hd_specialist.py example.
8
+ # It acts as a setup manager: it clones the SeedVR repo and then copies all
9
+ # necessary directories (projects, common, models, configs, ckpts) to the
10
+ # application root. It also handles the pip installation of the Apex dependency.
11
+ # This ensures that the SeedVR code runs in the exact file structure it expects.
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ import os
16
+ import gc
17
+ import logging
18
+ import sys
19
+ import subprocess
20
+ from pathlib import Path
21
+ from urllib.parse import urlparse
22
+ from torch.hub import download_url_to_file
23
+ import gradio as gr
24
+ import mediapy
25
+ from einops import rearrange
26
+ import shutil
27
+ from omegaconf import OmegaConf
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # --- Caminhos Globais ---
32
+ APP_ROOT = Path("/home/user/app")
33
+ DEPS_DIR = APP_ROOT / "deps"
34
+ SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
35
+ SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
36
+
37
+ class SeedVrManager:
38
+ def __init__(self, workspace_dir="deformes_workspace"):
39
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
40
+ self.runner = None
41
+ self.workspace_dir = workspace_dir
42
+ self.is_initialized = False
43
+ self._original_barrier = None
44
+ self.setup_complete = False # Flag para rodar o setup apenas uma vez
45
+ logger.info("SeedVrManager initialized. Setup will run on first use.")
46
+
47
+ def _full_setup(self):
48
+ """
49
+ Executa todo o processo de setup uma única vez.
50
+ """
51
+ if self.setup_complete:
52
+ return
53
+
54
+ logger.info("--- Starting Full SeedVR Setup ---")
55
+
56
+ # 1. Clonar o repositório se não existir
57
+ if not SEEDVR_SPACE_DIR.exists():
58
+ logger.info(f"Cloning SeedVR Space repo to {SEEDVR_SPACE_DIR}...")
59
+ DEPS_DIR.mkdir(exist_ok=True, parents=True)
60
+ subprocess.run(
61
+ ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
62
+ check=True, capture_output=True, text=True
63
+ )
64
+
65
+ # 2. Copiar as pastas necessárias para a raiz da aplicação
66
+ required_dirs = ["projects", "common", "models", "configs_3b", "configs_7b"]
67
+ for dirname in required_dirs:
68
+ source = SEEDVR_SPACE_DIR / dirname
69
+ target = APP_ROOT / dirname
70
+ if not target.exists():
71
+ logger.info(f"Copying '{dirname}' to application root...")
72
+ shutil.copytree(source, target)
73
+
74
+ # 3. Adicionar a raiz ao sys.path para garantir que os imports funcionem
75
+ if str(APP_ROOT) not in sys.path:
76
+ sys.path.insert(0, str(APP_ROOT))
77
+ logger.info(f"Added '{APP_ROOT}' to sys.path.")
78
+
79
+ # 4. Instalar dependências complexas como Apex
80
+ try:
81
+ import apex
82
+ logger.info("Apex is already installed.")
83
+ except ImportError:
84
+ logger.info("Installing Apex dependency...")
85
+ apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
86
+ apex_wheel_path = _load_file_from_url(url=apex_url, model_dir=str(DEPS_DIR))
87
+ subprocess.run(f"pip install {apex_wheel_path}", check=True, shell=True)
88
+ logger.info("Apex installed successfully.")
89
+
90
+ # 5. Baixar os modelos para a pasta ./ckpts na raiz
91
+ ckpt_dir = APP_ROOT / 'ckpts'
92
+ ckpt_dir.mkdir(exist_ok=True)
93
+ pretrain_model_urls = {
94
+ 'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
95
+ 'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
96
+ 'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
97
+ 'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
98
+ 'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
99
+ }
100
+ for name, url in pretrain_model_urls.items():
101
+ _load_file_from_url(url=url, model_dir=str(ckpt_dir))
102
+
103
+ self.setup_complete = True
104
+ logger.info("--- Full SeedVR Setup Complete ---")
105
+
106
+ def _initialize_runner(self, model_version: str):
107
+ if self.runner is not None: return
108
+
109
+ # Garante que todo o ambiente está configurado antes de prosseguir
110
+ self._full_setup()
111
+
112
+ # Agora que o setup está feito, podemos importar os módulos
113
+ from projects.video_diffusion_sr.infer import VideoDiffusionInfer
114
+ from common.config import load_config
115
+ from common.seed import set_seed
116
+
117
+ if dist.is_available() and not dist.is_initialized():
118
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
119
+ os.environ["MASTER_PORT"] = "12355"
120
+ os.environ["RANK"] = str(0)
121
+ os.environ["WORLD_SIZE"] = str(1)
122
+ dist.init_process_group(backend='gloo')
123
+ logger.info("Initialized torch.distributed process group.")
124
+
125
+ logger.info(f"Initializing SeedVR2 {model_version} runner...")
126
+ if model_version == '3B':
127
+ config_path = APP_ROOT / 'configs_3b' / 'main.yaml'
128
+ checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_3b.pth'
129
+ else: # Assumimos 7B
130
+ config_path = APP_ROOT / 'configs_7b' / 'main.yaml'
131
+ checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_7b.pth'
132
+
133
+ config = load_config(str(config_path))
134
+
135
+ self.runner = VideoDiffusionInfer(config)
136
+ OmegaConf.set_readonly(self.runner.config, False)
137
+
138
+ self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
139
+ self.runner.configure_vae_model()
140
+
141
+ if hasattr(self.runner.vae, "set_memory_limit"):
142
+ self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
143
+
144
+ self.is_initialized = True
145
+ logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
146
+
147
+ def _unload_runner(self):
148
+ if self.runner is not None:
149
+ del self.runner
150
+ self.runner = None
151
+ gc.collect()
152
+ torch.cuda.empty_cache()
153
+ self.is_initialized = False
154
+ logger.info("Runner do SeedVR2 descarregado da VRAM.")
155
+ if dist.is_initialized():
156
+ dist.destroy_process_group()
157
+ logger.info("Destroyed torch.distributed process group.")
158
+
159
+ def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
160
+ model_version: str = '7B', steps: int = 100, seed: int = 666,
161
+ progress: gr.Progress = None) -> str:
162
+ try:
163
+ self._initialize_runner(model_version)
164
+
165
+ # Precisamos importar aqui, pois o sys.path é modificado no setup
166
+ from common.seed import set_seed
167
+ from data.image.transforms.divisible_crop import DivisibleCrop
168
+ from data.image.transforms.na_resize import NaResize
169
+ from data.video.transforms.rearrange import Rearrange
170
+ from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
171
+ from torchvision.transforms import Compose, Lambda, Normalize
172
+ from torchvision.io.video import read_video
173
+
174
+ set_seed(seed, same_across_ranks=True)
175
+ self.runner.config.diffusion.timesteps.sampling.steps = steps
176
+ self.runner.configure_diffusion()
177
+
178
+ video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
179
+ res_h, res_w = video_tensor.shape[-2:]
180
+ video_transform = Compose([
181
+ NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
182
+ Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
183
+ DivisibleCrop((16, 16)),
184
+ Normalize(0.5, 0.5),
185
+ Rearrange("t c h w -> c t h w"),
186
+ ])
187
+ cond_latents = [video_transform(video_tensor.to(self.device))]
188
+ input_videos = cond_latents
189
+ self.runner.dit.to("cpu")
190
+ self.runner.vae.to(self.device)
191
+ cond_latents = self.runner.vae_encode(cond_latents)
192
+ self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
193
+ self.runner.dit.to(self.device)
194
+
195
+ pos_emb = torch.load(APP_ROOT / 'pos_emb.pt').to(self.device)
196
+ neg_emb = torch.load(APP_ROOT / 'neg_emb.pt').to(self.device)
197
+ text_embeds_dict = {"texts_pos": [pos_emb], "texts_neg": [neg_emb]}
198
+
199
+ noises = [torch.randn_like(latent) for latent in cond_latents]
200
+ conditions = [self.runner.get_condition(noise, latent_blur=latent, task="sr") for noise, latent in zip(noises, cond_latents)]
201
+
202
+ with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
203
+ video_tensors = self.runner.inference(noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict)
204
+
205
+ self.runner.dit.to("cpu"); gc.collect(); torch.cuda.empty_cache()
206
+ self.runner.vae.to(self.device)
207
+ samples = self.runner.vae_decode(video_tensors)
208
+ final_sample = samples[0]
209
+ input_video_sample = input_videos[0]
210
+ if final_sample.shape[1] < input_video_sample.shape[1]:
211
+ input_video_sample = input_video_sample[:, :final_sample.shape[1]]
212
+
213
+ final_sample = wavelet_reconstruction(rearrange(final_sample, "c t h w -> t c h w"), rearrange(input_video_sample, "c t h w -> t c h w"))
214
+ final_sample = rearrange(final_sample, "t c h w -> t h w c")
215
+ final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
216
+ final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
217
+
218
+ mediapy.write_video(output_video_path, final_sample_np, fps=24)
219
+ logger.info(f"HD Mastered video saved to: {output_video_path}")
220
+ return output_path
221
+ finally:
222
+ self._unload_runner()
223
+
224
+ def _load_file_from_url(url, model_dir='./', file_name=None):
225
+ os.makedirs(model_dir, exist_ok=True)
226
+ filename = file_name or os.path.basename(urlparse(url).path)
227
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
228
+ if not os.path.exists(cached_file):
229
+ logger.info(f'Downloading: "{url}" to {cached_file}')
230
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
231
+ return cached_file
232
+
233
+ seedvr_manager_singleton = SeedVrManager()