| """ | |
| @Date: 2021/09/01 | |
| @description: | |
| """ | |
| import warnings | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn, einsum | |
| from einops import rearrange | |
| def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): | |
| # Cut & paste from PyTorch official master until it's in a few official releases - RW | |
| # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf | |
| def norm_cdf(x): | |
| # Computes standard normal cumulative distribution function | |
| return (1. + math.erf(x / math.sqrt(2.))) / 2. | |
| if (mean < a - 2 * std) or (mean > b + 2 * std): | |
| warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " | |
| "The distribution of values may be incorrect.", | |
| stacklevel=2) | |
| with torch.no_grad(): | |
| # Values are generated by using a truncated uniform distribution and | |
| # then using the inverse CDF for the normal distribution. | |
| # Get upper and lower cdf values | |
| l = norm_cdf((a - mean) / std) | |
| u = norm_cdf((b - mean) / std) | |
| # Uniformly fill tensor with values from [l, u], then translate to | |
| # [2l-1, 2u-1]. | |
| tensor.uniform_(2 * l - 1, 2 * u - 1) | |
| # Use inverse cdf transform for normal distribution to get truncated | |
| # standard normal | |
| tensor.erfinv_() | |
| # Transform to proper mean, std | |
| tensor.mul_(std * math.sqrt(2.)) | |
| tensor.add_(mean) | |
| # Clamp to ensure it's in the proper range | |
| tensor.clamp_(min=a, max=b) | |
| return tensor | |
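
# Usage sketch (not part of the original module): initializing a weight tensor
# with trunc_normal_. The helper name and shapes below are illustrative assumptions.
def _example_trunc_normal_usage():
    weight = torch.empty(64, 128)
    trunc_normal_(weight, mean=0., std=0.02, a=-2., b=2.)  # samples clamped to [-2, 2]
    return weight
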
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


# compatibility pytorch < 1.4
class GELU(nn.Module):
    def forward(self, input):
        return F.gelu(input)


class Attend(nn.Module):
    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return F.softmax(input, dim=self.dim, dtype=input.dtype)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
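
# Usage sketch (not part of the original module): PreNorm applies LayerNorm
# before the wrapped block (pre-norm Transformer layout). Dimensions below are
# assumptions for illustration.
def _example_prenorm_feedforward():
    block = PreNorm(128, FeedForward(dim=128, hidden_dim=256, dropout=0.1))
    x = torch.randn(2, 16, 128)   # (batch, patches, dim)
    return x + block(x)           # residual connection around the pre-normed FFN
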
class RelativePosition(nn.Module):
    def __init__(self, heads, patch_num=None, rpe=None):
        super().__init__()
        self.rpe = rpe
        self.heads = heads
        self.patch_num = patch_num

        if rpe == 'lr_parameter':
            # distances span -255 ~ 0 ~ 255 (for patch_num = 256), count = patch_num * 2 - 1
            count = patch_num * 2 - 1
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'lr_parameter_mirror':
            # mirrored distances 0 ~ 128 (for patch_num = 256), count = patch_num // 2 + 1
            count = patch_num // 2 + 1
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'lr_parameter_half':
            # wrapped distances -127 ~ 0 ~ 128 (for patch_num = 256), count = patch_num
            count = patch_num
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'fix_angle':
            # mirrored distances 0 ~ 128 (for patch_num = 256), count = patch_num // 2 + 1
            count = patch_num // 2 + 1
            # we assume that closer patches should have stronger relationships
            rpe_table = (torch.arange(count, 0, -1) / count)[..., None].repeat(1, heads)
            self.register_buffer('rpe_table', rpe_table)

    def get_relative_pos_embed(self):
        range_vec = torch.arange(self.patch_num)
        distance_mat = range_vec[None, :] - range_vec[:, None]
        if self.rpe == 'lr_parameter':
            # shift -255 ~ 0 ~ 255 to 0 ~ 255 ~ 510 (for patch_num = 256)
            distance_mat += self.patch_num - 1  # remove negative indices
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
        elif self.rpe == 'lr_parameter_mirror' or self.rpe == 'fix_angle':
            distance_mat[distance_mat < 0] = -distance_mat[distance_mat < 0]  # mirror
            distance_mat[distance_mat > self.patch_num // 2] = self.patch_num - distance_mat[
                distance_mat > self.patch_num // 2]  # remove repeats
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
        elif self.rpe == 'lr_parameter_half':
            distance_mat[distance_mat > self.patch_num // 2] = distance_mat[
                distance_mat > self.patch_num // 2] - self.patch_num  # remove repeats > 128, e.g. 129 -> -127
            distance_mat[distance_mat < -self.patch_num // 2 + 1] = distance_mat[
                distance_mat < -self.patch_num // 2 + 1] + self.patch_num  # remove repeats < -127, e.g. -128 -> 128
            # shift -127 ~ 0 ~ 128 to 0 ~ 127 ~ 255 (for patch_num = 256)
            distance_mat += self.patch_num // 2 - 1  # remove negative indices
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]

    def forward(self, attn):
        return attn + self.get_relative_pos_embed()
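
# Usage sketch (not part of the original module): the bias returned by
# get_relative_pos_embed() has shape (1, heads, patch_num, patch_num) and is
# added to the attention map. heads=4 and patch_num=16 are assumed values.
def _example_relative_position():
    rel = RelativePosition(heads=4, patch_num=16, rpe='lr_parameter')
    attn = torch.zeros(2, 4, 16, 16)   # (batch, heads, patches, patches)
    return rel(attn)                   # attn + learned relative-position bias
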
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., patch_num=None, rpe=None, rpe_pos=1):
        """
        :param dim: token (patch embedding) dimension
        :param heads: number of attention heads
        :param dim_head: dimension of each head
        :param dropout: dropout rate applied to the output projection
        :param patch_num: number of patches (sequence length), required when rpe is used
        :param rpe: relative position embedding mode ('lr_parameter', 'lr_parameter_mirror',
                    'lr_parameter_half' or 'fix_angle')
        :param rpe_pos: where the relative position bias is added (0: before softmax, 1: after softmax)
        """
        super().__init__()
        self.relative_pos_embed = None if patch_num is None or rpe is None else RelativePosition(heads, patch_num, rpe)
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rpe_pos = rpe_pos

        self.attend = Attend(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        if self.rpe_pos == 0:
            if self.relative_pos_embed is not None:
                dots = self.relative_pos_embed(dots)

        attn = self.attend(dots)
        if self.rpe_pos == 1:
            if self.relative_pos_embed is not None:
                attn = self.relative_pos_embed(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
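
# Usage sketch (not part of the original module): self-attention over patch
# tokens with a learned relative-position bias added after the softmax
# (rpe_pos=1). dim/heads/patch_num values are assumptions for illustration.
def _example_attention():
    attn = Attention(dim=128, heads=4, dim_head=32, dropout=0.1,
                     patch_num=16, rpe='lr_parameter', rpe_pos=1)
    x = torch.randn(2, 16, 128)   # (batch, patch_num, dim)
    return attn(x)                # (2, 16, 128)
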
class AbsolutePosition(nn.Module):
    def __init__(self, dim, dropout=0., patch_num=None, ape=None):
        super().__init__()
        self.ape = ape

        if ape == 'lr_parameter':
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, patch_num, dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        elif ape == 'fix_angle':
            angle = torch.arange(0, patch_num, dtype=torch.float) / patch_num * (math.pi * 2)
            # (1, patch_num, dim); a plain tensor (not a parameter or buffer), so it is
            # not moved by .to(device) and is not saved in the state dict
            self.absolute_pos_embed = torch.sin(angle)[..., None].repeat(1, dim)[None]

    def forward(self, x):
        return x + self.absolute_pos_embed
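
# Usage sketch (not part of the original module): adds a learned (or fixed
# sinusoidal) absolute position embedding to the token sequence. Shapes are
# assumed values for illustration.
def _example_absolute_position():
    ape = AbsolutePosition(dim=128, patch_num=16, ape='lr_parameter')
    x = torch.randn(2, 16, 128)   # (batch, patch_num, dim)
    return ape(x)                 # (2, 16, 128)
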
class WinAttention(nn.Module):
    def __init__(self, dim, win_size=8, shift=0, heads=8, dim_head=64, dropout=0., rpe=None, rpe_pos=1):
        super().__init__()
        self.win_size = win_size
        self.shift = shift
        self.attend = Attention(dim, heads=heads, dim_head=dim_head,
                                dropout=dropout, patch_num=win_size,
                                rpe=None if rpe is None else 'lr_parameter',
                                rpe_pos=rpe_pos)

    def forward(self, x):
        b = x.shape[0]
        if self.shift != 0:
            x = torch.roll(x, shifts=self.shift, dims=-2)
        x = rearrange(x, 'b (m w) d -> (b m) w d', w=self.win_size)  # split windows
        out = self.attend(x)
        out = rearrange(out, '(b m) w d -> b (m w) d', b=b)  # recover windows
        if self.shift != 0:
            out = torch.roll(out, shifts=-self.shift, dims=-2)
        return out
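
# Usage sketch (not part of the original module): window attention over a
# sequence whose length is a multiple of win_size; a non-zero shift rolls the
# sequence circularly before it is split into windows and rolls it back
# afterwards. The sizes below are assumptions for illustration.
def _example_window_attention():
    win_attn = WinAttention(dim=128, win_size=8, shift=4, heads=4, dim_head=32)
    x = torch.randn(2, 32, 128)   # sequence length 32 = 4 windows of size 8
    return win_attn(x)            # (2, 32, 128)
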
class Conv(nn.Module):
    def __init__(self, dim, dropout=0.):
        super().__init__()
        self.dim = dim
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x.transpose(1, 2)
        x = torch.cat([x[..., -1:], x, x[..., :1]], dim=-1)  # circular padding of 1 on each side
        x = self.net(x)
        return x.transpose(1, 2)
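
# Usage sketch (not part of the original module): the circular padding plus a
# kernel_size=3 convolution treats the sequence as a closed loop and preserves
# its length. Shapes are assumed values for illustration.
def _example_conv():
    conv = Conv(dim=128, dropout=0.1)
    x = torch.randn(2, 32, 128)   # (batch, sequence, dim)
    return conv(x)                # (2, 32, 128)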