euiia committed on
Commit
2f6b6e4
·
verified ·
1 Parent(s): 1d6cce1

Update managers/audio_specialist.py

Browse files
Files changed (1) hide show
  1. managers/audio_specialist.py +217 -138
managers/audio_specialist.py CHANGED
@@ -1,163 +1,242 @@
1
- # audio_specialist.py
2
- # Especialista ADUC para geração de áudio, com gerenciamento de memória GPU.
3
- # Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
 
 
 
 
 
 
 
 
4
 
5
  import torch
6
- import logging
7
- import subprocess
8
  import os
9
- import time
10
- import yaml
11
  import gc
 
 
 
12
  from pathlib import Path
 
 
13
  import gradio as gr
14
-
15
- # Importa as classes e funções necessárias do MMAudio
16
- try:
17
- from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate as mmaudio_generate, load_video, make_video
18
- from mmaudio.model.flow_matching import FlowMatching
19
- from mmaudio.model.networks import MMAudio, get_my_mmaudio
20
- from mmaudio.model.utils.features_utils import FeaturesUtils
21
- from mmaudio.model.sequence_config import SequenceConfig
22
- except ImportError:
23
- raise ImportError("MMAudio não foi encontrado. Por favor, instale-o a partir do GitHub: git+https://github.com/hkchengrex/MMAudio.git")
24
 
25
  logger = logging.getLogger(__name__)
26
 
