# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os

import boto3
import torch
from torch import nn, distributed as dist
from torch.nn import functional as F

from imaginaire.utils.distributed import is_local_master
from imaginaire.utils.io import download_file_from_google_drive

def get_segmentation_hist_model(dataset_name, aws_credentials=None):
    """Build the segmentation network used to evaluate ``dataset_name`` and
    wrap it in a SegmentationHistModel. Returns None if no network is
    available for the dataset."""
    if dist.is_initialized() and not is_local_master():
        # Make sure only the first process in distributed training downloads
        # the model; the others wait here and then use the cached copy.
        # noinspection PyUnresolvedReferences
        torch.distributed.barrier()

    # Load the segmentation network.
    if dataset_name == "celebamask_hq":
        from imaginaire.evaluation.segmentation.celebamask_hq import Unet
        seg_network = Unet()
        checkpoint_dir = os.path.join(torch.hub.get_dir(), 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        model_path = os.path.join(checkpoint_dir, "celebamask_hq.pt")
        if not os.path.exists(model_path):
            if aws_credentials is not None:
                s3 = boto3.client('s3', **aws_credentials)
                s3.download_file('lpi-poe', 'model_zoo/celebamask_hq.pt', model_path)
            else:
                download_file_from_google_drive("1o1m-eT38zNCIFldcRaoWcLvvBtY8S4W3", model_path)
        state_dict = torch.load(model_path, map_location='cpu')
        seg_network.load_state_dict(state_dict)
    elif dataset_name == "cocostuff" or dataset_name == "getty":
        from imaginaire.evaluation.segmentation.cocostuff import DeepLabV2
        seg_network = DeepLabV2()
    else:
        print(f"No segmentation network for {dataset_name} was found.")
        # Defer returning until after the barrier below so that, in
        # distributed runs, the waiting processes are not left hanging.
        seg_network = None

    if dist.is_initialized() and is_local_master():
        # The first process has finished downloading; release the processes
        # waiting at the barrier above.
        # noinspection PyUnresolvedReferences
        torch.distributed.barrier()

    if seg_network is None:
        return None
    seg_network = seg_network.to('cuda').eval()
    return SegmentationHistModel(seg_network)

class SegmentationHistModel(nn.Module):
    """Runs the segmentation network on generated images and accumulates
    confusion histograms against the ground-truth segmentation maps."""

    def __init__(self, seg_network):
        super().__init__()
        self.seg_network = seg_network

    def forward(self, data, fake_images, align_corners=True):
        # Per-pixel class predictions for the generated images.
        pred = self.seg_network(fake_images, align_corners=align_corners)
        # Ground-truth maps are stored as floats in [0, 1]; scale back to
        # integer class IDs.
        gt = data["segmaps"]
        gt = gt * 255.0
        gt = gt.long()
        return compute_hist(pred, gt, self.seg_network.n_classes, self.seg_network.use_dont_care)

def compute_hist(pred, gt, n_classes, use_dont_care):
    """Compute an (n_classes x n_classes) prediction vs. ground-truth
    histogram for every image in the batch."""
    _, H, W = pred.size()
    # Resize the ground truth to the prediction resolution with nearest
    # neighbour so that class IDs are preserved.
    gt = F.interpolate(gt.float(), (H, W), mode="nearest").long().squeeze(1)
    ignore_idx = n_classes if use_dont_care else -1
    all_hist = []
    for cur_pred, cur_gt in zip(pred, gt):
        # Drop pixels labelled with the don't-care class, if any.
        keep = torch.logical_not(cur_gt == ignore_idx)
        merge = cur_pred[keep] * n_classes + cur_gt[keep]
        hist = torch.bincount(merge, minlength=n_classes ** 2)
        hist = hist.view((n_classes, n_classes))
        all_hist.append(hist)
    all_hist = torch.stack(all_hist)
    return all_hist
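
# Note on the histogram layout: hist[i, j] counts the pixels predicted as
# class i whose ground-truth label is j. For example, with n_classes = 2, a
# pixel with pred = 1 and gt = 0 lands in flat bin 1 * 2 + 0 = 2, which is
# hist[1, 0] after the reshape. Correctly classified pixels fall on the
# diagonal, which is exactly what get_miou below reads off.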

def get_miou(hist, eps=1e-8):
    """Reduce the stacked histograms to mean intersection-over-union (in percent)."""
    hist = hist.sum(0)
    # IoU per class: true positives over the union of predicted and
    # ground-truth pixels of that class.
    IOUs = torch.diag(hist) / (
        torch.sum(hist, dim=0, keepdim=False) + torch.sum(hist, dim=1, keepdim=False)
        - torch.diag(hist) + eps)
    mIOU = 100 * torch.mean(IOUs).item()
    return {"seg_mIOU": mIOU}
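
# Minimal usage sketch (not part of the original module): it exercises
# compute_hist and get_miou on synthetic tensors so the metric path can be
# checked without a GPU or a pretrained checkpoint. The tensor shapes below
# (a batch of predicted label maps and single-channel ground-truth maps) are
# illustrative assumptions, not values taken from imaginaire.
if __name__ == "__main__":
    n_classes = 4
    # Predicted label maps for a batch of 2 images at 8x8 resolution.
    pred = torch.randint(0, n_classes, (2, 8, 8))
    # Ground-truth label maps with a channel dimension, as expected by the
    # F.interpolate call inside compute_hist.
    gt = torch.randint(0, n_classes, (2, 1, 8, 8))
    hist = compute_hist(pred, gt, n_classes, use_dont_care=False)
    print(hist.shape)      # torch.Size([2, 4, 4])
    print(get_miou(hist))  # e.g. {'seg_mIOU': ...}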