aducsdr commited on
Commit
0115fac
·
verified ·
1 Parent(s): e01bb3e

Delete aduc_framework/managers/seedvr_manager.py

Browse files
aduc_framework/managers/seedvr_manager.py DELETED
@@ -1,243 +0,0 @@
1
- # hd_specialist.py (Versão Final - Corrigindo o Contexto de Execução de Caminhos)
2
- # https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B
3
-
4
- import torch
5
- import imageio
6
- import os
7
- import gc
8
- import logging
9
- import numpy as np
10
- import shlex
11
- import subprocess
12
- from pathlib import Path
13
- from urllib.parse import urlparse
14
- from torch.hub import download_url_to_file
15
- from omegaconf import OmegaConf
16
- import sys
17
-
18
- # --- Configuração do Logging ---
19
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
- logger = logging.getLogger(__name__)
21
-
22
- # --- Constantes de Caminho ---
23
- # Define a raiz do projeto (onde este script está) e cria um diretório para dependências
24
- # Usamos 'Path.cwd()' para ter certeza que é relativo ao diretório de execução do projeto principal.
25
- PROJECT_ROOT = Path.cwd()
26
- DEPS_DIR = PROJECT_ROOT / "deps"
27
- SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
28
- SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
29
-
30
- def setup_environment():
31
- """
32
- Clona o repositório SeedVR se não existir e o adiciona ao sys.path
33
- para que seus módulos (common, projects, etc.) possam ser importados.
34
- """
35
- if not SEEDVR_SPACE_DIR.is_dir():
36
- logger.info(f"Repositório SeedVR não encontrado. Clonando de '{SEEDVR_SPACE_URL}'...")
37
- try:
38
- DEPS_DIR.mkdir(exist_ok=True)
39
- subprocess.run(
40
- ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
41
- check=True, capture_output=True, text=True
42
- )
43
- logger.info(f"✅ Repositório clonado com sucesso em '{SEEDVR_SPACE_DIR}'")
44
- except subprocess.CalledProcessError as e:
45
- logger.error(f"❌ Falha ao clonar o repositório. Erro do Git: {e.stderr}")
46
- raise RuntimeError("Não foi possível clonar a dependência SeedVR do Hugging Face.")
47
- else:
48
- logger.info(f"Repositório SeedVR já existe em '{SEEDVR_SPACE_DIR}'.")
49
-
50
- resolved_path = str(SEEDVR_SPACE_DIR.resolve())
51
- if resolved_path not in sys.path:
52
- sys.path.insert(0, resolved_path)
53
- logger.info(f"Adicionado '{resolved_path}' ao sys.path.")
54
-
55
- # Executa a configuração do ambiente assim que o módulo é carregado
56
- setup_environment()
57
-
58
- # Função auxiliar de download (permanece a mesma)
59
- def _load_file_from_url(url, model_dir='./', file_name=None):
60
- os.makedirs(model_dir, exist_ok=True)
61
- filename = file_name or os.path.basename(urlparse(url).path)
62
- cached_file = os.path.abspath(os.path.join(model_dir, filename))
63
- if not os.path.exists(cached_file):
64
- logger.info(f'Baixando: "{url}" para {cached_file}')
65
- download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
66
- return cached_file
67
-
68
- # --- Importações do Repositório Clonado ---
69
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
70
- from common.config import load_config
71
- from common.seed import set_seed
72
- from data.image.transforms.divisible_crop import DivisibleCrop
73
- from data.image.transforms.na_resize import NaResize
74
- from data.video.transforms.rearrange import Rearrange
75
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
76
- from torchvision.transforms import Compose, Lambda, Normalize
77
- from torchvision.io.video import read_video
78
- from einops import rearrange
79
-
80
- class SeedVrManager:
81
- """
82
- Implementa o Especialista HD (Δ+) usando a infraestrutura oficial do SeedVR.
83
- """
84
- def __init__(self, workspace_dir="deformes_workspace"):
85
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
86
- self.runner = None
87
- self.workspace_dir = Path(workspace_dir)
88
- self.is_initialized = False
89
- logger.info("Especialista HD (SeedVR) inicializado. Modelo será carregado sob demanda.")
90
-
91
- def _setup_dependencies(self):
92
- """Instala dependências complexas como Apex."""
93
- logger.info("Configurando dependências do SeedVR (Apex)...")
94
- apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
95
- apex_wheel_path = _load_file_from_url(url=apex_url, model_dir=str(DEPS_DIR))
96
- subprocess.run(shlex.split(f"pip install {apex_wheel_path}"), check=True)
97
- logger.info("✅ Dependência Apex instalada com sucesso.")
98
-
99
- def _download_models(self):
100
- """Baixa os checkpoints necessários para o SeedVR2 DENTRO do repositório clonado."""
101
- logger.info("Verificando e baixando modelos do SeedVR2...")
102
- ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
103
-
104
- pretrain_model_url = {
105
- 'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
106
- 'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
107
- 'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
108
- 'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
109
- }
110
-
111
- _load_file_from_url(url=pretrain_model_url['dit'], model_dir=str(ckpt_dir))
112
- _load_file_from_url(url=pretrain_model_url['vae'], model_dir=str(ckpt_dir))
113
- _load_file_from_url(url=pretrain_model_url['pos_emb'], model_dir=str(SEEDVR_SPACE_DIR))
114
- _load_file_from_url(url=pretrain_model_url['neg_emb'], model_dir=str(SEEDVR_SPACE_DIR))
115
- logger.info("✅ Modelos do SeedVR2 baixados com sucesso.")
116
-
117
- def _initialize_runner(self):
118
- """Carrega e configura o modelo SeedVR sob demanda."""
119
- if self.runner is not None:
120
- return
121
-
122
- self._setup_dependencies()
123
- self._download_models()
124
-
125
- logger.info("Inicializando o runner do SeedVR2...")
126
-
127
- # --- CORREÇÃO CRÍTICA: MUDANÇA DE DIRETÓRIO DE TRABALHO ---
128
- original_cwd = Path.cwd()
129
- try:
130
- # Muda para o diretório do repositório clonado. Isso é essencial para que
131
- # o `load_config` encontre os arquivos .yaml herdados (como os da pasta `models`).
132
- os.chdir(SEEDVR_SPACE_DIR)
133
-
134
- logger.info(f"Diretório de trabalho alterado para: {SEEDVR_SPACE_DIR}")
135
-
136
- # Agora todos os caminhos são relativos à raiz do repositório
137
- config_path = './configs_3b/main.yaml'
138
- dit_checkpoint_path = './ckpts/seedvr2_ema_3b.pth'
139
-
140
- config = load_config(config_path)
141
-
142
- self.runner = VideoDiffusionInfer(config)
143
- OmegaConf.set_readonly(self.runner.config, False)
144
-
145
- self.runner.configure_dit_model(device=self.device, checkpoint=dit_checkpoint_path)
146
- self.runner.configure_vae_model()
147
-
148
- if hasattr(self.runner.vae, "set_memory_limit"):
149
- self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
150
-
151
- self.is_initialized = True
152
- logger.info("✅ Runner do SeedVR2 inicializado e pronto.")
153
-
154
- finally:
155
- # Garante que o diretório de trabalho original seja restaurado,
156
- # não importa se a inicialização foi bem-sucedida ou falhou.
157
- os.chdir(original_cwd)
158
- logger.info(f"Diretório de trabalho restaurado para: {original_cwd}")
159
- # --- FIM DA CORREÇÃO ---
160
-
161
- def _unload_runner(self):
162
- """Remove o runner da VRAM para liberar recursos."""
163
- if self.runner is not None:
164
- del self.runner; self.runner = None
165
- gc.collect(); torch.cuda.empty_cache()
166
- self.is_initialized = False
167
- logger.info("Runner do SeedVR2 descarregado da VRAM.")
168
-
169
- def process_video(self, input_video_path: str, output_video_path: str, prompt: str) -> str:
170
- """Aplica o aprimoramento HD a um vídeo usando a lógica oficial do SeedVR."""
171
- original_cwd = Path.cwd() # Salva o diretório original
172
- try:
173
- self._initialize_runner()
174
-
175
- # --- CORREÇÃO CRÍTICA 2: MUDAR DIRETÓRIO TAMBÉM DURANTE O PROCESSAMENTO ---
176
- # O código pode precisar acessar arquivos de embedding novamente
177
- os.chdir(SEEDVR_SPACE_DIR)
178
-
179
- set_seed(seed, same_across_ranks=True)
180
-
181
- self.runner.config.diffusion.cfg.scale = 1.0
182
- self.runner.config.diffusion.cfg.rescale = 0.0
183
- self.runner.config.diffusion.timesteps.sampling.steps = 1
184
- self.runner.configure_diffusion()
185
-
186
- logger.info(f"Processando vídeo de entrada: {input_video_path}")
187
- # Garante que os caminhos de entrada/saída sejam absolutos para não quebrar com a mudança de CWD
188
- abs_input_path = original_cwd / input_video_path
189
- abs_output_path = original_cwd / output_video_path
190
-
191
- video_tensor = read_video(str(abs_input_path), output_format="TCHW")[0] / 255.0
192
- if video_tensor.size(0) > 121:
193
- logger.warning(f"Vídeo com {video_tensor.size(0)} frames. Truncando para 121 frames.")
194
- video_tensor = video_tensor[:121]
195
-
196
- video_transform = Compose([
197
- NaResize(resolution=(1280 * 720)**0.5, mode="area", downsample_only=False),
198
- Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), DivisibleCrop((16, 16)),
199
- Normalize(0.5, 0.5), Rearrange("t c h w -> c t h w"),
200
- ])
201
-
202
- cond_latent = video_transform(video_tensor.to(self.device))
203
- input_video_for_colorfix = cond_latent.clone()
204
- ori_length = cond_latent.size(1)
205
-
206
- logger.info("Codificando vídeo para o espaço latente...")
207
- cond_latent = self.runner.vae_encode([cond_latent])[0]
208
-
209
- # Carrega embeddings com caminhos relativos ao diretório do repo
210
- text_pos_embeds = torch.load('pos_emb.pt').to(self.device)
211
- text_neg_embeds = torch.load('neg_emb.pt').to(self.device)
212
- text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
213
-
214
- noise = torch.randn_like(cond_latent)
215
-
216
- logger.info(f"Iniciando a geração de restauração para {ori_length} frames...")
217
- with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
218
- video_tensor_out = self.runner.inference(
219
- noises=[noise],
220
- conditions=[self.runner.get_condition(noise, task="sr", latent_blur=cond_latent)],
221
- dit_offload=False, **text_embeds_dict,
222
- )[0]
223
-
224
- sample = rearrange(video_tensor_out, "c t h w -> t c h w")
225
- if ori_length < sample.shape[0]:
226
- sample = sample[:ori_length]
227
-
228
- input_video_for_colorfix = rearrange(input_video_for_colorfix, "c t h w -> t c h w")
229
- sample = wavelet_reconstruction(sample.cpu(), input_video_for_colorfix[:sample.size(0)].cpu())
230
- sample = rearrange(sample, "t c h w -> t h w c")
231
- sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
232
-
233
- logger.info(f"Salvando vídeo aprimorado em: {abs_output_path}")
234
- self.workspace_dir.mkdir(parents=True, exist_ok=True)
235
- imageio.get_writer(str(abs_output_path), fps=fps_out, codec='libx264', quality=9).extend(sample)
236
-
237
- return str(abs_output_path)
238
- finally:
239
- os.chdir(original_cwd) # Restaura o diretório de trabalho original
240
- self._unload_runner()
241
-
242
- # Instância Singleton
243
- seedvr_manager_singleton = SeedVrManager()