27
class AudioSpecialist:
    """
    Specialist in charge of generating audio for video fragments.

    The MMAudio network and its feature utilities are loaded onto the CPU at
    construction time, moved to the GPU for the duration of each generation
    call, and moved back afterwards.
    """

    def __init__(self, workspace_dir):
        """
        Set up device/dtype choices and load the MMAudio models onto the CPU.

        Args:
            workspace_dir: Directory where generated videos are written when
                no explicit output path is supplied.
        """
        cuda_available = torch.cuda.is_available()
        self.device = "cuda" if cuda_available else "cpu"
        self.cpu_device = torch.device("cpu")
        # bfloat16 is only selected for CUDA; the CPU path stays in float32.
        if self.device == "cuda":
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32
        self.workspace_dir = workspace_dir

        self.model_config: ModelConfig = all_model_cfg['large_44k_v2']
        self.net: MMAudio = None
        self.feature_utils: FeaturesUtils = None
        self.seq_cfg: SequenceConfig = None

        self._load_models_to_cpu()

    def _load_models_to_cpu(self):
        """
        Load the MMAudio network and feature utilities into CPU memory.

        Any failure is logged and leaves ``self.net`` set to ``None``; no
        exception is propagated to the caller.
        """
        try:
            logger.info("Verificando e baixando modelos MMAudio, se necessário...")
            self.model_config.download_if_needed()

            self.seq_cfg = self.model_config.seq_cfg

            logger.info(f"Carregando modelo MMAudio: {self.model_config.model_name} para a CPU...")
            self.net = get_my_mmaudio(self.model_config.model_name).eval()
            state_dict = torch.load(
                self.model_config.model_path,
                map_location=self.cpu_device,
                weights_only=True,
            )
            self.net.load_weights(state_dict)

            logger.info("Carregando utilitários de features do MMAudio para a CPU...")
            self.feature_utils = FeaturesUtils(
                tod_vae_ckpt=self.model_config.vae_path,
                synchformer_ckpt=self.model_config.synchformer_ckpt,
                enable_conditions=True,
                mode=self.model_config.mode,
                bigvgan_vocoder_ckpt=self.model_config.bigvgan_16k_path,
                need_vae_encoder=False
            )
            self.feature_utils = self.feature_utils.eval()
            self.net.to(self.cpu_device)
            self.feature_utils.to(self.cpu_device)
            logger.info("Especialista de áudio pronto na CPU.")
        except Exception as load_error:
            logger.error(f"Falha ao carregar modelos de áudio: {load_error}", exc_info=True)
            # A missing network is what generate_audio_for_video checks for.
            self.net = None

    def to_gpu(self):
        """Move the network and feature utilities to the GPU; no-op on CPU-only hosts."""
        if self.device != 'cpu':
            logger.info(f"Movendo especialista de áudio para a GPU ({self.device})...")
            self.net.to(self.device, self.dtype)
            self.feature_utils.to(self.device, self.dtype)

    def to_cpu(self):
        """Move the models back to the CPU and release cached VRAM; no-op on CPU-only hosts."""
        if self.device != 'cpu':
            logger.info("Descarregando especialista de áudio da GPU...")
            self.net.to(self.cpu_device)
            self.feature_utils.to(self.cpu_device)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def generate_audio_for_video(self, video_path: str, prompt: str, duration_seconds: float, output_path_override: str = None) -> str:
        """
        Generate audio for a video file, using a fixed negative prompt ("human voice").

        Args:
            video_path (str): Path to the silent video.
            prompt (str): Scene description used to guide generation.
            duration_seconds (float): Duration of the audio to generate.
            output_path_override (str): Optional destination path. When it is
                not given, the result is written to the workspace directory
                as ``<video stem>_com_audio.mp4``.

        Returns:
            str: Path to the new video file with audio, or ``video_path``
            unchanged when ``duration_seconds`` is below one second.

        Raises:
            gr.Error: If the MMAudio network is not loaded.
        """
        if self.net is None:
            raise gr.Error("Modelo MMAudio não está carregado. Não é possível gerar áudio.")

        separator = "------------------------------------------------------"
        logger.info(separator)
        logger.info("--- Gerando Áudio para Fragmento de Vídeo ---")
        logger.info(f"--- Vídeo Fragmento: {os.path.basename(video_path)}")
        logger.info(f"--- Duração: {duration_seconds:.2f}s")
        logger.info(f"--- Prompt (Descrição da Cena): '{prompt}'")

        negative_prompt = "human voice"
        logger.info(f"--- Negative Prompt: '{negative_prompt}'")

        # Fragments shorter than one second are returned untouched.
        if duration_seconds < 1:
            logger.warning("Fragmento muito curto (<1s). Retornando vídeo silencioso.")
            logger.info(separator)
            return video_path

        if self.device == 'cpu':
            logger.warning("Gerando áudio na CPU. Isso pode ser muito lento.")

        try:
            self.to_gpu()
            with torch.no_grad():
                # The seed comes from the wall clock, so runs are not reproducible.
                rng = torch.Generator(device=self.device).manual_seed(int(time.time()))
                fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)

                video_info = load_video(Path(video_path), duration_seconds)
                self.seq_cfg.duration = video_info.duration_sec
                self.net.update_seq_lengths(
                    self.seq_cfg.latent_seq_len,
                    self.seq_cfg.clip_seq_len,
                    self.seq_cfg.sync_seq_len,
                )

                audios = mmaudio_generate(
                    clip_video=video_info.clip_frames.unsqueeze(0),
                    sync_video=video_info.sync_frames.unsqueeze(0),
                    text=[prompt],
                    negative_text=[negative_prompt],
                    feature_utils=self.feature_utils,
                    net=self.net,
                    fm=fm,
                    rng=rng,
                    cfg_strength=4.5
                )
                audio_waveform = audios.float().cpu()[0]

                fragment_name = Path(video_path).stem
                output_video_path = output_path_override or os.path.join(
                    self.workspace_dir, f"{fragment_name}_com_audio.mp4"
                )

                make_video(video_info, Path(output_video_path), audio_waveform, sampling_rate=self.seq_cfg.sampling_rate)
                logger.info(f"--- Fragmento com áudio salvo em: {os.path.basename(output_video_path)}")
                logger.info(separator)
                return output_video_path
        finally:
            # Always hand the VRAM back, even when generation fails.
            self.to_cpu()
