Spaces:
Runtime error
Runtime error
| """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" | |
| import os | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from scipy.io import loadmat | |
| from torch.nn.modules import BatchNorm2d | |
| from . import resnet | |
| from . import mobilenet | |
# Default number of output classes used by the decoders defined in this module.
NUM_CLASS = 150

# Directory that contains this file; the auxiliary data files are read from it.
base_path = os.path.dirname(os.path.abspath(__file__))
colors_path = os.path.join(base_path, 'color150.mat')
classes_path = os.path.join(base_path, 'object150_info.csv')

# Both files are read eagerly, at import time.
segm_options = {
    'colors': loadmat(colors_path)['colors'],
    'classes': pd.read_csv(classes_path),
}
class NormalizeTensor:
    """Channel-wise normalisation of a tensor: ``(tensor - mean) / std``.

    The statistics are broadcast as ``[None, :, None, None]``, i.e. against the
    second axis of a 4-D input.

    .. note::
        By default the transform works on a clone, so the input tensor is not
        modified. See :class:`~torchvision.transforms.Normalize` for details.

    Args:
        mean (sequence): Sequence of means, one per channel.
        std (sequence): Sequence of standard deviations, one per channel.
        inplace (bool, optional): Normalise the input tensor itself instead
            of a copy.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """Return the normalised tensor (the input itself when ``inplace``)."""
        target = tensor if self.inplace else tensor.clone()

        # Build the statistics with the same dtype/device as the data.
        like = dict(dtype=target.dtype, device=target.device)
        shift = torch.as_tensor(self.mean, **like)[None, :, None, None]
        scale = torch.as_tensor(self.std, **like)[None, :, None, None]

        target.sub_(shift)
        target.div_(scale)
        return target
