import math
import warnings
from functools import partial

import torch
from torch import nn
from einops.layers.torch import Rearrange

from .transformer import Block

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

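
# Usage sketch (illustrative, not part of the original module): trunc_normal_ fills a
# tensor in place with truncated-normal values; a and b are absolute truncation bounds.
#     w = torch.empty(768, 768)
#     trunc_normal_(w, std=.02)   # values drawn from N(0, 0.02^2), clamped to [-2, 2]
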
def get_num_patches(height=64, width=1001, patch_height=16, patch_width=16):
    return (height // patch_height) * (width // patch_width)

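
# Worked example (illustrative): with the defaults, a 64 x 1001 spectrogram and
# 16 x 16 patches give (64 // 16) * (1001 // 16) = 4 * 62 = 248 patches.
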
class PatchEmbed_v2(nn.Module):
    def __init__(self, patch_height=64, patch_width=4, embed_dim=768, input_dim=1):
        super().__init__()
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.patch_maker = Rearrange('b c (h p1) (w p2) -> b (w h) (p1 p2 c)',
                                     p1=patch_height, p2=patch_width)
        self.patch_embed = nn.Linear(patch_height * patch_width * input_dim, embed_dim)
    def forward(self, melspec, length=None):
        # Crop the spectrogram so both axes are integer multiples of the patch size.
        height = melspec.shape[2] - melspec.shape[2] % self.patch_height
        width = melspec.shape[3] - melspec.shape[3] % self.patch_width
        patch = self.patch_maker(melspec[:, :, :height, :width])
        patch_embed = self.patch_embed(patch)
        if length is not None:
            # Number of valid (non-padding) patches per example.
            patch_length = torch.div(height, self.patch_height, rounding_mode='trunc') * torch.div(
                length - length % self.patch_width, self.patch_width, rounding_mode='trunc')
        else:
            patch_length = None
        return patch, patch_embed, patch_length

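
# Usage sketch (illustrative; shapes assume the defaults above): a batch of mel
# spectrograms of shape (B, 1, 64, T) is cut into 64 x 4 patches and projected to 768-d tokens.
#     embed = PatchEmbed_v2(patch_height=64, patch_width=4, embed_dim=768)
#     mel = torch.randn(2, 1, 64, 1001)
#     patches, tokens, _ = embed(mel)   # tokens: (2, 250, 768), since 1001 // 4 == 250
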
class FrameAST(nn.Module):
    """ Frame-level Audio Spectrogram Transformer (ViT-style encoder) """

    def __init__(self, nprompt=0, spec_h=64, spec_w=1001, patch_w=16, patch_h=16, pos_type="cut", in_chans=1,
                 num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.,
                 drop_path_rate=0.0, norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.spec_w = spec_w
        self.spec_h = spec_h
        self.embed_dim = embed_dim
        self.patch_w = patch_w
        self.patch_h = patch_h
        self.pos_type = pos_type
        self.patch_embed = PatchEmbed_v2(patch_h, patch_w, embed_dim)
        self.mask_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        self.nprompt = nprompt
        if self.nprompt > 0:
            self.prompt_embed = nn.Parameter(torch.zeros(1, self.nprompt, self.embed_dim))
            trunc_normal_(self.prompt_embed, std=.02)

        num_patches = get_num_patches(spec_h, spec_w, patch_h, patch_w)
        self.num_patches = num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm_frame = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.mask_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def prepare_tokens(self, x, mask_index, length, mask=True):
        B, nc, h, w = x.shape
        mel_patches, x, patch_length = self.patch_embed(x, length)  # patch linear embedding
        B, T, C = x.shape
        if (mask_index is not None) and mask:
            mask_index_expand = mask_index.unsqueeze(2).expand(B, T, self.embed_dim).float()
            x = (1 - mask_index_expand) * x + mask_index_expand * self.mask_embed.expand(B, T, C)

        # add positional encoding to each token
        if self.pos_type == "cut":
            pos = self.pos_embed[:, 1:T + 1, :].expand(B, -1, -1)
            x = x + pos
        else:
            pos = self.interpolate_pos_encoding(x, h, w)
            x = x + pos[:, 1:]

        return self.pos_drop(x), pos, mel_patches, h, w, patch_length

    def forward(self, x, mask_index=None, mask_input=True, length=None):
        # Masked-prediction forward pass: both mask_index and length are expected to be
        # provided; only the representations of masked, in-length frames are returned.
        x, pos, mel_patches, h, w, patch_length = self.prepare_tokens(x, mask_index, length, mask_input)
        # Restrict the mask to patches that fall within each example's true length.
        length_mask = torch.arange(mel_patches.shape[1], device=x.device) < patch_length.unsqueeze(1)
        mask_index = mask_index & length_mask
        if self.nprompt > 0:
            x = torch.cat([self.prompt_embed.expand(x.shape[0], -1, -1), x], dim=1)
        for blk in self.blocks:
            x = blk(x, patch_length + self.nprompt)
        frame_repr = self.norm_frame(x)
        return frame_repr[:, self.nprompt:][mask_index]

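
    # Usage sketch (illustrative): during masked pretraining the model is called with a
    # boolean patch mask and the per-example frame lengths, e.g.
    #     model = FrameAST(patch_h=64, patch_w=4)
    #     mel = torch.randn(2, 1, 64, 1001)
    #     lengths = torch.tensor([1001, 800])
    #     mask = torch.rand(2, 250) < 0.5                              # True = masked patch
    #     masked_repr = model(mel, mask_index=mask, length=lengths)    # (n_masked, 768)
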
    def interpolate_pos_encoding(self, x, h, w):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == self.spec_w and h == self.spec_h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_width
        h0 = h // self.patch_embed.patch_height
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, self.spec_h // self.patch_h, self.spec_w // self.patch_w, dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / (self.spec_h // self.patch_h), w0 / (self.spec_w // self.patch_w)),
            mode='bicubic',
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def get_last_selfattention(self, x):
        x, _, _, _, _, _ = self.prepare_tokens(x, mask_index=None, length=None, mask=False)
        # Collect the self-attention maps from every block.
        atts = []
        for blk in self.blocks:
            x, att = blk(x, return_attention=True)
            atts.append(att)
        return atts

    def get_intermediate_layers(self, x, length, n=1, scene=True, other_emb=None):
        x, _, _, _, _, patch_length = self.prepare_tokens(x, mask_index=None, length=length, mask=False)
        # we return the output tokens from the `n` last blocks
        if other_emb is not None:
            x = torch.cat([other_emb, x], dim=1)
        output = []
        if self.nprompt > 0:
            x = torch.cat([self.prompt_embed.expand(x.shape[0], -1, -1), x], dim=1)
        for i, blk in enumerate(self.blocks):
            x = blk(x, patch_length + self.nprompt)
            if len(self.blocks) - i <= n:
                norm_x = self.norm_frame(x)
                if scene:
                    # Length-masked mean pooling over the valid frame tokens.
                    length_mask = torch.arange(x.shape[1] - self.nprompt).to(x.device) < patch_length.unsqueeze(1)
                    avg = torch.sum(norm_x[:, self.nprompt:] * length_mask.unsqueeze(-1), dim=1) / (
                        patch_length.unsqueeze(-1) + 1e-6)
                    output.append(avg)
                    if self.nprompt > 0:
                        output.append(torch.mean(norm_x[:, :self.nprompt], dim=1))
                else:
                    output.append(norm_x[:, self.nprompt:])
        return torch.cat(output, dim=-1)

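
    # Usage sketch (illustrative): clip-level feature extraction averages the frame
    # tokens of the last `n` blocks over the valid length, e.g.
    #     model = FrameAST(patch_h=64, patch_w=4)
    #     mel = torch.randn(2, 1, 64, 1001)
    #     lengths = torch.tensor([1001, 800])
    #     feats = model.get_intermediate_layers(mel, lengths, n=1, scene=True)  # (2, 768)
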
def get_cls_avg(output_i, cur_len, use_cls):
    # `use_cls` is currently unused; `cls` is returned as zero placeholders alongside
    # the length-masked average of each tensor in `output_i`.
    length_mask = torch.arange(output_i[0].shape[1]).to(output_i[0].device) < cur_len.unsqueeze(1)
    cls = [torch.zeros_like(x[:, 0]) for x in output_i]
    avg = [torch.sum(x * length_mask.unsqueeze(-1), dim=1) / (cur_len.unsqueeze(1) + 1e-6) for x in output_i]
    return cls, avg

def FrameASTModel(patch_h=64, patch_w=4, atst_dropout=0.1, **kwargs):
    return FrameAST(
        patch_h=patch_h,
        patch_w=patch_w,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop_path_rate=atst_dropout,
        drop_rate=atst_dropout,
        **kwargs)
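
# Usage sketch (illustrative): the factory builds a ViT-Base-sized frame encoder
# (768-d, 12 blocks, 12 heads) with dropout / stochastic depth set by atst_dropout, e.g.
#     model = FrameASTModel(patch_h=64, patch_w=4, atst_dropout=0.1)
#     n_params = sum(p.numel() for p in model.parameters())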