Commit ee2b9bc · JeffLiang committed: "update ovseg + sam" · Parent(s): d8659bc
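
In summary: `crop_with_mask` in clip_adapter/utils.py now returns the cropped mask alongside the masked crop; `open_vocab_seg/utils/__init__.py` re-exports a new `SAMVisualizationDemo`; predictor.py adds that class, which generates class-agnostic masks with Segment-Anything and labels each region with an open_clip ViT-L-14 model; and requirements.txt gains the segment-anything dependency.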
open_vocab_seg/modeling/clip_adapter/utils.py
CHANGED
@@ -63,7 +63,7 @@ def crop_with_mask(
         [image.new_full((1, b - t, r - l), fill_value=val) for val in fill]
     )
     # return image[:, t:b, l:r], mask[None, t:b, l:r]
-    return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask[None, t:b, l:r]) * new_image
+    return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask[None, t:b, l:r]) * new_image, mask[None, t:b, l:r]
 
 
 def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
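
The functional change is the return value: `crop_with_mask` used to hand back only the background-filled crop, and now also returns the cropped mask so callers can reuse it. Below is a minimal self-contained sketch of the new contract, with simplified box handling; it illustrates the semantics, it is not the repo's exact implementation.

    # Sketch of the new crop_with_mask contract (not the repo's implementation):
    # the function now yields the cropped mask alongside the masked crop.
    import torch

    def crop_with_mask_sketch(image, mask, box, fill=(0.0, 0.0, 0.0)):
        l, t, r, b = [int(v) for v in box]  # hypothetical (l, t, r, b) box layout
        background = torch.stack(
            [image.new_full((b - t, r - l), fill_value=v) for v in fill]
        )
        crop_mask = mask[None, t:b, l:r]
        crop = image[:, t:b, l:r] * crop_mask + (~crop_mask) * background
        return crop, crop_mask  # callers now unpack two values

    image = torch.rand(3, 64, 64)
    mask = torch.zeros(64, 64, dtype=torch.bool)
    mask[8:32, 8:48] = True
    region, region_mask = crop_with_mask_sketch(image, mask, (8, 8, 48, 32))
    print(region.shape, region_mask.shape)  # (3, 24, 40) and (1, 24, 40)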
open_vocab_seg/utils/__init__.py
CHANGED
@@ -2,4 +2,4 @@
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved
 
 from .events import setup_wandb, WandbWriter
-from .predictor import VisualizationDemo
+from .predictor import VisualizationDemo, SAMVisualizationDemo
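
With the updated re-export, downstream code (the Space's app, presumably) can pull both demo classes from the package:

    # Both names now resolve from the package root.
    from open_vocab_seg.utils import VisualizationDemo, SAMVisualizationDemo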
open_vocab_seg/utils/predictor.py
CHANGED
@@ -3,11 +3,19 @@
 
 import numpy as np
 import torch
+from torch.nn import functional as F
+import cv2
 
 from detectron2.data import MetadataCatalog
+from detectron2.structures import BitMasks
 from detectron2.engine.defaults import DefaultPredictor
 from detectron2.utils.visualizer import ColorMode, Visualizer
+from detectron2.modeling.postprocessing import sem_seg_postprocess
 
+import open_clip
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+from open_vocab_seg.modeling.clip_adapter.adapter import PIXEL_MEAN, PIXEL_STD
+from open_vocab_seg.modeling.clip_adapter.utils import crop_with_mask
 
 class OVSegPredictor(DefaultPredictor):
     def __init__(self, cfg):
@@ -129,4 +137,89 @@ class VisualizationDemo(object):
         else:
             raise NotImplementedError
 
-        return predictions, vis_output
+        return predictions, vis_output
+
+class SAMVisualizationDemo(object):
+    def __init__(self, cfg, granularity, sam_path, ovsegclip_path, instance_mode=ColorMode.IMAGE, parallel=False):
+        self.metadata = MetadataCatalog.get(
+            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
+        )
+
+        self.cpu_device = torch.device("cpu")
+        self.instance_mode = instance_mode
+
+        self.parallel = parallel
+        self.granularity = granularity
+        sam = sam_model_registry["vit_h"](checkpoint=sam_path)
+        self.predictor = SamAutomaticMaskGenerator(sam)
+        self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path)
+        self.clip_model.cuda()
+
+    def run_on_image(self, image, class_names):
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
+
+        masks = self.predictor.generate(image)
+        pred_masks = [masks[i]['segmentation'][None,:,:] for i in range(len(masks))]
+        pred_masks = np.row_stack(pred_masks)
+        pred_masks = BitMasks(pred_masks)
+        bboxes = pred_masks.get_bounding_boxes()
+
+        mask_fill = [255.0 * c for c in PIXEL_MEAN]
+
+        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+        regions = []
+        for bbox, mask in zip(bboxes, pred_masks):
+            region, _ = crop_with_mask(
+                image,
+                mask,
+                bbox,
+                fill=mask_fill,
+            )
+            regions.append(region.unsqueeze(0))
+        regions = [F.interpolate(r.to(torch.float), size=(224, 224), mode="bicubic") for r in regions]
+
+        pixel_mean = torch.tensor(PIXEL_MEAN).reshape(1, -1, 1, 1)
+        pixel_std = torch.tensor(PIXEL_STD).reshape(1, -1, 1, 1)
+        imgs = [(r/255.0 - pixel_mean) / pixel_std for r in regions]
+        imgs = torch.cat(imgs)
+        if len(class_names) == 1:
+            class_names.append('others')
+        txts = [f'a photo of {cls_name}' for cls_name in class_names]
+        text = open_clip.tokenize(txts)
+
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            image_features = self.clip_model.encode_image(imgs)
+            text_features = self.clip_model.encode_text(text)
+            image_features /= image_features.norm(dim=-1, keepdim=True)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+
+        class_preds = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+        select_cls = torch.zeros_like(class_preds)
+
+        max_scores, select_mask = torch.max(class_preds, dim=0)
+        if len(class_names) == 2 and class_names[-1] == 'others':
+            select_mask = select_mask[:-1]
+        if self.granularity < 1:
+            thr_scores = max_scores * self.granularity
+            select_mask = []
+            for i, thr in enumerate(thr_scores):
+                cls_pred = class_preds[:,i]
+                locs = torch.where(cls_pred > thr)
+                select_mask.extend(locs[0].tolist())
+        for idx in select_mask:
+            select_cls[idx] = class_preds[idx]
+        semseg = torch.einsum("qc,qhw->chw", select_cls, pred_masks.tensor.float())
+
+        r = semseg
+        blank_area = (r[0] == 0)
+        pred_mask = r.argmax(dim=0).to('cpu')
+        pred_mask[blank_area] = 255
+        pred_mask = np.array(pred_mask, dtype=np.int)
+
+        vis_output = visualizer.draw_sem_seg(
+            pred_mask
+        )
+
+        return None, vis_output
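
A minimal sketch of how the new class might be driven, under stated assumptions: `sam_vit_h.pth` and `ovseg_clip_l.pth` are placeholder checkpoint paths, the cfg here is a bare detectron2 default rather than whatever config the Space's app actually loads, and a CUDA device is required since the constructor calls `.cuda()`.

    # Hypothetical driver for SAMVisualizationDemo; paths and cfg are assumptions.
    import cv2
    from detectron2.config import get_cfg
    from open_vocab_seg.utils import SAMVisualizationDemo

    cfg = get_cfg()  # DATASETS.TEST is empty, so the demo falls back to "__unused" metadata

    demo = SAMVisualizationDemo(
        cfg,
        granularity=0.8,                    # < 1: keep every mask scoring above 0.8 x the per-class max
        sam_path="sam_vit_h.pth",           # SAM ViT-H checkpoint (placeholder path)
        ovsegclip_path="ovseg_clip_l.pth",  # open_clip ViT-L-14 weights (placeholder path)
    )
    img = cv2.imread("input.jpg")           # BGR; run_on_image converts to RGB itself
    _, vis = demo.run_on_image(img, ["cat", "grass"])
    cv2.imwrite("output.jpg", vis.get_image()[:, :, ::-1])  # VisImage is RGB; flip back to BGR

Two caveats when running against this revision: `run_on_image` uses `np.int`, which NumPy 1.24 removed, so the environment's pinned NumPy must predate that; and as committed, `imgs` and `text` stay on the CPU while the CLIP weights are moved to CUDA, so the encode calls may need explicit `.cuda()` transfers.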
requirements.txt
CHANGED
@@ -19,4 +19,7 @@ torchvision==0.11.2+cu113
 
 # Detectron
 --find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
-detectron2
+detectron2
+
+# Segment-anything
+git+https://github.com/facebookresearch/segment-anything.git
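
A quick smoke test after installing, checking that the two entry points predictor.py imports actually resolve:

    # Verify the segment-anything dependency installed from git.
    from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
    print(sorted(sam_model_registry.keys()))  # expect: default, vit_b, vit_h, vit_l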