| # Model Builder | |
class ModelBuilder:
    """Factory helpers that assemble the segmentation encoder and decoder.

    Every helper is a static method, so it can be called either on the class
    (``ModelBuilder.build_encoder(...)``) or on an instance.
    """

    @staticmethod
    def weights_init(m):
        """Custom weight initialisation, intended for ``nn.Module.apply``.

        Modules whose class name contains ``Conv`` get Kaiming-normal weights;
        modules whose class name contains ``BatchNorm`` get weight 1 and bias
        1e-4. Any other module is left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.kaiming_normal_(m.weight.data)
        elif 'BatchNorm' in classname:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    @staticmethod
    def _load_weights(net, weights, label):
        """Load the checkpoint file ``weights`` into ``net`` and announce it.

        Tensors are mapped to CPU storage while loading. ``strict=False`` means
        keys that are missing from, or unexpected in, the checkpoint are
        ignored instead of raising.
        """
        print(f'Loading weights for {label}')
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state = torch.load(weights, map_location=lambda storage, loc: storage)
        net.load_state_dict(state, strict=False)

    @staticmethod
    def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
        """Build an encoder network.

        Args:
            arch: One of ``mobilenetv2dilated``, ``resnet18``,
                ``resnet18dilated``, ``resnet50``, ``resnet50dilated``
                (case-insensitive).
            fc_dim: Accepted for signature symmetry with ``build_decoder``;
                not used here.
            weights: Path of a checkpoint to load. When empty, the backbone is
                requested with ``pretrained=True`` instead.

        Returns:
            The encoder module.

        Raises:
            ValueError: If ``arch`` is not one of the supported names.
        """
        pretrained = len(weights) == 0
        arch = arch.lower()
        if arch == 'mobilenetv2dilated':
            orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
            net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
        elif arch == 'resnet18':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet18dilated':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50dilated':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        else:
            raise ValueError('Architecture undefined!')

        # Encoders are usually pretrained, so the custom initialisation
        # (weights_init) is deliberately not applied here.
        if len(weights) > 0:
            ModelBuilder._load_weights(net_encoder, weights, 'net_encoder')
        return net_encoder

    @staticmethod
    def build_decoder(arch='ppm_deepsup',
                      fc_dim=512, num_class=NUM_CLASS,
                      weights='', use_softmax=False, drop_last_conv=False):
        """Build a decoder network.

        Args:
            arch: ``ppm_deepsup`` or ``c1_deepsup`` (case-insensitive).
            fc_dim: Channel count of the encoder feature map fed to the decoder.
            num_class: Number of output classes.
            weights: Path of a checkpoint to load; empty to keep the custom
                initialisation only.
            use_softmax: Forwarded to the decoder constructor.
            drop_last_conv: Forwarded to the decoder constructor.

        Returns:
            The decoder module, initialised with ``weights_init`` and then
            (optionally) overwritten by the checkpoint.

        Raises:
            ValueError: If ``arch`` is not one of the supported names.
        """
        arch = arch.lower()
        if arch == 'ppm_deepsup':
            decoder_cls = PPMDeepsup
        elif arch == 'c1_deepsup':
            decoder_cls = C1DeepSup
        else:
            raise ValueError('Architecture undefined!')

        net_decoder = decoder_cls(
            num_class=num_class,
            fc_dim=fc_dim,
            use_softmax=use_softmax,
            drop_last_conv=drop_last_conv)

        net_decoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            ModelBuilder._load_weights(net_decoder, weights, 'net_decoder')
        return net_decoder

    @staticmethod
    def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *arts, **kwargs):
        """Build the decoder and load its checkpoint from under ``weights_path``.

        The checkpoint is read from
        ``<weights_path>/ade20k/ade20k-<arch_encoder>-<arch_decoder>/decoder_epoch_20.pth``.
        The decoder is always built with ``use_softmax=True``. Extra positional
        and keyword arguments are accepted and ignored.
        """
        path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
        return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path, use_softmax=True, drop_last_conv=drop_last_conv)

    @staticmethod
    def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation,
                    *arts, **kwargs):
        """Build the encoder, optionally loading its checkpoint.

        When ``segmentation`` is truthy the checkpoint
        ``<weights_path>/ade20k/ade20k-<arch_encoder>-<arch_decoder>/encoder_epoch_20.pth``
        is loaded; otherwise no checkpoint path is passed and ``build_encoder``
        asks for a pretrained backbone. Extra positional and keyword arguments
        are accepted and ignored.
        """
        if segmentation:
            path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
        else:
            path = ''
        return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path)
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    """Return ``Conv2d(3x3, padding=1, no bias) -> BatchNorm2d -> ReLU(inplace)``."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    norm = BatchNorm2d(out_planes)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)
class SegmentationModule(nn.Module):
    """Encoder/decoder segmentation network with a multi-scale ``predict`` helper."""

    def __init__(self,
                 weights_path,
                 num_classes=150,
                 arch_encoder="resnet50dilated",
                 drop_last_conv=False,
                 net_enc=None,  # None for Default encoder
                 net_dec=None,  # None for Default decoder
                 encode=None,  # {None, 'binary', 'color', 'sky'}
                 use_default_normalization=False,
                 return_feature_maps=False,
                 return_feature_maps_level=3,  # {0, 1, 2, 3}
                 return_feature_maps_only=True,
                 **kwargs,
                 ):
        """Build (or adopt) the encoder and decoder.

        Args:
            weights_path: Directory handed to ``ModelBuilder.get_encoder`` /
                ``get_decoder`` to locate the checkpoints.
            num_classes: Accepted for compatibility; not used by this class.
            arch_encoder: ``"resnet50dilated"`` (paired with ``ppm_deepsup``,
                ``fc_dim=2048``) or ``"mobilenetv2dilated"`` (paired with
                ``c1_deepsup``, ``fc_dim=320``).
            drop_last_conv: Forwarded to the decoder builder.
            net_enc: Ready-made encoder; ``None`` builds the default one.
            net_dec: Ready-made decoder; ``None`` builds the default one.
            encode: Stored on the instance; not interpreted by this class.
            use_default_normalization: When true, ``predict`` normalises its
                input with ``normalize_input`` first.
            return_feature_maps: When true, ``forward`` also returns the
                encoder feature maps and ``predict`` returns aggregated
                features instead of class predictions.
            return_feature_maps_level: Index (0..3) of the encoder feature map
                aggregated by ``predict``.
            return_feature_maps_only: Accepted for compatibility; not used.
            **kwargs: Ignored.

        Raises:
            NotImplementedError: If ``arch_encoder`` is not supported.
        """
        super().__init__()
        self.weights_path = weights_path
        self.drop_last_conv = drop_last_conv
        self.arch_encoder = arch_encoder
        if self.arch_encoder == "resnet50dilated":
            self.arch_decoder = "ppm_deepsup"
            self.fc_dim = 2048
        elif self.arch_encoder == "mobilenetv2dilated":
            self.arch_decoder = "c1_deepsup"
            self.fc_dim = 320
        else:
            raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")

        model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
                                    arch_decoder=self.arch_decoder,
                                    fc_dim=self.fc_dim,
                                    drop_last_conv=drop_last_conv,
                                    weights_path=self.weights_path)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # get_encoder requires a ``segmentation`` argument; omitting it raised
        # TypeError. True makes it load the encoder checkpoint that sits next
        # to the decoder checkpoint loaded by get_decoder.
        if net_enc is None:
            self.encoder = ModelBuilder.get_encoder(segmentation=True, **model_builder_kwargs)
        else:
            self.encoder = net_enc
        if net_dec is None:
            self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs)
        else:
            self.decoder = net_dec

        self.use_default_normalization = use_default_normalization
        self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])
        self.encode = encode
        self.return_feature_maps = return_feature_maps

        assert 0 <= return_feature_maps_level <= 3
        self.return_feature_maps_level = return_feature_maps_level

    def normalize_input(self, tensor):
        """Apply the default mean/std normalisation.

        Raises:
            ValueError: If any value of ``tensor`` lies outside ``[0, 1]``.
        """
        if tensor.min() < 0 or tensor.max() > 1:
            raise ValueError("Tensor should be 0..1 before using normalize_input")
        return self.default_normalization(tensor)

    def feature_maps_channels(self):
        """Return ``256 * 2**return_feature_maps_level`` (256, 512, 1024 or 2048)."""
        return 256 * 2**(self.return_feature_maps_level)  # 256, 512, 1024, 2048

    def forward(self, img_data, segSize=None):
        """Run the encoder and the decoder.

        Returns:
            The decoder output, or ``(decoder_output, encoder_feature_maps)``
            when ``return_feature_maps`` is set.

        Raises:
            NotImplementedError: If ``segSize`` is not given.
        """
        if segSize is None:
            raise NotImplementedError("Please pass segSize param. By default: (300, 300)")

        fmaps = self.encoder(img_data, return_feature_maps=True)
        pred = self.decoder(fmaps, segSize=segSize)

        if self.return_feature_maps:
            return pred, fmaps
        return pred

    def multi_mask_from_multiclass(self, pred, classes):
        """Return a float mask that is 1 where ``pred`` equals any of ``classes``."""
        # Build the lookup on pred's own device so CPU and GPU inputs both work.
        wanted = torch.as_tensor(classes, dtype=torch.long, device=pred.device)
        return (pred[..., None] == wanted).any(-1).float()

    @staticmethod
    def multi_mask_from_multiclass_probs(scores, classes):
        """Sum ``scores[:, c]`` over every ``c`` in ``classes``.

        Returns ``None`` when ``classes`` is empty. ``scores`` is never
        modified: the running sum starts from a copy, not from a view.
        """
        res = None
        for c in classes:
            if res is None:
                res = scores[:, c].clone()
            else:
                res = res + scores[:, c]
        return res

    def predict(self, tensor, imgSizes=(-1,),  # (300, 375, 450, 525, 600)
                segSize=None):
        """Entry-point for segmentation. Use this method instead of forward.

        Arguments:
            tensor {torch.Tensor} -- BCHW

        Keyword Arguments:
            imgSizes {tuple or list} -- input sizes to run the network at;
                ``-1`` means "use the tensor as is".
                default: (-1,)
                original implementation: (300, 375, 450, 525, 600)
            segSize -- output size passed to ``forward``; defaults to the
                spatial size of ``tensor``.

        Returns:
            When ``return_feature_maps`` is false: ``(pred, result)`` where
            ``pred`` is the argmax over dim 1 of the scale-averaged decoder
            outputs and ``result`` lists the per-scale decoder outputs.
            When ``return_feature_maps`` is true: the scale-averaged encoder
            feature map at ``return_feature_maps_level``, resized to
            ``segSize``.
        """
        # Only fall back to the input size when the caller gave no segSize.
        if segSize is None:
            segSize = (tensor.shape[2], tensor.shape[3])
        n_scales = len(imgSizes)

        with torch.no_grad():
            if self.use_default_normalization:
                tensor = self.normalize_input(tensor)

            # Accumulators are created from the first network output, so they
            # inherit its device and shape instead of relying on self.device.
            scores = None
            features = None
            result = []

            for img_size in imgSizes:
                if img_size != -1:
                    img_data = F.interpolate(tensor.clone(), size=img_size)
                else:
                    img_data = tensor.clone()

                if self.return_feature_maps:
                    pred_current, fmaps = self.forward(img_data, segSize=segSize)
                else:
                    pred_current = self.forward(img_data, segSize=segSize)

                result.append(pred_current)
                share = pred_current / n_scales
                scores = share if scores is None else scores + share

                # Disclaimer: only one feature-map level is used and aggregated.
                if self.return_feature_maps:
                    fmap = F.interpolate(fmaps[self.return_feature_maps_level], size=segSize) / n_scales
                    features = fmap if features is None else features + fmap

            if self.return_feature_maps:
                if features is None:  # no scales requested: all-zero features
                    features = torch.zeros(1, self.feature_maps_channels(), segSize[0], segSize[1],
                                           device=tensor.device)
                return features

            if scores is None:  # no scales requested: all-zero scores
                scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1], device=tensor.device)
            _, pred = torch.max(scores, dim=1)
            return pred, result

    def get_edges(self, t):
        """Mark every position of the 4-D map ``t`` that differs from a neighbour.

        Two positions adjacent along the last or second-to-last axis that hold
        different values are both marked. The mask is created on ``t``'s device
        and returned as a half-precision tensor.
        """
        edge = torch.zeros(t.size(), dtype=torch.bool, device=t.device)
        horizontal = t[:, :, :, 1:] != t[:, :, :, :-1]
        vertical = t[:, :, 1:, :] != t[:, :, :-1, :]
        edge[:, :, :, 1:] |= horizontal
        edge[:, :, :, :-1] |= horizontal
        edge[:, :, 1:, :] |= vertical
        edge[:, :, :-1, :] |= vertical
        return edge.half()
