import torch
import torch.nn as nn
import torch.nn.functional as F

########################################################################################################################


# Upsample + BatchNorm
class UpSampleBN(nn.Module):
    def __init__(self, skip_input, output_features):
        super(UpSampleBN, self).__init__()
        self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
                                  nn.BatchNorm2d(output_features),
                                  nn.LeakyReLU(),
                                  nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
                                  nn.BatchNorm2d(output_features),
                                  nn.LeakyReLU())

    def forward(self, x, concat_with):
        # upsample x to the spatial size of the skip tensor, then concatenate along channels
        up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)
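
# Minimal usage sketch (added for illustration; not part of the original module).
# The channel sizes are arbitrary assumptions: skip_input must equal the channel
# count of x plus the channel count of the skip tensor it is concatenated with.
def _demo_upsample_bn():
    block = UpSampleBN(skip_input=256 + 64, output_features=128)
    x = torch.randn(2, 256, 8, 8)      # coarse decoder features
    skip = torch.randn(2, 64, 16, 16)  # higher-resolution encoder skip features
    out = block(x, skip)               # x is upsampled to 16x16, concatenated, convolved
    assert out.shape == (2, 128, 16, 16)
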
# Upsample + GroupNorm + Weight Standardization
# (uses the weight-standardized Conv2d defined below; unlike BatchNorm,
# GroupNorm + WS does not depend on batch statistics)
class UpSampleGN(nn.Module):
    def __init__(self, skip_input, output_features):
        super(UpSampleGN, self).__init__()
        self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
                                  nn.GroupNorm(8, output_features),
                                  nn.LeakyReLU(),
                                  Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
                                  nn.GroupNorm(8, output_features),
                                  nn.LeakyReLU())

    def forward(self, x, concat_with):
        up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)
# Conv2d with weight standardization
class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        # standardize each output filter to zero mean / unit std before convolving
        weight = self.weight
        weight_mean = weight.mean(dim=[1, 2, 3], keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
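
# Sanity-check sketch (illustrative, not from the original file). Conv2d above is a
# drop-in replacement for nn.Conv2d; recomputing the standardization shows that the
# weight actually used in forward() has ~zero mean and ~unit std per output filter.
def _demo_weight_standardization():
    conv = Conv2d(16, 32, kernel_size=3, padding=1)
    _ = conv(torch.randn(2, 16, 8, 8))  # same call signature as nn.Conv2d
    w = conv.weight - conv.weight.mean(dim=[1, 2, 3], keepdim=True)
    w = w / (w.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5)
    print(w.view(32, -1).mean(dim=1).abs().max().item())  # ~0
    print(w.view(32, -1).std(dim=1).mean().item())        # ~1
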
# normalize the raw (nx, ny, nz, kappa) network output:
# the normal is projected onto the unit sphere, and kappa is mapped to (min_kappa, inf)
def norm_normalize(norm_out):
    min_kappa = 0.01
    norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
    norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
    kappa = F.elu(kappa) + 1.0 + min_kappa
    final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
    return final_out
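
# Illustrative check (not in the original file): the first three output channels form
# a unit vector at every pixel, and kappa is bounded below by min_kappa.
def _demo_norm_normalize():
    raw = torch.randn(2, 4, 32, 32)  # unnormalized (nx, ny, nz, kappa) prediction
    out = norm_normalize(raw)
    sq_norm = (out[:, :3] ** 2).sum(dim=1)
    assert torch.allclose(sq_norm, torch.ones_like(sq_norm), atol=1e-4)
    assert (out[:, 3] >= 0.01).all()
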
# uncertainty-guided sampling (only used during training)
def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
    device = init_normal.device
    B, _, H, W = init_normal.shape
    N = int(sampling_ratio * H * W)

    # uncertainty map: kappa (channel 3) is the predicted concentration,
    # so low kappa = high uncertainty
    uncertainty_map = -1 * init_normal[:, 3, :, :]  # B, H, W

    # gt_invalid_mask (B, H, W): push invalid pixels to the bottom of the ranking
    if gt_norm_mask is not None:
        gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
        gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
        uncertainty_map[gt_invalid_mask] = -1e4

    # rank pixels by uncertainty (B, H*W)
    _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)

    if int(beta * N) > 0:
        # importance sampling: the beta*N most uncertain pixels
        importance = idx[:, :int(beta * N)]  # B, beta*N

        # remaining
        remaining = idx[:, int(beta * N):]  # B, H*W - beta*N

        # coverage: the rest of the budget is drawn uniformly at random
        num_coverage = N - int(beta * N)

        if num_coverage <= 0:
            samples = importance
        else:
            coverage_list = []
            for i in range(B):
                idx_c = torch.randperm(remaining.size()[1])  # shuffles "H*W - beta*N"
                coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1))  # 1, N - beta*N
            coverage = torch.cat(coverage_list, dim=0)  # B, N - beta*N
            samples = torch.cat((importance, coverage), dim=1)  # B, N
    else:
        # no importance sampling: draw all N points uniformly at random
        remaining = idx[:, :]  # B, H*W
        num_coverage = N

        coverage_list = []
        for i in range(B):
            idx_c = torch.randperm(remaining.size()[1])  # shuffles "H*W"
            coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1))  # 1, N
        coverage = torch.cat(coverage_list, dim=0)  # B, N
        samples = coverage

    # point coordinates, converted to the [-1, 1] grid_sample convention
    rows_int = samples // W  # 0 for first row, H-1 for last row
    rows_float = rows_int / float(H - 1)  # 0 to 1.0
    rows_float = (rows_float * 2.0) - 1.0  # -1.0 to 1.0
    cols_int = samples % W  # 0 for first column, W-1 for last column
    cols_float = cols_int / float(W - 1)  # 0 to 1.0
    cols_float = (cols_float * 2.0) - 1.0  # -1.0 to 1.0

    point_coords = torch.zeros(B, 1, N, 2)
    point_coords[:, 0, :, 0] = cols_float  # x coord
    point_coords[:, 0, :, 1] = rows_float  # y coord
    point_coords = point_coords.to(device)
    return point_coords, rows_int, cols_int
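
# Usage sketch (illustrative; shapes and parameter values are assumptions). The returned
# point_coords follow the [-1, 1] convention of F.grid_sample, so per-pixel predictions
# can be gathered at the sampled locations, e.g. for a point-wise training loss.
def _demo_sample_points():
    pred = norm_normalize(torch.randn(2, 4, 60, 80))  # (B, 4, H, W), kappa in channel 3
    gt_mask = torch.ones(2, 1, 60, 80)                # all pixels valid
    coords, rows, cols = sample_points(pred, gt_mask, sampling_ratio=0.4, beta=0.5)
    # gather predictions at the N sampled points
    sampled = F.grid_sample(pred, coords, mode='nearest', align_corners=True)
    print(sampled.shape)  # torch.Size([2, 4, 1, 1920])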