fixed a small bug on mask
lib/model_zoo/clip.py (+1, -24)

@@ -3,6 +3,7 @@ import torch.nn as nn
 import numpy as np
 from functools import partial
 from lib.model_zoo.common.get_model import register
+import torch.nn.functional as F
 
 symbol = 'clip'
 
@@ -104,7 +105,6 @@ class CLIPImageContextEncoder(AbstractEncoder):
         assert isinstance(masks, torch.Tensor)
         assert (len(masks.shape)==4) and (masks.shape[1]==1)
         masks = torch.clamp(masks, 0, 1)
-        masked_images = images*masks
         masks = masks.float()
         masks = F.interpolate(masks, [224, 224], mode='bilinear')
         if masks.sum() == masks.numel():
@@ -142,29 +142,6 @@ class CLIPImageContextEncoder(AbstractEncoder):
         z = z * vtoken_mask.to(dtype)
         return z
 
-    # def _encode_wmask(self, images, masks):
-    #     assert isinstance(masks, torch.Tensor)
-    #     assert (len(masks.shape)==4) and (masks.shape[1]==1)
-    #     masks = torch.clamp(masks, 0, 1)
-    #     masks = masks.float()
-    #     masks = F.interpolate(masks, [224, 224], mode='bilinear')
-    #     if masks.sum() == masks.numel():
-    #         return self._encode(images)
-
-    #     device = images.device
-    #     dtype = images.dtype
-
-    #     vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
-    #     vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
-    #     mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
-    #     vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
-    #     vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
-
-    #     z = self._encode(images)
-    #     z[:, 1:, :] = z[:, 1:, :] * vtoken_mask.to(dtype)
-    #     z[:, 0, :] = 0
-    #     return z
-
     def encode(self, images, masks=None):
         if masks is None:
             return self._encode(images)
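In short, the commit adds the `torch.nn.functional as F` import that `F.interpolate` depends on, and deletes both a leftover `masked_images = images*masks` line and a superseded commented-out draft of `_encode_wmask`. The surviving masked path weights each CLIP vision token by how much of its patch the mask covers. Below is a minimal, self-contained sketch of that pooling step, not the repo's exact code: the kernel size, stride, and 224x224 input are assumptions matching CLIP ViT-B/32, whereas the real module reads them from `self.model.vision_model.embeddings.patch_embedding`.

# Sketch of the per-token mask pooling (assumed ViT-B/32 geometry).
import torch
import torch.nn.functional as F

def pool_mask_to_vtokens(masks, kernel_size=(32, 32), stride=(32, 32)):
    # masks: (B, 1, H, W), values in [0, 1]
    masks = torch.clamp(masks, 0, 1).float()
    masks = F.interpolate(masks, [224, 224], mode='bilinear')
    # A ones kernel sums mask pixels inside each patch; dividing by the
    # patch area turns the sum into fractional coverage per vision token.
    kernel = torch.ones([1, 1, *kernel_size], device=masks.device)
    vtoken_mask = F.conv2d(masks, kernel, stride=stride)   # (B, 1, 7, 7)
    vtoken_mask = vtoken_mask.flatten(2).transpose(1, 2)   # (B, 49, 1)
    return vtoken_mask / (kernel_size[0] * kernel_size[1])

# Usage: scale patch-token features z of shape (B, 49, C), as in the diff:
#   z = z * pool_mask_to_vtokens(masks).to(z.dtype)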