euiia commited on
Commit
9fcad90
·
verified ·
1 Parent(s): 7974f2d

Rename managers/hd_specialist.py to managers/seedvr_manager.py

Browse files
managers/{hd_specialist.py → seedvr_manager.py} RENAMED
@@ -1,14 +1,12 @@
1
- # managers/hd_specialist.py
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 2.2.0
6
  #
7
- # This file implements the HD Specialist (Δ+), which uses the SeedVR model
8
- # for video super-resolution. It has been refactored to be self-contained by
9
- # automatically cloning its own dependencies from the official SeedVR repository
10
- # if they are not found locally. This removes the need for manual file copying
11
- # and makes the ADUC-SDR framework more robust and portable.
12
 
13
  import torch
14
  import os
@@ -23,6 +21,9 @@ import gradio as gr
23
  import mediapy
24
  from einops import rearrange
25
 
 
 
 
26
  logger = logging.getLogger(__name__)
27
 
28
  # --- Dependency Management ---
@@ -30,6 +31,47 @@ DEPS_DIR = Path("./deps")
30
  SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
31
  SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def _load_file_from_url(url, model_dir='./', file_name=None):
34
  """Helper function to download files from a URL to a local directory."""
35
  os.makedirs(model_dir, exist_ok=True)
@@ -40,67 +82,16 @@ def _load_file_from_url(url, model_dir='./', file_name=None):
40
  download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
41
  return cached_file
42
 
43
- class HDSpecialist:
44
  """
45
- Implements the HD Specialist (Δ+) using the SeedVR infrastructure.
46
- Manages model loading, inference, and memory on demand.
47
  """
48
  def __init__(self, workspace_dir="deformes_workspace"):
49
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
50
  self.runner = None
51
  self.workspace_dir = workspace_dir
52
  self.is_initialized = False
