Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import, division | |
| import torch | |
| from torch import nn | |
class GlobalGraph(nn.Module):
    """Global graph attention layer.

    Builds a dense, data-dependent attention matrix over the ``N`` nodes of
    each input sample. It then adds a learnable matrix ``C_k`` (zero
    initialised, same shape as ``adj``) and uses the result to aggregate node
    features produced by a 1x1 convolution ``g``.

    Args:
        adj: Adjacency tensor, stored as ``self.adj``. Within this class only
            its shape is used, to size ``C_k``. ``C_k`` is added to a
            ``(B*T, N, N)`` attention tensor, so ``adj`` presumably has shape
            ``(N, N)`` -- confirm against callers.
        in_channels: Number of input feature channels ``C``.
        inter_channels: Channel width of the ``theta``/``phi`` projections.
            When ``None``, defaults to ``max(1, in_channels // 2)``.

    Raises:
        ValueError: If ``inter_channels`` is not positive.

    The number of output channels is stored in ``g_channels``. It equals
    ``in_channels`` when ``inter_channels == in_channels // 2``, and
    ``inter_channels`` otherwise.
    """

    def __init__(self, adj, in_channels, inter_channels=None):
        super(GlobalGraph, self).__init__()
        if inter_channels is None:
            # Default to half the input width, but never zero channels.
            inter_channels = max(1, in_channels // 2)
        if inter_channels <= 0:
            raise ValueError(
                "inter_channels must be positive, got {}".format(inter_channels))

        self.adj = adj
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU(inplace=True)  # not used by forward()
        self.leakyrelu = nn.LeakyReLU(0.2)

        if self.inter_channels == self.in_channels // 2:
            self.g_channels = self.in_channels
        else:
            self.g_channels = self.inter_channels

        # Submodules are created and initialised in the original order so that
        # seeded parameter initialisation is unchanged.
        self.g = nn.Conv1d(in_channels=self.in_channels, out_channels=self.g_channels,
                           kernel_size=1, stride=1, padding=0)
        self.theta = nn.Conv1d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv1d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        # Learnable additive attention term, zero-initialised, shaped like adj.
        self.C_k = nn.Parameter(torch.zeros(self.adj.shape, dtype=torch.float))

        # Scores each node pair from its concatenated [theta_i, phi_j] features.
        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
        )

        nn.init.kaiming_normal_(self.concat_project[0].weight)
        nn.init.kaiming_normal_(self.g.weight)
        nn.init.constant_(self.g.bias, 0)
        nn.init.kaiming_normal_(self.theta.weight)
        nn.init.constant_(self.theta.bias, 0)
        nn.init.kaiming_normal_(self.phi.weight)
        nn.init.constant_(self.phi.bias, 0)

    def forward(self, x):
        """Apply global graph attention.

        Args:
            x: Tensor of shape ``(B*T, C, N)`` with ``C == in_channels``.

        Returns:
            Tensor of shape ``(B*T, g_channels, N)``.
        """
        batch_size = x.size(0)

        # g_x: (B*T, N, g_channels) -- the node features to be aggregated.
        g_x = self.g(x).view(batch_size, self.g_channels, -1).permute(0, 2, 1)

        # theta_x: (B*T, C', N, 1) and phi_x: (B*T, C', 1, N), C' = inter_channels.
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        # Broadcast both to (B*T, C', N, N). After channel concatenation,
        # position (i, j) holds [theta_i, phi_j].
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.expand(-1, -1, -1, w)
        phi_x = phi_x.expand(-1, -1, h, -1)
        concat_feature = torch.cat([theta_x, phi_x], dim=1)  # (B*T, 2C', N, N)

        f = self.concat_project(concat_feature)  # (B*T, 1, N, N)
        b, _, h, w = f.size()
        attention = self.leakyrelu(f.view(b, h, w))  # (B*T, N, N)
        # Softmax over the last axis, then add the learnable C_k term.
        attention = torch.add(self.softmax(attention), self.C_k)

        y = torch.matmul(attention, g_x)  # (B*T, N, g_channels)
        y = y.permute(0, 2, 1).contiguous()  # (B*T, g_channels, N)
        return y.view(batch_size, self.g_channels, *x.size()[2:])
