# managers/seedvr_manager.py
#
# Copyright (C) 2025 Carlos Rodrigues dos Santos
#
# Version: 4.0.0 (Root Installer & Executor)
#
# This version fully adopts the logic from the functional hd_specialist.py example.
# It acts as a setup manager: it clones the SeedVR Space repo, copies the
# required directories (projects, common, models, configs_3b, configs_7b) to the
# application root, downloads the checkpoints into ./ckpts, and handles the pip
# installation of the Apex dependency. This ensures that the SeedVR code runs in
# the exact file structure it expects.

import torch
import torch.distributed as dist
import os
import gc
import logging
import sys
import subprocess
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file
import gradio as gr
import mediapy
from einops import rearrange
import shutil
from omegaconf import OmegaConf

logger = logging.getLogger(__name__)

# --- Global Paths ---
APP_ROOT = Path("/home/user/app")
DEPS_DIR = APP_ROOT / "deps"
SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"

class SeedVrManager:
    def __init__(self, workspace_dir="deformes_workspace"):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.runner = None
        self.workspace_dir = workspace_dir
        self.is_initialized = False
        self._original_barrier = None
        self.setup_complete = False  # Flag to run the setup only once
        logger.info("SeedVrManager initialized. Setup will run on first use.")

    def _full_setup(self):
        """
        Executa todo o processo de setup uma única vez.
        """
        if self.setup_complete:
            return
        
        logger.info("--- Starting Full SeedVR Setup ---")
        
        # 1. Clone the repository if it does not already exist
        if not SEEDVR_SPACE_DIR.exists():
            logger.info(f"Cloning SeedVR Space repo to {SEEDVR_SPACE_DIR}...")
            DEPS_DIR.mkdir(exist_ok=True, parents=True)
            subprocess.run(
                ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
                check=True, capture_output=True, text=True
            )
        
        # 2. Copy the required directories to the application root
        required_dirs = ["projects", "common", "models", "configs_3b", "configs_7b"]
        for dirname in required_dirs:
            source = SEEDVR_SPACE_DIR / dirname
            target = APP_ROOT / dirname
            if not target.exists():
                logger.info(f"Copying '{dirname}' to application root...")
                shutil.copytree(source, target)
        
        # 3. Add the application root to sys.path so the imports resolve
        if str(APP_ROOT) not in sys.path:
            sys.path.insert(0, str(APP_ROOT))
            logger.info(f"Added '{APP_ROOT}' to sys.path.")

        # 4. Install complex dependencies such as Apex
        try:
            import apex
            logger.info("Apex is already installed.")
        except ImportError:
            logger.info("Installing Apex dependency...")
            apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
            apex_wheel_path = _load_file_from_url(url=apex_url, model_dir=str(DEPS_DIR))
            subprocess.run(f"pip install {apex_wheel_path}", check=True, shell=True)
            logger.info("Apex installed successfully.")

        # 5. Download the model checkpoints into the ./ckpts folder at the application root
        ckpt_dir = APP_ROOT / 'ckpts'
        ckpt_dir.mkdir(exist_ok=True)
        pretrain_model_urls = {
            'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
            'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
            'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
            'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
            'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
        }
        for name, url in pretrain_model_urls.items():
            _load_file_from_url(url=url, model_dir=str(ckpt_dir))
            
        self.setup_complete = True
        logger.info("--- Full SeedVR Setup Complete ---")

    def _initialize_runner(self, model_version: str):
        if self.runner is not None: return

        # Ensure the whole environment is configured before proceeding
        self._full_setup()

        # Now that setup is done, we can import the SeedVR modules
        from projects.video_diffusion_sr.infer import VideoDiffusionInfer
        from common.config import load_config
        from common.seed import set_seed
        
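        # SeedVR's inference stack assumes torch.distributed is initialized, so a
        # single-rank gloo process group is created on localhost for this one process.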
        if dist.is_available() and not dist.is_initialized():
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "12355"
            os.environ["RANK"] = str(0)
            os.environ["WORLD_SIZE"] = str(1)
            dist.init_process_group(backend='gloo')
            logger.info("Initialized torch.distributed process group.")

        logger.info(f"Initializing SeedVR2 {model_version} runner...")
        if model_version == '3B':
            config_path = APP_ROOT / 'configs_3b' / 'main.yaml'
            checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_3b.pth'
        else:  # Assume the 7B model
            config_path = APP_ROOT / 'configs_7b' / 'main.yaml'
            checkpoint_path = APP_ROOT / 'ckpts' / 'seedvr2_ema_7b.pth'

        config = load_config(str(config_path))
        
        self.runner = VideoDiffusionInfer(config)
        OmegaConf.set_readonly(self.runner.config, False)
        
        self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
        self.runner.configure_vae_model()
        
        if hasattr(self.runner.vae, "set_memory_limit"):
            self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
        
        self.is_initialized = True
        logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
        
    def _unload_runner(self):
        if self.runner is not None:
            del self.runner
            self.runner = None
            gc.collect()
            torch.cuda.empty_cache()
            self.is_initialized = False
            logger.info("Runner do SeedVR2 descarregado da VRAM.")
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Destroyed torch.distributed process group.")

    def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
                      model_version: str = '7B', steps: int = 100, seed: int = 666, 
                      progress: gr.Progress = None) -> str:
        try:
            self._initialize_runner(model_version)
            
            # We need to import here because sys.path is only modified during setup
            from common.seed import set_seed
            from data.image.transforms.divisible_crop import DivisibleCrop
            from data.image.transforms.na_resize import NaResize
            from data.video.transforms.rearrange import Rearrange
            from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
            from torchvision.transforms import Compose, Lambda, Normalize
            from torchvision.io.video import read_video

            set_seed(seed, same_across_ranks=True)
            self.runner.config.diffusion.timesteps.sampling.steps = steps
            self.runner.configure_diffusion()
            
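            # Load the input video as a float tensor in [0, 1] (TCHW) and build the
            # preprocessing pipeline: area resize, clamp, crop to multiples of 16,
            # normalize to [-1, 1], and rearrange to CTHW.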
            video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
            res_h, res_w = video_tensor.shape[-2:]
            video_transform = Compose([
                NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
                Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
                DivisibleCrop((16, 16)),
                Normalize(0.5, 0.5),
                Rearrange("t c h w -> c t h w"),
            ])
            cond_latents = [video_transform(video_tensor.to(self.device))]
            input_videos = cond_latents
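            # Keep VRAM usage low by shuttling models: the DiT sits on CPU while the
            # VAE encodes the conditioning latents on the GPU, then they swap places.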
            self.runner.dit.to("cpu")
            self.runner.vae.to(self.device)
            cond_latents = self.runner.vae_encode(cond_latents)
            self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
            self.runner.dit.to(self.device)
            
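            # Load the precomputed positive/negative text embeddings downloaded during setup.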
            pos_emb = torch.load(APP_ROOT / 'ckpts' / 'pos_emb.pt').to(self.device)
            neg_emb = torch.load(APP_ROOT / 'ckpts' / 'neg_emb.pt').to(self.device)
            text_embeds_dict = {"texts_pos": [pos_emb], "texts_neg": [neg_emb]}
            
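            # Build per-latent Gaussian noise and the super-resolution conditioning inputs.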
            noises = [torch.randn_like(latent) for latent in cond_latents]
            conditions = [self.runner.get_condition(noise, latent_blur=latent, task="sr") for noise, latent in zip(noises, cond_latents)]
            
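            # Run the diffusion sampler under bfloat16 autocast, offloading the DiT as it finishes.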
            with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
                video_tensors = self.runner.inference(noises=noises, conditions=conditions, dit_offload=True, **text_embeds_dict)
                
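            # Move the DiT back to CPU and decode the denoised latents to pixel space with the VAE.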
            self.runner.dit.to("cpu"); gc.collect(); torch.cuda.empty_cache()
            self.runner.vae.to(self.device)
            samples = self.runner.vae_decode(video_tensors)
            final_sample = samples[0]
            input_video_sample = input_videos[0]
            if final_sample.shape[1] < input_video_sample.shape[1]:
                input_video_sample = input_video_sample[:, :final_sample.shape[1]]
                
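            # Color-match the result to the input via wavelet reconstruction, then convert
            # from [-1, 1] CTHW tensors to uint8 THWC frames for writing.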
            final_sample = wavelet_reconstruction(rearrange(final_sample, "c t h w -> t c h w"), rearrange(input_video_sample, "c t h w -> t c h w"))
            final_sample = rearrange(final_sample, "t c h w -> t h w c")
            final_sample = final_sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
            
            mediapy.write_video(output_video_path, final_sample_np, fps=24)
            logger.info(f"HD Mastered video saved to: {output_video_path}")
            return output_video_path
        finally:
            self._unload_runner()

def _load_file_from_url(url, model_dir='./', file_name=None):
    os.makedirs(model_dir, exist_ok=True)
    filename = file_name or os.path.basename(urlparse(url).path)
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        logger.info(f'Downloading: "{url}" to {cached_file}')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
    return cached_file

seedvr_manager_singleton = SeedVrManager()
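
# A minimal usage sketch (the file paths below are hypothetical and not part of this
# module). Note that process_video currently relies on the precomputed pos_emb/neg_emb
# checkpoints, so the prompt argument does not influence the output of this pipeline.
if __name__ == "__main__":
    result_path = seedvr_manager_singleton.process_video(
        input_video_path="samples/input.mp4",       # hypothetical input clip
        output_video_path="samples/input_hd.mp4",   # hypothetical output location
        prompt="cinematic, high detail",
        model_version="3B",
        steps=50,
        seed=42,
    )
    print(f"HD mastered video written to: {result_path}")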