import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transform
import torch.nn.functional as F
from PIL import Image
import numpy as np
from collections import defaultdict, deque
import torch.distributed as dist


def colorize_mask(mask):
    # 19-class Cityscapes color palette, flattened as [R, G, B, R, G, B, ...].
    palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
               220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
               0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
    # A 'P' (palette) image expects 256 RGB entries; pad the rest with zeros.
    zero_pad = 256 * 3 - len(palette)
    for i in range(zero_pad):
        palette.append(0)
    # Interpret the (H, W) array of class indices as a palettized image.
    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
    new_mask.putpalette(palette)
    return new_mask
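
# Example usage of colorize_mask (a minimal sketch; `output` stands for model
# logits produced elsewhere -- the variable names and output path are
# illustrative, not part of the original code):
#
#     pred = output.argmax(1).squeeze(0).cpu().numpy()
#     colorize_mask(pred).save('prediction_color.png')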


def build_img(args):
    # Load the input image and force RGB so Normalize always sees 3 channels.
    img = Image.open(args.input_path).convert('RGB')
    # ImageNet mean/std; Resize is applied to the tensor here, which requires a
    # torchvision version with tensor-capable transforms (>= 0.8).
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225]),
        transform.Resize((256, 512))])
    resized_img = input_transform(img)
    # Add a batch dimension: (3, 256, 512) -> (1, 3, 256, 512).
    resized_img = resized_img.unsqueeze(0)
    return resized_img
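
# Minimal sketch of how build_img might be driven from a CLI (the --input-path
# flag and the argparse wiring are assumptions for illustration, not part of
# the original script):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--input-path', type=str, required=True)
#     args = parser.parse_args()
#     batch = build_img(args)  # tensor of shape (1, 3, 256, 512)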


class ConfusionMatrix(object):
    """Accumulates a confusion matrix for semantic segmentation:
    rows are ground-truth classes, columns are predicted classes."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None  # allocated lazily on the first update()

    def update(self, a, b):
        # a: ground-truth labels, b: predicted labels (both flattened 1-D tensors).
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            # Keep only valid labels (drops ignore indices such as 255).
            k = (a >= 0) & (a < n)
            # Encode each (gt, pred) pair as a single index into the n*n histogram.
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        # Global pixel accuracy, per-class accuracy, and per-class IoU.
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        # Sum the per-process confusion matrices during distributed evaluation;
        # a no-op outside an initialized process group.
        if not dist.is_available():
            return
        if not dist.is_initialized():
            return
        dist.barrier()
        dist.all_reduce(self.mat)

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'per-class IoU(%): \n {}\n'
            'mean IoU(%): {:.1f}').format(
                ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
                iu.mean().item() * 100)
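

if __name__ == '__main__':
    # Minimal, self-contained sketch of the ConfusionMatrix workflow using
    # random tensors in place of real labels and predictions (19 classes, as
    # in the Cityscapes palette above). Illustrative only; in practice `pred`
    # would come from a model, e.g. model(batch).argmax(1).
    num_classes = 19
    confmat = ConfusionMatrix(num_classes)
    target = torch.randint(0, num_classes, (2, 256, 512))  # ground-truth masks
    pred = torch.randint(0, num_classes, (2, 256, 512))    # predicted masks
    confmat.update(target.flatten(), pred.flatten())
    confmat.reduce_from_all_processes()  # no-op without an initialized process group
    print(confmat)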