|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from enum import Enum |
|
|
from typing import Optional |
|
|
import numpy as np |
|
|
import torch |
|
|
from diffusers.models.normalization import RMSNorm |
|
|
from einops import rearrange |
|
|
from torch import Tensor, nn |
|
|
|
|
|
from common.logger import get_logger |
|
|
|
|
|
# Module-level logger, obtained from the project's `get_logger` helper by
# passing this module's `__name__`.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
class MemoryState(Enum):
    """
    Lifecycle of the memory bank.

    DISABLED:     no memory bank is used.
    INITIALIZING: the model is processing the first clip, so the memory
                  bank has to be reset / initialized.
    ACTIVE:       the memory bank already contains some data.
    """

    DISABLED = 0
    INITIALIZING = 1
    ACTIVE = 2
|
|
|
|
|
|
|
|
def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """
    Apply ``norm_layer`` to a channel-first 4D or 5D tensor.

    * ``nn.LayerNorm`` / ``RMSNorm``: the channel axis (dim 1) is moved to the
      last position, the layer is applied, and the original layout is restored.
      Only 4D and 5D inputs are handled on this path.
    * ``nn.GroupNorm`` / ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm``: inputs with
      at most 4 dims are passed to the layer unchanged; for 5D inputs the third
      axis (labelled ``t`` in the patterns below) is folded into the batch axis
      before the layer is applied and unfolded again afterwards.

    Raises:
        NotImplementedError: for any other layer type / rank combination.
    """
    rank = x.ndim

    if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)):
        # (to channel-last, back to channel-first) patterns, keyed by rank.
        layouts = {
            4: ("b c h w -> b h w c", "b h w c -> b c h w"),
            5: ("b c t h w -> b t h w c", "b t h w c -> b c t h w"),
        }
        if rank in layouts:
            to_channel_last, to_channel_first = layouts[rank]
            normed = norm_layer(rearrange(x, to_channel_last))
            return rearrange(normed, to_channel_first)
        # Unsupported ranks deliberately fall through to the checks below.

    if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
        if rank <= 4:
            return norm_layer(x)
        if rank == 5:
            # Remember the size of the folded axis so it can be restored.
            folded_len = x.size(2)
            folded = rearrange(x, "b c t h w -> (b t) c h w")
            return rearrange(norm_layer(folded), "(b t) c h w -> b c t h w", t=folded_len)

    raise NotImplementedError
|
|
|
|
|
|
|
|
def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
    """
    Remove duplicated first frame features in the up-sampling process.

    Along dim 2, the entry at index 0 is kept, the next ``times`` entries are
    dropped, and everything after them is kept. When ``times`` is 0 the input
    tensor itself is returned.
    """
    if times != 0:
        first = tensor[:, :, :1]
        remainder = tensor[:, :, 1 + times :]
        return torch.cat((first, remainder), dim=2)
    return tensor
|
|
|
|
|
|
|
|
def extend_head(
    tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None
) -> Tensor:
    """
    Prepend extra entries to ``tensor`` along dim 2.

    * ``times == 0``: the input tensor is returned as is (``memory`` is not
      consulted).
    * ``memory`` given: ``memory`` is converted with ``memory.to(tensor)`` and
      concatenated in front of ``tensor``, to keep temporal consistency.
    * ``memory`` is None: the first entry along dim 2 is repeated ``times``
      times and concatenated in front of ``tensor`` (first-frame duplication
      for the down-sampling process).
    """
    if times == 0:
        return tensor

    if memory is None:
        # Repeat only along dim 2; every other axis keeps a factor of 1.
        repeats = np.ones(tensor.ndim, dtype=int)
        repeats[2] = times
        padding = torch.tile(tensor[:, :, :1], repeats.tolist())
        return torch.cat((padding, tensor), dim=2)

    return torch.cat((memory.to(tensor), tensor), dim=2)
|
|
|
|
|
|
|
|
def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
    """
    Inflate a 2D convolution weight into a 3D one, writing into ``weight_3d``.

    Parameters:
        weight_2d: The weight of the 2D conv to be inflated.
        weight_3d: The weight of the 3D conv to be initialized; it is
            overwritten in place (under ``torch.no_grad()``).
        inflation_mode: ``"replicate"`` copies ``weight_2d`` to every index
            along dim 2 of ``weight_3d``, divided by the size of that dim.
            ``"constant"`` zeroes ``weight_3d`` and copies ``weight_2d`` into
            the last index along dim 2 only.

    Returns:
        ``weight_3d`` (the same object that was passed in).

    The mode and the agreement of the first two dims of both weights are
    checked with ``assert``.
    """
    assert inflation_mode in ["constant", "replicate"]
    assert weight_3d.shape[:2] == weight_2d.shape[:2]

    with torch.no_grad():
        if inflation_mode != "replicate":
            # "constant": only the last slice along dim 2 carries the 2D weight.
            weight_3d.zero_()
            weight_3d[:, :, -1].copy_(weight_2d)
        else:
            # "replicate": spread the 2D weight evenly over all slices of dim 2.
            depth = weight_3d.size(2)
            stacked = weight_2d[:, :, None].repeat(1, 1, depth, 1, 1)
            weight_3d.copy_(stacked / depth)

    return weight_3d
|
|
|
|
|
|
|
|
def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
    """
    Copy a 2D convolution bias into the bias of a 3D convolution.

    Parameters:
        bias_2d: The bias tensor of the 2D conv to be inflated.
        bias_3d: The bias tensor of the 3D conv to be initialized; it is
            overwritten in place (under ``torch.no_grad()``).
        inflation_mode: Unused; kept so the signature lines up with
            ``inflate_weight``.

    Returns:
        ``bias_3d`` (the same object that was passed in).

    Both tensors must have the same shape; this is checked with ``assert``.
    """
    assert bias_3d.shape == bias_2d.shape
    with torch.no_grad():
        # copy_ returns the destination tensor itself.
        return bias_3d.copy_(bias_2d)
|
|
|
|
|
|
|
|
def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
    """
    Main entry point for inflating the 2D parameters of one layer to 3D.

    Looks up ``prefix + "weight"`` and ``prefix + "bias"`` in ``state_dict``
    and replaces the entries in place:

    * A 4-dim weight is replaced by the result of ``inflate_weight_fn``, called
      with the stored weight, ``layer.weight`` and ``layer.inflation_mode``.
    * A weight with any other number of dims ends processing immediately:
      neither weight nor bias is touched.
    * A 1-dim bias is replaced by the result of ``inflate_bias_fn``, called
      with the stored bias, ``layer.bias`` and ``layer.inflation_mode``.

    Returns:
        The same ``state_dict`` object that was passed in.
    """
    weight_key = prefix + "weight"
    bias_key = prefix + "bias"

    if weight_key in state_dict:
        stored_weight = state_dict[weight_key]
        if stored_weight.dim() != 4:
            # Not a 2D-conv weight: skip inflation for both weight and bias.
            return state_dict
        state_dict[weight_key] = inflate_weight_fn(
            weight_2d=stored_weight,
            weight_3d=layer.weight,
            inflation_mode=layer.inflation_mode,
        )

    if bias_key in state_dict:
        stored_bias = state_dict[bias_key]
        if stored_bias.dim() == 1:
            state_dict[bias_key] = inflate_bias_fn(
                bias_2d=stored_bias,
                bias_3d=layer.bias,
                inflation_mode=layer.inflation_mode,
            )

    return state_dict
|
|
|