File size: 10,821 Bytes
6feecd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
# managers/seedvr_manager.py
#
# Copyright (C) 2025 Carlos Rodrigues dos Santos
#
# Version: 4.0.0 (Root Installer & Executor)
#
# This version fully adopts the logic from the functional hd_specialist.py example.
# It acts as a setup manager: it clones the SeedVR Space repo and copies the
# source directories SeedVR needs (e.g. projects, common, models, configs_3b,
# configs_7b) to the application root, downloads the model checkpoints into
# ckpts/, and pip-installs the Apex dependency when it is missing.
# This ensures that the SeedVR code runs in the exact file structure it expects.
import torch
import torch.distributed as dist
import os
import gc
import logging
import sys
import subprocess
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file
import gradio as gr
import mediapy
from einops import rearrange
import shutil
from omegaconf import OmegaConf
logger = logging.getLogger(__name__)

# --- Global paths ---
# Application root: SeedVR source directories are copied here and checkpoints
# are downloaded into its "ckpts" subfolder.
APP_ROOT = Path("/home/user/app")
# Holds the cloned SeedVR Space and downloaded wheels.
DEPS_DIR = APP_ROOT.joinpath("deps")
# Local checkout of the SeedVR2 Hugging Face Space.
SEEDVR_SPACE_DIR = DEPS_DIR.joinpath("SeedVR_Space")
SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
class SeedVrManager:
    """Install, load and run the SeedVR2 video super-resolution pipeline.

    On first use the manager clones the SeedVR2 Hugging Face Space, copies its
    source directories into ``APP_ROOT`` (which is added to ``sys.path``),
    installs Apex if it is missing and downloads the checkpoints into
    ``APP_ROOT / 'ckpts'``. Every call to :meth:`process_video` builds a
    runner, upscales one video and unloads the runner again to free VRAM.
    """

    def __init__(self, workspace_dir="deformes_workspace"):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.runner = None
        self.workspace_dir = workspace_dir
        self.is_initialized = False
        self._original_barrier = None
        # True only while a torch.distributed process group created by this
        # manager is alive, so we never destroy a group owned by someone else.
        self._owns_process_group = False
        self.setup_complete = False  # ensures the setup runs only once
        logger.info("SeedVrManager initialized. Setup will run on first use.")

    def _full_setup(self):
        """Run the one-time environment setup: clone, copy, install Apex, download.

        Returns immediately once ``self.setup_complete`` has been set.

        Raises:
            subprocess.CalledProcessError: if ``git clone`` or ``pip install`` fails.
        """
        if self.setup_complete:
            return
        logger.info("--- Starting Full SeedVR Setup ---")

        # 1. Clone the Space repository if it is not present yet.
        if not SEEDVR_SPACE_DIR.exists():
            logger.info(f"Cloning SeedVR Space repo to {SEEDVR_SPACE_DIR}...")
            DEPS_DIR.mkdir(exist_ok=True, parents=True)
            try:
                subprocess.run(
                    ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
                    check=True, capture_output=True, text=True
                )
            except subprocess.CalledProcessError as err:
                # capture_output hides git's diagnostics: surface them, and remove
                # any partial checkout so the next attempt clones again.
                logger.error(f"git clone failed: {err.stderr}")
                shutil.rmtree(SEEDVR_SPACE_DIR, ignore_errors=True)
                raise

        # 2. Copy the source directories the SeedVR code imports into the app root.
        #    "data" is required by the transform imports in process_video().
        required_dirs = ["projects", "common", "data", "models", "configs_3b", "configs_7b"]
        for dirname in required_dirs:
            source = SEEDVR_SPACE_DIR / dirname
            target = APP_ROOT / dirname
            if not target.exists():
                logger.info(f"Copying '{dirname}' to application root...")
                shutil.copytree(source, target)

        # 3. Put the app root on sys.path so the copied packages are importable.
        if str(APP_ROOT) not in sys.path:
            sys.path.insert(0, str(APP_ROOT))
            logger.info(f"Added '{APP_ROOT}' to sys.path.")

        # 4. Install Apex from the prebuilt wheel when it is not importable.
        try:
            import apex  # noqa: F401  (availability probe only)
            logger.info("Apex is already installed.")
        except ImportError:
            logger.info("Installing Apex dependency...")
            apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
            apex_wheel_path = _load_file_from_url(url=apex_url, model_dir=str(DEPS_DIR))
            # Use this interpreter's pip (not whatever "pip" is on PATH) and an
            # argument list, so the path is never interpreted by a shell.
            subprocess.run([sys.executable, "-m", "pip", "install", apex_wheel_path], check=True)
            logger.info("Apex installed successfully.")

        # 5. Download every checkpoint (both DiT sizes) into APP_ROOT/ckpts.
        ckpt_dir = APP_ROOT / 'ckpts'
        ckpt_dir.mkdir(exist_ok=True, parents=True)
        pretrain_model_urls = {
            'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
            'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
            'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
            'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
            'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
        }
        for url in pretrain_model_urls.values():
            _load_file_from_url(url=url, model_dir=str(ckpt_dir))

        self.setup_complete = True
        logger.info("--- Full SeedVR Setup Complete ---")

    def _initialize_runner(self, model_version: str):
        """Build the SeedVR2 inference runner unless one is already loaded.

        Args:
            model_version: ``'3B'`` selects the 3B config/checkpoint; any other
                value selects 7B.
        """
        if self.runner is not None:
            return
        # The environment must be fully set up before the SeedVR imports below.
        self._full_setup()
        # These modules live in the directories _full_setup() copied into APP_ROOT.
        from projects.video_diffusion_sr.infer import VideoDiffusionInfer
        from common.config import load_config

        if dist.is_available() and not dist.is_initialized():
            # Single-process (rank 0 of 1) gloo group; presumably required by
            # SeedVR code that queries torch.distributed — TODO confirm.
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "12355"
            os.environ["RANK"] = str(0)
            os.environ["WORLD_SIZE"] = str(1)
            dist.init_process_group(backend='gloo')
            self._owns_process_group = True
            logger.info("Initialized torch.distributed process group.")

        logger.info(f"Initializing SeedVR2 {model_version} runner...")
        if model_version == '3B':
            config_path = APP_ROOT / 'configs_3b' / 'main.yaml'
            checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_3b.pth'
        else:  # anything else is treated as 7B
            config_path = APP_ROOT / 'configs_7b' / 'main.yaml'
            checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_7b.pth'

        config = load_config(str(config_path))
        self.runner = VideoDiffusionInfer(config)
        # The config is mutated later (sampling steps), so unlock it.
        OmegaConf.set_readonly(self.runner.config, False)
        self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
        self.runner.configure_vae_model()
        if hasattr(self.runner.vae, "set_memory_limit"):
            self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
        self.is_initialized = True
        logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")

    def _unload_runner(self):
        """Drop the runner, free cached CUDA memory and tear down our process group."""
        if self.runner is not None:
            del self.runner
            self.runner = None
            gc.collect()
            torch.cuda.empty_cache()
            self.is_initialized = False
            logger.info("Runner do SeedVR2 descarregado da VRAM.")
        # Only destroy a process group that _initialize_runner() created itself.
        if getattr(self, "_owns_process_group", False) and dist.is_initialized():
            dist.destroy_process_group()
            self._owns_process_group = False
            logger.info("Destroyed torch.distributed process group.")

    @staticmethod
    def _embedding_path(filename: str) -> Path:
        """Return the location of a text-embedding file.

        _full_setup() downloads the embeddings into ``APP_ROOT / 'ckpts'``; the
        application root itself is kept as a fallback location.
        """
        candidate = APP_ROOT / 'ckpts' / filename
        return candidate if candidate.exists() else APP_ROOT / filename

    def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
                      model_version: str = '7B', steps: int = 100, seed: int = 666,
                      progress: gr.Progress = None) -> str:
        """Upscale a video with SeedVR2 and write the result to disk.

        Args:
            input_video_path: Path of the video to upscale.
            output_video_path: Destination path of the mastered video.
            prompt: Not used by this implementation; fixed positive/negative
                text embeddings are loaded instead.
            model_version: ``'3B'`` or ``'7B'`` (anything other than ``'3B'`` uses 7B).
            steps: Number of diffusion sampling steps.
            seed: Seed passed to SeedVR's ``set_seed``.
            progress: Optional Gradio progress tracker (currently unused).

        Returns:
            ``output_video_path``.

        The runner is always unloaded afterwards, even on failure, to free VRAM.
        """
        try:
            self._initialize_runner(model_version)
            # Imported here: these modules only become importable after setup
            # has copied them into APP_ROOT and extended sys.path.
            from common.seed import set_seed
            from data.image.transforms.divisible_crop import DivisibleCrop
            from data.image.transforms.na_resize import NaResize
            from data.video.transforms.rearrange import Rearrange
            from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
            from torchvision.transforms import Compose, Lambda, Normalize
            from torchvision.io.video import read_video

            set_seed(seed, same_across_ranks=True)
            self.runner.config.diffusion.timesteps.sampling.steps = steps
            self.runner.configure_diffusion()

            # read_video returns (video, audio, info); frames here are in TCHW layout.
            frames, _, video_info = read_video(input_video_path, output_format="TCHW")
            video_tensor = frames / 255.0
            # Keep the source frame rate instead of always writing 24 fps.
            output_fps = video_info.get("video_fps") or 24

            res_h, res_w = video_tensor.shape[-2:]
            video_transform = Compose([
                NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
                Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
                DivisibleCrop((16, 16)),
                Normalize(0.5, 0.5),
                Rearrange("t c h w -> c t h w"),
            ])
            cond_latents = [video_transform(video_tensor.to(self.device))]
            # Pixel-space input, reused for the wavelet colour fix after decoding.
            input_videos = cond_latents

            # Encode with only the VAE on the device.
            self.runner.dit.to("cpu")
            self.runner.vae.to(self.device)
            cond_latents = self.runner.vae_encode(cond_latents)
            self.runner.vae.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.dit.to(self.device)

            pos_emb = torch.load(self._embedding_path('pos_emb.pt'), map_location="cpu").to(self.device)
            neg_emb = torch.load(self._embedding_path('neg_emb.pt'), map_location="cpu").to(self.device)
            text_embeds_dict = {"texts_pos": [pos_emb], "texts_neg": [neg_emb]}

            noises = [torch.randn_like(latent) for latent in cond_latents]
            conditions = [
                self.runner.get_condition(noise, latent_blur=latent, task="sr")
                for noise, latent in zip(noises, cond_latents)
            ]
            with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
                video_tensors = self.runner.inference(
                    noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict
                )

            # Decode with only the VAE on the device.
            self.runner.dit.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.vae.to(self.device)
            samples = self.runner.vae_decode(video_tensors)

            final_sample = samples[0]
            input_video_sample = input_videos[0]
            # If the decoded clip has fewer frames, trim the reference to match.
            if final_sample.shape[1] < input_video_sample.shape[1]:
                input_video_sample = input_video_sample[:, :final_sample.shape[1]]
            final_sample = wavelet_reconstruction(
                rearrange(final_sample, "c t h w -> t c h w"),
                rearrange(input_video_sample, "c t h w -> t c h w"),
            )
            final_sample = rearrange(final_sample, "t c h w -> t h w c")
            # Map [-1, 1] back to [0, 255].
            final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
            mediapy.write_video(output_video_path, final_sample_np, fps=output_fps)
            logger.info(f"HD Mastered video saved to: {output_video_path}")
            return output_video_path
        finally:
            self._unload_runner()
def _load_file_from_url(url, model_dir='./', file_name=None):
    """Fetch ``url`` into ``model_dir`` unless the file is already there.

    Args:
        url: Remote location of the file.
        model_dir: Directory to store the file in; created if missing.
        file_name: Optional local name; defaults to the last component of the
            URL path (query string ignored).

    Returns:
        Absolute path of the local file as a string.
    """
    os.makedirs(model_dir, exist_ok=True)
    if file_name:
        target_name = file_name
    else:
        target_name = os.path.basename(urlparse(url).path)
    destination = os.path.abspath(os.path.join(model_dir, target_name))
    if os.path.exists(destination):
        # Already downloaded: reuse the cached copy.
        return destination
    logger.info(f'Downloading: "{url}" to {destination}')
    download_url_to_file(url, destination, hash_prefix=None, progress=True)
    return destination
# Shared module-level instance. Construction only sets attributes; cloning,
# installing and downloading are deferred until the first process_video() call.
seedvr_manager_singleton = SeedVrManager(workspace_dir="deformes_workspace")