File size: 7,274 Bytes
97682d1
cc8649a
 
 
c9413de
9a6b3d7
 
cc8649a
5969983
cc8649a
 
dcaa934
dcda1c6
cc8649a
 
 
a7e6912
 
 
1cacf10
5969983
 
 
 
 
 
 
 
 
 
 
1cacf10
cc8649a
1cacf10
3aea503
 
 
 
 
a7e6912
 
1cacf10
 
cc8649a
1cacf10
 
a7e6912
cc8649a
 
1cacf10
cc8649a
 
 
 
 
 
1cacf10
c9413de
cc8649a
 
 
 
 
 
 
 
 
 
 
 
 
a7e6912
 
cc8649a
a7e6912
cc8649a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9413de
cc8649a
 
c9413de
cc8649a
 
 
 
 
c9413de
cc8649a
 
c9413de
 
cc8649a
 
 
 
 
 
 
 
 
 
 
 
 
 
c9413de
cc8649a
 
1cacf10
cc8649a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7e6912
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A dedicated, "hot" VAE service specialist.
# It holds the VAE model on a dedicated GPU and provides high-level services
# for encoding images/tensors into conditioning items and decoding latents back to pixels.

import os
import sys
import time

import threading
from pathlib import Path
from typing import List, Union, Tuple, Dict, Optional
import yaml
import torch
import numpy as np
from PIL import Image
from api.ltx.ltx_aduc_manager import LatentConditioningItem
from managers.gpu_manager import gpu_manager
from api.ltx.ltx_aduc_manager import ltx_aduc_manager


import warnings

# Silence Python warnings for this process. The message=".*" filter already matches every
# warning; the category-specific filters are kept for explicitness.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")

# The standard-library `logging` module is what the code below calls (`logging.info`,
# `logging.critical`). huggingface_hub's logging helpers are imported under an alias so they
# do not shadow it — that module only exposes verbosity helpers such as `set_verbosity_*`.
import logging
from huggingface_hub import logging as hf_logging

# Each set_verbosity_* call overwrites the library's log level, so only one call matters.
# The effective level of the previous error -> warning -> info -> debug chain was DEBUG;
# keep that. NOTE(review): confirm DEBUG is really intended given the warning silencing above.
hf_logging.set_verbosity_debug()

# ==============================================================================
# --- IMPORTAÇÕES DA ARQUITETURA E DO LTX ---
# ==============================================================================

# Make the LTX-Video repository importable by putting it at the front of sys.path.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
_LTX_VIDEO_REPO_PATH = str(LTX_VIDEO_REPO_DIR.resolve())
if _LTX_VIDEO_REPO_PATH not in sys.path:
    sys.path.insert(0, _LTX_VIDEO_REPO_PATH)

# Imported unconditionally: these names must exist even when the repo path was already on
# sys.path (e.g. inserted by another module), otherwise vae_encode / vae_decode would be
# undefined when the service methods below call them.
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder  # noqa: E402,F401
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode  # noqa: E402

# ==============================================================================
# --- CLASSE DO SERVIÇO VAE ---
# ==============================================================================

