import os
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
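# Token layout assumed by the masks below (as implied by the slicing in
# Conceptrol.__call__): the joint attention sequence has length M + 2N and is
# ordered [text tokens (M) | latent tokens (N) | image-condition tokens (N)],
# so [:M] selects the text, [M:-N] the latents, and [-N:] the image condition.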
class Conceptrol:
    def __init__(self, config):
        if "name" not in config:
            raise KeyError("name has to be provided as 'conceptrol' or 'ominicontrol'")
        name = config["name"]
        if name not in ["conceptrol", "ominicontrol"]:
            raise ValueError(
                f"Name must be one of ['conceptrol', 'ominicontrol'], got {name}"
            )
        log_attn_map = config.get("log_attn_map", False)

        # static
        self.NUM_BLOCKS = 19  # this is fixed for FLUX
        self.M = 512  # num of text tokens, fixed for FLUX
        self.N = 1024  # num of latent / image condition tokens, fixed for FLUX
        self.EP = -10e6  # large negative value, acts as -inf in the attention mask
        self.CONCEPT_BLOCK_IDX = 18

        # fixed during one generation
        self.name = name

        # variable during one generation
        self.textual_concept_mask = None
        self.forward_count = 0

        # log out for visualization
        if log_attn_map:
            self.attn_maps = {"latent_to_concept": [], "latent_to_image": []}
    def __call__(
        self,
        query: torch.FloatTensor,
        key: torch.FloatTensor,
        attention_mask: torch.Tensor,
        c_factor: torch.Tensor = torch.ones(1),  # condition scale; indexed as c_factor[0] below
    ) -> torch.Tensor:
        if not hasattr(self, "textual_concept_idx"):
            raise AttributeError(
                "textual_concept_idx must be registered before calling Conceptrol"
            )

        # Skip computation for ominicontrol
        if self.name == "ominicontrol":
            scale_factor = 1 / math.sqrt(query.size(-1))
            attention_weight = (
                query @ key.transpose(-2, -1) * scale_factor + attention_mask
            )
            attention_probs = torch.softmax(
                attention_weight, dim=-1
            )  # [B, H, M+2N, M+2N]
            return attention_probs
        if not self.textual_concept_idx[0] < self.textual_concept_idx[1]:
            raise ValueError(
                f"textual_concept_idx[0] must be less than textual_concept_idx[1], "
                f"got {self.textual_concept_idx[0]} >= {self.textual_concept_idx[1]}"
            )

        ### Reset attention mask predefined in ominicontrol
        attention_mask = torch.zeros_like(attention_mask)
        bias = torch.log(c_factor[0])

        # attention of image condition to latent
        attention_mask[-self.N :, self.M : -self.N] = bias
        # attention of latent to image condition
        attention_mask[self.M : -self.N, -self.N :] = bias
        # attention of textual concept to image condition
        attention_mask[
            self.textual_concept_idx[0] : self.textual_concept_idx[1], -self.N :
        ] = bias
        # attention of other words to image condition (set to negative inf)
        attention_mask[: self.textual_concept_idx[0], -self.N :] = self.EP
        attention_mask[self.textual_concept_idx[1] : self.M, -self.N :] = self.EP

        # If there is no textual_concept_mask yet, we are still in the layers
        # before the first concept-specific block
        if self.textual_concept_mask is None:
            self.textual_concept_mask = (
                torch.zeros_like(attention_mask).unsqueeze(0).unsqueeze(0)
            )
        ### Compute attention
        scale_factor = 1 / math.sqrt(query.size(-1))
        attention_weight = (
            query @ key.transpose(-2, -1) * scale_factor
            + attention_mask
            + self.textual_concept_mask
        )
        # [B, H, M+2N, M+2N]
        attention_probs = torch.softmax(attention_weight, dim=-1)

        ### Extract textual concept mask if it's a concept-specific block
        is_concept_block = (
            self.forward_count % self.NUM_BLOCKS == self.CONCEPT_BLOCK_IDX
        )
        if is_concept_block:
            # Shape: [B, H, N, S], where S is the number of subject tokens
            textual_concept_mask_local = attention_probs[
                :,
                :,
                self.M : -self.N,
                self.textual_concept_idx[0] : self.textual_concept_idx[1],
            ]
            # Consider the ratio within the context of the text
            textual_concept_mask_local = textual_concept_mask_local / torch.sum(
                attention_probs[:, :, self.M : -self.N, : self.M], dim=-1, keepdim=True
            )
            # Average over words and heads, Shape: [B, 1, N, 1]
            textual_concept_mask_local = torch.mean(
                textual_concept_mask_local, dim=(-1, 1), keepdim=True
            )
            # Normalize so the average is 1
            textual_concept_mask_local = textual_concept_mask_local / torch.mean(
                textual_concept_mask_local, dim=-2, keepdim=True
            )

            self.textual_concept_mask = (
                torch.zeros_like(attention_mask).unsqueeze(0).unsqueeze(0)
            )
            # log(A) in the paper
            self.textual_concept_mask[:, :, self.M : -self.N, -self.N :] = torch.log(
                textual_concept_mask_local
            )

        self.forward_count += 1
        return attention_probs
    def register(self, textual_concept_idx):
        # textual_concept_idx: (start, end) indices of the concept tokens within the text prompt
        self.textual_concept_idx = textual_concept_idx
    def visualize_attn_map(self, config_name: str, subject: str):
        save_dir = f"attn_maps/{config_name}/{subject}"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for attn_map_name, attn_maps in self.attn_maps.items():
            if "token_to_token" in attn_map_name:
                continue
            rows, cols = 8, 19
            fig, axes = plt.subplots(
                rows, cols, figsize=(64 * cols / 100, 64 * rows / 100)
            )
            fig.subplots_adjust(
                wspace=0.1, hspace=0.1
            )  # Adjust spacing between subplots

            # Plot each array in the list on the grid
            for i, ax in enumerate(axes.flatten()):
                if i < len(attn_maps):  # Only plot existing arrays
                    attn_map = attn_maps[i] / np.amax(attn_maps[i])
                    ax.imshow(attn_map, cmap="viridis")
                    ax.axis("off")  # Turn off axes for clarity
                else:
                    ax.axis("off")  # Turn off unused subplots

            fig.set_size_inches(64 * cols / 100, 64 * rows / 100)
            save_path = os.path.join(save_dir, f"{attn_map_name}.jpg")
            plt.savefig(save_path)
            plt.close()

        for attn_map_name, attn_maps in self.attn_maps.items():
            if "token_to_token" not in attn_map_name:
                continue
            rows, cols = 8, 19
            fig, axes = plt.subplots(
                rows, cols, figsize=(2560 * cols / 100, 2560 * rows / 100)
            )
            fig.subplots_adjust(
                wspace=0.1, hspace=0.1
            )  # Adjust spacing between subplots

            # Plot each array in the list on the grid
            for i, ax in enumerate(axes.flatten()):
                if i < len(attn_maps):  # Only plot existing arrays
                    attn_map = attn_maps[i] / np.amax(attn_maps[i])
                    ax.imshow(attn_map, cmap="viridis")
                    ax.axis("off")  # Turn off axes for clarity
                else:
                    ax.axis("off")  # Turn off unused subplots

            fig.set_size_inches(64 * cols / 100, 64 * rows / 100)
            save_path = os.path.join(save_dir, f"{attn_map_name}.jpg")
            plt.savefig(save_path)
            plt.close()

        # Reset per-generation state after dumping the maps
        for attn_map_name in self.attn_maps.keys():
            self.attn_maps[attn_map_name] = []
        self.textual_concept_mask = None
        self.forward_count = 0
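

# Minimal usage sketch (illustrative only, not part of the processor): it calls
# register() and __call__ directly with random tensors of the expected shapes.
# The head count, head dim, and the concept token range (5, 12) are made-up
# values; in a real pipeline the FLUX attention blocks drive this processor.
if __name__ == "__main__":
    conceptrol = Conceptrol({"name": "conceptrol", "log_attn_map": False})
    conceptrol.register((5, 12))  # hypothetical (start, end) of the concept tokens

    B, H, D = 1, 1, 64                     # toy batch size, heads, head dim
    L = conceptrol.M + 2 * conceptrol.N    # text + latent + image-condition tokens
    query = torch.randn(B, H, L, D)
    key = torch.randn(B, H, L, D)
    attention_mask = torch.zeros(L, L)

    attn_probs = conceptrol(query, key, attention_mask, c_factor=torch.tensor([1.0]))
    print(attn_probs.shape)  # torch.Size([1, 1, 2560, 2560])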