# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import math
import random
from abc import ABCMeta

import numpy as np
import torch
from PIL import Image, ImageDraw

from .utils import convert_to_numpy, convert_to_pil, single_rle_to_mask, get_mask_box, read_video_one_frame


class InpaintingAnnotator:
    def __init__(self, cfg, device=None):
        self.use_aug = cfg.get('USE_AUG', True)
        self.return_mask = cfg.get('RETURN_MASK', True)
        self.return_source = cfg.get('RETURN_SOURCE', True)
        self.mask_color = cfg.get('MASK_COLOR', 128)
        self.mode = cfg.get('MODE', "mask")
        assert self.mode in ["salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack",
                             "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack",
                             "label", "caption", "all"]
| if self.mode in ["salient", "salienttrack"]: | |
| from .salient import SalientAnnotator | |
| self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device) | |
| if self.mode in ['masktrack', 'bboxtrack', 'salienttrack']: | |
| from .sam2 import SAM2ImageAnnotator | |
| self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device) | |
        if self.mode in ['label', 'caption']:
            from .gdino import GDINOAnnotator
            from .sam2 import SAM2ImageAnnotator
            self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
            self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
        if self.mode in ['all']:
            from .salient import SalientAnnotator
            from .gdino import GDINOAnnotator
            from .sam2 import SAM2ImageAnnotator
            self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
            self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
            self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
        if self.use_aug:
            from .maskaug import MaskAugAnnotator
            self.maskaug_anno = MaskAugAnnotator(cfg={})

    def apply_plain_mask(self, image, mask, mask_color):
        # Paint every nonzero mask pixel with the flat mask color and return the
        # painted image together with the mask binarized to {0, 255}.
        bool_mask = mask > 0
        out_image = image.copy()
        out_image[bool_mask] = mask_color
        out_mask = np.where(bool_mask, 255, 0).astype(np.uint8)
        return out_image, out_mask

    def apply_seg_mask(self, image, mask, mask_color, mask_cfg=None):
        # Scale a {0, 1} segmentation mask to uint8, optionally augment it via
        # MaskAugAnnotator, then paint the masked region into the image.
        out_mask = (mask * 255).astype('uint8')
        if self.use_aug and mask_cfg is not None:
            out_mask = self.maskaug_anno.forward(out_mask, mask_cfg)
        bool_mask = out_mask > 0
        out_image = image.copy()
        out_image[bool_mask] = mask_color
        return out_image, out_mask

    def forward(self, image=None, mask=None, bbox=None, label=None, caption=None, mode=None,
                return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
        # Per-call arguments override the defaults chosen at construction time.
        mode = mode if mode is not None else self.mode
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_source = return_source if return_source is not None else self.return_source
        mask_color = mask_color if mask_color is not None else self.mask_color
        image = convert_to_numpy(image)
        out_image, out_mask = None, None
        if mode in ['salient']:
            mask = self.salient_model.forward(image)
            out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
        elif mode in ['mask']:
            mask_h, mask_w = mask.shape[:2]
            h, w = image.shape[:2]
            # Resize the mask only when its size differs from the image.
            if (mask_h != h) or (mask_w != w):
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
        elif mode in ['bbox']:
            x1, y1, x2, y2 = bbox
            h, w = image.shape[:2]
            # Clamp the box to the image bounds before painting.
            x1, y1 = int(max(0, x1)), int(max(0, y1))
            x2, y2 = int(min(w, x2)), int(min(h, y2))
            out_image = image.copy()
            out_image[y1:y2, x1:x2] = mask_color
            out_mask = np.zeros((h, w), dtype=np.uint8)
            out_mask[y1:y2, x1:x2] = 255
        elif mode in ['salientmasktrack']:
            mask = self.salient_model.forward(image)
            resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
            out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['salientbboxtrack']:
            mask = self.salient_model.forward(image)
            bbox = get_mask_box(np.array(mask), threshold=1)
            out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['maskpointtrack']:
            out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_point', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['maskbboxtrack']:
            out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['masktrack']:
            resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
            out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['bboxtrack']:
            out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['label']:
            gdino_res = self.gdino_model.forward(image, classes=label)
            if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
                bboxes = gdino_res['boxes'][0]
            else:
                raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
            out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['caption']:
            gdino_res = self.gdino_model.forward(image, caption=caption)
            if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
                bboxes = gdino_res['boxes'][0]
            else:
                raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
            out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
| ret_data = {"image": out_image} | |
| if return_mask: | |
| ret_data["mask"] = out_mask | |
| if return_source: | |
| ret_data["src_image"] = image | |
| return ret_data | |
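

# Example usage (a minimal sketch, not part of the original file; assumes `img` is an
# H x W x 3 numpy array). 'bbox' mode loads no sub-models, so the cfg only needs the
# mode switch:
#
#   anno = InpaintingAnnotator(cfg={'MODE': 'bbox', 'USE_AUG': False})
#   out = anno.forward(image=img, bbox=[50, 50, 200, 200])
#   masked_img, region_mask = out['image'], out['mask']  # mask is uint8 in {0, 255}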


class InpaintingVideoAnnotator:
    def __init__(self, cfg, device=None):
        self.use_aug = cfg.get('USE_AUG', True)
        self.return_frame = cfg.get('RETURN_FRAME', True)
        self.return_mask = cfg.get('RETURN_MASK', True)
        self.return_source = cfg.get('RETURN_SOURCE', True)
        self.mask_color = cfg.get('MASK_COLOR', 128)
        self.mode = cfg.get('MODE', "mask")
        assert self.mode in ["salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack",
                             "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack",
                             "label", "caption", "all"]
| if self.mode in ["salient", "salienttrack"]: | |
| from .salient import SalientAnnotator | |
| self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device) | |
| if self.mode in ['masktrack', 'bboxtrack', 'salienttrack']: | |
| from .sam2 import SAM2VideoAnnotator | |
| self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device) | |
        if self.mode in ['label', 'caption']:
            from .gdino import GDINOAnnotator
            from .sam2 import SAM2VideoAnnotator
            self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
            self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
        if self.mode in ['all']:
            from .salient import SalientAnnotator
            from .gdino import GDINOAnnotator
            from .sam2 import SAM2VideoAnnotator
            self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
            self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
            self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
        if self.use_aug:
            from .maskaug import MaskAugAnnotator
            self.maskaug_anno = MaskAugAnnotator(cfg={})

    def apply_plain_mask(self, frames, mask, mask_color, return_frame=True):
        # A single static mask applies to every frame, so the per-frame mask list
        # repeats the same binarized array.
        out_frames = []
        num_frames = len(frames)
        bool_mask = mask > 0
        out_masks = [np.where(bool_mask, 255, 0).astype(np.uint8)] * num_frames
        if not return_frame:
            return None, out_masks
        for i in range(num_frames):
            masked_frame = frames[i].copy()
            masked_frame[bool_mask] = mask_color
            out_frames.append(masked_frame)
        return out_frames, out_masks

    def apply_seg_mask(self, mask_data, frames, mask_color, mask_cfg=None, return_frame=True):
        # Decode the per-frame RLE masks returned by the SAM2 video annotator,
        # optionally augment each one, then paint it into the matching frame.
        out_frames = []
        out_masks = [(single_rle_to_mask(val[0]["mask"]) * 255).astype('uint8') for key, val in mask_data['annotations'].items()]
        if not return_frame:
            return None, out_masks
        num_frames = min(len(out_masks), len(frames))
        for i in range(num_frames):
            sub_mask = out_masks[i]
            if self.use_aug and mask_cfg is not None:
                sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
                out_masks[i] = sub_mask
            bool_mask = sub_mask > 0
            masked_frame = frames[i].copy()
            masked_frame[bool_mask] = mask_color
            out_frames.append(masked_frame)
        out_masks = out_masks[:num_frames]
        return out_frames, out_masks

    def forward(self, frames=None, video=None, mask=None, bbox=None, label=None, caption=None, mode=None,
                return_frame=None, return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
        # Per-call arguments override the defaults chosen at construction time.
        mode = mode if mode is not None else self.mode
        return_frame = return_frame if return_frame is not None else self.return_frame
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_source = return_source if return_source is not None else self.return_source
        mask_color = mask_color if mask_color is not None else self.mask_color
        out_frames, out_masks = [], []
        if mode in ['salient']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            mask = self.salient_model.forward(first_frame)
            out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
        elif mode in ['mask']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            mask_h, mask_w = mask.shape[:2]
            h, w = first_frame.shape[:2]
            # Resize the mask only when its size differs from the frame.
            if (mask_h != h) or (mask_w != w):
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
        elif mode in ['bbox']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            num_frames = len(frames)
            x1, y1, x2, y2 = bbox
            h, w = first_frame.shape[:2]
            x1, y1 = int(max(0, x1)), int(max(0, y1))
            x2, y2 = int(min(w, x2)), int(min(h, y2))
            mask = np.zeros((h, w), dtype=np.uint8)
            mask[y1:y2, x1:x2] = 255
            out_masks = [mask] * num_frames
            if not return_frame:
                out_frames = None
            else:
                for i in range(num_frames):
                    masked_frame = frames[i].copy()
                    masked_frame[y1:y2, x1:x2] = mask_color
                    out_frames.append(masked_frame)
        elif mode in ['salientmasktrack']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            salient_mask = self.salient_model.forward(first_frame)
            mask_data = self.sam2_model.forward(video=video, mask=salient_mask, task_type='mask')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['salientbboxtrack']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            salient_mask = self.salient_model.forward(first_frame)
            bbox = get_mask_box(np.array(salient_mask), threshold=1)
            mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['maskpointtrack']:
            mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_point')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['maskbboxtrack']:
            mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['masktrack']:
            mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['bboxtrack']:
            mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['label']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            gdino_res = self.gdino_model.forward(first_frame, classes=label)
            if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
                bboxes = gdino_res['boxes'][0]
            else:
                raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
            mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['caption']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            gdino_res = self.gdino_model.forward(first_frame, caption=caption)
            if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
                bboxes = gdino_res['boxes'][0]
            else:
                raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
            mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        ret_data = {}
        if return_frame:
            ret_data["frames"] = out_frames
        if return_mask:
            ret_data["masks"] = out_masks
        return ret_data
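

# Example usage (a minimal sketch, not part of the original file; 'masktrack' mode assumes
# a cfg whose 'SAM2' entry is understood by SAM2VideoAnnotator, a video readable by the
# tracker, and a first-frame mask as an H x W uint8 array — `sam2_cfg`, `frames`, and
# `first_frame_mask` are hypothetical placeholders):
#
#   anno = InpaintingVideoAnnotator(cfg={'MODE': 'masktrack', 'SAM2': sam2_cfg})
#   out = anno.forward(frames=frames, video='input.mp4', mask=first_frame_mask)
#   masked_frames, masks = out['frames'], out['masks']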