Spaces:
Runtime error
Runtime error
| import torch | |
| import os | |
| import sys | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from collections import OrderedDict | |
| from einops import rearrange | |
| from diffusers.utils.torch_utils import randn_tensor | |
| import numpy as np | |
| import math | |
| import random | |
| import PIL | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from torchvision import transforms | |
| from copy import deepcopy | |
| from typing import Any, Callable, Dict, List, Optional, Union | |
| from accelerate import Accelerator | |
| from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler | |
| from video_vae.modeling_causal_vae import CausalVideoVAE | |
| from trainer_misc import ( | |
| all_to_all, | |
| is_sequence_parallel_initialized, | |
| get_sequence_parallel_group, | |
| get_sequence_parallel_group_rank, | |
| get_sequence_parallel_rank, | |
| get_sequence_parallel_world_size, | |
| get_rank, | |
| ) | |
| from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT | |
| from .modeling_text_encoder import SD3TextEncoderWithMask | |
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """Draw ``batch_size`` timestep-density samples on the CPU.

    Args:
        weighting_scheme: ``"logit_normal"`` (sigmoid of a normal draw),
            ``"mode"`` (uniform draw warped by ``mode_scale``), or anything else
            for a plain uniform draw in [0, 1).
        batch_size: number of samples to draw.
        logit_mean, logit_std: normal parameters, used only by ``"logit_normal"``.
        mode_scale: warp strength, used only by ``"mode"``.

    Returns:
        A 1-D CPU tensor of length ``batch_size``.
    """
    shape = (batch_size,)

    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        normal_draw = torch.normal(mean=logit_mean, std=logit_std, size=shape, device="cpu")
        return torch.nn.functional.sigmoid(normal_draw)

    uniform_draw = torch.rand(size=shape, device="cpu")
    if weighting_scheme != "mode":
        return uniform_draw

    warp = torch.cos(math.pi * uniform_draw / 2) ** 2 - 1 + uniform_draw
    return 1 - uniform_draw - mode_scale * warp
