File size: 10,909 Bytes
a9f0d89
 
 
 
52d1c8b
a5720bf
 
 
52d1c8b
 
 
a5720bf
52d1c8b
a9f0d89
a5720bf
a9f0d89
 
 
d158086
 
a9f0d89
 
d158086
a9f0d89
 
 
 
 
 
 
 
 
 
 
52d1c8b
a9f0d89
 
 
 
52d1c8b
 
 
a9f0d89
52d1c8b
 
 
 
a9f0d89
52d1c8b
 
 
 
 
a9f0d89
 
 
 
d158086
a9f0d89
 
 
52d1c8b
a9f0d89
d158086
52d1c8b
 
a9f0d89
52d1c8b
 
 
 
 
a9f0d89
 
52d1c8b
 
 
a9f0d89
52d1c8b
 
a9f0d89
52d1c8b
 
a9f0d89
52d1c8b
 
 
a9f0d89
52d1c8b
 
 
 
a9f0d89
52d1c8b
 
 
 
 
a9f0d89
52d1c8b
 
 
 
 
a9f0d89
 
52d1c8b
 
 
 
 
 
a9f0d89
 
52d1c8b
 
 
a9f0d89
 
 
52d1c8b
a5720bf
a9f0d89
a5720bf
52d1c8b
a9f0d89
d158086
 
 
 
a9f0d89
52d1c8b
a9f0d89
 
 
a5720bf
a9f0d89
52d1c8b
a9f0d89
52d1c8b
a9f0d89
 
 
52d1c8b
 
 
d158086
52d1c8b
a9f0d89
52d1c8b
d158086
a9f0d89
 
d158086
52d1c8b
d158086
a9f0d89
d158086
 
 
a9f0d89
d158086
a9f0d89
d158086
 
 
 
 
52d1c8b
d158086
a9f0d89
d158086
 
a5720bf
a9f0d89
d158086
 
 
 
 
 
 
 
a9f0d89
 
a5720bf
d158086
52d1c8b
d158086
a9f0d89
52d1c8b
d158086
a9f0d89
d158086
a9f0d89
d158086
 
52d1c8b
d158086
 
 
a9f0d89
d158086
52d1c8b
d158086
a9f0d89
 
d158086
a9f0d89
 
 
 
 
 
 
 
 
 
52d1c8b
a9f0d89
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# FILE: api/seedvr_server.py
# DESCRIPTION: Backend service for SeedVR video upscaling.
# Features multi-GPU processing, memory swapping with other services,
# and detailed debug logging.

import os
import sys
import time
import subprocess
import queue
import multiprocessing as mp
from pathlib import Path
from typing import Optional, Callable
import logging

# ==============================================================================
# --- IMPORTAÇÃO DOS MÓDULOS COMPARTILHADOS ---
# ==============================================================================
# Import the services shared with the rest of the app. SeedVR needs gpu_manager
# (GPU allocation / swap policy) and video_generation_service (the LTX service it
# may move off the GPU while upscaling).
try:
    from api.gpu_manager import gpu_manager
    from api.ltx_server_refactored_complete import video_generation_service
    from api.utils.debug_utils import log_function_io
except ImportError:
    # Fallback no-op decorator in case the import fails.
    # NOTE(review): because of the unconditional `raise` below, this fallback
    # is never actually used — the module import aborts anyway. Either drop the
    # fallback, or drop the `raise` so that SeedVRServer() fails at the bottom
    # of the file and seedvr_server_singleton degrades to None.
    def log_function_io(func):
        return func
    logging.critical("CRITICAL: Failed to import shared modules like gpu_manager or video_generation_service.", exc_info=True)
    # In a real deployment we might want to exit here or disable the server.
    # The comment's original intent was "let the app continue without SeedVR",
    # but the re-raise makes the import failure fatal for this module.
    raise

