Commit aeaceee
Parent(s): c027c15

Add instance seg visualization

Files changed:
- app.py +2 -2
- predict.py +17 -17
app.py CHANGED

@@ -10,14 +10,14 @@ demo = gr.Blocks()
 
 with demo:
 
-    gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Transformer for Universal Segmentation</p>**")
+    gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Mask Transformer for Universal Segmentation</p>**")
     gr.Markdown("This space demonstrates the use of Mask2Former. It was introduced in the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) and first released in [this repository](https://github.com/facebookresearch/Mask2Former/). \
     Before Mask2Former, you'd have to resort to using a specialized architecture designed for solving a particular kind of image segmentation task (i.e. semantic, instance or panoptic segmentation). On the other hand, in the form of Mask2Former, for the first time, we have a single architecture that is capable of solving any segmentation task and performs on par or better than specialized architectures.")
 
     with gr.Box():
 
         with gr.Row():
-            segmentation_task = gr.Dropdown(["semantic", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
+            segmentation_task = gr.Dropdown(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
     with gr.Box():
         with gr.Row():
             input_image = gr.Image(type='filepath', label="Input Image", show_label=True)
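For orientation, here is a minimal sketch of how this Blocks UI plausibly wires the new three-way task dropdown into `predict_masks` from predict.py. Only the `Dropdown` and `Image` lines come from the diff above; the button, output component, and event wiring are assumptions (and `gr.Box` is Gradio 3.x API):

```python
import gradio as gr

from predict import predict_masks  # defined in predict.py (see diff below)

demo = gr.Blocks()

with demo:
    with gr.Box():
        with gr.Row():
            # The commit extends the task choices with "instance"
            segmentation_task = gr.Dropdown(
                ["semantic", "instance", "panoptic"],
                value="panoptic", label="Segmentation Task", show_label=True,
            )
    with gr.Box():
        with gr.Row():
            # type='filepath' hands predict_masks a path, matching its
            # input_img_path parameter
            input_image = gr.Image(type="filepath", label="Input Image", show_label=True)
            output_image = gr.Image(label="Segmentation Result")  # assumed output component

    # Assumed wiring: route both inputs into the predictor on click
    gr.Button("Run Segmentation").click(
        predict_masks, inputs=[input_image, segmentation_task], outputs=output_image
    )

demo.launch()
```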
predict.py CHANGED

@@ -40,6 +40,7 @@ def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
     return output_img
 
 def draw_semantic_segmentation(segmentation_map, image, palette):
+
     color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
     for label, color in enumerate(palette):
         color_segmentation_map[segmentation_map - 1 == label, :] = color

@@ -50,15 +51,20 @@ def draw_semantic_segmentation(segmentation_map, image, palette):
     img = img.astype(np.uint8)
     return img
 
-def visualize_instance_seg_mask(mask):
-
+def visualize_instance_seg_mask(mask, input_image):
+    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
+
     labels = np.unique(mask)
-    label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
-
-
-
-
-
+    label2color = {int(label): (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
+
+    for label, color in label2color.items():
+        color_segmentation_map[mask - 1 == label, :] = color
+
+    ground_truth_color_seg = color_segmentation_map[..., ::-1]
+
+    img = np.array(input_image) * 0.5 + ground_truth_color_seg * 0.5
+    img = img.astype(np.uint8)
+    return img
 
 def predict_masks(input_img_path: str, segmentation_task: str):

@@ -82,15 +88,9 @@ def predict_masks(input_img_path: str, segmentation_task: str):
         output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
 
     elif segmentation_task == "instance":
-
-
-
-        # # predicted_segmentation_map = torch.argmax(result, dim=0).numpy()
-        # # results = torch.argmax(predicted_segmentation_map, dim=0).numpy()
-        # print("predicted_segmentation_map:",predicted_segmentation_map)
-        # print("type predicted_segmentation_map:", type(predicted_segmentation_map))
-        # output_result = visualize_instance_seg_mask(predicted_segmentation_map)
-        # # mask = plot_semantic_map(predicted_segmentation_map, image)
+        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+        predicted_instance_map = result["segmentation"].cpu().detach().numpy()
+        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
 
     else:
         result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
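Taken together, the new instance branch post-processes the raw model outputs into a per-pixel instance-id map and overlays one random color per instance on the input image. Below is a standalone sketch of that path, runnable outside the Space. The checkpoint name and example image are assumptions (the Space's model-loading code is not part of this diff), and the overlay here draws all three color channels from the full 0-255 range rather than the diff's `randint(0, 1)` first channel:

```python
import random

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

# Assumed checkpoint; any Mask2Former instance-segmentation checkpoint works here.
checkpoint = "facebook/mask2former-swin-tiny-coco-instance"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg").convert("RGB")  # assumed input image
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Same post-processing call the commit adds to predict_masks;
# image.size is PIL's (width, height), so [::-1] gives the (height, width)
# target size the processor expects.
result = image_processor.post_process_instance_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
predicted_instance_map = result["segmentation"].cpu().numpy()

# Overlay in the spirit of visualize_instance_seg_mask: one random color per
# instance id, blended 50/50 with the input image.
color_map = np.zeros((*predicted_instance_map.shape, 3), dtype=np.uint8)
for label in np.unique(predicted_instance_map):
    color_map[predicted_instance_map == label] = [random.randint(0, 255) for _ in range(3)]
overlay = (np.array(image) * 0.5 + color_map * 0.5).astype(np.uint8)
Image.fromarray(overlay).save("instance_overlay.png")
```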