import torch
import numpy as np
from typing import Tuple, List
from cv2 import putText, getTextSize, FONT_HERSHEY_SIMPLEX
import matplotlib.pyplot as plt  # used by display() below
from PIL import Image
from src.prompt_to_prompt_controllers import AttentionStore


def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str],
                        is_cross: bool, select: int, prompts) -> torch.Tensor:
    """Average the stored attention maps of one prompt over layers and heads."""
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            # Keep only the layers whose spatial resolution matches `res`.
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    # Mean over all collected heads/layers.
    out = out.sum(0) / out.shape[0]
    return out.cpu()
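
# A minimal usage sketch (hypothetical names: `controller` is an AttentionStore
# registered on the diffusion UNet during sampling, `prompts` the list of text
# prompts used for that run). Cross-attention is usually aggregated at the
# 16x16 resolution over the "up" and "down" blocks:
#
#   agg = aggregate_attention(controller, res=16, from_where=["up", "down"],
#                             is_cross=True, select=0, prompts=prompts)
#   # agg: (16, 16, num_tokens) -- one spatial map per text token.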

def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str],
                         prompts, tokenizer, select: int = 0):
    """Render one cross-attention heatmap per token of the selected prompt."""
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(attention_store, res, from_where, True, select, prompts)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        # Normalize to [0, 255] and expand to a 3-channel grayscale image.
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))
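
# Usage sketch (hypothetical `pipe` is a diffusers pipeline whose tokenizer
# produced the prompts; `controller` and `prompts` as in the sketch above):
#
#   show_cross_attention(controller, res=16, from_where=["up", "down"],
#                        prompts=prompts, tokenizer=pipe.tokenizer)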

def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             prompts, max_com=10, select: int = 0):
    """Visualize the top principal components of the aggregated self-attention."""
    attention_maps = aggregate_attention(attention_store, res, from_where, False, select, prompts).numpy().reshape(
        (res ** 2, res ** 2))
    # Center the rows, then take the SVD; the right singular vectors give the
    # dominant spatial components of the self-attention.
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        # Min-max normalize each component to [0, 255].
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        images.append(image)
    view_images(np.concatenate(images, axis=1))
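
# Usage sketch: the leading components typically pick out coherent regions such
# as foreground objects versus background (names as in the sketches above):
#
#   show_self_attention_comp(controller, res=32, from_where=["up", "down"],
#                            prompts=prompts, max_com=10)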

def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile a list or batch of HxWx3 uint8 images into a padded grid and show it."""
    if type(images) is list:
        num_images = len(images)
    elif images.ndim == 4:
        num_images = images.shape[0]
    else:
        images = [images]
        num_images = 1
    # Pad with white images so the grid divides evenly into `num_rows` rows.
    num_empty = (num_rows - num_images % num_rows) % num_rows
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h,
                   j * (w + offset): j * (w + offset) + w] = images[i * num_cols + j]
    pil_img = Image.fromarray(image_)
    display(pil_img)
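
# Usage sketch: `show_cross_attention` passes a (N, 256, 256, 3) uint8 batch;
# a plain list works too (hypothetical images `img_a`..`img_d`), e.g. a 2-row grid:
#
#   view_images([img_a, img_b, img_c, img_d], num_rows=2)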

def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of `image` with `text` centered on a white strip below it."""
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img
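
# Self-contained sanity check (needs only numpy and OpenCV):
#
#   dummy = np.full((64, 64, 3), 200, dtype=np.uint8)
#   labeled = text_under_image(dummy, "cat")
#   assert labeled.shape == (64 + int(64 * 0.2), 64, 3)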

def display(image):
    """Show a PIL image (or array) with matplotlib."""
    plt.imshow(image)
    plt.axis("off")
    plt.show()
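
# In a Jupyter notebook, a lighter alternative is IPython's built-in display,
# which renders the PIL image inline without matplotlib:
#
#   from IPython.display import display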