| # pyramid pooling, deep supervision | |
class PPMDeepsup(nn.Module):
    """Pyramid-pooling decoder with an auxiliary deep-supervision branch."""

    def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 drop_last_conv=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        # One pooling branch per scale: pool -> 1x1 conv -> BN -> ReLU.
        self.ppm = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True),
            )
            for scale in pool_scales
        ])
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        fused_channels = fc_dim + len(pool_scales) * 512
        self.conv_last = nn.Sequential(
            nn.Conv2d(fused_channels, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1),
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        """Decode the encoder feature maps in ``conv_out``.

        Returns the concatenated pyramid features when ``drop_last_conv`` is
        set; softmax probabilities resized to ``segSize`` when ``use_softmax``
        is set; otherwise a ``(main, auxiliary)`` pair of log-softmax outputs.
        """
        conv5 = conv_out[-1]
        size = conv5.size()
        target_hw = (size[2], size[3])

        # The raw feature map plus every pooled branch, resized back to it.
        pyramid = [conv5]
        for branch in self.ppm:
            pooled = branch(conv5)
            pyramid.append(nn.functional.interpolate(
                pooled, target_hw, mode='bilinear', align_corners=False))
        fused = torch.cat(pyramid, 1)

        if self.drop_last_conv:
            return fused

        x = self.conv_last(fused)
        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            return nn.functional.softmax(x, dim=1)

        # Deep supervision: auxiliary prediction from the second-to-last map.
        aux = self.cbr_deepsup(conv_out[-2])
        aux = self.dropout_deepsup(aux)
        aux = self.conv_last_deepsup(aux)

        x = nn.functional.log_softmax(x, dim=1)
        aux = nn.functional.log_softmax(aux, dim=1)
        return (x, aux)
class Resnet(nn.Module):
    """Feature extractor that reuses the layers of an existing ResNet."""

    def __init__(self, orig_resnet):
        super().__init__()

        # Take the given resnet's layers, except AvgPool and FC.
        reused = (
            'conv1', 'bn1', 'relu1',
            'conv2', 'bn2', 'relu2',
            'conv3', 'bn3', 'relu3',
            'maxpool',
            'layer1', 'layer2', 'layer3', 'layer4',
        )
        for name in reused:
            setattr(self, name, getattr(orig_resnet, name))

    def forward(self, x, return_feature_maps=False):
        """Return the outputs of layer1..layer4, or only ``[layer4 output]``."""
        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, bn, relu in stem:
            x = relu(bn(conv(x)))
        x = self.maxpool(x)

        conv_out = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]