53
- self._seedvr_modules_loaded = False
54
- self._setup_dependencies()
55
- logger.info("HD Specialist (SeedVR) initialized. Dependencies checked. Model will be loaded on demand.")
56
-
57
- def _setup_dependencies(self):
58
- """
59
- Checks for the SeedVR repository locally. If not found, clones it.
60
- Then, it adds the repository to the Python path to make its modules importable.
61
- """
62
- if not SEEDVR_REPO_DIR.exists():
63
- logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
64
- try:
65
- DEPS_DIR.mkdir(exist_ok=True)
66
- subprocess.run(
67
- ["git", "clone", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
68
- check=True, capture_output=True, text=True
69
- )
70
- logger.info("SeedVR repository cloned successfully.")
71
- except subprocess.CalledProcessError as e:
72
- logger.error(f"Failed to clone SeedVR repository. Git stderr: {e.stderr}")
73
- raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.")
74
- else:
75
- logger.info("Found local SeedVR repository.")
76
-
77
- # Add the cloned repo to Python's path to allow direct imports
78
- if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
79
- sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
80
- logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")
81
-
82
- def _lazy_load_seedvr_modules(self):
83
- """
84
- Dynamically imports SeedVR modules only when needed.
85
- This prevents ImportError if the class is instantiated before dependencies are ready.
86
- """
87
- if self._seedvr_modules_loaded:
88
- return
89
-
90
- global VideoDiffusionInfer, load_config, set_seed, DivisibleCrop, NaResize, Rearrange, wavelet_reconstruction, Compose, Lambda, Normalize, read_video, OmegaConf
91
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
92
- from common.config import load_config
93
- from common.seed import set_seed
94
- from data.image.transforms.divisible_crop import DivisibleCrop
95
- from data.image.transforms.na_resize import NaResize
96
- from data.video.transforms.rearrange import Rearrange
97
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
98
- from torchvision.transforms import Compose, Lambda, Normalize
99
- from torchvision.io.video import read_video
100
- from omegaconf import OmegaConf
101
-
102
- self._seedvr_modules_loaded = True
103
- logger.info("SeedVR modules have been dynamically loaded.")
104
 
105
  def _download_models(self):
106
  """Downloads the necessary checkpoints for SeedVR2."""
@@ -123,10 +114,8 @@ class HDSpecialist:
123
 
124
  def _initialize_runner(self, model_version: str):
125
  """Loads and configures the SeedVR model on demand based on the selected version."""
126
- if self.runner is not None:
127
- return
128
-
129
- self._lazy_load_seedvr_modules()
130
  self._download_models()
131
 
132
  logger.info(f"Initializing SeedVR2 {model_version} runner...")
@@ -156,10 +145,8 @@ class HDSpecialist:
156
  def _unload_runner(self):
157
  """Removes the runner from VRAM to free resources."""
158
  if self.runner is not None:
159
- del self.runner
160
- self.runner = None
161
- gc.collect()
162
- torch.cuda.empty_cache()
163
  self.is_initialized = False
164
  logger.info("SeedVR2 runner unloaded from VRAM.")
165
 
@@ -204,12 +191,7 @@ class HDSpecialist:
204
  conditions = [self.runner.get_condition(noise, latent_blur=latent, task="sr") for noise, latent in zip(noises, cond_latents)]
205
 
206
  with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
207
- video_tensors = self.runner.inference(
208
- noises=noises,
209
- conditions=conditions,
210
- dit_offload=True,
211
- **text_embeds_dict,
212
- )
213
 
214
  self.runner.dit.to("cpu"); gc.collect(); torch.cuda.empty_cache()
215
 
@@ -223,10 +205,10 @@ class HDSpecialist:
223
  input_video_sample = input_video_sample[:, :final_sample.shape[1]]
224
 
225
  final_sample = wavelet_reconstruction(
226
- rearrange(final_sample, "c t h w -> t c h w"),
227
  rearrange(input_video_sample, "c t h w -> t c h w")
228
  )
229
-
230
  final_sample = rearrange(final_sample, "t c h w -> t h w c")
231
  final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
232
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
@@ -234,9 +216,8 @@ class HDSpecialist:
234
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
235
  logger.info(f"HD Mastered video saved to: {output_video_path}")
236
  return output_video_path
237
-
238
  finally:
239
  self._unload_runner()
240
 
241
- # Singleton instance
242
- hd_specialist_singleton = HDSpecialist()
 
1
+ # managers/seedvr_manager.py
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 2.3.0
6
  #
7
+ # This file implements the SeedVrManager, which uses the SeedVR model for
8
+ # video super-resolution. It is self-contained, automatically cloning its own
9
+ # dependencies from the official SeedVR repository.
 
 
10
 
11
  import torch
12
  import os
 
21
  import mediapy
22
  from einops import rearrange
23
 
24
+ # Internalized utility for color correction, ensuring stability.
25
+ from tools.tensor_utils import wavelet_reconstruction
26
+
27
  logger = logging.getLogger(__name__)
28
 
29
  # --- Dependency Management ---
 
31
  SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
32
  SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
33
 
34
+ def setup_seedvr_dependencies():
35
+ """
36
+ Ensures the SeedVR repository is cloned and available in the sys.path.
37
+ This function is run once when the module is first imported.
38
+ """
39
+ if not SEEDVR_REPO_DIR.exists():
40
+ logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
41
+ try:
42
+ DEPS_DIR.mkdir(exist_ok=True)
43
+ # Use --depth 1 for a shallow clone to save space and time
44
+ subprocess.run(
45
+ ["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)],
46
+ check=True, capture_output=True, text=True
47
+ )
48
+ logger.info("SeedVR repository cloned successfully.")
49
+ except subprocess.CalledProcessError as e:
50
+ logger.error(f"Failed to clone SeedVR repository. Git stderr: {e.stderr}")
51
+ raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.")
52
+ else:
53
+ logger.info("Found local SeedVR repository.")
54
+
55
+ # Add the cloned repo to Python's path to allow direct imports
56
+ if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
57
+ sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
58
+ logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")
59
+
60
+ # --- Execute dependency setup immediately upon module import ---
61
+ setup_seedvr_dependencies()
62
+
63
+ # --- Now that the path is set, we can safely import from the cloned repo ---
64
+ from projects.video_diffusion_sr.infer import VideoDiffusionInfer
65
+ from common.config import load_config
66
+ from common.seed import set_seed
67
+ from data.image.transforms.divisible_crop import DivisibleCrop
68
+ from data.image.transforms.na_resize import NaResize
69
+ from data.video.transforms.rearrange import Rearrange
70
+ from torchvision.transforms import Compose, Lambda, Normalize
71
+ from torchvision.io.video import read_video
72
+ from omegaconf import OmegaConf
73
+
74
+
75
  def _load_file_from_url(url, model_dir='./', file_name=None):
76
  """Helper function to download files from a URL to a local directory."""
77
  os.makedirs(model_dir, exist_ok=True)
 
82
  download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
83
  return cached_file
84
 
85
+ class SeedVrManager:
86
  """
87
+ Manages the SeedVR model for HD Mastering tasks.
 
88
  """
89
  def __init__(self, workspace_dir="deformes_workspace"):
90
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
91
  self.runner = None
92
  self.workspace_dir = workspace_dir
93
  self.is_initialized = False
94
+ logger.info("SeedVrManager initialized. Model will be loaded on demand.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  def _download_models(self):
97
  """Downloads the necessary checkpoints for SeedVR2."""
 
114
 
115
  def _initialize_runner(self, model_version: str):
116
  """Loads and configures the SeedVR model on demand based on the selected version."""
117
+ if self.runner is not None: return
118
+
 
 
119
  self._download_models()
120
 
121
  logger.info(f"Initializing SeedVR2 {model_version} runner...")
 
145
  def _unload_runner(self):
146
  """Removes the runner from VRAM to free resources."""
147
  if self.runner is not None:
148
+ del self.runner; self.runner = None
149
+ gc.collect(); torch.cuda.empty_cache()
 
 
150
  self.is_initialized = False
151
  logger.info("SeedVR2 runner unloaded from VRAM.")
152
 
 
191
  conditions = [self.runner.get_condition(noise, latent_blur=latent, task="sr") for noise, latent in zip(noises, cond_latents)]
192
 
193
  with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
194
+ video_tensors = self.runner.inference(noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict)
 
 
 
 
 
195
 
196
  self.runner.dit.to("cpu"); gc.collect(); torch.cuda.empty_cache()
197
 
 
205
  input_video_sample = input_video_sample[:, :final_sample.shape[1]]
206
 
207
  final_sample = wavelet_reconstruction(
208
+ rearrange(final_sample, "c t h w -> t c h w"),
209
  rearrange(input_video_sample, "c t h w -> t c h w")
210
  )
211
+
212
  final_sample = rearrange(final_sample, "t c h w -> t h w c")
213
  final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
214
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
 
216
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
217
  logger.info(f"HD Mastered video saved to: {output_video_path}")
218
  return output_video_path
 
219
  finally:
220
  self._unload_runner()
221
 
222
+ # --- Singleton Instance ---
223
+ seedvr_manager_singleton = SeedVrManager()