154
-
155
# Module-level singleton, built at import time from config.yaml.
# If anything goes wrong the error is logged and the singleton is left as None.
try:
    with open("config.yaml", 'r') as config_file:
        config = yaml.safe_load(config_file)
    WORKSPACE_DIR = config['application']['workspace_dir']
    audio_specialist_singleton = AudioSpecialist(workspace_dir=WORKSPACE_DIR)
except Exception as init_error:
    logger.error(f"Não foi possível inicializar o AudioSpecialist: {init_error}", exc_info=True)
    audio_specialist_singleton = None
 
1
+ # hd_specialist.py
2
+ #
3
+ # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
+ #
5
+ # Version: 2.2.0
6
+ #
7
+ # This file implements the HD Specialist (Δ+), which uses the SeedVR model
8
+ # for video super-resolution. It has been refactored to be self-contained by
9
+ # automatically cloning its own dependencies from the official SeedVR repository
10
+ # if they are not found locally. This removes the need for manual file copying
11
+ # and makes the ADUC-SDR framework more robust and portable.
12
 
13
  import torch
 
 
14
  import os
 
 
15
  import gc
16
+ import logging
17
+ import sys
18
+ import subprocess
19
  from pathlib import Path
20
+ from urllib.parse import urlparse
21
+ from torch.hub import download_url_to_file
22
  import gradio as gr
