from functools import partial

import torch
import torch.nn.functional as F
from torch.nn import ModuleList

from DenseAV.denseav.featurizers.DINO import Block

class ChannelNorm(torch.nn.Module):
    """Applies LayerNorm over the channel dimension of (B, C, H, W) features."""

    def __init__(self, dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-4)

    def forward_spatial(self, x):
        # Move channels last so LayerNorm normalizes over C, then restore layout.
        return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x, cls):
        return self.forward_spatial(x), self.forward_cls(cls)

    def forward_cls(self, cls):
        if cls is not None:
            return self.norm(cls)
        else:
            return None
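
# Usage sketch (illustrative only; the 384-channel size below is an assumed
# example, e.g. ViT-S features, not a value fixed by this file):
#
#   norm = ChannelNorm(384)
#   spatial, cls = norm(torch.randn(2, 384, 14, 14), torch.randn(2, 384))
#   # spatial keeps its (2, 384, 14, 14) shape; LayerNorm runs over the
#   # channel dimension independently at each spatial position.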

def id_conv(dim, strength=.9):
    """1x1 conv initialized close to the identity map (strength=1.0 is exact)."""
    conv = torch.nn.Conv2d(dim, dim, 1, padding="same")
    start_w = conv.weight.data
    conv.weight.data = torch.nn.Parameter(
        torch.eye(dim, device=start_w.device).unsqueeze(-1).unsqueeze(-1) * strength
        + start_w * (1 - strength))
    conv.bias.data = torch.nn.Parameter(conv.bias.data * (1 - strength))
    return conv
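
# Behaviour sketch (illustrative; the 8-channel size is an arbitrary example):
#
#   conv = id_conv(8, strength=1.0)  # exact identity at initialization
#   x = torch.randn(1, 8, 4, 4)
#   assert torch.allclose(conv(x), x, atol=1e-6)
#   # With the default strength=0.9 the map starts near, but not at, identity.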

class LinearAligner(torch.nn.Module):
    """Projects (spatial, cls) features from in_dim to out_dim channels."""

    def __init__(self, in_dim, out_dim, use_norm=True):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        if use_norm:
            self.norm = ChannelNorm(in_dim)
        else:
            self.norm = Identity2()
        if in_dim == out_dim:
            # Same width: a plain 1x1 conv (strength=0 leaves the default init).
            self.layer = id_conv(in_dim, 0)
        else:
            self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1)
        self.cls_layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, spatial, cls):
        norm_spatial, norm_cls = self.norm(spatial, cls)
        if cls is not None:
            # Note: the projection is applied to the raw cls token, not norm_cls.
            aligned_cls = self.cls_layer(cls)
        else:
            aligned_cls = None
        return self.layer(norm_spatial), aligned_cls
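
# Usage sketch (illustrative dims, assumed rather than taken from a config):
#
#   aligner = LinearAligner(in_dim=768, out_dim=512)
#   spatial, cls = aligner(torch.randn(2, 768, 14, 14), torch.randn(2, 768))
#   # spatial: (2, 512, 14, 14); cls: (2, 512)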

class IdLinearAligner(torch.nn.Module):
    """A learnable aligner that starts as an exact identity (requires in_dim == out_dim)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        assert self.out_dim == self.in_dim
        self.layer = id_conv(in_dim, 1.0)

    def forward(self, spatial, cls):
        return self.layer(spatial), cls

class FrequencyAvg(torch.nn.Module):
    """Averages audio features over the frequency axis of (B, C, F, T)."""

    def __init__(self):
        super().__init__()

    def forward(self, spatial, cls):
        return spatial.mean(2, keepdim=True), cls

class LearnedTimePool(torch.nn.Module):
    """Downsamples the time axis by `width` with a learned conv (plus optional max-pool)."""

    def __init__(self, dim, width, maxpool):
        super().__init__()
        self.dim = dim
        self.width = width
        self.norm = ChannelNorm(dim)
        if maxpool:
            self.layer = torch.nn.Sequential(
                torch.nn.Conv2d(dim, dim, kernel_size=width, stride=1, padding="same"),
                torch.nn.MaxPool2d(kernel_size=(1, width), stride=(1, width))
            )
        else:
            self.layer = torch.nn.Conv2d(dim, dim, kernel_size=(1, width), stride=(1, width))

    def forward(self, spatial, cls):
        norm_spatial, norm_cls = self.norm(spatial, cls)
        return self.layer(norm_spatial), norm_cls

class LearnedTimePool2(torch.nn.Module):
    """Like LearnedTimePool, but also changes channel width and skips the norm."""

    def __init__(self, in_dim, out_dim, width, maxpool, use_cls_layer):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.width = width
        if maxpool:
            self.layer = torch.nn.Sequential(
                torch.nn.Conv2d(in_dim, out_dim, kernel_size=width, stride=1, padding="same"),
                torch.nn.MaxPool2d(kernel_size=(1, width), stride=(1, width))
            )
        else:
            self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=(1, width), stride=(1, width))
        self.use_cls_layer = use_cls_layer
        if use_cls_layer:
            self.cls_layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, spatial, cls):
        if cls is not None:
            if self.use_cls_layer:
                aligned_cls = self.cls_layer(cls)
            else:
                aligned_cls = cls
        else:
            aligned_cls = None
        return self.layer(spatial), aligned_cls

class Sequential2(torch.nn.Module):
    """Sequential container for modules that take and return (spatial, cls) pairs."""

    def __init__(self, *modules):
        super().__init__()
        self.mod_list = ModuleList(modules)

    def forward(self, x, y):
        results = (x, y)
        for m in self.mod_list:
            results = m(*results)
        return results
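
# Composition sketch (illustrative dims; mirrors how get_aligner below wires stages):
#
#   pipe = Sequential2(LinearAligner(768, 512), FrequencyAvg())
#   spatial, cls = pipe(torch.randn(2, 768, 8, 32), torch.randn(2, 768))
#   # spatial: (2, 512, 1, 32) after frequency averaging; cls: (2, 512)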

class ProgressiveGrowing(torch.nn.Module):
    """Runs a growing prefix of `stages`, advancing one stage per phase boundary."""

    def __init__(self, stages, phase_lengths):
        super().__init__()
        self.stages = torch.nn.ModuleList(stages)
        self.phase_lengths = torch.tensor(phase_lengths)
        assert len(self.phase_lengths) + 1 == len(self.stages)
        self.phase_boundaries = self.phase_lengths.cumsum(0)
        # Stored as a buffer so the current phase survives checkpointing.
        self.register_buffer('phase', torch.tensor([1]))

    def maybe_change_phase(self, global_step):
        # Phase k means the first k stages are active.
        needed_phase = (global_step >= self.phase_boundaries).to(torch.int64).sum().item() + 1
        if needed_phase != self.phase.item():
            print(f"Changing aligner phase to {needed_phase}")
            self.phase.copy_(torch.tensor([needed_phase]).to(self.phase.device))
            return True
        else:
            return False

    def parameters(self, recurse: bool = True):
        # Only expose parameters of the active stages to the optimizer.
        phase = self.phase.item()
        used_stages = self.stages[:phase]
        print(f"Progressive Growing at stage {phase}")
        all_params = []
        for stage in used_stages:
            all_params.extend(stage.parameters(recurse))
        return iter(all_params)

    def forward(self, spatial, cls):
        pipeline = Sequential2(*self.stages[:self.phase.item()])
        return pipeline(spatial, cls)
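
# Phase-schedule sketch (illustrative step counts, not from the original config):
#
#   pg = ProgressiveGrowing(
#       [LinearAligner(768, 512), SelfAttentionAligner(512)],
#       phase_lengths=[1000])
#   pg.maybe_change_phase(0)     # False: still phase 1, only the linear stage runs
#   pg.maybe_change_phase(1000)  # True: advances to phase 2, both stages run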

class Identity2(torch.nn.Module):
    """Identity over (spatial, cls) pairs."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x, y

class SelfAttentionAligner(torch.nn.Module):
    """Runs one DINO-style transformer block over flattened spatial tokens."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.num_heads = 6
        # Pad channels up to a multiple of num_heads so heads divide evenly.
        if dim % self.num_heads != 0:
            self.padding = self.num_heads - (dim % self.num_heads)
        else:
            self.padding = 0
        self.block = Block(
            dim + self.padding,
            num_heads=self.num_heads,
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-4))

    def forward(self, spatial, cls):
        # Zero-pad the front of the channel dimension.
        padded_feats = F.pad(spatial, [0, 0, 0, 0, self.padding, 0])
        B, C, H, W = padded_feats.shape
        # Flatten (H, W) into a token sequence: (B, H*W, C).
        proj_feats = padded_feats.reshape(B, C, H * W).permute(0, 2, 1)

        if cls is not None:
            assert len(cls.shape) == 2
            padded_cls = F.pad(cls, [self.padding, 0])
            proj_feats = torch.cat([padded_cls.unsqueeze(1), proj_feats], dim=1)

        aligned_feat, attn, qkv = self.block(proj_feats, return_qkv=True)

        if cls is not None:
            aligned_cls = aligned_feat[:, 0, :]
            aligned_spatial = aligned_feat[:, 1:, :]
        else:
            aligned_cls = None
            aligned_spatial = aligned_feat

        # Restore the (B, C, H, W) layout and strip the channel padding.
        aligned_spatial = aligned_spatial.reshape(B, H, W, self.dim + self.padding).permute(0, 3, 1, 2)
        aligned_spatial = aligned_spatial[:, self.padding:, :, :]
        if aligned_cls is not None:
            aligned_cls = aligned_cls[:, self.padding:]

        return aligned_spatial, aligned_cls
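
# Shape sketch (illustrative; the 512-channel size is an assumption): 512 % 6 == 2,
# so 4 zero channels are padded in, attention runs at width 516, and the padding
# is stripped again on the way out.
#
#   sa = SelfAttentionAligner(512)
#   spatial, cls = sa(torch.randn(2, 512, 1, 32), torch.randn(2, 512))
#   # spatial: (2, 512, 1, 32); cls: (2, 512)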

def get_aligner(aligner_type, in_dim, out_dim, **kwargs):
    if aligner_type is None:
        return Identity2()

    if "prog" in aligner_type:
        # Read for prospective progressive-growing aligners; unused by the
        # types dispatched below.
        phase_length = kwargs["phase_length"]

    if aligner_type == "image_linear":
        return LinearAligner(in_dim, out_dim)
    elif aligner_type == "image_idlinear":
        return IdLinearAligner(in_dim, out_dim)
    elif aligner_type == "image_linear_no_norm":
        return LinearAligner(in_dim, out_dim, use_norm=False)
    elif aligner_type == "image_id":
        return Identity2()
    elif aligner_type == "image_norm":
        return ChannelNorm(in_dim)
    elif aligner_type == "audio_linear":
        return Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg())
    elif aligner_type == "audio_sa":
        return Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            SelfAttentionAligner(out_dim)
        )
    elif aligner_type == "audio_sa_sa":
        return Sequential2(
            FrequencyAvg(),
            LinearAligner(in_dim, out_dim),
            SelfAttentionAligner(out_dim),
            SelfAttentionAligner(out_dim)
        )
    elif aligner_type == "audio_3_3_pool":
        return Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            LearnedTimePool(out_dim, 3, False),
            LearnedTimePool(out_dim, 3, False),
        )
    elif aligner_type == "audio_sa_3_3_pool":
        return Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            LearnedTimePool(out_dim, 3, False),
            LearnedTimePool(out_dim, 3, False),
            SelfAttentionAligner(out_dim)
        )
    elif aligner_type == "audio_sa_3_3_pool_2":
        return Sequential2(
            FrequencyAvg(),
            ChannelNorm(in_dim),
            LearnedTimePool2(in_dim, out_dim, 3, False, True),
            LearnedTimePool2(out_dim, out_dim, 3, False, False),
            SelfAttentionAligner(out_dim)
        )
    else:
        raise ValueError(f"Unknown aligner type {aligner_type}")
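

if __name__ == "__main__":
    # Minimal smoke test (a hedged sketch; the dims are assumed examples,
    # not values from the original training setup).
    aligner = get_aligner("audio_linear", in_dim=768, out_dim=512)
    spatial = torch.randn(2, 768, 8, 32)  # (batch, channels, frequency, time)
    cls = torch.randn(2, 768)
    out_spatial, out_cls = aligner(spatial, cls)
    print(out_spatial.shape)  # torch.Size([2, 512, 1, 32])
    print(out_cls.shape)      # torch.Size([2, 512])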