aducsdr commited on
Commit
e01bb3e
·
verified ·
1 Parent(s): 180ac2c

Update aduc_framework/managers/seedvr_manager.py

Browse files
aduc_framework/managers/seedvr_manager.py CHANGED
@@ -1,4 +1,4 @@
1
- # hd_specialist.py (Versão Final - Gerenciando o Repositório Completo)
2
  # https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B
3
 
4
  import torch
@@ -7,8 +7,6 @@ import os
7
  import gc
8
  import logging
9
  import numpy as np
10
- from PIL import Image
11
- from tqdm import tqdm
12
  import shlex
13
  import subprocess
14
  from pathlib import Path
@@ -23,7 +21,8 @@ logger = logging.getLogger(__name__)
23
 
24
  # --- Constantes de Caminho ---
25
  # Define a raiz do projeto (onde este script está) e cria um diretório para dependências
26
- PROJECT_ROOT = Path(__file__).resolve().parent
 
27
  DEPS_DIR = PROJECT_ROOT / "deps"
28
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
29
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
@@ -36,9 +35,7 @@ def setup_environment():
36
  if not SEEDVR_SPACE_DIR.is_dir():
37
  logger.info(f"Repositório SeedVR não encontrado. Clonando de '{SEEDVR_SPACE_URL}'...")
38
  try:
39
- # Garante que o diretório de dependências exista
40
  DEPS_DIR.mkdir(exist_ok=True)
41
- # Clona o repositório. '--depth 1' baixa apenas a versão mais recente.
42
  subprocess.run(
43
  ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
44
  check=True, capture_output=True, text=True
@@ -50,7 +47,6 @@ def setup_environment():
50
  else:
51
  logger.info(f"Repositório SeedVR já existe em '{SEEDVR_SPACE_DIR}'.")
52
 
53
- # Adiciona a raiz do repositório clonado ao path do Python
54
  resolved_path = str(SEEDVR_SPACE_DIR.resolve())
55
  if resolved_path not in sys.path:
56
  sys.path.insert(0, resolved_path)
@@ -70,7 +66,6 @@ def _load_file_from_url(url, model_dir='./', file_name=None):
70
  return cached_file
71
 
72
  # --- Importações do Repositório Clonado ---
73
- # Agora estas importações funcionam porque `setup_environment` adicionou o repositório ao sys.path
74
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
75
  from common.config import load_config
76
  from common.seed import set_seed
@@ -97,7 +92,6 @@ class SeedVrManager:
97
  """Instala dependências complexas como Apex."""
98
  logger.info("Configurando dependências do SeedVR (Apex)...")
99
  apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
100
- # Baixa para um diretório temporário ou de dependências
101
  apex_wheel_path = _load_file_from_url(url=apex_url, model_dir=str(DEPS_DIR))
102
  subprocess.run(shlex.split(f"pip install {apex_wheel_path}"), check=True)
103
  logger.info("✅ Dependência Apex instalada com sucesso.")
@@ -105,7 +99,6 @@ class SeedVrManager:
105
  def _download_models(self):
106
  """Baixa os checkpoints necessários para o SeedVR2 DENTRO do repositório clonado."""
107
  logger.info("Verificando e baixando modelos do SeedVR2...")
108
- # O diretório de checkpoints agora é DENTRO do repositório clonado
109
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
110
 
111
  pretrain_model_url = {
@@ -115,7 +108,6 @@ class SeedVrManager:
115
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
116
  }
117
 
118
- # Salva os modelos nos locais corretos que o código espera
119
  _load_file_from_url(url=pretrain_model_url['dit'], model_dir=str(ckpt_dir))
120
  _load_file_from_url(url=pretrain_model_url['vae'], model_dir=str(ckpt_dir))
121
  _load_file_from_url(url=pretrain_model_url['pos_emb'], model_dir=str(SEEDVR_SPACE_DIR))
@@ -132,29 +124,39 @@ class SeedVrManager:
132
 
133
  logger.info("Inicializando o runner do SeedVR2...")
134
 
135
- # CORREÇÃO: O caminho para o config agora é relativo ao repositório clonado.
136
- config_path = SEEDVR_SPACE_DIR / 'configs_3b' / 'main.yaml'
137
-
138
- if not config_path.is_file():
139
- # Esta verificação agora é uma salvaguarda, pois o git clone deve ter baixado o arquivo.
140
- raise FileNotFoundError(f"Arquivo de configuração principal não encontrado em {config_path}. O clone do repositório pode ter falhado.")
 
 
141
 
142
- # `load_config` agora encontrará os arquivos herdados pois estamos no contexto do repo clonado.
143
- config = load_config(str(config_path))
144
-
145
- self.runner = VideoDiffusionInfer(config)
146
- OmegaConf.set_readonly(self.runner.config, False)
147
-
148
- # CORREÇÃO: Os caminhos para os checkpoints também são relativos ao repositório clonado.
149
- dit_checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
150
- self.runner.configure_dit_model(device=self.device, checkpoint=str(dit_checkpoint_path))
151
- self.runner.configure_vae_model()
152
-
153
- if hasattr(self.runner.vae, "set_memory_limit"):
154
- self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
155
-
156
- self.is_initialized = True
157
- logger.info("✅ Runner do SeedVR2 inicializado e pronto.")
 
 
 
 
 
 
 
 
158
 
159
  def _unload_runner(self):
160
  """Remove o runner da VRAM para liberar recursos."""
@@ -166,8 +168,14 @@ class SeedVrManager:
166
 
167
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str) -> str:
168
  """Aplica o aprimoramento HD a um vídeo usando a lógica oficial do SeedVR."""
 
