import torch
import random
import numpy as np
from PIL import Image
from collections import defaultdict
import os

# Listing detectron2 directly in requirements.txt makes pip try to build detectron2
# before torch is installed, which fails even if torch is listed first.
# Hence detectron2 is installed at runtime when using Gradio HF Spaces.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
def load_model_and_processor(model_ckpt: str):
    """Load a Mask2Former checkpoint and its image processor, moving the model to GPU if available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
    model.eval()
    image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_preprocessor
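# Example usage (a minimal sketch; any Mask2Former checkpoint id should work here):
#   model, processor = load_model_and_processor("facebook/mask2former-swin-tiny-ade-semantic")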
def load_default_ckpt(segmentation_task: str):
    """Map a segmentation task name to its default Mask2Former checkpoint."""
    if segmentation_task == "semantic":
        default_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
    elif segmentation_task == "instance":
        default_ckpt = "facebook/mask2former-swin-small-coco-instance"
    else:
        default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_ckpt
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    """Overlay panoptic predictions on the image using detectron2's Visualizer."""
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    for res in seg_info:
        # detectron2 expects 'category_id' and 'isthing'; the HF post-processor emits 'label_id'
        res['category_id'] = res.pop('label_id')
        pred_class = res['category_id']
        isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
        res['isthing'] = bool(isthing)
    visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    return output_img
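# For reference, each seg_info entry from post_process_panoptic_segmentation looks like
# {'id': 1, 'label_id': 17, 'was_fused': False, 'score': 0.97} (values illustrative);
# the loop above renames 'label_id' and adds 'isthing' so the dicts match what
# detectron2's draw_panoptic_seg_predictions expects.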
def draw_semantic_segmentation(segmentation_map, image, palette):
    """Color each predicted class with its palette color and blend the overlay with the input image."""
    color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8)  # height, width, 3
    for label, color in enumerate(palette):
        # post_process_semantic_segmentation returns 0-indexed class ids
        color_segmentation_map[segmentation_map == label, :] = color
    # Convert to BGR
    ground_truth_color_seg = color_segmentation_map[..., ::-1]
    img = np.array(image) * 0.5 + ground_truth_color_seg * 0.5
    img = img.astype(np.uint8)
    return img
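# ade_palette() is assumed to return one [R, G, B] triple per ADE20K class (150 entries),
# indexed by the 0-based class id used in the loop above.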
def visualize_instance_seg_mask(mask, input_image):
    """Assign a random RGB color to each segment id and blend the overlay with the input image."""
    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    labels = np.unique(mask)
    label2color = {int(label): (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) for label in labels}
    for label, color in label2color.items():
        # label2color is keyed by the ids actually present in mask, so compare directly
        color_segmentation_map[mask == label, :] = color
    ground_truth_color_seg = color_segmentation_map[..., ::-1]
    img = np.array(input_image) * 0.5 + ground_truth_color_seg * 0.5
    img = img.astype(np.uint8)
    return img
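# Note: `mask` here holds per-pixel segment ids from post_process_instance_segmentation,
# not class ids, so the colors distinguish instances rather than categories.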
def predict_masks(input_img_path: str, segmentation_task: str):
    # Load the model and image processor for the requested task
    default_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_ckpt)

    # Pass the input image through the image processor,
    # moving the tensors to the model's device (cpu or cuda)
    image = Image.open(input_img_path).convert("RGB")  # ensure 3-channel input
    inputs = image_processor(images=image, return_tensors="pt").to(model.device)

    # Pass inputs to the model for prediction
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the raw outputs for the requested task
    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
        output_heading = "Semantic Segmentation Output"
    elif segmentation_task == "instance":
        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_instance_map = result["segmentation"].cpu().numpy()
        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
        output_heading = "Instance Segmentation Output"
    else:
        result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result["segments_info"]
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
        output_heading = "Panoptic Segmentation Output"
    return output_result, output_heading
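# Minimal sketch of wiring predict_masks into a Gradio UI. This is an assumption about
# how the Space could expose the function; the actual app layout may differ.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Interface(
        fn=predict_masks,
        inputs=[
            gr.Image(type="filepath", label="Input Image"),
            gr.Radio(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task"),
        ],
        outputs=[gr.Image(label="Segmentation"), gr.Textbox(label="Output Heading")],
        title="Mask2Former Universal Image Segmentation",
    )
    demo.launch()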