Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import numpy as np | |
| from scipy.spatial import ConvexHull | |
| from skimage.draw import polygon | |
| from scipy import ndimage | |
| from .utils import convert_to_numpy | |
class MaskDrawAnnotator:
    """Build a binary (0/255) uint8 mask from scribble points, a bbox or a mask.

    Supported modes:

    * ``'maskpoint'``: fill the convex hull of the centroids of the connected
      regions of ``mask`` whose value is >= 255.
    * ``'maskbbox'``: fill the axis-aligned box spanned by those centroids and
      crop ``image`` (when given) to the same box.
    * ``'bbox'``: fill the box given as ``(x_min, y_min, x_max, y_max)`` and
      crop ``image`` (when given) to the same box.
    * ``'mask'``: return ``mask`` unchanged.
    """

    def __init__(self, cfg, device=None):
        """Read the default mode and return style from ``cfg``.

        Args:
            cfg: Mapping-like object with a ``get`` method. ``'MODE'`` defaults
                to ``'maskpoint'`` and ``'RETURN_DICT'`` defaults to ``True``.
            device: Accepted for interface compatibility; not used here.
        """
        self.mode = cfg.get('MODE', 'maskpoint')
        self.return_dict = cfg.get('RETURN_DICT', True)
        assert self.mode in ['maskpoint', 'maskbbox', 'mask', 'bbox']

    @staticmethod
    def _scribble_centers(mask):
        """Return an (N, 2) array of (x, y) centroids of regions >= 255."""
        # Transposing makes the first centroid coordinate run along the
        # mask's second axis (x) and the second along its first axis (y).
        scribble = mask.transpose(1, 0)
        labeled_array, num_features = ndimage.label(scribble >= 255)
        if num_features == 0:
            raise ValueError(
                'mask contains no region with value >= 255 to annotate')
        centers = ndimage.center_of_mass(scribble, labeled_array,
                                         range(1, num_features + 1))
        return np.array(centers)

    @staticmethod
    def _fill_box(shape, image, x_min, y_min, x_max, y_max):
        """Return (cropped image or None, full-size mask) for an inclusive box."""
        rows = slice(int(y_min), int(y_max) + 1)
        cols = slice(int(x_min), int(x_max) + 1)
        out_mask = np.zeros(shape, dtype=np.uint8)
        out_mask[rows, cols] = 255
        out_image = image[rows, cols] if image is not None else None
        return out_image, out_mask

    def forward(self,
                mask=None,
                image=None,
                bbox=None,
                mode=None,
                return_dict=None):
        """Produce the output mask (and image, when one is supplied).

        Args:
            mask: Input mask, converted with ``convert_to_numpy``. Required
                for every mode except ``'bbox'``.
            image: Optional image, converted with ``convert_to_numpy``. It is
                cropped in the box modes and passed through unchanged in the
                ``'maskpoint'`` and ``'mask'`` modes.
            bbox: ``(x_min, y_min, x_max, y_max)`` with inclusive bounds; only
                used in ``'bbox'`` mode.
            mode: Overrides the configured mode when not ``None``.
            return_dict: Overrides the configured return style when not
                ``None``.

        Returns:
            With ``return_dict`` true: ``{"image": ..., "mask": ...}`` when an
            image was supplied, otherwise ``{"mask": ...}``. With
            ``return_dict`` false: ``(image, mask)`` when an image was
            supplied, otherwise just the mask.

        Raises:
            NotImplementedError: If ``mode`` is not a supported mode.
            ValueError: If a required input is missing, or the mask has no
                region with value >= 255 in the scribble modes.
        """
        mode = mode if mode is not None else self.mode
        return_dict = return_dict if return_dict is not None else self.return_dict
        if mode not in ('maskpoint', 'maskbbox', 'mask', 'bbox'):
            raise NotImplementedError

        mask = convert_to_numpy(mask) if mask is not None else None
        image = convert_to_numpy(image) if image is not None else None

        if mask is not None:
            mask_shape = mask.shape
        elif mode == 'bbox' and image is not None:
            # 'bbox' mode never reads mask pixels, so the canvas size can be
            # taken from the image's first two axes instead.
            mask_shape = image.shape[:2]
        elif mode == 'bbox':
            raise ValueError(
                "mode 'bbox' requires `mask` or `image` to size the output")
        else:
            raise ValueError(f"mode '{mode}' requires `mask`")

        # Default: hand the image back untouched. Only the box modes crop it.
        out_image = image

        if mode == 'maskpoint':
            centers = self._scribble_centers(mask)
            out_mask = np.zeros(mask_shape, dtype=np.uint8)
            # ConvexHull itself rejects degenerate input (e.g. too few or
            # collinear points); that error is left to propagate.
            hull = ConvexHull(centers)
            hull_vertices = centers[hull.vertices]
            # polygon() takes row coordinates first, hence (y, x) order.
            rr, cc = polygon(hull_vertices[:, 1], hull_vertices[:, 0],
                             mask_shape)
            out_mask[rr, cc] = 255
        elif mode == 'maskbbox':
            centers = self._scribble_centers(mask)
            # (x1, y1, x2, y2)
            out_image, out_mask = self._fill_box(
                mask_shape, image,
                centers[:, 0].min(), centers[:, 1].min(),
                centers[:, 0].max(), centers[:, 1].max())
        elif mode == 'bbox':
            if bbox is None:
                raise ValueError("mode 'bbox' requires `bbox`")
            x_min, y_min, x_max, y_max = bbox
            out_image, out_mask = self._fill_box(
                mask_shape, image, x_min, y_min, x_max, y_max)
        else:  # mode == 'mask'
            out_mask = mask

        if return_dict:
            if image is not None:
                return {"image": out_image, "mask": out_mask}
            return {"mask": out_mask}
        if image is not None:
            return out_image, out_mask
        return out_mask