# Source: Aduc_sdr / seedvr_manager (2).py — Hugging Face upload by aducsdr
# ("Upload seedvr_manager (2).py", commit 971c7f9, verified, 10.8 kB).
# managers/seedvr_manager.py
#
# Copyright (C) 2025 Carlos Rodrigues dos Santos
#
# Version: 4.0.0 (Root Installer & Executor)
#
# This version fully adopts the logic from the functional hd_specialist.py example.
# It acts as a setup manager: it clones the SeedVR Space repo, copies the code and
# config directories it needs (e.g. projects, common, models, configs_3b, configs_7b)
# to the application root, and downloads the model checkpoints into <root>/ckpts.
# It also handles the pip installation of the Apex dependency.
# This ensures that the SeedVR code runs in the exact file structure it expects.
import torch
import torch.distributed as dist
import os
import gc
import logging
import sys
import subprocess
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file
import gradio as gr
import mediapy
from einops import rearrange
import shutil
from omegaconf import OmegaConf
logger = logging.getLogger(__name__)

# --- Global paths ---
# Root of the application; SeedVR code, configs and checkpoints end up here.
APP_ROOT = Path("/home/user/app")
# Scratch area for third-party material (the cloned Space, the Apex wheel).
DEPS_DIR = APP_ROOT.joinpath("deps")
SEEDVR_SPACE_DIR = DEPS_DIR.joinpath("SeedVR_Space")
SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
class SeedVrManager:
    """Lazily prepares SeedVR2 and uses it to upscale ("HD master") videos.

    Nothing heavy happens at construction time. The first call to
    :meth:`process_video` runs :meth:`_full_setup` (clone the SeedVR Space,
    copy its code/config directories into ``APP_ROOT``, install Apex,
    download checkpoints). Each ``process_video`` call then builds a runner,
    processes one video and unloads the runner again in a ``finally``.
    """

    # Directories copied from the cloned Space into APP_ROOT so that the
    # SeedVR imports (projects.*, common.*, models.*, data.*) resolve.
    # ``data`` is required by the transform imports in process_video().
    _REQUIRED_DIRS = ("projects", "common", "models", "data", "configs_3b", "configs_7b")

    # Where checkpoints and text embeddings are downloaded to.
    _CKPT_DIR = APP_ROOT / "ckpts"

    # Wheel built for CPython 3.10 on linux x86_64 (per its filename).
    _APEX_WHEEL_URL = (
        "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/"
        "apex-0.1-cp310-cp310-linux_x86_64.whl"
    )

    # Files downloaded into _CKPT_DIR by _full_setup().
    _PRETRAIN_MODEL_URLS = {
        "vae": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth",
        "dit_3b": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth",
        "dit_7b": "https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth",
        "pos_emb": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt",
        "neg_emb": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt",
    }

    def __init__(self, workspace_dir="deformes_workspace"):
        """Record configuration only; setup is deferred until first use.

        Args:
            workspace_dir: Stored on the instance (not used by this class itself).
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.runner = None
        self.workspace_dir = workspace_dir
        self.is_initialized = False
        self._original_barrier = None
        self.setup_complete = False  # True once _full_setup() has finished, so it runs only once.
        logger.info("SeedVrManager initialized. Setup will run on first use.")

    def _full_setup(self):
        """Prepare the SeedVR environment once per instance.

        Steps:
            1. Clone the SeedVR Space repo into ``SEEDVR_SPACE_DIR`` if missing.
            2. Copy ``_REQUIRED_DIRS`` into ``APP_ROOT`` (existing targets are kept).
            3. Put ``APP_ROOT`` at the front of ``sys.path``.
            4. Install the Apex wheel if ``import apex`` fails.
            5. Download checkpoints/embeddings into ``_CKPT_DIR`` (cached files are reused).

        Raises:
            subprocess.CalledProcessError: if ``git clone`` or ``pip install`` fails.
            FileNotFoundError: from ``shutil.copytree`` if a required directory
                is absent from the cloned repo.
        """
        if self.setup_complete:
            return
        logger.info("--- Starting Full SeedVR Setup ---")

        # 1. Clone the repository if it is not already present.
        if not SEEDVR_SPACE_DIR.exists():
            logger.info(f"Cloning SeedVR Space repo to {SEEDVR_SPACE_DIR}...")
            DEPS_DIR.mkdir(exist_ok=True, parents=True)
            try:
                subprocess.run(
                    ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
                    check=True, capture_output=True, text=True,
                )
            except subprocess.CalledProcessError as err:
                # Output is captured, so surface git's stderr before propagating.
                logger.error(f"git clone failed: {err.stderr}")
                raise

        # 2. Copy the needed directories to the application root.
        for dirname in self._REQUIRED_DIRS:
            target = APP_ROOT / dirname
            if not target.exists():
                logger.info(f"Copying '{dirname}' to application root...")
                shutil.copytree(SEEDVR_SPACE_DIR / dirname, target)

        # 3. Make the copied packages importable.
        if str(APP_ROOT) not in sys.path:
            sys.path.insert(0, str(APP_ROOT))
            logger.info(f"Added '{APP_ROOT}' to sys.path.")

        # 4. Install Apex if it is not importable.
        try:
            import apex  # noqa: F401
            logger.info("Apex is already installed.")
        except ImportError:
            logger.info("Installing Apex dependency...")
            apex_wheel_path = _load_file_from_url(url=self._APEX_WHEEL_URL, model_dir=str(DEPS_DIR))
            # Use the running interpreter's pip (no shell) so the wheel lands
            # in the same environment that will import it.
            subprocess.run([sys.executable, "-m", "pip", "install", apex_wheel_path], check=True)
            logger.info("Apex installed successfully.")

        # 5. Download models into <APP_ROOT>/ckpts.
        self._CKPT_DIR.mkdir(exist_ok=True)
        for url in self._PRETRAIN_MODEL_URLS.values():
            _load_file_from_url(url=url, model_dir=str(self._CKPT_DIR))

        self.setup_complete = True
        logger.info("--- Full SeedVR Setup Complete ---")

    def _initialize_runner(self, model_version: str):
        """Build the SeedVR2 runner for ``model_version`` unless one is loaded.

        ``'3B'`` selects ``configs_3b`` / ``seedvr2_ema_3b.pth``; any other
        value falls back to the 7B config and checkpoint.
        """
        if self.runner is not None:
            return

        # Make sure the whole environment is ready before importing SeedVR code.
        self._full_setup()

        # These modules only become importable after _full_setup() has copied
        # the code and patched sys.path.
        from projects.video_diffusion_sr.infer import VideoDiffusionInfer
        from common.config import load_config

        if dist.is_available() and not dist.is_initialized():
            # Single-process, local "gloo" group — presumably required by
            # SeedVR's distributed-aware code (TODO confirm).
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "12355"
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            dist.init_process_group(backend='gloo')
            logger.info("Initialized torch.distributed process group.")

        logger.info(f"Initializing SeedVR2 {model_version} runner...")
        if model_version == '3B':
            config_path = APP_ROOT / 'configs_3b' / 'main.yaml'
            checkpoint_path = self._CKPT_DIR / 'seedvr2_ema_3b.pth'
        else:  # Any other value is treated as 7B.
            config_path = APP_ROOT / 'configs_7b' / 'main.yaml'
            checkpoint_path = self._CKPT_DIR / 'seedvr2_ema_7b.pth'

        config = load_config(str(config_path))
        self.runner = VideoDiffusionInfer(config)
        # Allow process_video() to overwrite the number of sampling steps.
        OmegaConf.set_readonly(self.runner.config, False)

        self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
        self.runner.configure_vae_model()
        if hasattr(self.runner.vae, "set_memory_limit"):
            self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)

        self.is_initialized = True
        logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")

    def _unload_runner(self):
        """Drop the runner, free cached GPU memory and tear down the process group."""
        if self.runner is not None:
            del self.runner
            self.runner = None
            gc.collect()
            torch.cuda.empty_cache()
            self.is_initialized = False
            logger.info("Runner do SeedVR2 descarregado da VRAM.")
        # Also runs when runner construction failed after the group was created.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Destroyed torch.distributed process group.")

    def _embedding_path(self, filename):
        """Return the path of a text-embedding file.

        Prefers ``_CKPT_DIR`` (where _full_setup() downloads it) and falls
        back to ``APP_ROOT`` for setups that placed it there.
        """
        candidate = self._CKPT_DIR / filename
        return candidate if candidate.exists() else APP_ROOT / filename

    def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
                      model_version: str = '7B', steps: int = 100, seed: int = 666,
                      progress: gr.Progress = None, fps: float = 24) -> str:
        """Upscale one video with SeedVR2 and write the result to disk.

        Args:
            input_video_path: Video to read (via torchvision ``read_video``).
            output_video_path: Destination written with ``mediapy.write_video``.
            prompt: Accepted for interface compatibility; not used — the fixed
                ``pos_emb.pt`` / ``neg_emb.pt`` embeddings are used instead.
            model_version: ``'3B'`` or anything else for 7B.
            steps: Number of diffusion sampling steps.
            seed: Seed passed to SeedVR's ``set_seed``.
            progress: Accepted for Gradio compatibility; not used.
            fps: Frame rate of the written output video (default 24).

        Returns:
            ``output_video_path``.

        The runner is always unloaded afterwards, even on failure.
        """
        try:
            self._initialize_runner(model_version)

            # Imported here because sys.path is only patched during setup.
            from common.seed import set_seed
            from data.image.transforms.divisible_crop import DivisibleCrop
            from data.image.transforms.na_resize import NaResize
            from data.video.transforms.rearrange import Rearrange
            from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
            from torchvision.transforms import Compose, Lambda, Normalize
            from torchvision.io.video import read_video

            set_seed(seed, same_across_ranks=True)
            self.runner.config.diffusion.timesteps.sampling.steps = steps
            self.runner.configure_diffusion()

            # Frames as TCHW, scaled from 0..255 to 0..1.
            video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
            res_h, res_w = video_tensor.shape[-2:]
            video_transform = Compose([
                NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
                Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
                DivisibleCrop((16, 16)),
                Normalize(0.5, 0.5),
                Rearrange("t c h w -> c t h w"),
            ])
            # Kept for the colour-fix step after decoding.
            input_videos = [video_transform(video_tensor.to(self.device))]

            # Encode with only the VAE on the device (DiT parked on CPU).
            self.runner.dit.to("cpu")
            self.runner.vae.to(self.device)
            cond_latents = self.runner.vae_encode(input_videos)
            self.runner.vae.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.dit.to(self.device)

            pos_emb = torch.load(self._embedding_path('pos_emb.pt'), map_location="cpu").to(self.device)
            neg_emb = torch.load(self._embedding_path('neg_emb.pt'), map_location="cpu").to(self.device)
            text_embeds_dict = {"texts_pos": [pos_emb], "texts_neg": [neg_emb]}

            noises = [torch.randn_like(latent) for latent in cond_latents]
            conditions = [
                self.runner.get_condition(noise, latent_blur=latent, task="sr")
                for noise, latent in zip(noises, cond_latents)
            ]
            with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
                video_tensors = self.runner.inference(
                    noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict
                )

            # Decode with only the VAE on the device.
            self.runner.dit.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
            self.runner.vae.to(self.device)
            samples = self.runner.vae_decode(video_tensors)

            final_sample = samples[0]
            input_video_sample = input_videos[0]
            # The decoded clip may have fewer frames than the input; trim the reference.
            if final_sample.shape[1] < input_video_sample.shape[1]:
                input_video_sample = input_video_sample[:, :final_sample.shape[1]]
            final_sample = wavelet_reconstruction(
                rearrange(final_sample, "c t h w -> t c h w"),
                rearrange(input_video_sample, "c t h w -> t c h w"),
            )
            final_sample = rearrange(final_sample, "t c h w -> t h w c")
            # Map [-1, 1] back to [0, 255] (inverse of Normalize(0.5, 0.5) plus scaling).
            final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
            mediapy.write_video(output_video_path, final_sample_np, fps=fps)
            logger.info(f"HD Mastered video saved to: {output_video_path}")
            return output_video_path
        finally:
            self._unload_runner()
def _load_file_from_url(url, model_dir='./', file_name=None):
os.makedirs(model_dir, exist_ok=True)
filename = file_name or os.path.basename(urlparse(url).path)
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
logger.info(f'Downloading: "{url}" to {cached_file}')
download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
return cached_file
# Shared module-level instance. Construction only records settings and logs;
# the expensive setup runs on first use.
seedvr_manager_singleton = SeedVrManager(workspace_dir="deformes_workspace")