class PyramidDiTForVideoGeneration:
    """
    The pyramid DiT running-class wrapper for both image and video generation.

    This class is mainly for the fixed-unit implementation: 1 + n + n + n,
    i.e. one leading frame followed by units of ``frame_per_unit`` frames.
    """

    def __init__(self, model_path, model_dtype='bf16', use_gradient_checkpointing=False, return_log=True,
        model_variant="diffusion_transformer_768p", timestep_shift=1.0, stage_range=(0, 1/3, 2/3, 1),
        sample_ratios=(1, 1, 1), scheduler_gamma=1/3, use_mixed_training=False, use_flash_attn=False,
        load_text_encoder=True, load_vae=True, max_temporal_length=31, frame_per_unit=1, use_temporal_causal=True,
        corrupt_ratio=1/3, interp_condition_pos=True, stages=(1, 2, 4), **kwargs,
    ):
        """Load the DiT, and optionally the text encoder and the causal video VAE.

        Args:
            model_path: model root. The DiT is loaded from
                ``<model_path>/<model_variant>``, the VAE from
                ``<model_path>/causal_video_vae``, and ``model_path`` is passed
                to the text encoder.
            model_dtype: 'bf16' or 'fp16'; any other value selects float32.
            stage_range, scheduler_gamma, timestep_shift: forwarded to
                ``PyramidFlowMatchEulerDiscreteScheduler``.
            stages: one entry per pyramid stage. Only its length is read by
                the code in this class.
            use_mixed_training: if True the DiT is loaded without an explicit
                ``torch_dtype``.
            load_text_encoder, load_vae: when False the attribute is set to None.
            max_temporal_length, frame_per_unit: video length bookkeeping;
                ``max_temporal_length - 1`` must be a multiple of
                ``frame_per_unit``.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        if model_dtype == 'bf16':
            torch_dtype = torch.bfloat16
        elif model_dtype == 'fp16':
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32

        # Copy sequence arguments so instances never share (or alias a caller's) list.
        self.stages = list(stages)
        self.sample_ratios = list(sample_ratios)
        self.corrupt_ratio = corrupt_ratio

        dit_path = os.path.join(model_path, model_variant)

        # The dit
        dit_kwargs = dict(
            use_gradient_checkpointing=use_gradient_checkpointing,
            use_flash_attn=use_flash_attn, use_t5_mask=True,
            add_temp_pos_embed=True, temp_pos_embed_type='rope',
            use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
        )
        if use_mixed_training:
            print("using mixed precision training, do not explicitly casting models")
            self.dit = PyramidDiffusionMMDiT.from_pretrained(dit_path, **dit_kwargs)
        else:
            print("using half precision")
            self.dit = PyramidDiffusionMMDiT.from_pretrained(dit_path, torch_dtype=torch_dtype, **dit_kwargs)

        # The text encoder
        if load_text_encoder:
            self.text_encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
        else:
            self.text_encoder = None

        # The base video vae decoder
        if load_vae:
            self.vae = CausalVideoVAE.from_pretrained(os.path.join(model_path, 'causal_video_vae'), torch_dtype=torch_dtype, interpolate=False)
            # Freeze vae
            for parameter in self.vae.parameters():
                parameter.requires_grad = False
        else:
            self.vae = None

        # Latent normalization for the first (image) frame.
        self.vae_shift_factor = 0.1490
        self.vae_scale_factor = 1 / 1.8415
        # Latent normalization for the subsequent (video) frames.
        self.vae_video_shift_factor = -0.2343
        self.vae_video_scale_factor = 1 / 3.0986
        # Pixel-to-latent spatial factor used by prepare_latents.
        self.downsample = 8

        # Configure the video training hyper-parameters
        # The video sequence: one frame + N * unit
        self.frame_per_unit = frame_per_unit
        self.max_temporal_length = max_temporal_length
        assert (max_temporal_length - 1) % frame_per_unit == 0, "The frame number should be divided by the frame number per unit"
        self.num_units_per_video = 1 + ((max_temporal_length - 1) // frame_per_unit) + int(sum(self.sample_ratios))

        self.scheduler = PyramidFlowMatchEulerDiscreteScheduler(
            shift=timestep_shift, stages=len(self.stages),
            stage_range=list(stage_range), gamma=scheduler_gamma,
        )
        print(f"The start sigmas and end sigmas of each stage is Start: {self.scheduler.start_sigmas}, End: {self.scheduler.end_sigmas}, Ori_start: {self.scheduler.ori_start_sigmas}")

        self.cfg_rate = 0.1
        self.return_log = return_log
        self.use_flash_attn = use_flash_attn

    def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
        """Load DiT weights from a flat state-dict file into ``self.dit`` (strict).

        Keys starting with 'vae' or 'text_encoder' are skipped. Keys starting
        with 'dit' have their first dotted component stripped. ``model_key``
        and ``kwargs`` are accepted for API compatibility but are not used.
        """
        # NOTE(review): depending on the installed torch's `weights_only` default this
        # may unpickle arbitrary objects; only load trusted checkpoint files.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        dit_checkpoint = OrderedDict()
        for key, value in checkpoint.items():
            if key.startswith(('vae', 'text_encoder')):
                continue
            if key.startswith('dit'):
                key = key.partition('.')[2]
            dit_checkpoint[key] = value

        load_result = self.dit.load_state_dict(dit_checkpoint, strict=True)
        print(f"Load checkpoint from {checkpoint_path}, load result: {load_result}")

    def load_vae_checkpoint(self, vae_checkpoint_path, model_key='model'):
        """Load VAE weights stored under ``checkpoint[model_key]`` into ``self.vae``.

        Only keys starting with 'vae.' are kept, with that prefix stripped.
        """
        # NOTE(review): same trust caveat as load_checkpoint regarding torch.load.
        checkpoint = torch.load(vae_checkpoint_path, map_location='cpu')[model_key]
        loaded_checkpoint = OrderedDict(
            (key.partition('.')[2], value)
            for key, value in checkpoint.items()
            if key.startswith('vae.')
        )
        load_result = self.vae.load_state_dict(loaded_checkpoint)
        print(f"Load the VAE from {vae_checkpoint_path}, load result: {load_result}")

    def get_pyramid_latent(self, x, stage_num):
        """Build a resolution pyramid from a clean VAE latent.

        Args:
            x: latent laid out as [b, c, t, h, w].
            stage_num: number of successive 2x bilinear spatial downsamplings.

        Returns:
            A list of ``stage_num + 1`` latents ordered coarsest first; the last
            entry is ``x`` itself.
        """
        vae_latent_list = [x]
        temp, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
        for _ in range(stage_num):
            height //= 2
            width //= 2
            x = rearrange(x, 'b c t h w -> (b t) c h w')
            x = F.interpolate(x, size=(height, width), mode='bilinear')
            x = rearrange(x, '(b t) c h w -> b c t h w', t=temp)
            vae_latent_list.append(x)
        return list(reversed(vae_latent_list))

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        temp,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        """Sample Gaussian latents of shape [b, c, temp, height // 8, width // 8].

        ``height``/``width`` are in pixels and are divided by ``self.downsample``;
        ``temp`` is used directly as the latent time dimension.
        """
        shape = (
            batch_size,
            num_channels_latents,
            int(temp),
            int(height) // self.downsample,
            int(width) // self.downsample,
        )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    def sample_block_noise(self, bs, ch, temp, height, width):
        """Sample 2x2-block-correlated noise for re-noising after upsampling.

        Each non-overlapping 2x2 spatial block receives one 4-dim Gaussian draw
        with covariance ``(1 + gamma) * I - gamma * ones``, where ``gamma`` is
        the scheduler's configured gamma. The result has shape
        [bs, ch, temp, 2 * (height // 2), 2 * (width // 2)].
        """
        gamma = self.scheduler.config.gamma
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma
        )
        block_number = bs * ch * temp * (height // 2) * (width // 2)
        # One batched draw instead of `block_number` Python-level draws: same
        # distribution, much faster. For a fixed global seed the individual
        # values differ from the former per-block loop.
        noise = dist.sample((block_number,))  # [block number, 4]
        noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)', b=bs, c=ch, t=temp, h=height // 2, w=width // 2, p=2, q=2)
        return noise

    def generate_one_unit(
        self,
        latents,
        past_conditions,  # List of past conditions, contains the conditions of each stage
        prompt_embeds,
        prompt_attention_mask,
        pooled_prompt_embeds,
        num_inference_steps,
        height,
        width,
        temp,
        device,
        dtype,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        is_first_frame: bool = False,
    ):
        """Denoise one temporal unit through every pyramid stage.

        Before every stage after the first, the latents are upsampled 2x
        (nearest) and mixed with block noise (``alpha * latents + beta * noise``).
        Then the stage's scheduler loop runs with ``past_conditions[i_s]``
        prepended to the DiT input.

        Args:
            latents: coarsest-stage noise for this unit, [b, c, t, h, w].
            past_conditions: per-stage lists of clean condition latents.
            num_inference_steps: per-stage step counts (indexable by stage).
            height, width: latent size at the coarsest stage.
            temp: number of latent frames in this unit.
            is_first_frame: selects ``guidance_scale`` (True) or
                ``video_guidance_scale`` (False) for CFG.

        Returns:
            The latents at the end of each stage, in stage order; the last entry
            is the full-resolution result.
        """
        stages = self.stages
        intermed_latents = []

        for i_s in range(len(stages)):
            self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
            timesteps = self.scheduler.timesteps

            if i_s > 0:
                height *= 2
                width *= 2
                latents = rearrange(latents, 'b c t h w -> (b t) c h w')
                latents = F.interpolate(latents, size=(height, width), mode='nearest')
                latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
                # Fix the stage
                ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]  # the original coeff of signal
                gamma = self.scheduler.config.gamma
                alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
                beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

                bs, ch, temp, height, width = latents.shape
                noise = self.sample_block_noise(bs, ch, temp, height, width)
                noise = noise.to(device=device, dtype=dtype)
                latents = alpha * latents + beta * noise  # To fix the block artifact

            for t in timesteps:
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
                latent_model_input = past_conditions[i_s] + [latent_model_input]

                noise_pred = self.dit(
                    sample=[latent_model_input],
                    timestep_ratio=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    pooled_projections=pooled_prompt_embeds,
                )
                noise_pred = noise_pred[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    scale = self.guidance_scale if is_first_frame else self.video_guidance_scale
                    noise_pred = noise_pred_uncond + scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    model_output=noise_pred,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

            intermed_latents.append(latents)

        return intermed_latents

    def _build_past_conditions(self, generated_latents_list, unit_index):
        """Assemble the per-stage history conditions for unit ``unit_index`` (>= 1).

        For stage ``i_s`` the returned list is ordered oldest first. It holds
        everything older at stage 0 (coarsest), then progressively newer units
        at finer stages, ending with the most recent unit at stage ``i_s``'s
        own resolution. Each entry is duplicated along the batch dimension
        when classifier-free guidance is on.
        """
        fpu = self.frame_per_unit
        num_stages = len(self.stages)

        def _maybe_double(latent):
            # Match the [uncond, cond] batch layout of the CFG prompt embeddings.
            return torch.cat([latent] * 2) if self.do_classifier_free_guidance else latent

        clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), num_stages - 1)

        past_condition_latents = []
        for i_s in range(num_stages):
            stage_input = [_maybe_double(clean_latents_list[i_s][:, :, -fpu:])]

            # pad the past clean latents: one unit further back per coarser stage
            cur_stage = i_s
            cur_unit_ptx = 1
            while cur_unit_ptx < unit_index:
                cur_stage = max(cur_stage - 1, 0)
                if cur_stage == 0:
                    break
                cur_unit_ptx += 1
                cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * fpu):-((cur_unit_ptx - 1) * fpu)]
                stage_input.append(_maybe_double(cond_latents))

            # whatever history remains is conditioned at the coarsest stage
            if cur_stage == 0 and cur_unit_ptx < unit_index:
                cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * fpu)]
                stage_input.append(_maybe_double(cond_latents))

            past_condition_latents.append(list(reversed(stage_input)))

        return past_condition_latents

    def generate_i2v(
        self,
        prompt: Union[str, List[str]] = '',
        input_image: "PIL.Image.Image" = None,
        temp: int = 1,
        num_inference_steps: Optional[Union[int, List[int]]] = 28,
        guidance_scale: float = 7.0,
        video_guidance_scale: float = 4.0,
        min_guidance_scale: float = 2.0,
        use_linear_guidance: bool = False,
        alpha: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        save_memory: bool = True,
    ):
        """Image-to-video generation.

        ``input_image`` is VAE-encoded as the first latent frame. Then
        ``temp // frame_per_unit`` units are generated autoregressively, each
        conditioned on everything generated so far. ``temp`` is used directly
        as the latent time dimension and must be a multiple of
        ``frame_per_unit``.

        Returns:
            The concatenated latents when ``output_type == "latent"``,
            otherwise the decoded frames from ``decode_latent``.
        """
        device = self.device
        dtype = self.dtype

        width = input_image.width
        height = input_image.height

        assert temp % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"

        if isinstance(prompt, str):
            batch_size = 1
            prompt = prompt + ", hyper quality, Ultra HD, 8K"  # adding this prompt to improve aesthetics
        else:
            assert isinstance(prompt, list)
            batch_size = len(prompt)
            prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]

        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * len(self.stages)

        negative_prompt = negative_prompt or ""

        # Get the text embeddings
        prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
        negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)

        if use_linear_guidance:
            max_guidance_scale = guidance_scale
            guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp + 1)]
            print(guidance_scale_list)

        self._guidance_scale = guidance_scale
        self._video_guidance_scale = video_guidance_scale

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # Create the initial random noise
        num_channels_latents = self.dit.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            temp,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]

        latents = rearrange(latents, 'b c t h w -> (b t) c h w')
        # by default, we start from the block noise at the coarsest stage; the * 2
        # presumably restores unit variance after 2x bilinear averaging
        for _ in range(len(self.stages) - 1):
            height //= 2
            width //= 2
            latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
        latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)

        num_units = temp // self.frame_per_unit

        # encode the image latents
        image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        input_image_tensor = image_transform(input_image).unsqueeze(0).unsqueeze(2)  # [b c 1 h w]
        # ToTensor yields float32; match the VAE weights' dtype so encoding also
        # works outside an autocast context.
        vae_dtype = next(self.vae.parameters()).dtype
        input_image_tensor = input_image_tensor.to(device=device, dtype=vae_dtype)
        input_image_latent = (self.vae.encode(input_image_tensor).latent_dist.sample() - self.vae_shift_factor) * self.vae_scale_factor  # [b c 1 h w]

        generated_latents_list = [input_image_latent]  # The generated results

        for unit_index in tqdm(range(1, num_units + 1)):
            if use_linear_guidance:
                self._guidance_scale = guidance_scale_list[unit_index]
                self._video_guidance_scale = guidance_scale_list[unit_index]

            # prepare the condition latents
            past_condition_latents = self._build_past_conditions(generated_latents_list, unit_index)

            intermed_latents = self.generate_one_unit(
                latents[:, :, (unit_index - 1) * self.frame_per_unit:unit_index * self.frame_per_unit],
                past_condition_latents,
                prompt_embeds,
                prompt_attention_mask,
                pooled_prompt_embeds,
                num_inference_steps,
                height,
                width,
                self.frame_per_unit,
                device,
                dtype,
                generator,
                is_first_frame=False,
            )

            generated_latents_list.append(intermed_latents[-1])

        generated_latents = torch.cat(generated_latents_list, dim=2)

        if output_type == "latent":
            image = generated_latents
        else:
            image = self.decode_latent(generated_latents, save_memory=save_memory)

        return image

    def generate(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        temp: int = 1,
        num_inference_steps: Optional[Union[int, List[int]]] = 28,
        video_num_inference_steps: Optional[Union[int, List[int]]] = 28,
        guidance_scale: float = 7.0,
        video_guidance_scale: float = 7.0,
        min_guidance_scale: float = 2.0,
        use_linear_guidance: bool = False,
        alpha: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        save_memory: bool = True,
    ):
        """Text-to-video (or text-to-image when ``temp == 1``) generation.

        The first unit is a single latent frame denoised with
        ``num_inference_steps`` and ``guidance_scale``. The remaining
        ``(temp - 1) // frame_per_unit`` units are generated autoregressively
        with ``video_num_inference_steps`` and ``video_guidance_scale``, each
        conditioned on everything generated so far.

        Returns:
            The concatenated latents when ``output_type == "latent"``,
            otherwise the decoded frames from ``decode_latent``.
        """
        device = self.device
        dtype = self.dtype

        assert (temp - 1) % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"

        if isinstance(prompt, str):
            batch_size = 1
            prompt = prompt + ", hyper quality, Ultra HD, 8K"  # adding this prompt to improve aesthetics
        else:
            assert isinstance(prompt, list)
            batch_size = len(prompt)
            prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]

        if isinstance(num_inference_steps, int):
            num_inference_steps = [num_inference_steps] * len(self.stages)

        if isinstance(video_num_inference_steps, int):
            video_num_inference_steps = [video_num_inference_steps] * len(self.stages)

        negative_prompt = negative_prompt or ""

        # Get the text embeddings
        prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
        negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)

        if use_linear_guidance:
            max_guidance_scale = guidance_scale
            guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp)]
            print(guidance_scale_list)

        self._guidance_scale = guidance_scale
        self._video_guidance_scale = video_guidance_scale

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # Create the initial random noise
        num_channels_latents = self.dit.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            temp,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]

        latents = rearrange(latents, 'b c t h w -> (b t) c h w')
        # by default, we start from the block noise at the coarsest stage; the * 2
        # presumably restores unit variance after 2x bilinear averaging
        for _ in range(len(self.stages) - 1):
            height //= 2
            width //= 2
            latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
        latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)

        num_units = 1 + (temp - 1) // self.frame_per_unit

        generated_latents_list = []  # The generated results

        for unit_index in tqdm(range(num_units)):
            if use_linear_guidance:
                self._guidance_scale = guidance_scale_list[unit_index]
                self._video_guidance_scale = guidance_scale_list[unit_index]

            if unit_index == 0:
                past_condition_latents = [[] for _ in range(len(self.stages))]
                intermed_latents = self.generate_one_unit(
                    latents[:, :, :1],
                    past_condition_latents,
                    prompt_embeds,
                    prompt_attention_mask,
                    pooled_prompt_embeds,
                    num_inference_steps,
                    height,
                    width,
                    1,
                    device,
                    dtype,
                    generator,
                    is_first_frame=True,
                )
            else:
                # prepare the condition latents
                past_condition_latents = self._build_past_conditions(generated_latents_list, unit_index)

                intermed_latents = self.generate_one_unit(
                    latents[:, :, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
                    past_condition_latents,
                    prompt_embeds,
                    prompt_attention_mask,
                    pooled_prompt_embeds,
                    video_num_inference_steps,
                    height,
                    width,
                    self.frame_per_unit,
                    device,
                    dtype,
                    generator,
                    is_first_frame=False,
                )

            generated_latents_list.append(intermed_latents[-1])

        generated_latents = torch.cat(generated_latents_list, dim=2)

        if output_type == "latent":
            image = generated_latents
        else:
            image = self.decode_latent(generated_latents, save_memory=save_memory)

        return image

    def decode_latent(self, latents, save_memory=True):
        """Decode latents laid out as [B, C, T, H, W] into a flat list of PIL frames.

        The first frame is un-normalized with the image factors and later frames
        with the video factors. The caller's ``latents`` tensor is not modified.
        Frames are ordered batch-major, per the ``(B T)`` flatten.
        """
        if latents.shape[2] == 1:
            latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
        else:
            # Build a new tensor instead of assigning into slices of the input.
            image_part = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor
            video_part = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor
            latents = torch.cat([image_part, video_part], dim=2)

        if save_memory:
            # reducing the tile size and temporal chunk window size
            image = self.vae.decode(latents, temporal_chunk=True, window_size=1, tile_sample_min_size=256).sample
        else:
            image = self.vae.decode(latents, temporal_chunk=True, window_size=2, tile_sample_min_size=512).sample

        image = image.float()
        image = (image / 2 + 0.5).clamp(0, 1)
        image = rearrange(image, "B C T H W -> (B T) C H W")
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = self.numpy_to_pil(image)

        return image

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images (values in [0, 1], channels last) to PIL images.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @property
    def device(self):
        """Device of the DiT's first parameter."""
        return next(self.dit.parameters()).device

    @property
    def dtype(self):
        """Dtype of the DiT's first parameter."""
        return next(self.dit.parameters()).dtype

    @property
    def guidance_scale(self):
        """CFG scale used for the first frame."""
        return self._guidance_scale

    @property
    def video_guidance_scale(self):
        """CFG scale used for the subsequent video units."""
        return self._video_guidance_scale

    @property
    def do_classifier_free_guidance(self):
        """True when the current first-frame guidance scale is positive."""
        return self._guidance_scale > 0