# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A dedicated, "hot" VAE service.
# It keeps the VAE model resident on a dedicated GPU and provides high-level
# services for encoding images/tensors into conditioning items and for
# decoding latents back to pixels.

import os
import sys
import time
import copy
import threading
from pathlib import Path
from typing import List, Union, Tuple, Optional
from dataclasses import dataclass

import torch
import numpy as np
from PIL import Image
from einops import rearrange
import torch.nn.functional as F

from managers.gpu_manager import gpu_manager
from utils.debug_utils import log_function_io
from diffusers.utils.torch_utils import randn_tensor

import logging
import warnings

# --- Logging and warnings configuration ---
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*")

try:
    from huggingface_hub import logging as hf_logging
    hf_logging.set_verbosity_error()
except ImportError:
    pass

# ==============================================================================
# --- IMPORTS AND TYPE DEFINITIONS ---
# ==============================================================================

LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
    sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))

from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode, latent_to_pixel_coords
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline 

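# A lightweight, pre-encoded counterpart to ConditioningItem: the media is
# already in latent space, so prepare_conditioning can skip the VAE encode.
# conditioning_strength acts as a lerp weight (0 = ignore, 1 = fully replace).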
@dataclass
class LatentConditioningItem:
    latent_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float

# ==============================================================================
# --- MAIN VAE SERVICE CLASS ---
# ==============================================================================

class VaeAducPipeline:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self):
        if hasattr(self, '_initialized') and self._initialized: return
        with self._lock:
            if hasattr(self, '_initialized') and self._initialized: return
            logging.info("⚙️ Initializing VaeAducPipeline Singleton...")
            t0 = time.time()
            self.device = gpu_manager.get_ltx_vae_device()
            
            try:
                from api.ltx.ltx_aduc_manager import ltx_aduc_manager
                main_pipeline = ltx_aduc_manager.get_pipeline()
                if main_pipeline is None:
                    raise RuntimeError("LTXPoolManager must be initialized before VaeAducPipeline.")
                self.vae: CausalVideoAutoencoder = main_pipeline.vae
                self.patchifier = main_pipeline.patchifier
                self.transformer = main_pipeline.transformer
                self.vae_scale_factor = main_pipeline.vae_scale_factor
            except Exception as e:
                logging.critical(f"Failed to get components from LTXPoolManager. Error: {e}", exc_info=True)
                raise

            self.vae.to(self.device).eval()
            self.dtype = self.vae.dtype
            self._initialized = True
            logging.info(f"✅ VaeAducPipeline ready. Components are 'hot' on {self.device}. Startup time: {time.time() - t0:.2f}s")
    
    # --- PUBLIC SERVICE METHODS ---

    @log_function_io
    def encode_video(self, video_tensor: torch.Tensor, vae_per_channel_normalize: bool = True) -> torch.Tensor:
        logging.info(f"VaeAducPipeline: Encoding video with shape {video_tensor.shape}")
        if video_tensor.ndim != 5:
            raise ValueError(f"Input video tensor must be 5D (B, C, F, H, W), but got shape {video_tensor.shape}")
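        # Map pixel values from [0, 1] to the [-1, 1] range expected by the VAE.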
        video_tensor_normalized = (video_tensor * 2.0) - 1.0
        try:
            video_gpu = video_tensor_normalized.to(self.device, dtype=self.dtype)
            with torch.no_grad():
                latents = vae_encode(video_gpu, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
            logging.info(f"VaeAducPipeline: Successfully encoded video to latents of shape {latents.shape}")
            return latents.cpu()
        finally:
            self._cleanup_gpu()

    @log_function_io
    def decode_and_resize_video(self, latent_tensor: torch.Tensor, target_height: int, target_width: int, decode_timestep: float = 0.05) -> torch.Tensor:
        logging.info(f"VaeAducPipeline: Decoding latents {latent_tensor.shape} and resizing to {target_height}x{target_width}")
        pixel_video = self.decode_to_pixels(latent_tensor, decode_timestep)
        num_frames = pixel_video.shape[2]
        current_height, current_width = pixel_video.shape[3:]

        if current_height == target_height and current_width == target_width:
            logging.info("VaeAducPipeline: Resizing skipped, already at target resolution.")
            return pixel_video

        videos_flat = rearrange(pixel_video, "b c f h w -> (b f) c h w")
        videos_resized = F.interpolate(videos_flat, size=(target_height, target_width), mode="bilinear", align_corners=False)
        final_video = rearrange(videos_resized, "(b f) c h w -> b c f h w", f=num_frames)
        logging.info(f"VaeAducPipeline: Resized video to final shape {final_video.shape}")
        return final_video

    @log_function_io
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items = latent_tensor_gpu.shape[0]
            timestep_tensor = torch.tensor([decode_timestep] * num_items, device=self.device, dtype=self.dtype)
            with torch.no_grad():
                pixels = vae_decode(latent_tensor_gpu, self.vae, is_video=True, timestep=timestep_tensor, vae_per_channel_normalize=True)
            logging.info(f"VaeAducPipeline: Decoded latents {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()
        finally:
            self._cleanup_gpu()

    @log_function_io
    def prepare_conditioning(
        self,
        conditioning_items: Optional[List[Union[ConditioningItem, LatentConditioningItem]]],
        init_latents: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        vae_per_channel_normalize: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
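        """Blend conditioning items into `init_latents` and patchify the result.

        Accepts either pixel-space ConditioningItems (encoded here via the VAE)
        or pre-encoded LatentConditioningItems. Returns (patchified latents,
        pixel coordinates, conditioning mask or None, number of extra latent
        tokens), with all returned tensors moved to CPU.
        """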
        init_latents = init_latents.to(self.device, dtype=self.dtype)

        if not conditioning_items:
            latents_p, coords_l = self.patchifier.patchify(latents=init_latents)
            coords_p = self._latent_to_pixel_coords(coords_l)
            return latents_p.cpu(), coords_p.cpu(), None, 0

        mask = torch.zeros(init_latents.shape[0], *init_latents.shape[2:], dtype=torch.float32, device=self.device)
        extra_latents, extra_coords, extra_masks = [], [], []
        num_extra_latents = 0

        is_latent_mode = isinstance(conditioning_items[0], LatentConditioningItem)

        with torch.no_grad():
            if is_latent_mode:
                for item in conditioning_items:
                    latents = item.latent_tensor.to(device=self.device, dtype=self.dtype)
                    if item.media_frame_number == 0:
                        f, h, w = latents.shape[-3:]
                        init_latents[..., :f, :h, :w] = torch.lerp(init_latents[..., :f, :h, :w], latents, item.conditioning_strength)
                        mask[..., :f, :h, :w] = item.conditioning_strength
                    else:
                        if latents.shape[2] > 1:
                            init_latents, mask, latents = self._handle_non_first_sequence(
                                init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
                            )
                        if latents is not None:
                            latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
                            extra_latents.append(latents_p)
                            extra_coords.append(coords_p)
                            extra_masks.append(new_mask)
                            num_extra_latents += num_new
            else:
                for item in conditioning_items:
                    item_resized = self._resize_conditioning_item(item, height, width)
                    media_item = item_resized.media_item.to(self.device, dtype=self.dtype)
                    latents = vae_encode(media_item, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)

                    if item.media_frame_number == 0:
                        latents_pos, lx, ly = self._get_latent_spatial_position(latents, item_resized, height, width)
                        f, h, w = latents_pos.shape[-3:]
                        init_latents[..., :f, ly:ly+h, lx:lx+w] = torch.lerp(init_latents[..., :f, ly:ly+h, lx:lx+w], latents_pos, item.conditioning_strength)
                        mask[..., :f, ly:ly+h, lx:lx+w] = item.conditioning_strength
                    else:
                        if media_item.shape[2] > 1:
                            init_latents, mask, latents = self._handle_non_first_sequence(
                                init_latents, mask, latents, item.media_frame_number, item.conditioning_strength
                            )
                        if latents is not None:
                            latents_p, coords_p, new_mask, num_new = self._process_extra_item(latents, item, generator)
                            extra_latents.append(latents_p)
                            extra_coords.append(coords_p)
                            extra_masks.append(new_mask)
                            num_extra_latents += num_new

        # --- Final consolidation ---
        latents_p, coords_l = self.patchifier.patchify(latents=init_latents)
        coords_p = self._latent_to_pixel_coords(coords_l)
        mask_p, _ = self.patchifier.patchify(latents=mask.unsqueeze(1))
        mask_p = mask_p.squeeze(-1)

        if extra_latents:
            latents_p = torch.cat([*extra_latents, latents_p], dim=1)
            coords_p = torch.cat([*extra_coords, coords_p], dim=2)
            mask_p = torch.cat([*extra_masks, mask_p], dim=1)
            
            use_flash = getattr(self.transformer.config, 'use_tpu_flash_attention', False)
            if use_flash:
                latents_p = latents_p[:, :-num_extra_latents]
                coords_p = coords_p[:, :, :-num_extra_latents]
                mask_p = mask_p[:, :-num_extra_latents]
        
        return latents_p.cpu(), coords_p.cpu(), mask_p.cpu(), num_extra_latents

    # --- PRIVATE HELPER METHODS ---
    def _cleanup_gpu(self):
        if torch.cuda.is_available():
            with torch.cuda.device(self.device):
                torch.cuda.empty_cache()

    def _latent_to_pixel_coords(self, c):
        return latent_to_pixel_coords(c, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
    
    @staticmethod
    def _resize_tensor(m, h, w):
        if m.shape[-2:] != (h, w):
            n = m.shape[2]
            flat = rearrange(m, "b c n h w -> (b n) c h w")
            resized = F.interpolate(flat, (h, w), mode="bilinear", align_corners=False)
            return rearrange(resized, "(b n) c h w -> b c n h w", n=n)
        return m

    def _resize_conditioning_item(self, i, h, w):
        n = copy.copy(i)
        n.media_item = self._resize_tensor(i.media_item, h, w)
        return n

    def _get_latent_spatial_position(self, l, i, h, w, strip=True):
        # Resolve the pixel-space top-left corner of the item (centered when
        # media_x/media_y are unset), then convert it to latent coordinates.
        s, hi, wi = self.vae_scale_factor, i.media_item.shape[-2], i.media_item.shape[-1]
        xs = (w - wi) // 2 if i.media_x is None else i.media_x
        ys = (h - hi) // 2 if i.media_y is None else i.media_y
        if strip:
            # Drop border latent rows/columns on sides that do not touch the
            # frame edge, shifting the origin accordingly.
            if xs > 0:
                xs += s
                l = l[..., :, 1:]
            if ys > 0:
                ys += s
                l = l[..., 1:, :]
            if (xs + wi) < w:
                l = l[..., :, :-1]
            if (ys + hi) < h:
                l = l[..., :-1, :]
        return l, xs // s, ys // s

    def _handle_non_first_sequence(
        self,
        init_latents: torch.Tensor,
        mask: torch.Tensor,
        latents: torch.Tensor,
        media_frame_number: int,
        conditioning_strength: float,
        num_prefix=2,
        mode="concat"
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        fl, flp = latents.shape[2], num_prefix
        if fl > flp:
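            # NOTE: the "// 8" below assumes the VAE's temporal downsampling
            # factor; it is hard-coded rather than derived from the model config.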
            start = media_frame_number // 8 + flp
            end = start + fl - flp
            init_latents[..., start:end, :, :] = torch.lerp(init_latents[..., start:end, :, :], latents[..., flp:, :, :], conditioning_strength)
            mask[..., start:end, :, :] = conditioning_strength
        if mode == "concat":
            latents = latents[..., :flp, :, :]
        else:
            latents = None
        return init_latents, mask, latents

    def _process_extra_item(self, l, i, g):
        # Blend the conditioning latents with fresh noise according to the
        # conditioning strength, then patchify and offset the temporal coordinate.
        n = randn_tensor(l.shape, generator=g, device=self.device, dtype=self.dtype)
        l = torch.lerp(n, l, i.conditioning_strength)
        lp, cl = self.patchifier.patchify(latents=l)
        cp = self._latent_to_pixel_coords(cl)
        cp[:, 0] += i.media_frame_number
        nl = lp.shape[1]
        nm = torch.full(lp.shape[:2], i.conditioning_strength, dtype=torch.float32, device=self.device)
        return lp, cp, nm, nl

# --- Singleton instantiation ---
vae_aduc_pipeline = VaeAducPipeline()
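
# Illustrative usage (a sketch, not executed at import; assumes the LTX pool
# manager has been initialized first and that a CUDA device is available):
#
#   video = torch.rand(1, 3, 9, 512, 768)           # (B, C, F, H, W) in [0, 1]
#   latents = vae_aduc_pipeline.encode_video(video)  # latents returned on CPU
#   item = LatentConditioningItem(latents, media_frame_number=0,
#                                 conditioning_strength=1.0)
#   pixels = vae_aduc_pipeline.decode_and_resize_video(latents, 512, 768)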