euiia committed on
Commit
be881e5
·
verified ·
1 Parent(s): 3a03753

Delete managers/audio_specialist.py

Browse files
Files changed (1) hide show
  1. managers/audio_specialist.py +0 -242
managers/audio_specialist.py DELETED
@@ -1,242 +0,0 @@
1
- # hd_specialist.py
2
- #
3
- # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
- #
5
- # Version: 2.2.0
6
- #
7
- # This file implements the HD Specialist (Δ+), which uses the SeedVR model
8
- # for video super-resolution. It has been refactored to be self-contained by
9
- # automatically cloning its own dependencies from the official SeedVR repository
10
- # if they are not found locally. This removes the need for manual file copying
11
- # and makes the ADUC-SDR framework more robust and portable.
12
-
13
- import torch
14
- import os
15
- import gc
16
- import logging
17
- import sys
18
- import subprocess
19
- from pathlib import Path
20
- from urllib.parse import urlparse
21
- from torch.hub import download_url_to_file
22
- import gradio as gr
23
- import mediapy
24
- from einops import rearrange
25
-
26
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# --- Dependency Management ---
# Upstream location of the SeedVR sources and the local directory they are
# cloned into (relative to the current working directory).
SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
DEPS_DIR = Path("./deps")
SEEDVR_REPO_DIR = DEPS_DIR.joinpath("SeedVR")
32
-
33
def _load_file_from_url(url, model_dir='./', file_name=None):
    """Download ``url`` into ``model_dir`` unless the target file already exists.

    Args:
        url: Source URL of the file.
        model_dir: Directory that receives the file; created if missing.
        file_name: Name to store the file under. When falsy, the last path
            component of ``url`` is used.

    Returns:
        Absolute path of the local file.

    Raises:
        ValueError: If no ``file_name`` is given and none can be derived from
            ``url`` (for example when the URL path ends with a slash).
    """
    filename = file_name or os.path.basename(urlparse(url).path)
    if not filename:
        # An empty name would make the "cached file" resolve to model_dir
        # itself, which exists, so the directory would be returned as if it
        # were the downloaded file.
        raise ValueError(f"Cannot derive a file name from URL: {url!r}")

    os.makedirs(model_dir, exist_ok=True)
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        logger.info(f'Downloading: "{url}" to {cached_file}')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
    return cached_file
