Spaces:
Runtime error
Runtime error
| """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" | |
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| try: | |
| from urllib import urlretrieve | |
| except ImportError: | |
| from urllib.request import urlretrieve | |
def load_url(url, model_dir='./pretrained', map_location=None):
    """Load a torch checkpoint from ``url``, caching the download in ``model_dir``.

    The cached file is named after the last ``/``-separated component of
    ``url``. If that file already exists it is loaded directly; otherwise it is
    downloaded first.

    Args:
        url: Location of the checkpoint (anything ``urlretrieve`` accepts).
        model_dir: Directory used as the download cache; created if missing.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.
    """
    import tempfile

    # exist_ok avoids a crash when several processes create the directory at once.
    os.makedirs(model_dir, exist_ok=True)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Download into a temporary file in the same directory and move it into
        # place only once the transfer has finished. Otherwise an interrupted
        # download would leave a truncated file that every later call would
        # treat as a valid cached checkpoint.
        fd, tmp_path = tempfile.mkstemp(
            dir=model_dir, prefix=filename + '.', suffix='.partial')
        os.close(fd)
        try:
            urlretrieve(url, tmp_path)
            os.replace(tmp_path, cached_file)
        finally:
            # No-op after a successful replace; removes partial data otherwise.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    # SECURITY NOTE: torch.load unpickles the file, so only use trusted URLs.
    return torch.load(cached_file, map_location=map_location)
def color_encode(labelmap, colors, mode='RGB'):
    """Paint a 2-D label map as a 3-channel ``uint8`` colour image.

    Args:
        labelmap: Array indexed as ``[row, col]``; its values are cast to
            ``int`` and used as labels.
        colors: Indexable by label; ``colors[label]`` must yield a 3-element
            colour (e.g. an ``(N, 3)`` array, a list of triples, or a dict).
        mode: ``'BGR'`` reverses the channel order of the result; any other
            value returns the channels in the order given by ``colors``.

    Returns:
        ``uint8`` array of shape ``(labelmap.shape[0], labelmap.shape[1], 3)``.
        Pixels with a negative label are left as ``(0, 0, 0)``.
    """
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)
    for label in np.unique(labelmap):
        if label < 0:
            continue
        # Label masks are disjoint, so direct assignment is equivalent to the
        # old accumulate-into-zeros approach. It avoids building a full-image
        # tiled copy per label and accepts non-uint8 colour tables.
        labelmap_rgb[labelmap == label] = colors[label]

    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    else:
        return labelmap_rgb