Spaces:
Runtime error
Runtime error
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from .exceptions import EmptyTensorError | |
def preprocess_image(image, preprocessing=None):
    """Convert an H x W x 3 image to a normalised 3 x H x W float array.

    The input is copied to ``float32`` and its axes are reordered from
    (height, width, channel) to (channel, height, width); the caller's
    array is never modified.

    Args:
        image: NumPy array indexed as (row, column, channel) with three
            channels.
        preprocessing: ``None`` to only cast and transpose, ``'caffe'`` to
            reverse the channel order and subtract a per-channel mean, or
            ``'torch'`` to scale by 1/255 and standardise with a
            per-channel mean and standard deviation.

    Returns:
        The channel-first array. It is ``float32`` when ``preprocessing``
        is ``None``; the other modes mix in ``float64`` constants, so the
        result is promoted to ``float64``.

    Raises:
        ValueError: If ``preprocessing`` is not one of the values above.
    """
    chw = np.transpose(image.astype(np.float32), [2, 0, 1])

    if preprocessing is None:
        return chw

    if preprocessing == 'caffe':
        # RGB -> BGR: reverse the channel axis.
        flipped = chw[:: -1, :, :]
        # Zero-center by mean pixel (one value per channel).
        channel_mean = np.array([103.939, 116.779, 123.68])
        return flipped - channel_mean.reshape([3, 1, 1])

    if preprocessing == 'torch':
        # In place is safe: `chw` is a view of the private float32 copy.
        chw /= 255.0
        channel_mean = np.array([0.485, 0.456, 0.406])
        channel_std = np.array([0.229, 0.224, 0.225])
        centered = chw - channel_mean.reshape([3, 1, 1])
        return centered / channel_std.reshape([3, 1, 1])

    raise ValueError('Unknown preprocessing parameter.')
def imshow_image(image, preprocessing=None):
    """Undo a normalisation and return an H x W x 3 ``uint8`` image.

    This is the inverse of the ``'caffe'`` / ``'torch'`` normalisations
    that use the same constants: the per-channel statistics are added
    back, the axes are reordered from (channel, height, width) to
    (height, width, channel), and the values are rounded to 8-bit integers.
    Despite the name, nothing is drawn; the array is only returned.

    Args:
        image: NumPy array indexed as (channel, row, column) with three
            channels.
        preprocessing: ``None`` (no de-normalisation), ``'caffe'`` or
            ``'torch'``; must match how ``image`` was normalised.

    Returns:
        A ``uint8`` array indexed as (row, column, channel). Values are
        clamped to ``[0, 255]`` before the cast.

    Raises:
        ValueError: If ``preprocessing`` is not one of the values above.
    """
    if preprocessing is None:
        pass
    elif preprocessing == 'caffe':
        mean = np.array([103.939, 116.779, 123.68])
        image = image + mean.reshape([3, 1, 1])
        # BGR -> RGB: reverse the channel axis back to display order.
        image = image[:: -1, :, :]
    elif preprocessing == 'torch':
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1])
        # `image` is a fresh array at this point, so scaling in place does
        # not touch the caller's data.
        image *= 255.0
    else:
        raise ValueError('Unknown preprocessing parameter.')

    image = np.transpose(image, [1, 2, 0])
    # Clamp before casting: a float outside [0, 255] would otherwise wrap
    # around (or be an undefined conversion for negatives) when converted
    # to uint8, turning slightly-too-bright pixels black and vice versa.
    image = np.clip(np.round(image), 0.0, 255.0).astype(np.uint8)
    return image
def grid_positions(h, w, device, matrix=False):
    """Enumerate the (row, column) coordinates of an ``h`` x ``w`` grid.

    Args:
        h: Number of rows.
        w: Number of columns.
        device: Torch device on which the coordinates are created.
        matrix: Selects the layout of the result (see Returns).

    Returns:
        A float tensor whose first slice holds row indices and whose
        second slice holds column indices. With ``matrix=True`` the shape
        is ``(2, h, w)``; otherwise the grid is flattened row by row into
        shape ``(2, h * w)``.
    """
    row_index = torch.arange(0, h, device=device).float()
    col_index = torch.arange(0, w, device=device).float()

    # Broadcast each 1-D index over the other axis to get full h x w maps.
    rows = row_index.view(-1, 1).repeat(1, w)
    cols = col_index.view(1, -1).repeat(h, 1)

    grid = torch.stack([rows, cols], dim=0)
    if matrix:
        return grid
    # Flattening the stacked grid is the same as flattening each map and
    # concatenating them along dim 0.
    return grid.view(2, -1)
def upscale_positions(pos, scaling_steps=0):
    """Apply the map ``p -> 2 * p + 0.5`` to ``pos`` ``scaling_steps`` times.

    Args:
        pos: Positions to transform; anything supporting ``* 2`` and
            ``+ 0.5`` (a number, NumPy array or tensor).
        scaling_steps: How many times the map is applied. With the default
            of 0, ``pos`` is returned unchanged.

    Returns:
        The transformed positions. The input is rebound, not modified in
        place.
    """
    for _step in range(scaling_steps):
        doubled = pos * 2
        pos = doubled + 0.5
    return pos