42
-
43
class HDSpecialist:
    """
    Implements the HD Specialist (Δ+) using the SeedVR infrastructure.

    Construction only verifies that the SeedVR sources are available (cloning
    them when missing) and puts them on ``sys.path``. The model itself is
    loaded inside ``process_video`` and released again when that call ends.
    """

    # Per-version (config directory, DiT checkpoint file name), both relative
    # to the SeedVR checkout.
    _MODEL_FILES = {
        '3B': ('configs_3b', 'seedvr2_ema_3b.pth'),
        '7B': ('configs_7b', 'seedvr2_ema_7b.pth'),
    }

    # Checkpoints needed regardless of the selected model version.
    _SHARED_CHECKPOINT_URLS = (
        'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
        'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
        'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
    )

    # DiT checkpoint URL for each supported model version.
    _DIT_CHECKPOINT_URLS = {
        '3B': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
        '7B': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
    }

    def __init__(self, workspace_dir="deformes_workspace"):
        """Record settings and make sure the SeedVR sources are importable.

        Args:
            workspace_dir: Stored on the instance as ``workspace_dir``; not
                otherwise used by this class.

        Raises:
            RuntimeError: If the SeedVR repository is missing and cannot be cloned.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.runner = None
        self.workspace_dir = workspace_dir
        self.is_initialized = False
        self._seedvr_modules_loaded = False
        self._setup_dependencies()
        logger.info("HD Specialist (SeedVR) initialized. Dependencies checked. Model will be loaded on demand.")

    @staticmethod
    def _free_memory():
        """Run the garbage collector and, when CUDA is present, empty its cache."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _setup_dependencies(self):
        """
        Checks for the SeedVR repository locally. If not found, clones it.
        Then, it adds the repository to the Python path to make its modules importable.

        Raises:
            RuntimeError: If ``git clone`` fails or the ``git`` executable is
                not available. The underlying error is chained as the cause.
        """
        if not SEEDVR_REPO_DIR.exists():
            logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
            try:
                DEPS_DIR.mkdir(parents=True, exist_ok=True)
                subprocess.run(
                    ["git", "clone", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
                    check=True, capture_output=True, text=True
                )
                logger.info("SeedVR repository cloned successfully.")
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to clone SeedVR repository. Git stderr: {e.stderr}")
                raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.") from e
            except FileNotFoundError as e:
                # subprocess raises this when the 'git' executable is missing.
                logger.error("Failed to clone SeedVR repository: 'git' executable not found.")
                raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.") from e
        else:
            logger.info("Found local SeedVR repository.")

        # Add the cloned repo to Python's path to allow direct imports
        repo_path = str(SEEDVR_REPO_DIR.resolve())
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)
            logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")

    def _lazy_load_seedvr_modules(self):
        """
        Dynamically imports SeedVR modules only when needed.
        This prevents ImportError if the class is instantiated before dependencies are ready.

        The imported names are published as module-level globals because
        ``_initialize_runner`` and ``process_video`` reference them unqualified.
        """
        if self._seedvr_modules_loaded:
            return

        global VideoDiffusionInfer, load_config, set_seed
        global DivisibleCrop, NaResize, Rearrange, wavelet_reconstruction
        global Compose, Lambda, Normalize, read_video, OmegaConf
        from projects.video_diffusion_sr.infer import VideoDiffusionInfer
        from common.config import load_config
        from common.seed import set_seed
        from data.image.transforms.divisible_crop import DivisibleCrop
        from data.image.transforms.na_resize import NaResize
        from data.video.transforms.rearrange import Rearrange
        from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
        from torchvision.transforms import Compose, Lambda, Normalize
        from torchvision.io.video import read_video
        from omegaconf import OmegaConf

        self._seedvr_modules_loaded = True
        logger.info("SeedVR modules have been dynamically loaded.")

    def _download_models(self, model_version=None):
        """Downloads the necessary checkpoints for SeedVR2.

        Args:
            model_version: ``'3B'`` or ``'7B'`` to fetch only that DiT
                checkpoint alongside the shared files (VAE and text
                embeddings). ``None`` (the default) fetches the DiT
                checkpoints of every supported version.

        Raises:
            ValueError: If ``model_version`` is given but not supported.
        """
        if model_version is None:
            dit_urls = list(self._DIT_CHECKPOINT_URLS.values())
        elif model_version in self._DIT_CHECKPOINT_URLS:
            dit_urls = [self._DIT_CHECKPOINT_URLS[model_version]]
        else:
            raise ValueError(f"Unsupported SeedVR model version: {model_version}")

        logger.info("Verifying and downloading SeedVR2 models...")
        ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        # _load_file_from_url skips files that are already present.
        for url in (*self._SHARED_CHECKPOINT_URLS, *dit_urls):
            _load_file_from_url(url=url, model_dir=str(ckpt_dir))

        logger.info("SeedVR2 models downloaded successfully.")

    def _initialize_runner(self, model_version: str):
        """Loads and configures the SeedVR model on demand based on the selected version.

        Does nothing when a runner is already loaded.

        Args:
            model_version: ``'3B'`` or ``'7B'``.

        Raises:
            ValueError: If ``model_version`` is not supported. This is checked
                before any import or download takes place.
        """
        if self.runner is not None:
            return

        # Validate first so that a bad version string fails immediately rather
        # than after importing modules and downloading checkpoints.
        if model_version not in self._MODEL_FILES:
            raise ValueError(f"Unsupported SeedVR model version: {model_version}")

        self._lazy_load_seedvr_modules()
        self._download_models(model_version)

        logger.info(f"Initializing SeedVR2 {model_version} runner...")
        config_dir, checkpoint_name = self._MODEL_FILES[model_version]
        config_path = SEEDVR_REPO_DIR / config_dir / 'main.yaml'
        checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / checkpoint_name

        config = load_config(str(config_path))

        self.runner = VideoDiffusionInfer(config)
        # process_video writes the sampling step count into the config later.
        OmegaConf.set_readonly(self.runner.config, False)

        self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
        self.runner.configure_vae_model()

        if hasattr(self.runner.vae, "set_memory_limit"):
            self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)

        self.is_initialized = True
        logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")

    def _unload_runner(self):
        """Removes the runner from VRAM to free resources."""
        if self.runner is not None:
            self.runner = None
            self._free_memory()
            self.is_initialized = False
            logger.info("SeedVR2 runner unloaded from VRAM.")

    def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
                      model_version: str = '3B', steps: int = 50, seed: int = 666,
                      progress: gr.Progress = None) -> str:
        """Applies HD enhancement to a video using the SeedVR logic.

        Args:
            input_video_path: Path of the video to enhance.
            output_video_path: Path the enhanced video is written to; its
                parent directory is created if missing.
            prompt: Currently unused. The positive/negative text embeddings
                are loaded from the checkpoint files ``pos_emb.pt`` and
                ``neg_emb.pt`` instead.
            model_version: ``'3B'`` or ``'7B'``.
            steps: Number of sampling steps written into the runner config.
            seed: Seed passed to SeedVR's ``set_seed``.
            progress: Currently unused.

        Returns:
            ``output_video_path``.

        Raises:
            ValueError: If ``model_version`` is unsupported or the input video
                contains no frames.
        """
        try:
            self._initialize_runner(model_version)
            set_seed(seed, same_across_ranks=True)

            self.runner.config.diffusion.timesteps.sampling.steps = steps
            self.runner.configure_diffusion()

            frames, _, video_info = read_video(input_video_path, output_format="TCHW")
            if frames.shape[0] == 0:
                raise ValueError(f"No video frames could be read from: {input_video_path}")
            # Keep the source frame rate; fall back to 24 when it is not reported.
            output_fps = video_info.get("video_fps") or 24

            video_tensor = frames / 255.0
            res_h, res_w = video_tensor.shape[-2:]

            video_transform = Compose([
                NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
                Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
                DivisibleCrop((16, 16)),
                Normalize(0.5, 0.5),
                Rearrange("t c h w -> c t h w"),
            ])

            # The transformed input is kept for the color-fix step after decoding.
            input_videos = [video_transform(video_tensor.to(self.device))]

            # Only one of DiT / VAE is kept on the compute device at a time.
            self.runner.dit.to("cpu")
            self.runner.vae.to(self.device)
            cond_latents = self.runner.vae_encode(input_videos)
            self.runner.vae.to("cpu")
            self._free_memory()
            self.runner.dit.to(self.device)

            ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
            # map_location lets tensors saved from another device load on this host.
            text_pos_embeds = torch.load(ckpt_dir / 'pos_emb.pt', map_location="cpu").to(self.device)
            text_neg_embeds = torch.load(ckpt_dir / 'neg_emb.pt', map_location="cpu").to(self.device)
            text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}

            noises = [torch.randn_like(latent) for latent in cond_latents]
            conditions = [
                self.runner.get_condition(noise, latent_blur=latent, task="sr")
                for noise, latent in zip(noises, cond_latents)
            ]

            # CUDA autocast is only meaningful when actually running on CUDA.
            use_autocast = self.device == 'cuda'
            with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=use_autocast):
                video_tensors = self.runner.inference(
                    noises=noises,
                    conditions=conditions,
                    dit_offload=True,
                    **text_embeds_dict,
                )

            self.runner.dit.to("cpu")
            self._free_memory()

            self.runner.vae.to(self.device)
            samples = self.runner.vae_decode(video_tensors)

            final_sample = samples[0]
            input_video_sample = input_videos[0]

            # Both tensors are laid out "c t h w": trim the input to the number
            # of frames the model produced.
            if final_sample.shape[1] < input_video_sample.shape[1]:
                input_video_sample = input_video_sample[:, :final_sample.shape[1]]

            final_sample = wavelet_reconstruction(
                rearrange(final_sample, "c t h w -> t c h w"),
                rearrange(input_video_sample, "c t h w -> t c h w")
            )

            # Map [-1, 1] to [0, 255] and move channels last for writing.
            final_sample = rearrange(final_sample, "t c h w -> t h w c")
            final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            final_sample_np = final_sample.to(torch.uint8).cpu().numpy()

            Path(output_video_path).parent.mkdir(parents=True, exist_ok=True)
            mediapy.write_video(output_video_path, final_sample_np, fps=output_fps)
            logger.info(f"HD Mastered video saved to: {output_video_path}")
            return output_video_path

        finally:
            # Always release the model, whether processing succeeded or failed.
            self._unload_runner()
240
-
241
- # Singleton instance shared by importers of this module. It is constructed at
- # import time, so HDSpecialist.__init__ and its dependency check run on import.
242
- hd_specialist_singleton = HDSpecialist()