class VaeAducPipeline:
    """Process-wide singleton that keeps the LTX VAE resident on its dedicated device.

    The VAE module is not loaded here: it is taken from the pipeline held by
    ``ltx_aduc_manager`` and moved to the device returned by
    ``gpu_manager.get_ltx_vae_device()``.

    Services:
        * ``generate_conditioning_items`` - encode images / pixel tensors into
          ``LatentConditioningItem`` objects whose latents live on the CPU.
        * ``decode_to_pixels`` - decode a latent tensor back to pixels, returned on the CPU.
    """

    _instance = None
    # Guards both instance creation in __new__ and the one-time setup in __init__.
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self):
        # Fast path: every construction after a successful first one is a no-op.
        if self._initialized:
            return
        with self._lock:
            # Re-check under the lock in case another thread completed setup meanwhile.
            if self._initialized:
                return

            logging.info("⚙️ Initializing VaeServer Singleton...")
            t0 = time.time()

            # 1. Dedicated VAE device from the central GPU manager.
            self.device = gpu_manager.get_ltx_vae_device()

            # 2. Reuse the VAE already loaded by the LTX pool manager so the model is never
            #    loaded twice. The import is done here so the manager's current binding is read.
            try:
                from api.ltx.ltx_aduc_manager import ltx_aduc_manager
                if ltx_aduc_manager is None or ltx_aduc_manager.get_pipeline() is None:
                    raise RuntimeError("LTXPoolManager is not initialized yet. VaeServer must be initialized after.")
                self.vae = ltx_aduc_manager.get_pipeline().vae
            except Exception as e:
                logging.critical(f"Failed to get VAE from LTXPoolManager. Error: {e}", exc_info=True)
                raise

            # 3. Move the VAE to its device and switch it to evaluation mode.
            self.vae.to(self.device)
            self.vae.eval()
            self.dtype = self.vae.dtype

            # Set last, so a failure above leaves the singleton able to retry initialization.
            self._initialized = True
            logging.info(f"✅ VaeServer ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. Startup time: {time.time() - t0:.2f}s")

    def _cleanup_gpu(self):
        """Release cached, unused CUDA memory on the VAE's device.

        Does nothing when CUDA is unavailable or the VAE device is not a CUDA device.
        ``torch.cuda.device`` raises ``ValueError`` for non-CUDA devices, and this method
        runs inside ``finally`` blocks, where such an error would mask the real outcome.
        """
        if not torch.cuda.is_available():
            return
        device = torch.device(self.device)
        if device.type != "cuda":
            return
        with torch.cuda.device(device):
            torch.cuda.empty_cache()

    def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
        """Turn a PIL image or pixel tensor into a 5-D VAE input with values in [-1, 1].

        Args:
            item: Either a PIL image or a tensor shaped (C, H, W) or (1, C, H, W).
                A PIL image is converted to RGB, fitted to ``target_resolution`` with
                ``ImageOps.fit`` (LANCZOS) and scaled to [0, 1].
                A tensor is used as-is: it is not resized and is presumably already in
                [0, 1] (TODO confirm with callers).
            target_resolution: Size given to ``ImageOps.fit`` (PIL order: (width, height)).
                Only used for PIL inputs.

        Returns:
            Tensor of shape (1, C, 1, H, W), mapped from [0, 1] to [-1, 1].

        Raises:
            ValueError: Tensor input that is neither (C, H, W) nor (1, C, H, W).
            TypeError: Input that is neither a PIL image nor a tensor.
        """
        if isinstance(item, Image.Image):
            from PIL import ImageOps
            img = item.convert("RGB")
            processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
            image_np = np.array(processed_img).astype(np.float32) / 255.0
            tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
        elif isinstance(item, torch.Tensor):
            if item.ndim == 4 and item.shape[0] == 1:
                tensor = item.squeeze(0)
            elif item.ndim == 3:
                tensor = item
            else:
                # Report the shape: a 4-D tensor with batch > 1 is also rejected here, so
                # reporting only ndim would be misleading.
                raise ValueError(
                    f"Input tensor must be CHW or 1xCHW (batch size 1), but got shape {tuple(item.shape)}"
                )
        else:
            raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")

        # Add batch and frame axes -> (B=1, C, F=1, H, W), then map [0, 1] -> [-1, 1].
        tensor_5d = tensor.unsqueeze(0).unsqueeze(2)
        return (tensor_5d * 2.0) - 1.0

    @torch.no_grad()
    def generate_conditioning_items(
        self,
        media_items: List[Union[Image.Image, torch.Tensor]],
        target_frames: List[int],
        strengths: List[float],
        target_resolution: Tuple[int, int]
    ) -> List[LatentConditioningItem]:
        """Encode media items into latent conditioning items for the LTX pipeline.

        Each item is preprocessed, moved to the VAE device/dtype, encoded with
        ``vae_encode(..., vae_per_channel_normalize=True)`` and wrapped as
        ``LatentConditioningItem(latents_on_cpu, frame, strength)``.

        Args:
            media_items: PIL images or pixel tensors (see ``_preprocess_input``).
            target_frames: Frame index paired with each media item.
            strengths: Conditioning strength paired with each media item.
            target_resolution: Forwarded to ``_preprocess_input`` (PIL inputs only).

        Returns:
            One ``LatentConditioningItem`` per input, in input order.

        Raises:
            ValueError: If the three input lists differ in length.
        """
        t0 = time.time()
        logging.info(f"VaeServer: Generating {len(media_items)} latent conditioning items...")

        if not (len(media_items) == len(target_frames) == len(strengths)):
            raise ValueError("Input lists for conditioning items must have the same length.")

        conditioning_items = []
        try:
            for item, frame, strength in zip(media_items, target_frames, strengths):
                pixel_tensor = self._preprocess_input(item, target_resolution)
                pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
                latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)
                # Latents are kept on the CPU so the VAE device's memory can be freed below.
                conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))

            logging.info(f"VaeServer: Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
            return conditioning_items
        finally:
            self._cleanup_gpu()

    @torch.no_grad()
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode a latent tensor into pixels and return them on the CPU.

        Args:
            latent_tensor: Latents whose first dimension is the batch size; moved to the VAE
                device and dtype before decoding.
            decode_timestep: Timestep used for every batch element in ``vae_decode``.

        Returns:
            Output of ``vae_decode(..., is_video=True, vae_per_channel_normalize=True)``,
            moved to the CPU.
        """
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items_in_batch = latent_tensor_gpu.shape[0]
            # One timestep value per batch element.
            timestep_tensor = torch.full(
                (num_items_in_batch,), decode_timestep, device=self.device, dtype=self.dtype
            )

            pixels = vae_decode(
                latent_tensor_gpu, self.vae, is_video=True,
                timestep=timestep_tensor, vae_per_channel_normalize=True
            )
            logging.info(f"VaeServer: Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()
        finally:
            self._cleanup_gpu()

# Module-level singleton, constructed at import time. VaeAducPipeline.__init__ raises a
# RuntimeError when ltx_aduc_manager is None or has no pipeline yet, so this module must only
# be imported after the LTX pool manager has been set up.
vae_aduc_pipeline = VaeAducPipeline()