Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import torch | |
def make_one_hot(cfg, is_inference, data):
    r"""Replace the configured label maps in ``data`` with one-hot tensors.

    Args:
        cfg: Config object. Must define ``one_hot_num_classes``, a dict
            mapping a data_type key to its number of classes; may define
            ``use_dont_care`` (defaults to False), which is forwarded to the
            encoder.
        is_inference (bool): Not used by this function; accepted so it shares
            the ``(cfg, is_inference, data)`` signature of the other
            preprocessing functions in this module.
        data (dict): Maps data_type to a tensor. Only keys listed in
            ``cfg.one_hot_num_classes`` are touched, and listed keys that are
            absent from ``data`` are skipped. Each selected tensor is
            multiplied by 255 before encoding (presumably it stores
            ``index / 255`` — TODO confirm against the loader).

    Returns:
        data (dict): The same dict, updated in place, with each selected
        entry replaced by a float one-hot tensor.
    """
    assert hasattr(cfg, 'one_hot_num_classes')
    class_counts = getattr(cfg, 'one_hot_num_classes')
    keep_dont_care = getattr(cfg, 'use_dont_care', False)
    for key, n_classes in class_counts.items():
        if key not in data:
            continue
        # Scale back up to integer-valued class indices before encoding.
        index_map = data[key] * 255.0
        encoded = _encode_onehot(index_map, n_classes, keep_dont_care)
        data[key] = encoded.float()
    return data
def _concat_and_pop(data, label_keys, dataset_type, output_key):
    r"""Move the tensors named by ``label_keys`` into one concatenated entry.

    All tensors are looked up and concatenated before any key is removed, so
    a missing key (KeyError) or a ``torch.cat`` failure leaves ``data``
    unchanged.

    Args:
        data (dict): Sample dict; modified in place.
        label_keys (iterable): Keys whose tensors are concatenated, in order.
        dataset_type (str): Concatenate along dim 1 if it contains
            ``'video'``, else along dim 0 (presumably the channel axis of
            (T, C, H, W) video and (C, H, W) image tensors — TODO confirm).
        output_key (str): Key under which the result is stored.

    Returns:
        data (dict): The same dict, with ``label_keys`` removed and
        ``output_key`` set.
    """
    labels = [data[key] for key in label_keys]
    concat_dim = 1 if 'video' in dataset_type else 0
    concatenated = torch.cat(labels, dim=concat_dim)
    for key in label_keys:
        data.pop(key)
    data[output_key] = concatenated
    return data


def concat_labels(cfg, is_inference, data):
    r"""Concatenate the entries listed in ``cfg.input_labels`` into ``data['label']``.

    Args:
        cfg: Config object with ``input_labels`` (keys to concatenate, in
            order) and ``type`` (dataset type string; a value containing
            ``'video'`` selects dim 1, anything else dim 0).
        is_inference (bool): Not used; kept for the shared
            ``(cfg, is_inference, data)`` preprocessing signature.
        data (dict): Sample dict; the listed keys are removed from it.

    Returns:
        data (dict): The same dict with the listed entries replaced by a
        single ``'label'`` tensor.

    Raises:
        AssertionError: If ``cfg`` has no ``input_labels``.
        KeyError: If a listed key is missing; ``data`` is left unchanged.
    """
    assert hasattr(cfg, 'input_labels')
    return _concat_and_pop(data, cfg.input_labels, cfg.type, 'label')


def concat_few_shot_labels(cfg, is_inference, data):
    r"""Concatenate the entries in ``cfg.input_few_shot_labels`` into ``data['few_shot_label']``.

    Args:
        cfg: Config object with ``input_few_shot_labels`` (keys to
            concatenate, in order) and ``type`` (dataset type string; a value
            containing ``'video'`` selects dim 1, anything else dim 0).
        is_inference (bool): Not used; kept for the shared
            ``(cfg, is_inference, data)`` preprocessing signature.
        data (dict): Sample dict; the listed keys are removed from it.

    Returns:
        data (dict): The same dict with the listed entries replaced by a
        single ``'few_shot_label'`` tensor.

    Raises:
        AssertionError: If ``cfg`` has no ``input_few_shot_labels``.
        KeyError: If a listed key is missing; ``data`` is left unchanged.
    """
    assert hasattr(cfg, 'input_few_shot_labels')
    return _concat_and_pop(
        data, cfg.input_few_shot_labels, cfg.type, 'few_shot_label')
def move_dont_care(cfg, is_inference, data):
    r"""Send out-of-range label indices to the don't-care index.

    For every ``data_type -> num_classes`` pair in ``cfg.move_dont_care``,
    the tensor ``data[data_type]`` is scaled by 255, every value that is
    negative or ``>= num_classes`` is replaced with ``num_classes``, and the
    result is scaled back down by 255 (presumably the tensor stores
    ``index / 255`` — TODO confirm).

    Args:
        cfg: Config object with a ``move_dont_care`` dict mapping data_type
            to its number of classes.
        is_inference (bool): Not used; kept for the shared
            ``(cfg, is_inference, data)`` preprocessing signature.
        data (dict): Sample dict; every configured data_type must be present
            (a missing one raises KeyError).

    Returns:
        data (dict): The same dict, updated in place.
    """
    assert hasattr(cfg, 'move_dont_care')
    dont_care_table = cfg.move_dont_care
    for key, dont_care_index in dont_care_table.items():
        scaled = data[key] * 255.0
        # Negative or too-large indices both collapse onto num_classes.
        out_of_range = (scaled < 0) | (scaled >= dont_care_index)
        scaled[out_of_range] = dont_care_index
        data[key] = scaled / 255.0
    return data
| def _encode_onehot(label_map, num_classes, use_dont_care): | |
| r"""Make input one-hot. | |
| Args: | |
| label_map (torch.Tensor): (C, H, W) tensor containing indices. | |
| num_classes (int): Number of labels to expand tensor to. | |
| use_dont_care (bool): Use the dont care label or not? | |
| Returns: | |
| output (torch.Tensor): (num_classes, H, W) one-hot tensor. | |
| """ | |
| # All labels lie in [0. num_classes - 1]. | |
| # Encode dont care as num_classes. | |
| label_map[label_map < 0] = num_classes | |
| label_map[label_map >= num_classes] = num_classes | |
| size = label_map.size() | |
| output_size = (num_classes + 1, size[-2], size[-1]) | |
| output = torch.zeros(*output_size) | |
| if label_map.dim() == 4: | |
| output = output.unsqueeze(0).repeat(label_map.size(0), 1, 1, 1) | |
| output = output.scatter_(1, label_map.data.long(), 1.0) | |
| if not use_dont_care: | |
| output = output[:, :num_classes, ...] | |
| else: | |
| output = output.scatter_(0, label_map.data.long(), 1.0) | |
| if not use_dont_care: | |
| output = output[:num_classes, ...] | |
| return output | |