169
  try:
170
  self._initialize_runner()
 
 
 
 
 
171
  set_seed(seed, same_across_ranks=True)
172
 
173
  self.runner.config.diffusion.cfg.scale = 1.0
@@ -176,7 +184,11 @@ class SeedVrManager:
176
  self.runner.configure_diffusion()
177
 
178
  logger.info(f"Processando vídeo de entrada: {input_video_path}")
179
- video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
 
 
 
 
180
  if video_tensor.size(0) > 121:
181
  logger.warning(f"Vídeo com {video_tensor.size(0)} frames. Truncando para 121 frames.")
182
  video_tensor = video_tensor[:121]
@@ -194,11 +206,9 @@ class SeedVrManager:
194
  logger.info("Codificando vídeo para o espaço latente...")
195
  cond_latent = self.runner.vae_encode([cond_latent])[0]
196
 
197
- # CORREÇÃO: Carrega os embeddings a partir do repositório clonado
198
- pos_emb_path = SEEDVR_SPACE_DIR / 'pos_emb.pt'
199
- neg_emb_path = SEEDVR_SPACE_DIR / 'neg_emb.pt'
200
- text_pos_embeds = torch.load(pos_emb_path).to(self.device)
201
- text_neg_embeds = torch.load(neg_emb_path).to(self.device)
202
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
203
 
204
  noise = torch.randn_like(cond_latent)
@@ -220,12 +230,13 @@ class SeedVrManager:
220
  sample = rearrange(sample, "t c h w -> t h w c")
221
  sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
222
 
223
- logger.info(f"Salvando vídeo aprimorado em: {output_video_path}")
224
  self.workspace_dir.mkdir(parents=True, exist_ok=True)
225
- imageio.get_writer(output_video_path, fps=fps_out, codec='libx264', quality=9).extend(sample)
226
 
227
- return output_video_path
228
  finally:
 
229
  self._unload_runner()
230
 
231
  # Instância Singleton
 
1
+ # hd_specialist.py (Versão Final - Corrigindo o Contexto de Execução de Caminhos)
2
  # https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B
3
 
4
  import torch
 
7
  import gc
8
  import logging
9
  import numpy as np
 
 
10
  import shlex
11
  import subprocess
12
  from pathlib import Path
 
21
 
22
  # --- Constantes de Caminho ---
23
  # Define a raiz do projeto (onde este script está) e cria um diretório para dependências
24
+ # Usamos 'Path.cwd()' para ter certeza que é relativo ao diretório de execução do projeto principal.
25
+ PROJECT_ROOT = Path.cwd()
26
  DEPS_DIR = PROJECT_ROOT / "deps"
27
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
28
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
 
35
  if not SEEDVR_SPACE_DIR.is_dir():
36
  logger.info(f"Repositório SeedVR não encontrado. Clonando de '{SEEDVR_SPACE_URL}'...")
37
  try:
 
38
  DEPS_DIR.mkdir(exist_ok=True)
 