23
+ import mediapy
24
+ from einops import rearrange
 
 
 
 
 
 
 
 
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
+ # --- Dependency Management ---
29
+ DEPS_DIR = Path("./deps")
30
+ SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
31
+ SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
32
+
33
+ def _load_file_from_url(url, model_dir='./', file_name=None):
34
+ """Helper function to download files from a URL to a local directory."""
35
+ os.makedirs(model_dir, exist_ok=True)
36
+ filename = file_name or os.path.basename(urlparse(url).path)
37
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
38
+ if not os.path.exists(cached_file):
39
+ logger.info(f'Downloading: "{url}" to {cached_file}')
40
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
41
+ return cached_file
42
+
43
class HDSpecialist:
    """
    Implements the HD Specialist (Δ+) using the SeedVR infrastructure.

    The SeedVR repository is cloned on construction when it is missing. The
    model itself is only loaded inside ``process_video`` and is unloaded
    again when that call finishes.
    """

    # Supported model versions -> (config directory, DiT checkpoint file name),
    # both relative to the SeedVR checkout.
    _MODEL_VARIANTS = {
        '3B': ('configs_3b', 'seedvr2_ema_3b.pth'),
        '7B': ('configs_7b', 'seedvr2_ema_7b.pth'),
    }
    # Frame rate used when the source video does not report one.
    _FALLBACK_FPS = 24

    def __init__(self, workspace_dir="deformes_workspace"):
        """
        Record the configuration and make sure the SeedVR sources are available.

        Args:
            workspace_dir: Workspace directory; stored on the instance.

        Raises:
            RuntimeError: If the SeedVR repository has to be cloned and the
                clone fails.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.runner = None
        self.workspace_dir = workspace_dir
        self.is_initialized = False
        self._seedvr_modules_loaded = False
        self._setup_dependencies()
        logger.info("HD Specialist (SeedVR) initialized. Dependencies checked. Model will be loaded on demand.")

    def _setup_dependencies(self):
        """
        Checks for the SeedVR repository locally. If not found, clones it.
        Then, it adds the repository to the Python path to make its modules importable.

        Raises:
            RuntimeError: If the clone fails or the ``git`` executable is missing.
        """
        if not SEEDVR_REPO_DIR.exists():
            logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
            try:
                DEPS_DIR.mkdir(parents=True, exist_ok=True)
                subprocess.run(
                    ["git", "clone", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
                    check=True, capture_output=True, text=True
                )
                logger.info("SeedVR repository cloned successfully.")
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to clone SeedVR repository. Git stderr: {e.stderr}")
                raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.") from e
            except FileNotFoundError as e:
                # subprocess raises this when the 'git' binary itself is absent.
                logger.error("Failed to clone SeedVR repository: the 'git' executable was not found.")
                raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.") from e
        else:
            logger.info("Found local SeedVR repository.")

        # Add the cloned repo to Python's path to allow direct imports
        repo_path = str(SEEDVR_REPO_DIR.resolve())
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)
            logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")

    def _lazy_load_seedvr_modules(self):
        """
        Dynamically imports SeedVR modules only when needed.
        This prevents ImportError if the class is instantiated before dependencies are ready.

        The imported names are published as module globals so the other
        methods of this class can use them directly.
        """
        if self._seedvr_modules_loaded:
            return

        global VideoDiffusionInfer, load_config, set_seed, DivisibleCrop, NaResize, Rearrange, wavelet_reconstruction, Compose, Lambda, Normalize, read_video, OmegaConf
        from projects.video_diffusion_sr.infer import VideoDiffusionInfer
        from common.config import load_config
        from common.seed import set_seed
        from data.image.transforms.divisible_crop import DivisibleCrop
        from data.image.transforms.na_resize import NaResize
        from data.video.transforms.rearrange import Rearrange
        from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
        from torchvision.transforms import Compose, Lambda, Normalize
        from torchvision.io.video import read_video
        from omegaconf import OmegaConf

        self._seedvr_modules_loaded = True
        logger.info("SeedVR modules have been dynamically loaded.")

    def _download_models(self, model_version=None):
        """
        Downloads the necessary checkpoints for SeedVR2.

        Args:
            model_version: '3B' or '7B' to fetch only that DiT checkpoint
                (plus the shared VAE and text embeddings). ``None`` fetches
                every checkpoint.
        """
        logger.info("Verifying and downloading SeedVR2 models...")
        ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        pretrain_model_urls = {
            'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
            'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
            'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
            'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
            'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
        }

        if model_version is not None:
            # Skip the DiT checkpoint of the variant that will not be loaded.
            needed_dit = f'dit_{model_version.lower()}'
            pretrain_model_urls = {
                key: url for key, url in pretrain_model_urls.items()
                if not key.startswith('dit_') or key == needed_dit
            }

        for url in pretrain_model_urls.values():
            _load_file_from_url(url=url, model_dir=str(ckpt_dir))

        logger.info("SeedVR2 models downloaded successfully.")

    def _initialize_runner(self, model_version: str):
        """
        Loads and configures the SeedVR model on demand based on the selected version.

        Does nothing when a runner is already loaded.

        Args:
            model_version: '3B' or '7B'.

        Raises:
            ValueError: If ``model_version`` is not supported. This is checked
                before any import or download takes place.
        """
        if self.runner is not None:
            return

        if model_version not in self._MODEL_VARIANTS:
            raise ValueError(f"Unsupported SeedVR model version: {model_version}")
        config_dir, checkpoint_name = self._MODEL_VARIANTS[model_version]

        self._lazy_load_seedvr_modules()
        self._download_models(model_version)

        logger.info(f"Initializing SeedVR2 {model_version} runner...")
        config_path = SEEDVR_REPO_DIR / config_dir / 'main.yaml'
        checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / checkpoint_name

        config = load_config(str(config_path))

        self.runner = VideoDiffusionInfer(config)
        # process_video overrides the sampling step count, so the config must be writable.
        OmegaConf.set_readonly(self.runner.config, False)

        self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
        self.runner.configure_vae_model()

        if hasattr(self.runner.vae, "set_memory_limit"):
            self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)

        self.is_initialized = True
        logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")

    def _unload_runner(self):
        """Drops the runner and releases cached GPU memory, if a runner is loaded."""
        if self.runner is not None:
            self.runner = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            self.is_initialized = False
            logger.info("SeedVR2 runner unloaded from VRAM.")

    def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
                      model_version: str = '3B', steps: int = 50, seed: int = 666,
                      progress: gr.Progress = None) -> str:
        """
        Applies HD enhancement to a video using the SeedVR logic.

        Args:
            input_video_path: Video to enhance.
            output_video_path: Where the enhanced video is written; missing
                parent directories are created.
            prompt: Accepted for interface compatibility; not used here, the
                text conditioning comes from the downloaded pos/neg embeddings.
            model_version: '3B' or '7B'.
            steps: Number of diffusion sampling steps.
            seed: Random seed passed to SeedVR's ``set_seed``.
            progress: Accepted for interface compatibility; not used here.

        Returns:
            ``output_video_path``.

        Raises:
            ValueError: If ``model_version`` is not supported.
        """
        try:
            self._initialize_runner(model_version)
            set_seed(seed, same_across_ranks=True)

            self.runner.config.diffusion.timesteps.sampling.steps = steps
            self.runner.configure_diffusion()

            frames, _, video_info = read_video(input_video_path, output_format="TCHW")
            video_tensor = frames / 255.0
            # Keep the source frame rate so playback speed is unchanged.
            source_fps = video_info.get("video_fps") or self._FALLBACK_FPS
            res_h, res_w = video_tensor.shape[-2:]

            video_transform = Compose([
                NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
                Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
                DivisibleCrop((16, 16)),
                Normalize(0.5, 0.5),
                Rearrange("t c h w -> c t h w"),
            ])

            cond_latents = [video_transform(video_tensor.to(self.device))]
            # Keep the transformed input around for the colour fix after decoding.
            input_videos = cond_latents

            # Only one of DiT / VAE is kept on the device at a time.
            self.runner.dit.to("cpu")
            self.runner.vae.to(self.device)
            cond_latents = self.runner.vae_encode(cond_latents)
            self.runner.vae.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.dit.to(self.device)

            pos_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'pos_emb.pt'
            neg_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'neg_emb.pt'
            # map_location lets the embeddings load on hosts that lack the
            # device they were saved from.
            text_pos_embeds = torch.load(pos_emb_path, map_location=self.device).to(self.device)
            text_neg_embeds = torch.load(neg_emb_path, map_location=self.device).to(self.device)
            text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}

            noises = [torch.randn_like(latent) for latent in cond_latents]
            conditions = [
                self.runner.get_condition(noise, latent_blur=latent, task="sr")
                for noise, latent in zip(noises, cond_latents)
            ]

            with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
                video_tensors = self.runner.inference(
                    noises=noises,
                    conditions=conditions,
                    dit_offload=True,
                    **text_embeds_dict,
                )

            self.runner.dit.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()

            self.runner.vae.to(self.device)
            samples = self.runner.vae_decode(video_tensors)

            final_sample = samples[0]
            input_video_sample = input_videos[0]

            # Trim the reference along dim 1 when the decoded sample is shorter there.
            if final_sample.shape[1] < input_video_sample.shape[1]:
                input_video_sample = input_video_sample[:, :final_sample.shape[1]]

            final_sample = wavelet_reconstruction(
                rearrange(final_sample, "c t h w -> t c h w"),
                rearrange(input_video_sample, "c t h w -> t c h w")
            )

            final_sample = rearrange(final_sample, "t c h w -> t h w c")
            # Map [-1, 1] to [0, 255] before converting to uint8.
            final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            final_sample_np = final_sample.to(torch.uint8).cpu().numpy()

            Path(output_video_path).parent.mkdir(parents=True, exist_ok=True)
            mediapy.write_video(output_video_path, final_sample_np, fps=source_fps)
            logger.info(f"HD Mastered video saved to: {output_video_path}")
            return output_video_path

        finally:
            # Always free the model, whether processing succeeded or not.
            self._unload_runner()
240
+
241
# Shared module-level instance, created at import time with the default
# workspace directory. Construction runs the SeedVR dependency check.
hd_specialist_singleton = HDSpecialist()