# ==============================================================================
# --- CONFIGURAÇÃO DE AMBIENTE ---
# ==============================================================================
# Worker processes below are launched with the 'spawn' start method (presumably
# because CUDA cannot be used safely in a fork()ed child). force=True replaces
# any earlier choice; the RuntimeError handler is kept as a defensive guard.
try:
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
except RuntimeError:
    logging.warning("Multiprocessing context is already set. Skipping.")

# Select PyTorch's cudaMallocAsync allocator backend, unless the environment
# already configures the allocator.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

# Put the SeedVR repository at the front of sys.path so that its `src.core.*`
# modules (imported inside the worker processes) can be resolved.
SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
if not any(entry == str(SEEDVR_REPO_PATH) for entry in sys.path):
    sys.path.insert(0, str(SEEDVR_REPO_PATH))

# Imports pesados após a configuração de path e multiprocessing
import torch
import cv2
import numpy as np
from datetime import datetime

# ==============================================================================
# --- FUNÇÕES WORKER E AUXILIARES (I/O de Vídeo) ---
# ==============================================================================
# (Estas funções são de baixo nível e não precisam do decorador de log principal)

def extract_frames_from_video(video_path, debug=False):
    """Decode every frame of a video file into one RGB float16 tensor.

    Args:
        video_path: Path of the video file to decode.
        debug: When True, emit ``logging.debug`` progress lines.

    Returns:
        ``(frames, fps)``: ``frames`` is a CPU ``torch.float16`` tensor built by
        stacking the decoded frames along a new leading axis, converted from
        OpenCV's BGR channel order to RGB and scaled from 0..255 to 0..1;
        ``fps`` is whatever ``cv2.CAP_PROP_FPS`` reports for the file.

    Raises:
        FileNotFoundError: ``video_path`` does not exist.
        IOError: OpenCV cannot open the file.
        ValueError: no frame could be decoded.
    """
    if debug: logging.debug(f"🎬 [SeedVR] Extracting frames from: {video_path}")
    if not os.path.exists(video_path): raise FileNotFoundError(f"Video file not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened(): raise IOError(f"Cannot open video file: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Keep decoded frames as uint8 while reading (4x less memory than
        # per-frame float32 copies); normalisation happens once, after stacking.
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret: break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even when opening or decoding fails.
        cap.release()

    if not frames: raise ValueError(f"No frames extracted from: {video_path}")
    if debug: logging.debug(f"✅ [SeedVR] {len(frames)} frames extracted successfully.")

    stacked = np.stack(frames).astype(np.float32)
    del frames  # drop the uint8 copies before allocating the float16 tensor
    # In-place float32 division yields exactly the same values as dividing each
    # frame's float32 copy by 255.0 individually.
    stacked /= 255.0
    return torch.from_numpy(stacked).to(torch.float16), fps

def save_frames_to_video(frames_tensor, output_path, fps=30.0, debug=False):
    """Encode RGB frames with values in [0, 1] into an MP4 file (``mp4v`` FourCC).

    Args:
        frames_tensor: tensor indexed (frame, height, width, channel); the shape
            unpacking below requires exactly 4 dims, and the RGB->BGR conversion
            expects 3 channels. Values outside [0, 1] are clipped.
        output_path: destination file; its parent directory is created if needed.
        fps: frame rate written into the container.
        debug: When True, emit ``logging.debug`` progress lines.

    Raises:
        IOError: OpenCV cannot open a writer for ``output_path``.
    """
    if debug: logging.debug(f"💾 [SeedVR] Saving {frames_tensor.shape[0]} frames to: {output_path}")
    parent_dir = os.path.dirname(output_path)
    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Scale in float32 (not the tensor's float16), round to nearest and clip
    # before the cast: a plain `.astype(np.uint8)` truncates every pixel
    # downward, and casting values outside [0, 255] wraps / is undefined.
    frames_f32 = frames_tensor.cpu().numpy().astype(np.float32)
    frames_np = np.clip(np.rint(frames_f32 * 255.0), 0.0, 255.0).astype(np.uint8)
    del frames_f32

    _, H, W, _ = frames_np.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
    try:
        if not out.isOpened(): raise IOError(f"Cannot create video writer for: {output_path}")
        for frame in frames_np:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise/close the container, even if a write fails.
        out.release()
    if debug: logging.debug(f"✅ [SeedVR] Video saved successfully: {output_path}")

def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
    """Processo filho (worker) que executa o upscaling em uma GPU dedicada."""
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    # É importante reimportar torch aqui para que ele respeite a variável de ambiente
    import torch
    from src.core.model_manager import configure_runner
    from src.core.generation import generation_loop
    
    try:
        frames_tensor = torch.from_numpy(frames_np).to('cuda', dtype=torch.float16)
        callback = (lambda b, t, _, m: progress_queue.put((proc_idx, b, t, m))) if progress_queue else None

        runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
        result_tensor = generation_loop(
            runner=runner, images=frames_tensor, cfg_scale=1.0, seed=shared_args["seed"],
            res_h=shared_args["resolution"], # Assumindo que a UI passa a altura
            batch_size=shared_args["batch_size"],
            preserve_vram=shared_args["preserve_vram"], temporal_overlap=0,
            debug=shared_args["debug"], progress_callback=callback
        )
        return_queue.put((proc_idx, result_tensor.cpu().numpy()))
    except Exception as e:
        import traceback
        error_msg = f"ERROR in worker {proc_idx} (GPU {device_id}): {e}\n{traceback.format_exc()}"
        logging.error(error_msg)
        if progress_queue: progress_queue.put((proc_idx, -1, -1, error_msg))
        return_queue.put((proc_idx, error_msg))

# ==============================================================================
# --- CLASSE DO SERVIDOR PRINCIPAL ---
# ==============================================================================

class SeedVRServer:
    """SeedVR video-upscaling service.

    Splits a video's frames across the GPUs that ``gpu_manager`` allocates to
    SeedVR and upscales each chunk in its own worker process.
    """

    # Seconds to block on the result queue before relaying progress and
    # re-checking that pending workers are still alive.
    _POLL_INTERVAL_S = 0.5

    @log_function_io
    def __init__(self, **kwargs):
        """Initialise the server: resolve paths and the GPUs allocated to SeedVR.

        ``kwargs`` is accepted for call-site compatibility but is not used.

        Raises:
            NotADirectoryError: the SeedVR repository directory is missing.
        """
        logging.info("⚙️ SeedVRServer initializing...")
        self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/output"))

        self.device_list = gpu_manager.get_seedvr_devices()
        self.num_gpus = len(self.device_list)
        logging.info(f"[SeedVR] Allocated to use {self.num_gpus} GPU(s): {self.device_list}")

        # Dependency setup is done by the main setup.py; here we only verify it.
        if not SEEDVR_REPO_PATH.is_dir():
            raise NotADirectoryError(f"SeedVR repository not found at {SEEDVR_REPO_PATH}. Run setup.py first.")

        logging.info("📦 SeedVRServer ready.")

    @log_function_io
    def run_inference(
        self, file_path: str, *, seed: int, resolution: int, batch_size: int,
        model: str = "seedvr2_ema_7b_sharp_fp16.safetensors", fps: Optional[float] = None,
        debug: bool = True, preserve_vram: bool = True,
        progress: Optional[Callable] = None
    ) -> str:
        """Run the full video-upscaling pipeline and return the output file path.

        Steps:
            1. Move the LTX service to CPU when ``gpu_manager`` requires a
               memory swap.
            2. Decode the input video.
            3. Upscale the frames in parallel, one process per allocated GPU.
            4. Write the result as MP4 under ``OUTPUT_ROOT``.
            5. Restore the LTX service if it was swapped out.

        Args:
            file_path: input video file.
            seed, resolution, batch_size, model, preserve_vram, debug: forwarded
                to the SeedVR workers (``resolution`` is passed as ``res_h``).
            fps: output frame rate; the source video's fps is used when this is
                None or not positive.
            progress: optional ``progress(fraction, message)`` callback.

        Raises:
            RuntimeError: no GPU is allocated, or a worker fails.
        """
        if progress: progress(0.01, "⌛ Initializing SeedVR inference...")

        # Fail before touching the LTX service or decoding the video.
        if self.num_gpus == 0:
            raise RuntimeError("SeedVR requires at least 1 allocated GPU, but found none.")

        swapped = False
        try:
            if gpu_manager.requires_memory_swap():
                logging.warning("[SWAP] Memory swapping is active. Moving LTX service to CPU to free VRAM for SeedVR.")
                if progress: progress(0.02, "🔄 Freeing VRAM for SeedVR...")
                # Flag first so that a partially failed move is still undone below.
                swapped = True
                video_generation_service.move_to_cpu()

            if progress: progress(0.05, "🎬 Extracting frames from video...")
            frames_tensor, original_fps = extract_frames_from_video(file_path, debug)

            shared_args = {
                "model": model, "model_dir": "/data/models/SeedVR", "preserve_vram": preserve_vram,
                "debug": debug, "seed": seed, "resolution": resolution, "batch_size": batch_size
            }
            result_tensor = self._upscale_in_parallel(frames_tensor, shared_args, progress)
            del frames_tensor

            if progress: progress(0.95, "💾 Saving final video...")
            out_dir = self.OUTPUT_ROOT / f"seedvr_run_{int(time.time())}"
            out_dir.mkdir(parents=True, exist_ok=True)
            output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"

            final_fps = fps if fps and fps > 0 else original_fps
            if not final_fps or final_fps <= 0:
                # CAP_PROP_FPS can report 0 when the container carries no frame
                # rate; fall back to save_frames_to_video's own default.
                final_fps = 30.0
            save_frames_to_video(result_tensor, str(output_filepath), final_fps, debug)

            logging.info(f"✅ Video successfully saved to: {output_filepath}")
            return str(output_filepath)

        finally:
            # Restore LTX to its proper devices (main and VAE) if it was swapped out.
            if swapped:
                logging.warning("[SWAP] SeedVR inference finished. Moving LTX service back to GPU(s)...")
                if progress: progress(0.99, "🔄 Restoring LTX environment...")
                ltx_main_device = gpu_manager.get_ltx_device()
                ltx_vae_device = gpu_manager.get_ltx_vae_device()
                video_generation_service.move_to_device(
                    main_device_str=str(ltx_main_device),
                    vae_device_str=str(ltx_vae_device)
                )
                logging.info(f"[SWAP] LTX service restored to Main: {ltx_main_device}, VAE: {ltx_vae_device}.")

    def _upscale_in_parallel(self, frames_tensor, shared_args, progress):
        """Split frames across the allocated GPUs, upscale each chunk in its own
        process, and return the concatenated result as a float16 tensor."""
        n_frames = frames_tensor.shape[0]
        # torch.chunk may return fewer chunks than requested (5 frames over 4
        # GPUs -> 3 chunks; 1 frame over 2 GPUs -> 1 chunk), which made
        # chunks[idx] raise IndexError. tensor_split returns exactly n_workers
        # pieces, and capping n_workers at n_frames keeps every piece non-empty.
        n_workers = max(1, min(self.num_gpus, n_frames))
        chunks = torch.tensor_split(frames_tensor, n_workers, dim=0)
        logging.info(f"[SeedVR] Splitting {n_frames} frames into {n_workers} chunks for parallel processing.")
        if progress: progress(0.1, f"🚀 Starting generation on {n_workers} GPU(s)...")

        workers = []
        # The context manager shuts the manager process down on exit; it used
        # to be leaked on every call.
        with mp.Manager() as manager:
            return_queue = manager.Queue()
            progress_queue = manager.Queue() if progress else None
            try:
                for idx, (device_id, chunk) in enumerate(zip(self.device_list, chunks)):
                    worker = mp.Process(
                        target=_worker_process,
                        args=(idx, device_id, chunk.cpu().numpy(), shared_args, return_queue, progress_queue),
                    )
                    worker.start()
                    workers.append(worker)
                results_np = self._collect_results(workers, return_queue, progress_queue, progress)
            except BaseException:
                # Do not leave GPU workers running after a failure or interrupt.
                for worker in workers:
                    if worker.is_alive():
                        worker.terminate()
                raise
            finally:
                for worker in workers:
                    worker.join()

        return torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)

    def _collect_results(self, workers, return_queue, progress_queue, progress):
        """Wait for one result per worker and return them ordered by worker index.

        While waiting, relays worker progress and checks that each pending
        worker process is still alive. A worker that dies without reporting
        (crash, OOM kill) therefore raises instead of hanging forever.

        Raises:
            RuntimeError: a worker reports an error string, or exits without
                putting a result.
        """
        results = [None] * len(workers)
        fractions = [0.0] * len(workers)
        pending = set(range(len(workers)))

        def accept(message):
            idx, payload = message
            # Workers report failures as a str (message + traceback) instead of an array.
            if isinstance(payload, str):
                raise RuntimeError(f"SeedVR worker {idx} failed:\n{payload}")
            results[idx] = payload
            pending.discard(idx)

        while pending:
            if progress_queue is not None and progress:
                self._relay_progress(progress_queue, fractions, progress)
            try:
                accept(return_queue.get(timeout=self._POLL_INTERVAL_S))
                continue
            except queue.Empty:
                pass

            dead = [idx for idx in pending if not workers[idx].is_alive()]
            if not dead:
                continue
            # Manager proxy calls block until the manager replies, so a worker
            # that has exited has already delivered any result it sent: drain
            # once more before declaring it lost.
            while True:
                try:
                    accept(return_queue.get_nowait())
                except queue.Empty:
                    break
            lost = sorted(idx for idx in dead if idx in pending)
            if lost:
                codes = [workers[idx].exitcode for idx in lost]
                raise RuntimeError(f"SeedVR worker(s) {lost} exited (exit codes {codes}) without returning a result.")
        return results

    @staticmethod
    def _relay_progress(progress_queue, fractions, progress):
        """Drain pending progress events and report the mean per-worker completion.

        Events are ``(worker_idx, done, total, message)`` tuples produced by the
        worker's progress callback. Presumably ``done``/``total`` count batches
        (TODO confirm against ``generation_loop``). Events with a non-positive
        or non-numeric total are skipped; this includes the worker's ``-1``
        error sentinel, and errors surface through the result queue instead.
        """
        latest_message = None
        updated = False
        while True:
            try:
                idx, done, total, message = progress_queue.get_nowait()
            except queue.Empty:
                break
            try:
                done_f, total_f = float(done), float(total)
            except (TypeError, ValueError):
                continue
            if total_f <= 0:
                continue
            fractions[idx] = min(max(done_f / total_f, 0.0), 1.0)
            latest_message = message
            updated = True

        if updated:
            overall = sum(fractions) / len(fractions)
            text = latest_message if isinstance(latest_message, str) and latest_message else "🚀 Upscaling frames..."
            # Map worker progress into the 0.1-0.9 band between "start" and "save".
            progress(0.1 + 0.8 * overall, text)

# --- ENTRY POINT AND INSTANTIATION ---
# The shared instance is created when this module is first imported. If
# construction fails, the error is logged and the singleton stays None so the
# UI can detect that SeedVR is unavailable.
seedvr_server_singleton = None
try:
    seedvr_server_singleton = SeedVRServer()
except Exception:
    logging.critical("Failed to initialize SeedVRServer singleton.", exc_info=True)