import cv2
import torch
import numpy as np
from PIL import Image
from typing import List, Callable, Optional
from functools import partial
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
| """ Model wrapper to return a tensor""" | |
| class HuggingfaceToTensorModelWrapper(torch.nn.Module): | |
| def __init__(self, model): | |
| super(HuggingfaceToTensorModelWrapper, self).__init__() | |
| self.model = model | |
| def forward(self, x): | |
| return self.model(x).logits | |


class ClassActivationMap(object):
    def __init__(self, model, processor):
        self.model = HuggingfaceToTensorModelWrapper(model)
        # Use the final layer norm of the SwinV2 backbone as the Grad-CAM target layer.
        target_layer = model.swinv2.layernorm
        self.target_layer = [target_layer]
        self.processor = processor

    def swinT_reshape_transform_huggingface(self, tensor, width, height):
        # Reshape the (batch, tokens, channels) transformer output into a
        # (batch, channels, height, width) feature map that Grad-CAM can use.
        result = tensor.reshape(tensor.size(0),
                                height,
                                width,
                                tensor.size(2))
        result = result.transpose(2, 3).transpose(1, 2)
        return result

    def run_grad_cam_on_image(self,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM):
        with method(model=self.model,
                    target_layers=self.target_layer,
                    reshape_transform=reshape_transform) as cam:
            # Replicate the tensor for each of the categories we want to create a Grad-CAM for:
            repeated_tensor = input_tensor[None, :].repeat(len(targets_for_gradcam), 1, 1, 1)
            batch_results = cam(input_tensor=repeated_tensor,
                                targets=targets_for_gradcam)
            results = []
            for grayscale_cam in batch_results:
                # Overlay the CAM heatmap on the (normalized) input image.
                visualization = show_cam_on_image(np.float32(input_image) / 255,
                                                  grayscale_cam,
                                                  use_rgb=True)
                # Optionally downscale the visualization so it weighs less when displayed
                # (a divisor of 1 keeps the original size):
                visualization = cv2.resize(visualization,
                                           (visualization.shape[1] // 1, visualization.shape[0] // 1))
                results.append(visualization)
            # Stack the per-category visualizations side by side.
            return np.hstack(results)

    def get_cam(self, image, category_id):
        # PIL's resize expects (width, height).
        image = Image.fromarray(image).resize((self.processor.size['width'], self.processor.size['height']))
        img_tensor = self.processor(images=image, return_tensors="pt")['pixel_values'].squeeze()
        targets_for_gradcam = [ClassifierOutputTarget(category_id)]
        # SwinV2 downsamples the input by a total factor of 32 before the final layer norm,
        # so the token grid seen by Grad-CAM is (H // 32, W // 32).
        reshape_transform = partial(self.swinT_reshape_transform_huggingface,
                                    width=img_tensor.shape[2] // 32,
                                    height=img_tensor.shape[1] // 32)
        cam = self.run_grad_cam_on_image(input_tensor=img_tensor,
                                         input_image=image,
                                         targets_for_gradcam=targets_for_gradcam,
                                         reshape_transform=reshape_transform)
        return cam
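

# Minimal usage sketch, assuming a SwinV2 classification checkpoint loaded with transformers
# (e.g. "microsoft/swinv2-tiny-patch4-window8-256") and a local RGB image "example.jpg";
# the checkpoint name, image path, and class id below are illustrative, not part of this module.
if __name__ == "__main__":
    from transformers import AutoImageProcessor, Swinv2ForImageClassification

    checkpoint = "microsoft/swinv2-tiny-patch4-window8-256"  # assumed checkpoint; any SwinV2 classifier should work
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Swinv2ForImageClassification.from_pretrained(checkpoint)

    cam_runner = ClassActivationMap(model, processor)

    # Load an RGB image as a numpy array and visualize the CAM for a chosen class id.
    input_image = np.array(Image.open("example.jpg").convert("RGB"))
    visualization = cam_runner.get_cam(input_image, category_id=0)
    Image.fromarray(visualization).save("cam_visualization.jpg")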