| # Resnet Dilated | |
class ResnetDilated(nn.Module):
    """ResNet feature extractor whose late stages trade stride for dilation."""

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()

        # Which stages to rewrite, and with which dilation rate.
        if dilate_scale == 8:
            plan = ((orig_resnet.layer3, 2), (orig_resnet.layer4, 4))
        elif dilate_scale == 16:
            plan = ((orig_resnet.layer4, 2),)
        else:
            plan = ()
        for stage, rate in plan:
            stage.apply(lambda m, rate=rate: self._nostride_dilate(m, rate))

        # Take the given resnet's layers, except AvgPool and FC.
        reused = (
            'conv1', 'bn1', 'relu1',
            'conv2', 'bn2', 'relu2',
            'conv3', 'bn3', 'relu3',
            'maxpool',
            'layer1', 'layer2', 'layer3', 'layer4',
        )
        for name in reused:
            setattr(self, name, getattr(orig_resnet, name))

    def _nostride_dilate(self, m, dilate):
        """Rewrite one module in place; only classes named ``*Conv*`` are touched.

        A stride-2 convolution becomes stride 1 (and, if 3x3, gets dilation and
        padding ``dilate // 2``); any other 3x3 convolution gets dilation and
        padding ``dilate``.
        """
        if 'Conv' not in m.__class__.__name__:
            return

        is_3x3 = m.kernel_size == (3, 3)
        if m.stride == (2, 2):
            # The convolution with stride.
            m.stride = (1, 1)
            if is_3x3:
                half = dilate // 2
                m.dilation = (half, half)
                m.padding = (half, half)
        elif is_3x3:
            # Other convolutions.
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        """Return the outputs of layer1..layer4, or only ``[layer4 output]``."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        conv_out = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            conv_out.append(x)

        return conv_out if return_feature_maps else [x]
class MobileNetV2Dilated(nn.Module):
    """MobileNetV2 feature extractor whose late blocks trade stride for dilation."""

    def __init__(self, orig_net, dilate_scale=8):
        super().__init__()

        # Take the given mobilenet's feature blocks, minus the last one.
        self.features = orig_net.features[:-1]
        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]

        # (first block, one-past-last block, dilation rate) per rewritten span.
        if dilate_scale == 8:
            spans = (
                (self.down_idx[-2], self.down_idx[-1], 2),
                (self.down_idx[-1], self.total_idx, 4),
            )
        elif dilate_scale == 16:
            spans = ((self.down_idx[-1], self.total_idx, 2),)
        else:
            spans = ()

        for first, stop, rate in spans:
            for idx in range(first, stop):
                self.features[idx].apply(
                    lambda m, rate=rate: self._nostride_dilate(m, rate))

    def _nostride_dilate(self, m, dilate):
        """Rewrite one module in place; only classes named ``*Conv*`` are touched.

        A stride-2 convolution becomes stride 1 (and, if 3x3, gets dilation and
        padding ``dilate // 2``); any other 3x3 convolution gets dilation and
        padding ``dilate``.
        """
        if 'Conv' not in m.__class__.__name__:
            return

        is_3x3 = m.kernel_size == (3, 3)
        if m.stride == (2, 2):
            # The convolution with stride.
            m.stride = (1, 1)
            if is_3x3:
                half = dilate // 2
                m.dilation = (half, half)
                m.padding = (half, half)
        elif is_3x3:
            # Other convolutions.
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        """Return ``[final output]``, or the ``down_idx`` outputs plus the final one."""
        if not return_feature_maps:
            return [self.features(x)]

        conv_out = []
        for idx in range(self.total_idx):
            x = self.features[idx](x)
            if idx in self.down_idx:
                conv_out.append(x)
        conv_out.append(x)
        return conv_out
| # last conv, deep supervision | |
class C1DeepSup(nn.Module):
    """Single conv-block decoder with an auxiliary deep-supervision branch."""

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        reduced = fc_dim // 4
        self.cbr = conv3x3_bn_relu(fc_dim, reduced, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, reduced, 1)

        # Final 1x1 classifiers for the main and the auxiliary branch.
        self.conv_last = nn.Conv2d(reduced, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(reduced, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        """Decode the encoder feature maps in ``conv_out``.

        Returns the conv-block features when ``drop_last_conv`` is set; softmax
        probabilities resized to ``segSize`` when ``use_softmax`` is set;
        otherwise a ``(main, auxiliary)`` pair of log-softmax outputs.
        """
        x = self.cbr(conv_out[-1])
        if self.drop_last_conv:
            return x

        x = self.conv_last(x)
        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            return nn.functional.softmax(x, dim=1)

        # Deep supervision: auxiliary prediction from the second-to-last map.
        aux = self.cbr_deepsup(conv_out[-2])
        aux = self.conv_last_deepsup(aux)

        x = nn.functional.log_softmax(x, dim=1)
        aux = nn.functional.log_softmax(aux, dim=1)
        return (x, aux)
| # last conv | |
class C1(nn.Module):
    """Single conv-block decoder: conv3x3-BN-ReLU followed by a 1x1 classifier."""

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        super().__init__()
        self.use_softmax = use_softmax

        reduced = fc_dim // 4
        self.cbr = conv3x3_bn_relu(fc_dim, reduced, 1)
        # Final 1x1 classifier.
        self.conv_last = nn.Conv2d(reduced, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        """Classify the last feature map in ``conv_out``.

        Returns softmax probabilities resized to ``segSize`` when
        ``use_softmax`` is set, otherwise log-softmax scores at feature-map
        resolution.
        """
        logits = self.conv_last(self.cbr(conv_out[-1]))

        if not self.use_softmax:
            return nn.functional.log_softmax(logits, dim=1)

        # is True during inference
        resized = nn.functional.interpolate(
            logits, size=segSize, mode='bilinear', align_corners=False)
        return nn.functional.softmax(resized, dim=1)
| # pyramid pooling | |
class PPM(nn.Module):
    """Pyramid-pooling decoder without the deep-supervision branch."""

    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super().__init__()
        self.use_softmax = use_softmax

        # One pooling branch per scale: pool -> 1x1 conv -> BN -> ReLU.
        self.ppm = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True),
            )
            for scale in pool_scales
        ])

        fused_channels = fc_dim + len(pool_scales) * 512
        self.conv_last = nn.Sequential(
            nn.Conv2d(fused_channels, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1),
        )

    def forward(self, conv_out, segSize=None):
        """Classify the last feature map in ``conv_out``.

        Returns softmax probabilities resized to ``segSize`` when
        ``use_softmax`` is set, otherwise log-softmax scores at feature-map
        resolution.
        """
        conv5 = conv_out[-1]
        size = conv5.size()
        target_hw = (size[2], size[3])

        # The raw feature map plus every pooled branch, resized back to it.
        pyramid = [conv5]
        for branch in self.ppm:
            pooled = branch(conv5)
            pyramid.append(nn.functional.interpolate(
                pooled, target_hw, mode='bilinear', align_corners=False))

        logits = self.conv_last(torch.cat(pyramid, 1))

        if not self.use_softmax:
            return nn.functional.log_softmax(logits, dim=1)

        # is True during inference
        resized = nn.functional.interpolate(
            logits, size=segSize, mode='bilinear', align_corners=False)
        return nn.functional.softmax(resized, dim=1)