Spaces: change small model

- app.py +28 -4
- clip_vitb_imagenet_zeroweights.pt +3 -0
app.py
CHANGED
@@ -17,6 +17,8 @@ from sklearn import metrics
 import torch
 from torchvision import transforms
 
+from tqdm import tqdm
+
 from models.submodular_vit_efficient_plus import MultiModalSubModularExplanationEfficientPlus
 
 data_transform = transforms.Compose(
@@ -42,7 +44,7 @@ class CLIPModel_Super(torch.nn.Module):
         self.device = device
         self.model, _ = clip.load(type, device=self.device, download_root=download_root)
 
-        self.model = self.model.
+        self.model = self.model.type(torch.float32)
 
     def forward(self, vision_inputs):
         """
@@ -70,18 +72,40 @@ def transform_vision_data(image):
     image = data_transform(image)
     return image
 
-
+def zeroshot_classifier(model, classnames, templates, device):
+    with torch.no_grad():
+        zeroshot_weights = []
+        for classname in tqdm(classnames):
+            texts = [template.format(classname) for template in templates] #format with class
+            texts = clip.tokenize(texts).to(device) #tokenize
+
+            with torch.no_grad():
+                class_embeddings = model.model.encode_text(texts)
+
+            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+            class_embedding = class_embeddings.mean(dim=0)
+            class_embedding /= class_embedding.norm()
+            zeroshot_weights.append(class_embedding)
+        zeroshot_weights = torch.stack(zeroshot_weights).cuda()
+    return zeroshot_weights*100
+
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cuda"
 # Instantiate model
-vis_model = CLIPModel_Super("ViT-
+vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
 vis_model.eval()
 vis_model.to(device)
 print("load clip model")
 
-semantic_path = "./
+semantic_path = "./clip_vitb_imagenet_zeroweights.pt"
 if os.path.exists(semantic_path):
     semantic_feature = torch.load(semantic_path, map_location="cpu")
     semantic_feature = semantic_feature.to(device)
     semantic_feature = semantic_feature.type(torch.float32)
+else:
+    semantic_feature = zeroshot_classifier(vis_model, imagenet_classes, imagenet_templates, device)
+    torch.save(semantic_feature, semantic_path)
+
 
 explainer = MultiModalSubModularExplanationEfficientPlus(
     vis_model, semantic_feature, transform_vision_data, device=device,
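The added zeroshot_classifier follows the standard CLIP zero-shot recipe: embed a set of prompt templates per ImageNet class with the text encoder, normalize each prompt embedding, average and renormalize per class, scale by 100, and cache the stacked matrix so it is reused on later startups. Below is a minimal offline sketch of how that cache could be regenerated, assuming the OpenAI clip package (already used by app.py) and tiny stand-in lists in place of the imagenet_classes / imagenet_templates that app.py imports elsewhere:

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device, download_root="./ckpt")
model.eval()

# Hypothetical stand-ins for the full imagenet_classes / imagenet_templates lists.
classnames = ["goldfish", "tabby cat"]
templates = ["a photo of a {}.", "a blurry photo of a {}."]

with torch.no_grad():
    weights = []
    for name in classnames:
        texts = clip.tokenize([t.format(name) for t in templates]).to(device)
        emb = model.encode_text(texts)              # (num_templates, embed_dim)
        emb = emb / emb.norm(dim=-1, keepdim=True)  # normalize each prompt embedding
        mean = emb.mean(dim=0)
        mean = mean / mean.norm()                   # renormalize the class average
        weights.append(mean)
    weights = torch.stack(weights) * 100            # same 100x scaling as in the diff

torch.save(weights.float().cpu(), "./clip_vitb_imagenet_zeroweights.pt")

Caching the class embeddings to disk means the Space pays the text-encoding cost once instead of re-encoding every ImageNet prompt on each restart.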
clip_vitb_imagenet_zeroweights.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c552bb4a3eebecf3162e53861a8368417a2d0b5c3af5454041369c89160ac34e
+size 2048880
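The new .pt entry is a Git LFS pointer: only the version/oid/size metadata is committed, while the tensor itself lives in LFS storage. The 2,048,880-byte size is consistent with 1,000 ImageNet classes of 512-dimensional float32 ViT-B/16 text embeddings (1000 x 512 x 4 = 2,048,000 bytes) plus torch serialization overhead. A hypothetical sanity check after fetching the object with git lfs pull:

import torch

# Load the cached zero-shot weights and confirm they look like one
# embedding per ImageNet class (the exact shape/dtype is an assumption).
feat = torch.load("./clip_vitb_imagenet_zeroweights.pt", map_location="cpu")
print(feat.shape, feat.dtype)  # expected: roughly torch.Size([1000, 512]), float32
assert feat.dim() == 2 and feat.shape[0] == 1000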