euiia commited on
Commit
3dce029
·
verified ·
1 Parent(s): 18aaf4a

Update hd_specialist.py

Browse files
Files changed (1) hide show
  1. hd_specialist.py +158 -75
hd_specialist.py CHANGED
@@ -1,97 +1,180 @@
1
- # hd_specialist.py
 
 
2
  import torch
3
  import imageio
4
  import os
 
5
  import logging
 
6
  from PIL import Image
7
- from diffusers import SeedVR2Pipeline
8
  from tqdm import tqdm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
  class HDSpecialist:
13
  """
14
- Implementa o Especialista HD (Δ+) da arquitetura ADUC-SDR.
15
- Utiliza o modelo SeedVR2 para realizar a restauração e o aprimoramento
16
- de vídeo, atuando como a etapa final de pós-produção.
17
  """
18
- def __init__(self):
19
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
- self.pipeline = None
21
- self.model_id = "ByteDance-Seed/SeedVR2-3B"
22
- logger.info(f"Especialista HD inicializado. Aguardando carregamento do modelo {self.model_id}...")
23
-
24
- def _load_pipeline(self):
25
- """Carrega o pipeline do modelo sob demanda para economizar memória."""
26
- if self.pipeline is None:
27
- logger.info("Carregando o pipeline SeedVR2... Isso pode levar alguns minutos.")
28
- try:
29
- self.pipeline = SeedVR2Pipeline.from_pretrained(
30
- self.model_id,
31
- torch_dtype=torch.float16,
32
- variant="fp16"
33
- ).to(self.device)
34
- logger.info("Pipeline SeedVR2 carregado com sucesso.")
35
- except Exception as e:
36
- logger.error(f"Falha ao carregar o pipeline SeedVR2: {e}")
37
- raise
38
-
39
- def process_video(self, input_video_path: str, output_video_path: str, prompt: str, strength: float = 0.8, batch_size: int = 8) -> str:
40
- """
41
- Aplica o aprimoramento HD a um vídeo.
42
 
43
- Args:
44
- input_video_path (str): Caminho para o vídeo de entrada.
45
- output_video_path (str): Caminho para salvar o vídeo aprimorado.
46
- prompt (str): Um prompt para guiar a restauração (pode ser o prompt global).
47
- strength (float): Força do efeito de restauração.
48
- batch_size (int): Número de frames para processar por lote para gerenciar a VRAM.
49
 
50
- Returns:
51
- str: O caminho para o vídeo aprimorado.
52
- """
53
- self._load_pipeline()
54
- if not os.path.exists(input_video_path):
55
- logger.error(f"Vídeo de entrada não encontrado em: {input_video_path}")
56
- raise FileNotFoundError(f"Vídeo de entrada não encontrado: {input_video_path}")
57
 
58
- logger.info(f"Iniciando processo HD para: {input_video_path}")
 
 
 
 
 
 
 
 
59
 
60
- # 1. Ler os frames do vídeo de entrada
61
- reader = imageio.get_reader(input_video_path)
62
- fps = reader.get_meta_data()['fps']
63
- input_frames = [Image.fromarray(frame) for frame in reader]
64
- reader.close()
65
 
66
- logger.info(f"Vídeo lido. Total de {len(input_frames)} frames a {fps} FPS.")
67
 