39
  subprocess.run(
40
  ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
41
  check=True, capture_output=True, text=True
 
47
  else:
48
  logger.info(f"Repositório SeedVR já existe em '{SEEDVR_SPACE_DIR}'.")
49
 
 
50
  resolved_path = str(SEEDVR_SPACE_DIR.resolve())
51
  if resolved_path not in sys.path:
52
  sys.path.insert(0, resolved_path)
 
66
  return cached_file
67
 
68
  # --- Importações do Repositório Clonado ---
 
69
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
70
  from common.config import load_config
71
  from common.seed import set_seed
 
92
  """Instala dependências complexas como Apex."""
93
  logger.info("Configurando dependências do SeedVR (Apex)...")
94
  apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
 
95
  apex_wheel_path = _load_file_from_url(url=apex_url, model_dir=str(DEPS_DIR))
96
  subprocess.run(shlex.split(f"pip install {apex_wheel_path}"), check=True)
97
  logger.info("✅ Dependência Apex instalada com sucesso.")
 
99
  def _download_models(self):
100
  """Baixa os checkpoints necessários para o SeedVR2 DENTRO do repositório clonado."""
101
  logger.info("Verificando e baixando modelos do SeedVR2...")
 
102
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
103
 
104
  pretrain_model_url = {
 
108
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
109
  }
110
 
 
111
  _load_file_from_url(url=pretrain_model_url['dit'], model_dir=str(ckpt_dir))
112
  _load_file_from_url(url=pretrain_model_url['vae'], model_dir=str(ckpt_dir))
113
  _load_file_from_url(url=pretrain_model_url['pos_emb'], model_dir=str(SEEDVR_SPACE_DIR))
 
124
 
125
  logger.info("Inicializando o runner do SeedVR2...")
126
 
127
+ # --- CORREÇÃO CRÍTICA: MUDANÇA DE DIRETÓRIO DE TRABALHO ---
128
+ original_cwd = Path.cwd()
129
+ try:
130
+ # Muda para o diretório do repositório clonado. Isso é essencial para que
131
+ # o `load_config` encontre os arquivos .yaml herdados (como os da pasta `models`).
132
+ os.chdir(SEEDVR_SPACE_DIR)
133
+
134
+ logger.info(f"Diretório de trabalho alterado para: {SEEDVR_SPACE_DIR}")
135
 
136
+ # Agora todos os caminhos são relativos à raiz do repositório
137
+ config_path = './configs_3b/main.yaml'
138
+ dit_checkpoint_path = './ckpts/seedvr2_ema_3b.pth'
139
+
140
+ config = load_config(config_path)
141
+
142
+ self.runner = VideoDiffusionInfer(config)
143
+ OmegaConf.set_readonly(self.runner.config, False)
144
+
145
+ self.runner.configure_dit_model(device=self.device, checkpoint=dit_checkpoint_path)
146
+ self.runner.configure_vae_model()
147
+
148
+ if hasattr(self.runner.vae, "set_memory_limit"):
149
+ self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
150
+
151
+ self.is_initialized = True
152
+ logger.info("✅ Runner do SeedVR2 inicializado e pronto.")
153
+
154
+ finally:
155
+ # Garante que o diretório de trabalho original seja restaurado,
156
+ # não importa se a inicialização foi bem-sucedida ou falhou.
157
+ os.chdir(original_cwd)
158
+ logger.info(f"Diretório de trabalho restaurado para: {original_cwd}")
159
+ # --- FIM DA CORREÇÃO ---
160
 
161
  def _unload_runner(self):
162
  """Remove o runner da VRAM para liberar recursos."""
 
168
 
169
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str) -> str:
170
  """Aplica o aprimoramento HD a um vídeo usando a lógica oficial do SeedVR."""
171
+ original_cwd = Path.cwd() # Salva o diretório original
172
  try:
173
  self._initialize_runner()
174
+
175
+ # --- CORREÇÃO CRÍTICA 2: MUDAR DIRETÓRIO TAMBÉM DURANTE O PROCESSAMENTO ---
176
+ # O código pode precisar acessar arquivos de embedding novamente
177
+ os.chdir(SEEDVR_SPACE_DIR)
178
+
179
  set_seed(seed, same_across_ranks=True)
180
 
181
  self.runner.config.diffusion.cfg.scale = 1.0
 
184
  self.runner.configure_diffusion()
185
 
186
  logger.info(f"Processando vídeo de entrada: {input_video_path}")
187
+ # Garante que os caminhos de entrada/saída sejam absolutos para não quebrar com a mudança de CWD
188
+ abs_input_path = original_cwd / input_video_path
189
+ abs_output_path = original_cwd / output_video_path
190
+
191
+ video_tensor = read_video(str(abs_input_path), output_format="TCHW")[0] / 255.0
192
  if video_tensor.size(0) > 121:
193
  logger.warning(f"Vídeo com {video_tensor.size(0)} frames. Truncando para 121 frames.")
194
  video_tensor = video_tensor[:121]
 
206
  logger.info("Codificando vídeo para o espaço latente...")
207
  cond_latent = self.runner.vae_encode([cond_latent])[0]
208
 
209
+ # Carrega embeddings com caminhos relativos ao diretório do repo
210
+ text_pos_embeds = torch.load('pos_emb.pt').to(self.device)
211
+ text_neg_embeds = torch.load('neg_emb.pt').to(self.device)
 
 
212
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
213
 
214
  noise = torch.randn_like(cond_latent)
 
230
  sample = rearrange(sample, "t c h w -> t h w c")
231
  sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
232
 
233
+ logger.info(f"Salvando vídeo aprimorado em: {abs_output_path}")
234
  self.workspace_dir.mkdir(parents=True, exist_ok=True)
235
+ imageio.get_writer(str(abs_output_path), fps=fps_out, codec='libx264', quality=9).extend(sample)
236
 
237
+ return str(abs_output_path)
238
  finally:
239
+ os.chdir(original_cwd) # Restaura o diretório de trabalho original
240
  self._unload_runner()
241
 
242
  # Instância Singleton