Spaces:
Runtime error
Runtime error
| # -------------------------------------------------------- | |
| # Semantic-SAM: Segment and Recognize Anything at Any Granularity | |
| # Copyright (c) 2023 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # Written by Hao Zhang (hzhangcx@connect.ust.hk) | |
| # -------------------------------------------------------- | |
| import torch | |
| import numpy as np | |
| from torchvision import transforms | |
| from task_adapter.utils.visualizer import Visualizer | |
| from typing import Tuple | |
| from PIL import Image | |
| from detectron2.data import MetadataCatalog | |
| metadata = MetadataCatalog.get('coco_2017_train_panoptic') | |
| def interactive_infer_image(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']): | |
| t = [] | |
| t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC)) | |
| transform1 = transforms.Compose(t) | |
| image_ori = transform1(image['image']) | |
| mask_ori = transform1(image['mask']) | |
| width = image_ori.size[0] | |
| height = image_ori.size[1] | |
| image_ori = np.asarray(image_ori) | |
| images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() | |
| all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':') | |
| data = {"image": images, "height": height, "width": width} | |
| mask_ori = np.asarray(mask_ori)[:,:,0:1].copy() | |
| mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0] | |
| points=mask_ori.nonzero().float().to(images.device) | |
| if len(points)==0: | |
| point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]]) | |
| else: | |
| point_=points.mean(0)[None] | |
| point=point_.clone() | |
| point[0, 0] = point_[0, 0] / mask_ori.shape[0] | |
| point[0, 1] = point_[0, 1] / mask_ori.shape[1] | |
| point = point[:, [1, 0]] | |
| point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1) | |
| data['targets'] = [dict()] | |
| data['targets'][0]['points']=point | |
| data['targets'][0]['pb']=point.new_tensor([0.]) | |
| batch_inputs = [data] | |
| masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts) | |
| pred_masks_poses = masks | |
| reses=[] | |
| ious=ious[0,0] | |
| ids=torch.argsort(ious,descending=True) | |
| text_res='' | |
| try: | |
| thresh=float(thresh) | |
| except Exception: | |
| thresh=0.0 | |
| mask_ls=[] | |
| ious_res=[] | |
| areas=[] | |
| for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])): | |
| iou=round(float(iou),2) | |
| texts=f'{iou}' | |
| mask=(pred_masks_pos>0.0).cpu().numpy() | |
| area=mask.sum() | |
| conti=False | |
| if iou<thresh: | |
| conti=True | |
| for m in mask_ls: | |
| if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95: | |
| conti=True | |
| break | |
| if i == len(pred_masks_poses[ids])-1 and mask_ls==[]: | |
| conti=False | |
| if conti: | |
| continue | |
| ious_res.append(iou) | |
| mask_ls.append(mask) | |
| areas.append(area) | |
| mask,_=remove_small_regions(mask,int(hole_scale),mode="holes") | |
| mask,_=remove_small_regions(mask,int(island_scale),mode="islands") | |
| mask=(mask).astype(np.float) | |
| out_txt = texts | |
| visual = Visualizer(image_ori, metadata=metadata) | |
| color=[0.,0.,1.0] | |
| # demo = visual.draw_binary_mask(mask, color=color, text=texts) | |
| demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode) | |
| res = demo.get_image() | |
| point_x0=max(0,int(point_[0, 1])-3) | |
| point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3) | |
| point_y0 = max(0, int(point_[0, 0]) - 3) | |
| point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3) | |
| # res[point_y0:point_y1,point_x0:point_x1,0]=255 | |
| # res[point_y0:point_y1,point_x0:point_x1,1]=0 | |
| # res[point_y0:point_y1,point_x0:point_x1,2]=0 | |
| reses.append(Image.fromarray(res)) | |
| text_res=text_res+';'+out_txt | |
| ids=list(torch.argsort(torch.tensor(areas),descending=False)) | |
| ids = [int(i) for i in ids] | |
| torch.cuda.empty_cache() | |
| return reses,[reses[i] for i in ids] | |
| def interactive_infer_image_3l(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None): | |
| t = [] | |
| t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC)) | |
| transform1 = transforms.Compose(t) | |
| image_ori = transform1(image['image']) | |
| mask_ori = transform1(image['mask']) | |
| width = image_ori.size[0] | |
| height = image_ori.size[1] | |
| image_ori = np.asarray(image_ori) | |
| images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() | |
| all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':') | |
| data = {"image": images, "height": height, "width": width} | |
| mask_ori = np.asarray(mask_ori)[:,:,0:1].copy() | |
| mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0] | |
| points=mask_ori.nonzero().float().to(images.device) | |
| if len(points)==0: | |
| point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]]) | |
| else: | |
| point_=points.mean(0)[None] | |
| point=point_.clone() | |
| point[0, 0] = point_[0, 0] / mask_ori.shape[0] | |
| point[0, 1] = point_[0, 1] / mask_ori.shape[1] | |
| point = point[:, [1, 0]] | |
| point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1) | |
| data['targets'] = [dict()] | |
| data['targets'][0]['points']=point | |
| data['targets'][0]['pb']=point.new_tensor([0.]) | |
| batch_inputs = [data] | |
| masks, ious, pred_class, pred_class_score = model.model.evaluate_demo(batch_inputs,all_classes,all_parts, level=[0,1,2]) | |
| pred_masks_poses = masks | |
| reses=[] | |
| ious=ious[0,0] | |
| ids=torch.argsort(ious,descending=True) | |
| text_res='' | |
| try: | |
| thresh=float(thresh) | |
| except Exception: | |
| thresh=0.0 | |
| mask_ls=[] | |
| ious_res=[] | |
| areas=[] | |
| new_pred_class = [] | |
| new_pred_class_score = [] | |
| for i in ids: | |
| new_pred_class_score.append(pred_class_score[i]) | |
| new_pred_class.append(pred_class[i]) | |
| # import ipdb; ipdb.set_trace() | |
| for i,(pred_masks_pos,iou, cls_name, cls_score) in enumerate(zip(pred_masks_poses[ids],ious[ids], new_pred_class, new_pred_class_score)): | |
| iou=round(float(iou),2) | |
| texts=f'{iou}_{cls_name}_{cls_score}' | |
| mask=(pred_masks_pos>0.0).cpu().numpy() | |
| area=mask.sum() | |
| conti=False | |
| if iou<thresh: | |
| conti=True | |
| for m in mask_ls: | |
| if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95: | |
| conti=True | |
| break | |
| if i == len(pred_masks_poses[ids])-1 and mask_ls==[]: | |
| conti=False | |
| if conti: | |
| continue | |
| ious_res.append(iou) | |
| mask_ls.append(mask) | |
| areas.append(area) | |
| mask,_=remove_small_regions(mask,int(hole_scale),mode="holes") | |
| mask,_=remove_small_regions(mask,int(island_scale),mode="islands") | |
| mask=(mask).astype(np.float) | |
| out_txt = texts | |
| visual = Visualizer(image_ori, metadata=metadata) | |
| color=[0.,0.,1.0] | |
| demo = visual.draw_binary_mask(mask, color=color, text=texts) | |
| res = demo.get_image() | |
| point_x0=max(0,int(point_[0, 1])-3) | |
| point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3) | |
| point_y0 = max(0, int(point_[0, 0]) - 3) | |
| point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3) | |
| res[point_y0:point_y1,point_x0:point_x1,0]=255 | |
| res[point_y0:point_y1,point_x0:point_x1,1]=0 | |
| res[point_y0:point_y1,point_x0:point_x1,2]=0 | |
| reses.append(Image.fromarray(res)) | |
| text_res=text_res+';'+out_txt | |
| ids=list(torch.argsort(torch.tensor(areas),descending=False)) | |
| ids = [int(i) for i in ids] | |
| torch.cuda.empty_cache() | |
| return reses,[reses[i] for i in ids] | |
| def interactive_infer_image_semantic(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None): | |
| t = [] | |
| t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC)) | |
| transform1 = transforms.Compose(t) | |
| image_ori = transform1(image['image']) | |
| mask_ori = transform1(image['mask']) | |
| width = image_ori.size[0] | |
| height = image_ori.size[1] | |
| image_ori = np.asarray(image_ori) | |
| images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() | |
| all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':') | |
| data = {"image": images, "height": height, "width": width} | |
| mask_ori = np.asarray(mask_ori)[:,:,0:1].copy() | |
| mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0] | |
| points=mask_ori.nonzero().float().to(images.device) | |
| if len(points)==0: | |
| point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]]) | |
| else: | |
| point_=points.mean(0)[None] | |
| point=point_.clone() | |
| point[0, 0] = point_[0, 0] / mask_ori.shape[0] | |
| point[0, 1] = point_[0, 1] / mask_ori.shape[1] | |
| point = point[:, [1, 0]] | |
| point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1) | |
| data['targets'] = [dict()] | |
| data['targets'][0]['points']=point | |
| data['targets'][0]['pb']=point.new_tensor([0.]) | |
| data['targets'][0]['pb']=point.new_tensor([1.]) | |
| batch_inputs = [data] | |
| masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts) | |
| pred_masks_poses = masks | |
| reses=[] | |
| ious=ious[0,0] | |
| ids=torch.argsort(ious,descending=True) | |
| text_res='' | |
| try: | |
| thresh=float(thresh) | |
| except Exception: | |
| thresh=0.0 | |
| mask_ls=[] | |
| ious_res=[] | |
| areas=[] | |
| for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])): | |
| iou=round(float(iou),2) | |
| texts=f'{iou}' | |
| mask=(pred_masks_pos>0.0).cpu().numpy() | |
| area=mask.sum() | |
| conti=False | |
| if iou<thresh: | |
| conti=True | |
| for m in mask_ls: | |
| if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95: | |
| conti=True | |
| break | |
| if i == len(pred_masks_poses[ids])-1 and mask_ls==[]: | |
| conti=False | |
| if conti: | |
| continue | |
| ious_res.append(iou) | |
| mask_ls.append(mask) | |
| areas.append(area) | |
| mask,_=remove_small_regions(mask,int(hole_scale),mode="holes") | |
| mask,_=remove_small_regions(mask,int(island_scale),mode="islands") | |
| mask=(mask).astype(np.float) | |
| out_txt = texts | |
| visual = Visualizer(image_ori, metadata=metadata) | |
| color=[0.,0.,1.0] | |
| demo = visual.draw_binary_mask(mask, color=color, text=texts) | |
| res = demo.get_image() | |
| point_x0=max(0,int(point_[0, 1])-3) | |
| point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3) | |
| point_y0 = max(0, int(point_[0, 0]) - 3) | |
| point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3) | |
| res[point_y0:point_y1,point_x0:point_x1,0]=255 | |
| res[point_y0:point_y1,point_x0:point_x1,1]=0 | |
| res[point_y0:point_y1,point_x0:point_x1,2]=0 | |
| reses.append(Image.fromarray(res)) | |
| text_res=text_res+';'+out_txt | |
| ids=list(torch.argsort(torch.tensor(areas),descending=False)) | |
| ids = [int(i) for i in ids] | |
| torch.cuda.empty_cache() | |
| return reses,[reses[i] for i in ids] | |
| def remove_small_regions( | |
| mask: np.ndarray, area_thresh: float, mode: str | |
| ) -> Tuple[np.ndarray, bool]: | |
| """ | |
| Removes small disconnected regions and holes in a mask. Returns the | |
| mask and an indicator of if the mask has been modified. | |
| """ | |
| import cv2 # type: ignore | |
| assert mode in ["holes", "islands"] | |
| correct_holes = mode == "holes" | |
| working_mask = (correct_holes ^ mask).astype(np.uint8) | |
| n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) | |
| sizes = stats[:, -1][1:] # Row 0 is background label | |
| small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] | |
| if len(small_regions) == 0: | |
| return mask, False | |
| fill_labels = [0] + small_regions | |
| if not correct_holes: | |
| fill_labels = [i for i in range(n_labels) if i not in fill_labels] | |
| # If every region is below threshold, keep largest | |
| if len(fill_labels) == 0: | |
| fill_labels = [int(np.argmax(sizes)) + 1] | |
| mask = np.isin(regions, fill_labels) | |
| return mask, True |