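"""Video diffusion inference utilities.

Defines `VideoDiffusionInfer`, which wraps DiT and VAE construction, diffusion
schedule/sampler creation, VAE encode/decode, resolution-dependent timestep
shifting, and classifier-free-guidance sampling.
"""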
from typing import List, Optional, Tuple, Union

import torch
from einops import rearrange
from omegaconf import DictConfig, ListConfig
from torch import Tensor

from common.config import create_object
from common.decorators import log_on_entry, log_runtime
from common.diffusion import (
    classifier_free_guidance_dispatcher,
    create_sampler_from_config,
    create_sampling_timesteps_from_config,
    create_schedule_from_config,
)
from common.distributed import (
    get_device,
    get_global_rank,
)
from common.distributed.meta_init_utils import (
    meta_non_persistent_buffer_init_fn,
)
from models.dit_v2 import na


class VideoDiffusionInfer:
    """End-to-end inference helper: DiT + VAE + diffusion sampler."""

    def __init__(self, config: DictConfig):
        self.config = config
        self.device = "cuda"

    def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor:
        t, h, w, c = latent.shape
        cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
        if task == "t2v" or t == 1:
            # t2i / t2v generation: no frame conditioning. For single-frame sr,
            # still provide the blurred latent as condition.
            if task == "sr":
                cond[:, ..., :-1] = latent_blur[:]
                cond[:, ..., -1:] = 1.0
            return cond
        if task == "i2v":
            # i2v generation: condition on the first frame.
            cond[:1, ..., :-1] = latent[:1]
            cond[:1, ..., -1:] = 1.0
            return cond
        if task == "v2v":
            # v2v generation: condition on the first two frames.
            cond[:2, ..., :-1] = latent[:2]
            cond[:2, ..., -1:] = 1.0
            return cond
        if task == "sr":
            # sr generation: condition on the blurred latent for every frame.
            cond[:, ..., :-1] = latent_blur[:]
            cond[:, ..., -1:] = 1.0
            return cond
        raise NotImplementedError(f"Unknown task: {task}")

    @log_on_entry
    @log_runtime
    def configure_dit_model(self, device="cuda", checkpoint=None):
        # Create the DiT model.
        with torch.device(self.device):
            self.dit = create_object(self.config.dit.model)
        self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)

        if checkpoint:
            # Load the pretrained checkpoint, then materialize non-persistent
            # buffers (see meta_init_utils).
            state = torch.load(checkpoint, map_location=self.device, mmap=True)
            loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
            print(f"Loading pretrained ckpt from {checkpoint}")
            print(f"Loading info: {loading_info}")
            self.dit = meta_non_persistent_buffer_init_fn(self.dit)

        # Print model size.
        num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
        print(f"DiT trainable parameters: {num_params:,}")

    @log_on_entry
    @log_runtime
    def configure_vae_model(self):
        # Create the VAE model.
        dtype = getattr(torch, self.config.vae.dtype)
        self.vae = create_object(self.config.vae.model)
        self.vae.requires_grad_(False).eval()
        self.vae.to(device=get_device(), dtype=dtype)

        # Load the VAE checkpoint.
        state = torch.load(
            self.config.vae.checkpoint, map_location=get_device(), mmap=True
        )
        self.vae.load_state_dict(state)

        # Set causal slicing if both the model and the config support it.
        if hasattr(self.vae, "set_causal_slicing") and hasattr(self.config.vae, "slicing"):
            self.vae.set_causal_slicing(**self.config.vae.slicing)

    def configure_diffusion(self):
        self.schedule = create_schedule_from_config(
            config=self.config.diffusion.schedule,
            device=get_device(),
        )
        self.sampling_timesteps = create_sampling_timesteps_from_config(
            config=self.config.diffusion.timesteps.sampling,
            schedule=self.schedule,
            device=get_device(),
        )
        self.sampler = create_sampler_from_config(
            config=self.config.diffusion.sampler,
            schedule=self.schedule,
            timesteps=self.sampling_timesteps,
        )

    @torch.no_grad()
    def vae_encode(self, samples: List[Tensor]) -> List[Tensor]:
        use_sample = self.config.vae.get("use_sample", True)
        latents = []
        if len(samples) > 0:
            device = get_device()
            dtype = getattr(torch, self.config.vae.dtype)
            scale = self.config.vae.scaling_factor
            shift = self.config.vae.get("shifting_factor", 0.0)

            # Scale/shift may be given as lists in the config; convert to tensors.
            if isinstance(scale, ListConfig):
                scale = torch.tensor(scale, device=device, dtype=dtype)
            if isinstance(shift, ListConfig):
                shift = torch.tensor(shift, device=device, dtype=dtype)

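            # Group inputs so they can be encoded in batches; `indices` is kept
            # to restore the original per-sample order via na.unpack below.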
            if self.config.vae.grouping:
                batches, indices = na.pack(samples)
            else:
                batches = [sample.unsqueeze(0) for sample in samples]

            # Encode each batch and map it into the scaled latent space.
            for sample in batches:
                sample = sample.to(device, dtype)
                if hasattr(self.vae, "preprocess"):
                    sample = self.vae.preprocess(sample)
                if use_sample:
                    latent = self.vae.encode(sample).latent
                else:
                    # Deterministic encoding: take the posterior mode instead of sampling.
                    latent = self.vae.encode(sample).posterior.mode().squeeze(2)
                latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
                latent = rearrange(latent, "b c ... -> b ... c")
                latent = (latent - shift) * scale
                latents.append(latent)

            # Restore the original order and drop the batch dimension.
            if self.config.vae.grouping:
                latents = na.unpack(latents, indices)
            else:
                latents = [latent.squeeze(0) for latent in latents]

        return latents

    @torch.no_grad()
    def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
        samples = []
        if len(latents) > 0:
            device = get_device()
            dtype = getattr(torch, self.config.vae.dtype)
            scale = self.config.vae.scaling_factor
            shift = self.config.vae.get("shifting_factor", 0.0)

            # Scale/shift may be given as lists in the config; convert to tensors.
            if isinstance(scale, ListConfig):
                scale = torch.tensor(scale, device=device, dtype=dtype)
            if isinstance(shift, ListConfig):
                shift = torch.tensor(shift, device=device, dtype=dtype)

            # Group latents so they can be decoded in batches.
            if self.config.vae.grouping:
                latents, indices = na.pack(latents)
            else:
                latents = [latent.unsqueeze(0) for latent in latents]

            # Undo the latent scaling, then decode each batch.
            for latent in latents:
                latent = latent.to(device, dtype)
                latent = latent / scale + shift
                latent = rearrange(latent, "b ... c -> b c ...")
                latent = latent.squeeze(2)
                sample = self.vae.decode(latent).sample
                if hasattr(self.vae, "postprocess"):
                    sample = self.vae.postprocess(sample)
                samples.append(sample)

            # Restore the original order and drop the batch dimension.
            if self.config.vae.grouping:
                samples = na.unpack(samples, indices)
            else:
                samples = [sample.squeeze(0) for sample in samples]

        return samples

    def timestep_transform(self, timesteps: Tensor, latents_shapes: Tensor):
        # Skip if the timestep transform is not enabled.
        if not self.config.diffusion.timesteps.get("transform", False):
            return timesteps

        # Recover pixel-space resolution from latent shapes.
        vt = self.config.vae.model.get("temporal_downsample_factor", 4)
        vs = self.config.vae.model.get("spatial_downsample_factor", 8)
        frames = (latents_shapes[:, 0] - 1) * vt + 1
        heights = latents_shapes[:, 1] * vs
        widths = latents_shapes[:, 2] * vs

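        # Resolution-dependent shift: interpolate the shift factor linearly in
        # pixel count between two anchor resolutions (separate anchors for
        # images and videos), then warp the normalized timestep as
        # t' = s * t / (1 + (s - 1) * t).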
        def get_lin_function(x1, y1, x2, y2):
            m = (y2 - y1) / (x2 - x1)
            b = y1 - m * x1
            return lambda x: m * x + b

        img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
        vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
        shift = torch.where(
            frames > 1,
            vid_shift_fn(heights * widths * frames),
            img_shift_fn(heights * widths),
        )

        timesteps = timesteps / self.schedule.T
        timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
        timesteps = timesteps * self.schedule.T
        return timesteps

    @torch.no_grad()
    def inference(
        self,
        noises: List[Tensor],
        conditions: List[Tensor],
        texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
        texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
        cfg_scale: Optional[float] = None,
        dit_offload: bool = False,
    ) -> List[Tensor]:
        assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg)
        batch_size = len(noises)

        # Return early on an empty batch.
        if batch_size == 0:
            return []

        # Fall back to the configured cfg scale.
        if cfg_scale is None:
            cfg_scale = self.config.diffusion.cfg.scale

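        # Text conditioning comes in one of three forms: raw strings (encoded
        # here with self.text_encode), per-sample tuples of embedding tensors
        # (flattened stream by stream), or plain embedding tensors (flattened
        # directly).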
        assert type(texts_pos[0]) is type(texts_neg[0])
        if isinstance(texts_pos[0], str):
            text_pos_embeds, text_pos_shapes = self.text_encode(texts_pos)
            text_neg_embeds, text_neg_shapes = self.text_encode(texts_neg)
        elif isinstance(texts_pos[0], tuple):
            text_pos_embeds, text_pos_shapes = [], []
            text_neg_embeds, text_neg_shapes = [], []
            for pos in zip(*texts_pos):
                emb, shape = na.flatten(pos)
                text_pos_embeds.append(emb)
                text_pos_shapes.append(shape)
            for neg in zip(*texts_neg):
                emb, shape = na.flatten(neg)
                text_neg_embeds.append(emb)
                text_neg_shapes.append(shape)
        else:
            text_pos_embeds, text_pos_shapes = na.flatten(texts_pos)
            text_neg_embeds, text_neg_shapes = na.flatten(texts_neg)

        # Flatten noises and conditions into packed sequences plus shape metadata.
        latents, latents_shapes = na.flatten(noises)
        latents_cond, _ = na.flatten(conditions)

        # Enter eval mode for sampling.
        was_training = self.dit.training
        self.dit.eval()

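        # Classifier-free guidance sampling: at each step the DiT is run on the
        # noisy latent concatenated with the task condition, once with positive
        # and once with negative text; the dispatcher combines the two outputs
        # using cfg_scale (and rescale). Guidance is only applied for the first
        # diffusion.cfg.partial fraction of steps (scale 1.0 afterwards).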
        latents = self.sampler.sample(
            x=latents,
            f=lambda args: classifier_free_guidance_dispatcher(
                pos=lambda: self.dit(
                    vid=torch.cat([args.x_t, latents_cond], dim=-1),
                    txt=text_pos_embeds,
                    vid_shape=latents_shapes,
                    txt_shape=text_pos_shapes,
                    timestep=args.t.repeat(batch_size),
                ).vid_sample,
                neg=lambda: self.dit(
                    vid=torch.cat([args.x_t, latents_cond], dim=-1),
                    txt=text_neg_embeds,
                    vid_shape=latents_shapes,
                    txt_shape=text_neg_shapes,
                    timestep=args.t.repeat(batch_size),
                ).vid_sample,
                scale=(
                    cfg_scale
                    if (args.i + 1) / len(self.sampler.timesteps)
                    <= self.config.diffusion.cfg.get("partial", 1)
                    else 1.0
                ),
                rescale=self.config.diffusion.cfg.rescale,
            ),
        )

        # Restore the previous training mode.
        self.dit.train(was_training)

        # Unflatten back to a list of per-sample latents.
        latents = na.unflatten(latents, latents_shapes)

        if dit_offload:
            self.dit.to("cpu")

        # Decode latents to samples with the VAE.
        self.vae.to(get_device())
        samples = self.vae_decode(latents)

        if dit_offload:
            self.dit.to(get_device())
        return samples
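

# Minimal usage sketch (kept as a comment; not part of this module's API). The
# config and checkpoint paths below are placeholders, and the expected config
# keys (dit.model, vae.model, diffusion.schedule / timesteps / sampler / cfg,
# vae.scaling_factor, ...) follow the accesses above. A text_encode method /
# text encoder (referenced in inference) is assumed to be set up elsewhere.
#
#   from omegaconf import OmegaConf
#
#   config = OmegaConf.load("configs/infer.yaml")             # hypothetical path
#   runner = VideoDiffusionInfer(config)
#   runner.configure_dit_model(checkpoint="ckpts/dit.pth")    # hypothetical path
#   runner.configure_vae_model()
#   runner.configure_diffusion()
#
#   latent = runner.vae_encode([video])[0]                    # `video`: pixel-space tensor
#   cond = runner.get_condition(latent, latent_blur=None, task="i2v")
#   noise = torch.randn_like(latent)
#   samples = runner.inference([noise], [cond], texts_pos=[prompt], texts_neg=[neg_prompt])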