import os

import torch
import torch.nn as nn
from einops import rearrange
from diffusers import (
    AutoencoderKL,
    AutoencoderKLTemporalDecoder,
    StableDiffusionPipeline,
)


def default(value, default_value):
    return default_value if value is None else value


def load_stable_model(model_path):
    vae_model = StableDiffusionPipeline.from_pretrained(model_path)
    vae_model.set_use_memory_efficient_attention_xformers(True)
    return vae_model.vae


def process_image(image: torch.Tensor, resolution=None) -> torch.Tensor:
    """
    Resize and normalize an image tensor.

    Args:
        image: Image tensor with pixel values in [0, 255]
        resolution: Optional target resolution; when given, the last two
            (spatial) dimensions are resized with bilinear interpolation

    Returns:
        Image tensor normalized to [-1, 1]
    """
    if resolution is not None:
        image = torch.nn.functional.interpolate(
            image.float(), size=resolution, mode="bilinear", align_corners=False
        )
    return image / 127.5 - 1.0


def encode_video_chunk(
    model,
    video,
    target_resolution,
) -> torch.Tensor:
    """
    Encode a chunk of video frames into latent space.

    Args:
        model: VAE wrapper exposing an encode_video method on (B, C, T, H, W) input
        video: Video tensor of shape (T, H, W, C) with pixel values in [0, 255]
        target_resolution: Target spatial resolution; falls back to the shorter
            side of the input frames when None

    Returns:
        Latent tensor of shape (T, C, H', W')
    """
    video = rearrange(video, "t h w c -> c t h w")
    vid_rez = min(video.shape[-1], video.shape[-2])
    to_rez = default(target_resolution, vid_rez)
    video = process_image(video, to_rez)
    encoded = model.encode_video(video.cuda().unsqueeze(0)).squeeze(0)
    return rearrange(encoded, "c t h w -> t c h w")
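
# Illustrative shapes (hypothetical names): for a clip `frames` of shape
# (T, H, W, C) in [0, 255] and a VaeWrapper `vae` (defined below),
# encode_video_chunk(vae, frames, 512) yields latents of shape
# (T, C, 512 // vae.down_factor, 512 // vae.down_factor).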


class VaeWrapper(nn.Module):
    def __init__(self, latent_type, max_chunk_decode=16, variant="fp16"):
        super().__init__()
        self.vae_model = self.get_vae(latent_type, variant)
        # self.latent_scale = latent_scale
        self.latent_type = latent_type
        self.max_chunk_decode = max_chunk_decode

    def get_vae(self, latent_type, variant="fp16"):
        if latent_type == "stable":
            vae_model = load_stable_model("stabilityai/stable-diffusion-x4-upscaler")
            vae_model.enable_slicing()
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 4
        elif latent_type == "video":
            vae_model = AutoencoderKLTemporalDecoder.from_pretrained(
                "stabilityai/stable-video-diffusion-img2vid",
                subfolder="vae",
                torch_dtype=torch.float16 if variant == "fp16" else torch.float32,
                variant="fp16" if variant == "fp16" else None,
            )
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 8
        elif latent_type == "refiner":
            vae_model = AutoencoderKL.from_pretrained(
                "stabilityai/stable-diffusion-xl-refiner-1.0",
                subfolder="vae",
                revision=None,
            )
            vae_model.enable_slicing()
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 8
        else:
            raise ValueError(f"Unknown latent_type: {latent_type}")

        vae_model.eval()
        vae_model.requires_grad_(False)
        vae_model.cuda()
        vae_model = torch.compile(vae_model)
        return vae_model

    # def accelerate_model(self, example_shape):
    #     self.vae_model = torch.jit.trace(self.vae_model, torch.randn(example_shape).cuda())
    #     self.vae_model = torch.compile(self.vae_model)
    #     self.is_accelerated = True

    def disable_slicing(self):
        self.vae_model.disable_slicing()

    def encode_video(self, video):
        """
        video: (B, C, T, H, W)
        """
        is_video = False
        if len(video.shape) == 5:
            is_video = True
            T = video.shape[2]
            video = rearrange(video, "b c t h w -> (b t) c h w")
        or_dtype = video.dtype
        # if not self.is_accelerated:
        #     self.accelerate_model(video.shape)
        if self.latent_type in ["stable", "refiner", "video"]:
            encoded_video = (
                self.vae_model.encode(video.to(dtype=self.vae_model.dtype))
                .latent_dist.sample()
                .to(dtype=or_dtype)
                * self.vae_model.config.scaling_factor
            )
        elif self.latent_type == "ldm":
            encoded_video = self.vae_model.encode_first_stage(video) * 0.18215
        if not is_video:
            return encoded_video
        return rearrange(encoded_video, "(b t) c h w -> b c t h w", t=T)

    def decode_video(self, encoded_video):
        """
        encoded_video: (B, C, T, H, W)
        """
        is_video = False
        B, T = encoded_video.shape[0], 1
        if len(encoded_video.shape) == 5:
            is_video = True
            T = encoded_video.shape[2]
            encoded_video = rearrange(encoded_video, "b c t h w -> (b t) c h w")
        decoded_full = []
        or_dtype = encoded_video.dtype
        for i in range(0, T * B, self.max_chunk_decode):  # Slow but no memory issues
            if self.latent_type in ["stable", "refiner"]:
                decoded_full.append(
                    self.vae_model.decode(
                        (1 / self.vae_model.config.scaling_factor)
                        * encoded_video[i : i + self.max_chunk_decode]
                    ).sample
                )
            elif self.latent_type == "video":
                chunk = encoded_video[i : i + self.max_chunk_decode].to(
                    dtype=self.vae_model.dtype
                )
                num_frames_in = chunk.shape[0]
                decode_kwargs = {}
                decode_kwargs["num_frames"] = num_frames_in
                decoded_full.append(
                    self.vae_model.decode(
                        1 / self.vae_model.config.scaling_factor * chunk,
                        **decode_kwargs,
                    ).sample.to(or_dtype)
                )
            elif self.latent_type == "ldm":
                decoded_full.append(
                    self.vae_model.decode_first_stage(
                        1 / 0.18215 * encoded_video[i : i + self.max_chunk_decode]
                    )
                )
        decoded_video = torch.cat(decoded_full, dim=0)
        if not is_video:
            return decoded_video.clamp(-1.0, 1.0)
        return rearrange(decoded_video, "(b t) c h w -> b c t h w", t=T).clamp(
            -1.0, 1.0
        )
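

# --- Usage sketch ---
# A minimal, hedged example of round-tripping a short clip through the SVD
# ("video") VAE. It assumes a CUDA device and access to the pretrained weights
# referenced above; tensor names and sizes are illustrative only.
if __name__ == "__main__":
    vae = VaeWrapper(latent_type="video", max_chunk_decode=16)

    # Dummy clip: 8 RGB frames at 256x256, shaped (T, H, W, C) with values in [0, 255].
    frames = torch.randint(0, 256, (8, 256, 256, 3), dtype=torch.uint8)

    # Encode to latents of shape (T, C, 256 // vae.down_factor, 256 // vae.down_factor).
    latents = encode_video_chunk(vae, frames, target_resolution=256)

    # decode_video expects (B, C, T, H, W); add a batch axis and reorder before decoding.
    decoded = vae.decode_video(rearrange(latents, "t c h w -> 1 c t h w"))
    print(latents.shape, decoded.shape)  # e.g. (8, 4, 32, 32) and (1, 3, 8, 256, 256)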