from contextlib import nullcontext
from typing import Literal, Optional, Tuple, Union

import diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, SpatialNorm
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.downsampling import Downsample2D
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.unets.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D
from diffusers.models.upsampling import Upsample2D
from diffusers.utils import is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange

from common.distributed.advanced import get_sequence_parallel_world_size
from common.logger import get_logger
from models.video_vae_v3.modules.causal_inflation_lib import (
    InflatedCausalConv3d,
    causal_norm_wrapper,
    init_causal_conv3d,
    remove_head,
)
from models.video_vae_v3.modules.context_parallel_lib import (
    causal_conv_gather_outputs,
    causal_conv_slice_inputs,
)
from models.video_vae_v3.modules.global_config import set_norm_limit
from models.video_vae_v3.modules.types import (
    CausalAutoencoderOutput,
    CausalDecoderOutput,
    CausalEncoderOutput,
    MemoryState,
    _inflation_mode_t,
    _memory_device_t,
    _receptive_field_t,
)

logger = get_logger(__name__)


class Upsample3D(Upsample2D):
    """A 3D upsampling layer with an optional convolution."""

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        temporal_up: bool = False,
        spatial_up: bool = True,
        slicing: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        conv = self.conv if self.name == "conv" else self.Conv2d_0
        assert type(conv) is not nn.ConvTranspose2d

        # Replace the 2D convolution created by the parent class with a causal
        # 3D convolution.
        conv = init_causal_conv3d(
            self.channels,
            self.out_channels,
            3,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.temporal_up = temporal_up
        self.spatial_up = spatial_up
        self.temporal_ratio = 2 if temporal_up else 1
        self.spatial_ratio = 2 if spatial_up else 1
        self.slicing = slicing

        assert not self.interpolate

        if not self.interpolate:
            # Upsample by 3D pixel shuffle: a 1x1x1 conv expands channels by
            # the upscale ratio and forward() folds the extra channels into
            # the temporal/spatial axes. The conv is initialized to the
            # identity, so training starts from nearest-neighbor upsampling.
            upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
            self.upscale_conv = nn.Conv3d(
                self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
            )
            identity = (
                torch.eye(self.channels)
                .repeat(upscale_ratio, 1)
                .reshape_as(self.upscale_conv.weight)
            )
            self.upscale_conv.weight.data.copy_(identity)
            nn.init.zeros_(self.upscale_conv.bias)

        if self.name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if hasattr(self, "norm") and self.norm is not None:
            hidden_states = causal_norm_wrapper(self.norm, hidden_states)

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Optionally split along time to bound peak memory during upsampling.
        if self.slicing:
            split_size = hidden_states.size(2) // 2
            hidden_states = list(
                hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2)
            )
        else:
            hidden_states = [hidden_states]

        for i in range(len(hidden_states)):
            hidden_states[i] = self.upscale_conv(hidden_states[i])
            hidden_states[i] = rearrange(
                hidden_states[i],
                "b (x y z c) f h w -> b c (f z) (h x) (w y)",
                x=self.spatial_ratio,
                y=self.spatial_ratio,
                z=self.temporal_ratio,
            )

        # For causal temporal upsampling, the leading duplicated frame is
        # dropped unless memory from a previous slice is active.
        if self.temporal_up and memory_state != MemoryState.ACTIVE:
            hidden_states[0] = remove_head(hidden_states[0])

        if not self.slicing:
            hidden_states = hidden_states[0]

        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states, memory_state=memory_state)
            else:
                hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state)

        if not self.slicing:
            return hidden_states
        else:
            return torch.cat(hidden_states, dim=2)
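

# Illustrative sketch (not used by the model): the upsampling above is a 3D
# depth-to-space ("pixel shuffle"). After `upscale_conv` expands channels by
# spatial_ratio**2 * temporal_ratio, `rearrange` folds them back into the
# (f, h, w) axes. The shapes below assume spatial_ratio=2, temporal_ratio=1.
def _pixel_shuffle_3d_sketch() -> torch.Tensor:
    x = torch.randn(1, 2 * 2 * 1 * 8, 5, 6, 6)  # b, (x y z c), f, h, w
    y = rearrange(x, "b (x y z c) f h w -> b c (f z) (h x) (w y)", x=2, y=2, z=1)
    assert y.shape == (1, 8, 5, 12, 12)
    return y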


class Downsample3D(Downsample2D):
    """A 3D downsampling layer with an optional convolution."""

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        spatial_down: bool = False,
        temporal_down: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        conv = self.conv
        self.temporal_down = temporal_down
        self.spatial_down = spatial_down

        self.temporal_ratio = 2 if temporal_down else 1
        self.spatial_ratio = 2 if spatial_down else 1

        self.temporal_kernel = 3 if temporal_down else 1
        self.spatial_kernel = 3 if spatial_down else 1

        # Replace the 2D op created by the parent class with its 3D causal
        # counterpart, keeping stride/kernel at 1 along axes that are not
        # downsampled.
        if type(conv) in [nn.Conv2d, LoRACompatibleConv]:
            conv = init_causal_conv3d(
                self.channels,
                self.out_channels,
                kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel),
                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                padding=(
                    1 if self.temporal_down else 0,
                    self.padding if self.spatial_down else 0,
                    self.padding if self.spatial_down else 0,
                ),
                inflation_mode=inflation_mode,
            )
        elif type(conv) is nn.AvgPool2d:
            assert self.channels == self.out_channels
            conv = nn.AvgPool3d(
                kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
            )
        else:
            raise NotImplementedError

        if self.name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        else:
            self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if hasattr(self, "norm") and self.norm is not None:
            hidden_states = causal_norm_wrapper(self.norm, hidden_states)

        # With padding=0, pad one column/row on the right/bottom so a stride-2
        # convolution covers the full spatial extent.
        if self.use_conv and self.padding == 0 and self.spatial_down:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

        # nn.AvgPool3d does not accept memory_state, so only forward it to the
        # causal convolution.
        if isinstance(self.conv, nn.AvgPool3d):
            hidden_states = self.conv(hidden_states)
        else:
            hidden_states = self.conv(hidden_states, memory_state=memory_state)

        return hidden_states
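

# Shape sketch (not part of the model): F.pad orders pads last-dim-first, so
# (0, 1, 0, 1) adds one column on the right and one row at the bottom. Combined
# with an unpadded stride-2 kernel this yields ceil(n / 2) outputs per axis.
def _asymmetric_pad_sketch() -> torch.Tensor:
    x = torch.randn(1, 1, 4, 7, 7)  # b c f h w
    x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
    assert x.shape == (1, 1, 4, 8, 8)
    return x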


class ResnetBlock3D(ResnetBlock2D):
    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        slicing: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # "half" keeps conv1 spatial-only (1x3x3); "full" gives both convs a
        # temporal extent of 3.
        self.conv1 = init_causal_conv3d(
            self.in_channels,
            self.out_channels,
            kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3),
            stride=1,
            padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1),
            inflation_mode=inflation_mode,
        )

        self.conv2 = init_causal_conv3d(
            self.out_channels,
            self.conv2.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        if self.up:
            assert type(self.upsample) is Upsample2D
            self.upsample = Upsample3D(
                self.in_channels,
                use_conv=False,
                inflation_mode=inflation_mode,
                slicing=slicing,
            )
        elif self.down:
            assert type(self.downsample) is Downsample2D
            self.downsample = Downsample3D(
                self.in_channels,
                use_conv=False,
                padding=1,
                name="op",
                inflation_mode=inflation_mode,
            )

        if self.use_in_shortcut:
            self.conv_shortcut = init_causal_conv3d(
                self.in_channels,
                self.conv_shortcut.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=(self.conv_shortcut.bias is not None),
                inflation_mode=inflation_mode,
            )

    def forward(
        self, input_tensor, temb, memory_state: MemoryState = MemoryState.DISABLED, **kwargs
    ):
        hidden_states = input_tensor

        hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # Workaround inherited from the 2D block: make tensors contiguous
            # for large batches before upsampling.
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, memory_state=memory_state)
            hidden_states = self.upsample(hidden_states, memory_state=memory_state)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, memory_state=memory_state)
            hidden_states = self.downsample(hidden_states, memory_state=memory_state)

        hidden_states = self.conv1(hidden_states, memory_state=memory_state)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            # Broadcast over the (f, h, w) axes of the 5D hidden states.
            temb = self.time_emb_proj(temb)[:, :, None, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = causal_norm_wrapper(self.norm2, hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, memory_state=memory_state)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class DownEncoderBlock3D(DownEncoderBlock2D):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_down: bool = True,
        spatial_down: bool = True,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            num_layers=num_layers,
            resnet_eps=resnet_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_pre_norm=resnet_pre_norm,
            output_scale_factor=output_scale_factor,
            add_downsample=add_downsample,
            downsample_padding=downsample_padding,
        )
        # Rebuild the resnets created by the parent class as 3D blocks.
        resnets = []
        temporal_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )
            temporal_modules.append(nn.Identity())

        self.resnets = nn.ModuleList(resnets)
        self.temporal_modules = nn.ModuleList(temporal_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                        temporal_down=temporal_down,
                        spatial_down=spatial_down,
                        inflation_mode=inflation_mode,
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        for resnet, temporal in zip(self.resnets, self.temporal_modules):
            hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
            hidden_states = temporal(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, memory_state=memory_state)

        return hidden_states


class UpDecoderBlock3D(UpDecoderBlock2D):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temb_channels: Optional[int] = None,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up: bool = True,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            num_layers=num_layers,
            resnet_eps=resnet_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_pre_norm=resnet_pre_norm,
            output_scale_factor=output_scale_factor,
            add_upsample=add_upsample,
            temb_channels=temb_channels,
        )
        # Rebuild the resnets created by the parent class as 3D blocks.
        resnets = []
        temporal_modules = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                    slicing=slicing,
                )
            )
            temporal_modules.append(nn.Identity())

        self.resnets = nn.ModuleList(resnets)
        self.temporal_modules = nn.ModuleList(temporal_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    Upsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        temporal_up=temporal_up,
                        spatial_up=spatial_up,
                        interpolate=False,
                        inflation_mode=inflation_mode,
                        slicing=slicing,
                    )
                ]
            )
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        # `temb` is accepted for interface compatibility but not used by the
        # resnets in this block.
        for resnet, temporal in zip(self.resnets, self.temporal_modules):
            hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
            hidden_states = temporal(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, memory_state=memory_state)

        return hidden_states


class UNetMidBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # One resnet before the first attention layer, then alternating
        # (attention, resnet) pairs.
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            logger.warning(
                f"It is not recommended to pass `attention_head_dim=None`. "
                f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=(
                            resnet_groups if resnet_time_scale_shift == "default" else None
                        ),
                        spatial_norm_dim=(
                            temb_channels if resnet_time_scale_shift == "spatial" else None
                        ),
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)

            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None, memory_state: MemoryState = MemoryState.DISABLED):
        video_length, frame_height, frame_width = hidden_states.size()[-3:]
        hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                # Attention is applied per frame: fold time into the batch
                # axis, attend spatially, then unfold.
                hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
                hidden_states = attn(hidden_states, temb=temb)
                hidden_states = rearrange(
                    hidden_states, "(b f) c h w -> b c f h w", f=video_length
                )
            hidden_states = resnet(hidden_states, temb, memory_state=memory_state)

        return hidden_states
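

# Layout sketch (not part of the model): the mid-block attention is purely
# spatial. Frames fold into the batch axis, are attended per frame, and
# unfold again; the rearrange round trip is lossless.
def _per_frame_attention_layout_sketch():
    x = torch.randn(2, 16, 5, 8, 8)  # b c f h w
    folded = rearrange(x, "b c f h w -> (b f) c h w")
    unfolded = rearrange(folded, "(b f) c h w -> b c f h w", f=5)
    assert torch.equal(x, unfolded)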


class Encoder3D(nn.Module):
    r"""
    [Override] Overrides most of the parent logic to support an extra
    condition input and causal convolutions.

    The `Encoder` layer of a variational autoencoder that encodes its input
    into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock3D",)`):
            The types of down blocks to use. See
            `~diffusers.models.unet_2d_blocks.get_down_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See
            `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        temporal_down_num: int = 2,
        extra_cond_dim: Optional[int] = None,
        gradient_checkpoint: bool = False,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_down_num = temporal_down_num

        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])
        self.extra_cond_dim = extra_cond_dim

        self.conv_extra_cond = nn.ModuleList([])

        # Down blocks. Only the last `temporal_down_num` downsampling blocks
        # halve the temporal axis.
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1

            assert down_block_type == "DownEncoderBlock3D"

            down_block = DownEncoderBlock3D(
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                temporal_down=is_temporal_down_block,
                spatial_down=True,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.down_blocks.append(down_block)

            def zero_module(module):
                # Zero-init so the extra condition starts as a no-op.
                for p in module.parameters():
                    p.detach().zero_()
                return module

            self.conv_extra_cond.append(
                zero_module(
                    nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
                )
                if self.extra_cond_dim is not None and self.extra_cond_dim > 0
                else None
            )

        # Mid block.
        self.mid_block = UNetMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # Output head.
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = init_causal_conv3d(
            block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode
        )

        self.gradient_checkpointing = gradient_checkpoint

    def forward(
        self,
        sample: torch.FloatTensor,
        extra_cond=None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""
        sample = self.conv_in(sample, memory_state=memory_state)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), sample, memory_state, use_reentrant=False
                )
                if extra_block is not None:
                    sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:])

            sample = self.mid_block(sample, memory_state=memory_state)

        else:
            for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
                sample = down_block(sample, memory_state=memory_state)
                if extra_block is not None:
                    sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:])

            sample = self.mid_block(sample, memory_state=memory_state)

        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)

        return sample
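

# Schedule sketch (mirrors the index arithmetic above): with `num_blocks`
# encoder blocks, every non-final block halves space, and only the last
# `temporal_down_num` of those also halve time.
def _temporal_schedule_sketch(num_blocks: int = 4, temporal_down_num: int = 2):
    schedule = []
    for i in range(num_blocks):
        has_downsampler = i < num_blocks - 1  # the final block never downsamples
        halves_time = has_downsampler and i >= num_blocks - temporal_down_num - 1
        schedule.append((has_downsampler, halves_time))
    # 4 blocks, temporal_down_num=2 -> [(True, False), (True, True), (True, True), (False, False)]
    return schedule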


class Decoder3D(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent
    representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock3D",)`):
            The types of up blocks to use. See
            `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See
            `~diffusers.models.activations.get_activation` for available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",
        mid_block_add_attention=True,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up_num: int = 2,
        slicing_up_num: int = 0,
        gradient_checkpoint: bool = False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_up_num = temporal_up_num

        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # Mid block.
        self.mid_block = UNetMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # Up blocks. The first `temporal_up_num` blocks double the temporal
        # axis; the last `slicing_up_num` blocks upsample in slices to bound
        # peak memory.
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        logger.info(f"slicing_up_num: {slicing_up_num}")
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1
            is_temporal_up_block = i < self.temporal_up_num
            is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num

            assert up_block_type == "UpDecoderBlock3D"
            up_block = UpDecoderBlock3D(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=norm_type,
                temb_channels=temb_channels,
                temporal_up=is_temporal_up_block,
                slicing=is_slicing_up_block,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # Output head.
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
            )
        self.conv_act = nn.SiLU()
        self.conv_out = init_causal_conv3d(
            block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode
        )

        self.gradient_checkpointing = gradient_checkpoint

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        sample = self.conv_in(sample, memory_state=memory_state)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
                sample = sample.to(upscale_dtype)

                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        memory_state,
                        use_reentrant=False,
                    )
            else:
                sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
                sample = sample.to(upscale_dtype)

                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block), sample, latent_embeds, memory_state
                    )
        else:
            sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
            sample = sample.to(upscale_dtype)

            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds, memory_state=memory_state)

        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)

        return sample


class AutoencoderKL(diffusers.AutoencoderKL):
    """
    A thin wrapper that inherits the model code from diffusers.
    """

    def __init__(self, attention: bool = True, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Optionally strip the mid-block attention layers.
        if not attention:
            self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
            self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

    def load_state_dict(self, state_dict, strict=True):
        # Newer diffusers releases may drop this private helper, so only call
        # it when present.
        convert_deprecated_attention_blocks = getattr(
            self, "_convert_deprecated_attention_blocks", None
        )
        if callable(convert_deprecated_attention_blocks):
            convert_deprecated_attention_blocks(state_dict)
        return super().load_state_dict(state_dict, strict)


class VideoAutoencoderKL(diffusers.AutoencoderKL):
    """
    Inherits the diffusers AutoencoderKL interface and replaces the
    encoder/decoder with causal 3D modules.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: bool = True,
        attention: bool = True,
        temporal_scale_num: int = 2,
        slicing_up_num: int = 0,
        gradient_checkpoint: bool = False,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "full",
        slicing_sample_min_size: int = 32,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        *args,
        **kwargs,
    ):
        extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
        self.slicing_sample_min_size = slicing_sample_min_size
        self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)

        # Initialize the parent with the equivalent 2D block names, then
        # replace its encoder/decoder with the 3D versions below.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=tuple(
                down_block_type.replace("3D", "2D") for down_block_type in down_block_types
            ),
            up_block_types=tuple(
                up_block_type.replace("3D", "2D") for up_block_type in up_block_types
            ),
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            force_upcast=force_upcast,
            *args,
            **kwargs,
        )

        self.encoder = Encoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            extra_cond_dim=extra_cond_dim,
            temporal_down_num=temporal_scale_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        self.decoder = Decoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            temporal_up_num=temporal_scale_num,
            slicing_up_num=slicing_up_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        self.quant_conv = (
            init_causal_conv3d(
                in_channels=2 * latent_channels,
                out_channels=2 * latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_quant_conv
            else None
        )
        self.post_quant_conv = (
            init_causal_conv3d(
                in_channels=latent_channels,
                out_channels=latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_post_quant_conv
            else None
        )

        # Optionally strip the mid-block attention layers.
        if not attention:
            self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
            self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.slicing_encode(x)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    @apply_forward_hook
    def decode(
        self, z: torch.Tensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.Tensor]:
        decoded = self.slicing_decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def _encode(
        self, x: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED
    ) -> torch.Tensor:
        _x = x.to(self.device)
        # Shard the input along time for sequence-parallel workers, then
        # gather the outputs.
        _x = causal_conv_slice_inputs(_x, self.slicing_sample_min_size, memory_state=memory_state)
        h = self.encoder(_x, memory_state=memory_state)
        if self.quant_conv is not None:
            output = self.quant_conv(h, memory_state=memory_state)
        else:
            output = h
        output = causal_conv_gather_outputs(output)
        return output.to(x.device)

    def _decode(
        self, z: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED
    ) -> torch.Tensor:
        _z = z.to(self.device)
        _z = causal_conv_slice_inputs(_z, self.slicing_latent_min_size, memory_state=memory_state)
        if self.post_quant_conv is not None:
            _z = self.post_quant_conv(_z, memory_state=memory_state)
        output = self.decoder(_z, memory_state=memory_state)
        output = causal_conv_gather_outputs(output)
        return output.to(z.device)

    def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
        # Encode long videos slice by slice: frame 0 rides with the first
        # slice (MemoryState.INITIALIZING); later slices reuse the causal
        # memory (MemoryState.ACTIVE).
        sp_size = get_sequence_parallel_world_size()
        if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
            x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
            encoded_slices = [
                self._encode(
                    torch.cat((x[:, :, :1], x_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            for x_idx in range(1, len(x_slices)):
                encoded_slices.append(
                    self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
                )
            return torch.cat(encoded_slices, dim=2)
        else:
            return self._encode(x)

    def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
        sp_size = get_sequence_parallel_world_size()
        if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
            z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
            decoded_slices = [
                self._decode(
                    torch.cat((z[:, :, :1], z_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            for z_idx in range(1, len(z_slices)):
                decoded_slices.append(
                    self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
                )
            return torch.cat(decoded_slices, dim=2)
        else:
            return self._decode(z)

    def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
    ):
        if mode == "encode":
            h = self.encode(x)
            return h.latent_dist
        elif mode == "decode":
            h = self.decode(x)
            return h.sample
        else:
            h = self.encode(x)
            h = self.decode(h.latent_dist.mode())
            return h.sample

    def load_state_dict(self, state_dict, strict=False):
        # Newer diffusers releases may drop this private helper, so only call
        # it when present.
        convert_deprecated_attention_blocks = getattr(
            self, "_convert_deprecated_attention_blocks", None
        )
        if callable(convert_deprecated_attention_blocks):
            convert_deprecated_attention_blocks(state_dict)
        return super().load_state_dict(state_dict, strict)
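

# Layout sketch (pure indexing, independent of the model): `slicing_encode`
# keeps frame 0 with the first chunk so the causal memory is initialized once,
# then splits the remaining frames into chunks of `min_size * sp_size`.
# Concatenating the chunks reproduces the original frame order exactly.
def _causal_slicing_layout_sketch(num_frames: int = 17, min_size: int = 4, sp_size: int = 1):
    x = torch.arange(num_frames).view(1, 1, num_frames, 1, 1)
    chunks = list(x[:, :, 1:].split(min_size * sp_size, dim=2))
    chunks[0] = torch.cat((x[:, :, :1], chunks[0]), dim=2)
    assert torch.equal(torch.cat(chunks, dim=2), x)
    return [c.shape[2] for c in chunks]  # defaults -> [5, 4, 4, 4]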


class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
    def __init__(
        self,
        *args,
        spatial_downsample_factor: int,
        temporal_downsample_factor: int,
        freeze_encoder: bool,
        **kwargs,
    ):
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        self.freeze_encoder = freeze_encoder
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput:
        with torch.no_grad() if self.freeze_encoder else nullcontext():
            z, p = self.encode(x)
        x = self.decode(z).sample
        return CausalAutoencoderOutput(x, z, p)

    def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput:
        # Treat single images as one-frame videos.
        if x.ndim == 4:
            x = x.unsqueeze(2)
        p = super().encode(x).latent_dist
        z = p.sample().squeeze(2)
        return CausalEncoderOutput(z, p)

    def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput:
        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = super().decode(z).sample.squeeze(2)
        return CausalDecoderOutput(x)

    def preprocess(self, x: torch.Tensor):
        # Causal temporal downsampling expects 4k + 1 frames (a first frame
        # plus a multiple of the temporal factor).
        assert x.ndim == 4 or x.size(2) % 4 == 1
        return x

    def postprocess(self, x: torch.Tensor):
        return x

    def set_causal_slicing(
        self,
        *,
        split_size: Optional[int],
        memory_device: _memory_device_t,
    ):
        assert (
            split_size is None or memory_device is not None
        ), "if split_size is set, memory_device must not be None."
        if split_size is not None:
            self.enable_slicing()
            self.slicing_sample_min_size = split_size
            self.slicing_latent_min_size = split_size // self.temporal_downsample_factor
        else:
            self.disable_slicing()
        for module in self.modules():
            if isinstance(module, InflatedCausalConv3d):
                module.set_memory_device(memory_device)

    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
        set_norm_limit(norm_max_mem)
        for m in self.modules():
            if isinstance(m, InflatedCausalConv3d):
                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
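

# Usage sketch (hypothetical configuration; real configs ship with the repo's
# model definitions and checkpoints). `temporal_downsample_factor` should be
# 2 ** temporal_scale_num, and inputs should carry 4k + 1 frames:
#
#   vae = VideoAutoencoderKLWrapper(
#       spatial_downsample_factor=8,
#       temporal_downsample_factor=4,
#       freeze_encoder=False,
#       down_block_types=("DownEncoderBlock3D",) * 4,
#       up_block_types=("UpDecoderBlock3D",) * 4,
#       block_out_channels=(128, 256, 512, 512),
#       latent_channels=16,
#       temporal_scale_num=2,
#   )
#   vae.set_causal_slicing(split_size=16, memory_device="cpu")
#   out = vae(video)  # video: (b, 3, 4k + 1, h, w) -> CausalAutoencoderOutput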