|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
|
from typing import Literal, Optional |
|
|
from torch import Tensor |
|
|
from torch.nn import Conv3d |
|
|
|
|
|
from models.video_vae_v3.modules.inflated_lib import ( |
|
|
MemoryState, |
|
|
extend_head, |
|
|
inflate_bias, |
|
|
inflate_weight, |
|
|
modify_state_dict, |
|
|
) |
|
|
|
|
|
# Accepted values for the `inflation_mode` keyword of `InflatedCausalConv3d`.
# "none" leaves state-dict loading untouched; any other value routes the
# state dict through `modify_state_dict` before loading.
_inflation_mode_t = Literal["none", "tail", "replicate"] |
|
|
# Accepted values for `memory_device`: None means the layer never stores its
# memory, "cpu" moves the stored memory to the CPU, "same" stores it as-is.
_memory_device_t = Optional[Literal["cpu", "same"]] |
|
|
|
|
|
|
|
|
class InflatedCausalConv3d(Conv3d):
    """
    `Conv3d` whose padding on the first kernel axis is removed and replaced by
    an explicit extension of the input head (see `extend_head`).

    The layer can keep the trailing slice of its extended input as `memory`
    and reuse it on the next call, so a long clip can be processed in chunks.
    When `inflation_mode` is not "none", incoming state dicts are passed
    through `modify_state_dict` before being loaded.
    """

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t,
        memory_device: _memory_device_t = "same",
        **kwargs,
    ):
        """
        Parameters:
            *args / **kwargs: Forwarded unchanged to `Conv3d.__init__`.
            inflation_mode: Selects the state-dict loading behaviour
                (see `_load_from_state_dict`).
            memory_device: Where the cached memory is kept. None disables
                storing, "cpu" moves it to the CPU, "same" keeps it as-is.
        """
        # These two plain attributes are assigned ahead of the parent constructor.
        self.inflation_mode = inflation_mode
        self.memory = None
        super().__init__(*args, **kwargs)

        self.memory_device = memory_device
        # Split off the first padding entry and zero it in the value Conv3d
        # uses; `forward` extends the head of the input instead.
        temporal_pad, *remaining_pad = self.padding
        self.temporal_padding = temporal_pad
        self.padding = (0, *remaining_pad)

    def set_memory_device(self, memory_device: _memory_device_t):
        """Change where (or whether) the cached memory is stored."""
        self.memory_device = memory_device

    def forward(
        self,
        input: Tensor,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> Tensor:
        """
        Extend the head of `input`, optionally refresh the cached memory, and
        run the regular `Conv3d` forward pass on the extended tensor.

        Parameters:
            input: Tensor to convolve; it is sliced along dim 2 to build the
                memory (presumably the temporal axis -- confirm against
                `extend_head`).
            memory_state: `MemoryState.DISABLED` skips all memory handling.
                `MemoryState.ACTIVE` reuses the stored memory when one exists.
                Any non-disabled state may refresh the stored memory.

        Returns:
            The result of `Conv3d.forward` on the extended input.
        """
        caching_requested = memory_state != MemoryState.DISABLED
        # stride - kernel_size on the first axis; used as the start index of
        # the slice that becomes the new memory.
        tail_start = self.stride[0] - self.kernel_size[0]

        has_usable_memory = (self.memory is not None) and (
            memory_state == MemoryState.ACTIVE
        )
        if has_usable_memory:
            input = extend_head(input, memory=self.memory)
        else:
            input = extend_head(input, times=self.temporal_padding * 2)

        if caching_requested and tail_start != 0:
            tail = input[:, :, tail_start:].detach()
        else:
            tail = None

        # Memory is only written outside training and when a device is configured.
        may_store = (
            caching_requested
            and not self.training
            and self.memory_device is not None
        )
        if may_store:
            if tail is not None and self.memory_device == "cpu":
                tail = tail.to("cpu")
            self.memory = tail

        return super().forward(input)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Load parameters, first rewriting `state_dict` with `modify_state_dict`
        (using the "tail" variants of `inflate_weight` / `inflate_bias`) when
        `inflation_mode` is not "none".

        Strict loading is only honoured when `inflation_mode` is "none".
        """
        if self.inflation_mode != "none":
            weight_fn = partial(inflate_weight, position="tail")
            bias_fn = partial(inflate_bias, position="tail")
            state_dict = modify_state_dict(
                self,
                state_dict,
                prefix,
                inflate_weight_fn=weight_fn,
                inflate_bias_fn=bias_fn,
            )

        effective_strict = strict and self.inflation_mode == "none"
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            effective_strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
|
|
|
|
|
|
|
|
def init_causal_conv3d(
    *args,
    inflation_mode: _inflation_mode_t,
    **kwargs,
):
    """
    Build an `InflatedCausalConv3d` layer.

    Parameters:
        *args / **kwargs: Passed through unchanged to `InflatedCausalConv3d`.
        inflation_mode: How checkpoints are loaded into the layer.
            - none: No inflation is performed; state-dict loading falls back
              to the default logic.
            - tail / replicate: See the definition of `InflatedCausalConv3d`.

    Returns:
        The newly constructed `InflatedCausalConv3d` instance.
    """
    layer = InflatedCausalConv3d(
        *args,
        inflation_mode=inflation_mode,
        **kwargs,
    )
    return layer
|
|
|