68
- # 2. Processar frames em lotes
69
- processed_frames = []
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- # Usamos tqdm para ter uma barra de progresso no console
72
- for i in tqdm(range(0, len(input_frames), batch_size), desc="Aprimorando frames com SeedVR2"):
73
- batch = input_frames[i:i + batch_size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # O pipeline retorna uma lista de imagens PIL
76
- enhanced_batch = self.pipeline(
77
- prompt=prompt,
78
- frames=batch,
79
- num_inference_steps=20,
80
- guidance_scale=7.5,
81
- strength=strength,
82
- ).frames
83
-
84
- processed_frames.extend(enhanced_batch)
85
-
86
- # 3. Salvar os frames aprimorados em um novo vídeo
87
- logger.info(f"Salvando vídeo aprimorado em: {output_video_path}")
88
- writer = imageio.get_writer(output_video_path, fps=fps, codec='libx264', quality=9)
89
- for frame in processed_frames:
90
- writer.append_data(np.array(frame))
91
- writer.close()
92
-
93
- logger.info("Processo HD concluído com sucesso.")
94
- return output_video_path
95
-
96
- # Instância Singleton para ser usada em toda a aplicação
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  hd_specialist_singleton = HDSpecialist()
 
1
+ # hd_specialist.py (Versão Corrigida, usando o código-fonte do SeedVR)
2
+ # https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B/tree/main
3
+
4
  import torch
5
  import imageio
6
  import os
7
+ import gc
8
  import logging
9
+ import numpy as np
10
  from PIL import Image
 
11
  from tqdm import tqdm
12
+ import shlex
13
+ import subprocess
14
+ from pathlib import Path
15
+ from urllib.parse import urlparse
16
+ from torch.hub import download_url_to_file, get_dir
17
+ from omegaconf import OmegaConf
18
+
19
+ # --- Importações do código-fonte do SeedVR ---
20
+ # Certifique-se de que a pasta 'SeedVR' está no seu projeto
21
+ from SeedVR.projects.video_diffusion_sr.infer import VideoDiffusionInfer
22
+ from SeedVR.common.config import load_config
23
+ from SeedVR.common.seed import set_seed
24
+ from SeedVR.data.image.transforms.divisible_crop import DivisibleCrop
25
+ from SeedVR.data.image.transforms.na_resize import NaResize
26
+ from SeedVR.data.video.transforms.rearrange import Rearrange
27
+ from SeedVR.projects.video_diffusion_sr.color_fix import wavelet_reconstruction
28
+ from torchvision.transforms import Compose, Lambda, Normalize
29
+ from torchvision.io.video import read_video
30
+ from einops import rearrange
31
 
32
  logger = logging.getLogger(__name__)
33
 
34
  class HDSpecialist:
35
  """
36
+ Implementa o Especialista HD (Δ+) usando a infraestrutura oficial do SeedVR.
 
 
37
  """
38
+ def __init__(self, workspace_dir="deformes_workspace"):
39
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
40
+ self.runner = None
41
+ self.workspace_dir = workspace_dir
42
+ self.is_initialized = False
43
+ logger.info("Especialista HD (SeedVR) inicializado. Modelo será carregado sob demanda.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ def _download_models(self):
46
+ """Baixa os checkpoints e dependências necessários para o SeedVR2."""
47
+ logger.info("Verificando e baixando modelos do SeedVR2...")
48
+ ckpt_dir = Path('./ckpts')
49
+ ckpt_dir.mkdir(exist_ok=True)
 
50
 
51
+ pretrain_model_url = {
52
+ 'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
53
+ 'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
54
+ 'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
55
+ 'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
56
+ }
 
57
 
58
+ # Função auxiliar para download
59
+ def load_file_from_url(url, model_dir='./', file_name=None):
60
+ os.makedirs(model_dir, exist_ok=True)
61
+ filename = file_name or os.path.basename(urlparse(url).path)
62
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
63
+ if not os.path.exists(cached_file):
64
+ logger.info(f'Baixando: "{url}" para {cached_file}')
65
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
66
+ return cached_file
67
 
68
+ load_file_from_url(url=pretrain_model_url['dit'], model_dir='./ckpts/')
69
+ load_file_from_url(url=pretrain_model_url['vae'], model_dir='./ckpts/')
70
+ load_file_from_url(url=pretrain_model_url['pos_emb'])
71
+ load_file_from_url(url=pretrain_model_url['neg_emb'])
72
+ logger.info("Modelos do SeedVR2 baixados com sucesso.")
73
 
 
74
 
75
+ def _initialize_runner(self):
76
+ """Carrega e configura o modelo SeedVR sob demanda."""
77
+ if self.runner is not None:
78
+ return
79
+
80
+ self._download_models()
81
+
82
+ logger.info("Inicializando o runner do SeedVR2...")
83
+ config_path = os.path.join('./SeedVR/configs_3b', 'main.yaml')
84
+ config = load_config(config_path)
85
+
86
+ self.runner = VideoDiffusionInfer(config)
87
+ OmegaConf.set_readonly(self.runner.config, False)
88
 
89
+ self.runner.configure_dit_model(device=self.device, checkpoint='./ckpts/seedvr2_ema_3b.pth')
90
+ self.runner.configure_vae_model()
91
+
92
+ if hasattr(self.runner.vae, "set_memory_limit"):
93
+ self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
94
+
95
+ self.is_initialized = True
96
+ logger.info("Runner do SeedVR2 inicializado e pronto.")
97
+
98
+ def _unload_runner(self):
99
+ """Remove o runner da VRAM para liberar recursos."""
100
+ if self.runner is not None:
101
+ del self.runner
102
+ self.runner = None
103
+ gc.collect()
104
+ torch.cuda.empty_cache()
105
+ self.is_initialized = False
106
+ logger.info("Runner do SeedVR2 descarregado da VRAM.")
107
+
108
+ def process_video(self, input_video_path: str, output_video_path: str, prompt: str, seed: int = 666, fps_out: int = 24) -> str:
109
+ """
110
+ Aplica o aprimoramento HD a um vídeo usando a lógica oficial do SeedVR.
111
+ """
112
+ try:
113
+ self._initialize_runner()
114
+ set_seed(seed, same_across_ranks=True)
115
+
116
+ # --- Configuração do Pipeline (adaptado de app.py) ---
117
+ self.runner.config.diffusion.cfg.scale = 1.0 # cfg_scale
118
+ self.runner.config.diffusion.cfg.rescale = 0.0 # cfg_rescale
119
+ self.runner.config.diffusion.timesteps.sampling.steps = 1 # sample_steps (one-step model)
120
+ self.runner.configure_diffusion()
121
+
122
+ # --- Preparação do Vídeo de Entrada ---
123
+ logger.info(f"Processando vídeo de entrada: {input_video_path}")
124
+ video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
125
+ if video_tensor.size(0) > 121:
126
+ logger.warning(f"Vídeo com {video_tensor.size(0)} frames. Truncando para 121 frames.")
127
+ video_tensor = video_tensor[:121]
128
+
129
+ video_transform = Compose([
130
+ NaResize(resolution=(1280 * 720)**0.5, mode="area", downsample_only=False),
131
+ Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
132
+ DivisibleCrop((16, 16)),
133
+ Normalize(0.5, 0.5),
134
+ Rearrange("t c h w -> c t h w"),
135
+ ])
136
 
137
+ cond_latent = video_transform(video_tensor.to(self.device))
138
+ input_video_for_colorfix = cond_latent.clone() # Salva para o color fix
139
+ ori_length = cond_latent.size(1)
140
+
141
+ # --- Codificação VAE e Geração ---
142
+ logger.info("Codificando vídeo para o espaço latente...")
143
+ cond_latent = self.runner.vae_encode([cond_latent])[0]
144
+
145
+ text_pos_embeds = torch.load('pos_emb.pt').to(self.device)
146
+ text_neg_embeds = torch.load('neg_emb.pt').to(self.device)
147
+ text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
148
+
149
+ noise = torch.randn_like(cond_latent)
150
+
151
+ logger.info(f"Iniciando a geração de restauração para {ori_length} frames...")
152
+ with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
153
+ video_tensor_out = self.runner.inference(
154
+ noises=[noise],
155
+ conditions=[self.runner.get_condition(noise, task="sr", latent_blur=cond_latent)],
156
+ dit_offload=False,
157
+ **text_embeds_dict,
158
+ )[0]
159
+
160
+ sample = rearrange(video_tensor_out, "c t h w -> t c h w")
161
+
162
+ # --- Pós-processamento e Salvamento ---
163
+ if ori_length < sample.shape[0]:
164
+ sample = sample[:ori_length]
165
+
166
+ input_video_for_colorfix = rearrange(input_video_for_colorfix, "c t h w -> t c h w")
167
+ sample = wavelet_reconstruction(sample.cpu(), input_video_for_colorfix[:sample.size(0)].cpu())
168
+
169
+ sample = rearrange(sample, "t c h w -> t h w c")
170
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
171
+
172
+ logger.info(f"Salvando vídeo aprimorado em: {output_video_path}")
173
+ imageio.get_writer(output_video_path, fps=fps_out, codec='libx264', quality=9).extend(sample)
174
+
175
+ return output_video_path
176
+ finally:
177
+ self._unload_runner()
178
+
179
+ # Instância Singleton
180
  hd_specialist_singleton = HDSpecialist()