class MultiGlobalGraph(nn.Module):
    """Multi-head global graph attention applied to every frame of a sequence.

    Runs ``in_channels // inter_channels`` independent ``GlobalGraph`` heads
    and concatenates their outputs along the channel axis. The result is mixed
    with a 1x1 convolution, BatchNorm, ReLU and optional dropout.

    Args:
        adj: Adjacency tensor forwarded to every head.
        in_channels: Feature channels ``C`` of the input (and of the output).
        inter_channels: Projection width passed to each head.
        dropout: Dropout probability, or ``None`` to disable dropout.

    Raises:
        ValueError: If ``inter_channels`` is not positive. Also raised if the
            heads' concatenated output width differs from ``in_channels``.
            ``forward`` reshapes that concatenation back to the input shape,
            so such a configuration could never run. This happens, for
            example, when ``inter_channels`` does not divide ``in_channels``.
    """

    def __init__(self, adj, in_channels, inter_channels, dropout=None):
        super(MultiGlobalGraph, self).__init__()
        if inter_channels <= 0:
            raise ValueError(
                "inter_channels must be positive, got {}".format(inter_channels))

        self.num_non_local = in_channels // inter_channels
        self.attentions = nn.ModuleList(
            [GlobalGraph(adj, in_channels, inter_channels) for _ in range(self.num_non_local)]
        )

        # Fail early with a clear message instead of an opaque view() error
        # inside forward().
        concat_channels = sum(head.g_channels for head in self.attentions)
        if concat_channels != in_channels:
            raise ValueError(
                "heads produce {} concatenated channels but in_channels is {}; "
                "choose inter_channels so the head outputs sum to in_channels".format(
                    concat_channels, in_channels))

        self.cat_conv = nn.Conv2d(in_channels, in_channels, 1, bias=False)
        self.cat_bn = nn.BatchNorm2d(in_channels, momentum=0.1)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout) if dropout is not None else None

    def forward(self, x):
        """Apply every head per frame, then mix channels.

        Args:
            x: Tensor of shape ``(B, T, K, C)`` with ``C == in_channels``.

        Returns:
            Tensor of shape ``(B, T, K, C)``. It is a permuted view and is not
            guaranteed to be contiguous.
        """
        x_size = x.shape
        # (B, T, K, C) -> (B*T, K, C) -> (B*T, C, K): one graph per frame.
        x = x.contiguous().view(-1, *x_size[2:]).permute(0, 2, 1)
        # Concatenate the per-head outputs along channels: (B*T, C, K).
        x = torch.cat([head(x) for head in self.attentions], dim=1)
        # (B*T, C, K) -> (B*T, K, C) -> (B, T, K, C) -> (B, C, T, K) for Conv2d/BN.
        x = x.permute(0, 2, 1).contiguous().view(*x_size).permute(0, 3, 1, 2)
        x = self.relu(self.cat_bn(self.cat_conv(x)))
        if self.dropout is not None:
            x = self.dropout(x)
        # (B, C, T, K) -> (B, T, K, C)
        return x.permute(0, 2, 3, 1)
class SingleGlobalGraph(nn.Module):
    """Single-head global graph attention applied to every frame of a sequence.

    Each frame's ``K`` node features go through one ``GlobalGraph`` head,
    followed by BatchNorm, ReLU and optional dropout.

    Args:
        adj: Adjacency tensor forwarded to the attention head.
        in_channels: Feature channels ``C`` of the input (and of the output).
        output_channels: Sets the head's projection width to
            ``output_channels // 2``. Despite the name, the output always has
            ``in_channels`` channels. The value must make the head output
            ``in_channels`` channels, e.g. ``output_channels == in_channels``.
        dropout: Dropout probability, or ``None`` to disable dropout.

    Raises:
        ValueError: If the head would not output ``in_channels`` channels.
            ``forward`` reshapes the head output back to the input shape, so
            such a configuration could never run.
    """

    def __init__(self, adj, in_channels, output_channels, dropout=None):
        super(SingleGlobalGraph, self).__init__()
        self.attentions = GlobalGraph(adj, in_channels, output_channels // 2)

        # Fail early with a clear message instead of an opaque view() error
        # inside forward().
        if self.attentions.g_channels != in_channels:
            raise ValueError(
                "attention head outputs {} channels but in_channels is {}; "
                "use output_channels == in_channels".format(
                    self.attentions.g_channels, in_channels))

        self.bn = nn.BatchNorm2d(in_channels, momentum=0.1)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout) if dropout is not None else None

    def forward(self, x):
        """Apply the attention head per frame, then BatchNorm/ReLU/dropout.

        Args:
            x: Tensor of shape ``(B, T, K, C)`` with ``C == in_channels``.

        Returns:
            Tensor of shape ``(B, T, K, C)``. It is a permuted view and is not
            guaranteed to be contiguous.
        """
        x_size = x.shape
        # (B, T, K, C) -> (B*T, K, C) -> (B*T, C, K): one graph per frame.
        x = x.contiguous().view(-1, *x_size[2:]).permute(0, 2, 1)
        x = self.attentions(x)
        # (B*T, C, K) -> (B*T, K, C) -> (B, T, K, C) -> (B, C, T, K) for BatchNorm2d.
        x = x.permute(0, 2, 1).contiguous().view(*x_size).permute(0, 3, 1, 2)
        x = self.relu(self.bn(x))
        if self.dropout is not None:
            x = self.dropout(x)
        # (B, C, T, K) -> (B, T, K, C)
        return x.permute(0, 2, 3, 1)