Spaces: Running on Zero
| from einops import rearrange | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
class Generator(nn.Module):
    """Pool a sequence of frame features into per-example scores in [1, 5].

    Pipeline, as implemented in ``forward``:
      1. A position-wise feed-forward network (Linear -> ReLU -> Linear)
         maps each ``in_features``-dim frame back to ``in_features`` dims.
      2. ``attn_layer_num`` self-attention layers are applied one after
         another to the FFN output. There is no residual connection or
         normalisation between them.
      3. The attention output is mean-pooled over time. The FFN output
         (not the attention output) is max-pooled over time. The two are
         concatenated into a ``2 * in_features`` vector and passed through
         dropout.
      4. A linear layer produces ``num_classes`` values. These are squashed
         with ``tanh`` and rescaled to the interval [1, 5].

    Args:
        in_features: Feature size D of each input frame. Must be divisible
            by ``num_heads`` (an ``nn.MultiheadAttention`` requirement).
        ffd_hidden_size: Hidden width of the feed-forward network.
        num_classes: Number of output values per example.
        attn_layer_num: Number of stacked self-attention layers. With 0,
            the mean-pooling reads the FFN output directly.
        num_heads: Attention heads per layer. Defaults to 8, the value
            previously hard-coded.
        dropout: Dropout probability used both inside the attention layers
            and on the pooled vector. Defaults to 0.2, the value previously
            hard-coded.
    """

    def __init__(self,
                 in_features,
                 ffd_hidden_size,
                 num_classes,
                 attn_layer_num,
                 *,
                 num_heads=8,
                 dropout=0.2,
                 ):
        super().__init__()
        self.attn = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=in_features,
                    num_heads=num_heads,
                    dropout=dropout,
                    batch_first=True,
                )
                for _ in range(attn_layer_num)
            ]
        )
        self.ffd = nn.Sequential(
            nn.Linear(in_features, ffd_hidden_size),
            nn.ReLU(),
            nn.Linear(ffd_hidden_size, in_features),
        )
        self.dropout = nn.Dropout(dropout)
        # Input is [mean-pooled attention output ; max-pooled FFN output].
        self.fc = nn.Linear(in_features * 2, num_classes)
        self.proj = nn.Tanh()

    def forward(self, ssl_feature, judge_id=None):
        """Score a batch of feature sequences.

        Args:
            ssl_feature: Float tensor of shape ``[B, T, D]`` with
                ``D == in_features``.
            judge_id: Accepted for interface compatibility; currently ignored.

        Returns:
            Tensor of shape ``[B, num_classes]`` with values in [1, 5].

        Raises:
            ValueError: If ``ssl_feature`` is not 3-dimensional.
        """
        if ssl_feature.dim() != 3:
            # nn.MultiheadAttention would silently treat a 2-D input as an
            # unbatched [T, D] sequence, and the dim=1 pooling below would
            # then reduce over the feature axis. Reject it explicitly.
            raise ValueError(
                "ssl_feature must have shape [B, T, D], "
                f"got {tuple(ssl_feature.shape)}"
            )

        frame_feature = self.ffd(ssl_feature)  # [B, T, D]
        attn_feature = frame_feature
        for attn in self.attn:
            # Self-attention: query, key and value are the same sequence.
            # The weights are discarded, so skip computing them; this also
            # lets PyTorch use its optimized attention kernel.
            attn_feature, _ = attn(
                attn_feature, attn_feature, attn_feature, need_weights=False
            )

        # NOTE(review): max-pooling reads the FFN output rather than the
        # attention output. Kept as-is because changing it would alter the
        # predictions of existing checkpoints -- confirm this is intended.
        pooled = torch.cat(
            [attn_feature.mean(dim=1), frame_feature.max(dim=1).values],
            dim=1,
        )  # [B, 2 * D]
        pooled = self.dropout(pooled)

        x = self.fc(pooled)  # [B, num_classes]
        # tanh output lies in [-1, 1], so this rescales it to [1, 5].
        return self.proj(x) * 2.0 + 3