def downscale_positions(pos, scaling_steps=0):
    """Apply the map ``p -> (p - 0.5) / 2`` to ``pos`` ``scaling_steps`` times.

    One step of this map undoes one step of ``p -> 2 * p + 0.5``.

    Args:
        pos: Positions to transform; anything supporting ``- 0.5`` and
            ``/ 2`` (a number, NumPy array or tensor).
        scaling_steps: How many times the map is applied. With the default
            of 0, ``pos`` is returned unchanged.

    Returns:
        The transformed positions. The input is rebound, not modified in
        place.
    """
    for _step in range(scaling_steps):
        shifted = pos - 0.5
        pos = shifted / 2
    return pos
def interpolate_dense_features(pos, dense_features, return_corners=False):
    """Bilinearly sample a feature map at fractional (row, column) positions.

    Positions whose four surrounding integer corners do not all lie inside
    the map are discarded; the remaining ones are interpolated from those
    corners.

    Args:
        pos: Tensor indexed as ``pos[0, k]`` (row) and ``pos[1, k]``
            (column) for the k-th query position.
        dense_features: 3-D tensor; the last two dimensions are the map
            height and width, the first is sampled in full at each corner.
        return_corners: When true, also return the integer corner
            coordinates used for each kept position.

    Returns:
        A list ``[descriptors, pos, ids]`` or, with ``return_corners``,
        ``[descriptors, pos, ids, corners]``:

        * ``descriptors`` - interpolated features, one column per kept
          position.
        * ``pos`` - the kept positions, rows stacked over columns.
        * ``ids`` - indices of the kept positions in the input ``pos``.
        * ``corners`` - integer tensor stacked as (top-left, top-right,
          bottom-left, bottom-right), each holding a row and a column
          vector.

    Raises:
        EmptyTensorError: If no position has all four corners in bounds.
    """
    _, h, w = dense_features.size()

    i = pos[0, :]
    j = pos[1, :]

    # Integer neighbours along each axis; the four corners are their
    # combinations (top = floor(i), bottom = ceil(i), left = floor(j),
    # right = ceil(j)).
    i_top = torch.floor(i).long()
    i_bottom = torch.ceil(i).long()
    j_left = torch.floor(j).long()
    j_right = torch.ceil(j).long()

    # A position is kept only when every corner is inside the map.
    inside = (i_top >= 0) & (j_left >= 0) & (i_bottom < h) & (j_right < w)

    ids = torch.arange(0, pos.size(1), device=pos.device)[inside]
    if ids.size(0) == 0:
        raise EmptyTensorError

    i_top = i_top[inside]
    i_bottom = i_bottom[inside]
    j_left = j_left[inside]
    j_right = j_right[inside]

    # Interpolation
    i = i[ids]
    j = j[ids]
    frac_i = i - i_top.float()
    frac_j = j - j_left.float()

    # Weights are based on the offset from the top-left corner, so an
    # exactly integral coordinate puts all weight on that corner.
    weight_tl = (1 - frac_i) * (1 - frac_j)
    weight_tr = (1 - frac_i) * frac_j
    weight_bl = frac_i * (1 - frac_j)
    weight_br = frac_i * frac_j

    descriptors = (
        weight_tl * dense_features[:, i_top, j_left] +
        weight_tr * dense_features[:, i_top, j_right] +
        weight_bl * dense_features[:, i_bottom, j_left] +
        weight_br * dense_features[:, i_bottom, j_right]
    )

    pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    if return_corners:
        corner_pairs = [
            (i_top, j_left),
            (i_top, j_right),
            (i_bottom, j_left),
            (i_bottom, j_right),
        ]
        corners = torch.stack(
            [torch.stack([ci, cj], dim=0) for ci, cj in corner_pairs], dim=0
        )
        return [descriptors, pos, ids, corners]

    return [descriptors, pos, ids]
def savefig(filepath, fig=None, dpi=None):
    """Save a figure with hidden axes and no surrounding padding.

    Args:
        filepath: Destination passed through to ``Figure.savefig``.
        fig: Figure to save. When ``None``, the current pyplot figure is
            used.
        dpi: Resolution passed through to ``Figure.savefig``.

    Side effects:
        The figure is modified before saving: subplot margins and spacing
        are set to zero and every axes has its axis lines, margins and tick
        locators removed.
    """
    # TomNorway - https://stackoverflow.com/a/53516034
    if fig is None:
        fig = plt.gcf()

    # Adjust the figure that is being saved. The pyplot-level function
    # would act on the *current* figure, which is not necessarily `fig`
    # when the caller passes one explicitly.
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

    for ax in fig.axes:
        ax.axis('off')
        ax.margins(0, 0)
        # Null locators drop the ticks so they cannot enlarge the tight
        # bounding box.
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())

    fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi)