Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # Modified by Jialian Wu from https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py | |
| import torch | |
| from detectron2.engine.defaults import DefaultPredictor | |
| from detectron2.utils.visualizer import ColorMode, Visualizer | |
class Visualizer_GRiT(Visualizer):
    """Detectron2 ``Visualizer`` that labels boxes with GRiT object descriptions.

    Compared with the stock ``draw_instance_predictions``, this variant never
    draws masks or keypoints. Each predicted box is labelled with its
    generated text description rather than a class name.
    """

    def __init__(self, image, instance_mode=None):
        """Create a visualizer for ``image``.

        Args:
            image: Image to draw on, forwarded unchanged to ``Visualizer``.
                The ``VisualizationDemo`` caller passes an RGB array.
            instance_mode: A ``ColorMode`` value or ``None``, forwarded to
                ``Visualizer``. No metadata is passed, so the parent's default
                metadata is used.
        """
        super().__init__(image, instance_mode=instance_mode)

    def draw_instance_predictions(self, predictions, *, show_scores=False):
        """Draw predicted boxes labelled with their object descriptions.

        Args:
            predictions: An ``Instances``-like object. When present (checked
                via ``has``), these fields are read: ``pred_boxes``,
                ``scores``, ``pred_classes``, ``pred_masks`` and
                ``pred_object_descriptions``.
            show_scores: When True, append ``"|<score>"`` to each description,
                with the score rounded to one decimal. Defaults to False, which
                keeps the previous output unchanged.

        Returns:
            ``self.output``, the parent visualizer's output image with the
            overlay drawn onto it.
        """
        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = (
            predictions.pred_classes.tolist()
            if predictions.has("pred_classes")
            else None
        )
        # Guarded like the other fields. A prediction without descriptions now
        # draws unlabelled boxes instead of raising AttributeError.
        # NOTE(review): `.data` presumably holds one string per box; confirm.
        object_description = (
            predictions.pred_object_descriptions.data
            if predictions.has("pred_object_descriptions")
            else None
        )
        if show_scores and object_description is not None and scores is not None:
            object_description = [
                c + "|" + str(round(s.item(), 1))
                for c, s in zip(object_description, scores)
            ]

        # Per-class colours are only usable when class ids are available.
        # Without the `classes is not None` check, iterating None would raise
        # TypeError.
        if (
            self._instance_mode == ColorMode.SEGMENTATION
            and self.metadata.get("thing_colors")
            and classes is not None
        ):
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
                for c in classes
            ]
            alpha = 0.8
        else:
            colors = None
            alpha = 0.5

        if self._instance_mode == ColorMode.IMAGE_BW:
            # Grey out the background. The union of the masks (if any) is
            # passed to the parent helper.
            self.output.reset_image(
                self._create_grayscale_image(
                    (predictions.pred_masks.any(dim=0) > 0).numpy()
                    if predictions.has("pred_masks")
                    else None
                )
            )
            alpha = 0.3

        # Only boxes and text labels are drawn. Masks and keypoints are
        # intentionally omitted.
        self.overlay_instances(
            masks=None,
            boxes=boxes,
            labels=object_description,
            keypoints=None,
            assigned_colors=colors,
            alpha=alpha,
        )
        return self.output
class VisualizationDemo:
    """Run a detectron2 predictor on one image and render its predictions."""

    def __init__(self, cfg, instance_mode=ColorMode.IMAGE):
        """Build the predictor from ``cfg``.

        Args:
            cfg: Config passed directly to ``DefaultPredictor``.
            instance_mode: ``ColorMode`` forwarded to ``Visualizer_GRiT`` when
                drawing.
        """
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode
        self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, image):
        """Predict on ``image`` and draw the resulting instances.

        Args:
            image: Array fed to the predictor as-is. It is assumed to be an
                OpenCV-style H x W x C BGR image (TODO confirm with callers).

        Returns:
            A ``(predictions, vis_output)`` tuple. ``predictions`` is the raw
            predictor output; ``vis_output`` is what
            ``Visualizer_GRiT.draw_instance_predictions`` returns.
        """
        outputs = self.predictor(image)
        # Reverse the channel axis: BGR (OpenCV) -> RGB (Matplotlib).
        rgb_image = image[:, :, ::-1]
        drawer = Visualizer_GRiT(rgb_image, instance_mode=self.instance_mode)
        cpu_instances = outputs["instances"].to(self.cpu_device)
        rendered = drawer.draw_instance_predictions(predictions=cpu_instances)
        return outputs, rendered