# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import numpy as np
import torch
import torch.fft as fft
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from torchvision.utils import save_image

from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available

# from .masactrl_utils import AttentionBase  # shadowed by the local AttentionBase below
# from masactrl.masactrl import MutualSelfAttentionControl

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

class AttentionBase:
    """Bookkeeping base class for attention editors: counts attention layers
    within each denoising step and dispatches to `forward` for the actual edit."""

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        pass

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            # One full pass over all attention layers == one denoising step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.after_step()
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        # Default behaviour: plain attention output, no editing.
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

class MaskPromptedStyleAttentionControl(AttentionBase):
    def __init__(self, start_step=4, start_layer=10, style_attn_step=35, layer_idx=None, step_idx=None, total_steps=50, style_guidance=0.1,
                 only_masked_region=False, guidance=0.0,
                 style_mask=None, source_mask=None, de_bug=False):
        """
        Mask-prompted style attention control (MaskPromptedSAC).

        Args:
            start_step: the first denoising step to apply style attention control
            start_layer: the first U-Net attention layer to apply style attention control
            style_attn_step: before this step the target branch reuses the content queries;
                afterwards its own queries attend to the style keys/values
            layer_idx: list of layers to apply style attention control (overrides start_layer)
            step_idx: list of steps to apply style attention control (overrides start_step)
            total_steps: the total number of denoising steps
            style_guidance: scale of the (stylized - content) difference added to the target output
            only_masked_region: if True, stylize only the masked region and keep the content output
                elsewhere; otherwise foreground and background are stylized separately by the masks
            guidance: stored but not used by this controller
            style_mask: binary mask of the styled region in the style image, shape (1, 1, H, W)
            source_mask: binary mask of the region to restyle in the source image, shape (1, 1, H, W)
            de_bug: drop into pdb at key points for debugging
        """
        super().__init__()
        self.total_steps = total_steps
        self.total_layers = 16
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        print("using MaskPromptedStyleAttentionControl")
        print("MaskedSAC at denoising steps: ", self.step_idx)
        print("MaskedSAC at U-Net layers: ", self.layer_idx)

        self.de_bug = de_bug
        self.style_guidance = style_guidance
        self.only_masked_region = only_masked_region
        self.style_attn_step = style_attn_step
        self.self_attns = []
        self.cross_attns = []
        self.guidance = guidance
        self.style_mask = style_mask
        self.source_mask = source_mask

    def after_step(self):
        self.self_attns = []
        self.cross_attns = []

    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)

        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
        if q_mask is not None:
            # Block attention from query positions outside the query mask.
            sim = sim.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
        if k_mask is not None:
            # Block attention to key positions outside the key mask.
            sim = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(-1) if attn is None else attn

        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out

    def attn_batch_fg_bg(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)

        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
        # Split the similarity map into a foreground and a background version:
        # the fg map only attends inside the masks, the bg map only outside.
        sim_fg, sim_bg = sim, sim
        if q_mask is not None:
            sim_fg = sim_fg.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
            sim_bg = sim_bg.masked_fill(q_mask.unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
        if k_mask is not None:
            sim_fg = sim_fg.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
            sim_bg = sim_bg.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
        sim = torch.cat([sim_fg, sim_bg])
        attn = sim.softmax(-1)

        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """Attention forward function."""
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)

        B = q.shape[0] // num_heads // 2
        H = W = int(np.sqrt(q.shape[1]))

        if self.style_mask is not None and self.source_mask is not None:
            # Downsample the pixel-space masks to the current attention resolution.
            height, width = self.style_mask.shape[-2:]
            mask_style = self.style_mask    # (1, 1, H, W)
            mask_source = self.source_mask  # (1, 1, H, W)
            scale = int(np.sqrt(height * width / q.shape[1]))
            spatial_mask_source = F.interpolate(mask_source, (height // scale, width // scale)).reshape(-1, 1)
            spatial_mask_style = F.interpolate(mask_style, (height // scale, width // scale)).reshape(-1, 1)
        else:
            spatial_mask_source = None
            spatial_mask_style = None

        if spatial_mask_style is None or spatial_mask_source is None:
            out_s, out_c, out_t = self.style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
        else:
            if self.only_masked_region:
                out_s, out_c, out_t = self.mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
            else:
                out_s, out_c, out_t = self.separate_mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)

        out = torch.cat([out_s, out_c, out_t], dim=0)
        return out

    def style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        if self.de_bug:
            import pdb; pdb.set_trace()

        # qs / qc / qt: queries of the style, source (content) and target branches.
        qs, qc, qt = q.chunk(3)

        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)

        if self.cur_step < self.style_attn_step:
            # Early steps: the target reuses the content queries against the style keys/values.
            out_t = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        else:
            out_t = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
            if self.style_guidance >= 0:
                out_t = out_c + (out_t - out_c) * self.style_guidance
        return out_s, out_c, out_t

    def mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        qs, qc, qt = q.chunk(3)

        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch(qc, k[num_heads:2 * num_heads], v[num_heads:2 * num_heads], sim[num_heads:2 * num_heads], attn[num_heads:2 * num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c_new = self.attn_batch(qc, k[num_heads:2 * num_heads], v[num_heads:2 * num_heads], sim[num_heads:2 * num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)  # recomputed content attention, currently unused

        if self.de_bug:
            import pdb; pdb.set_trace()

        if self.cur_step < self.style_attn_step:
            out_t = out_c
        else:
            # Stylize only the masked (foreground) region; keep the content output elsewhere.
            out_t_fg = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c_fg = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_t = out_t_fg
            if self.style_guidance >= 0:
                out_t = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
            out_t = out_t * spatial_mask_source + out_c * (1 - spatial_mask_source)

        if self.de_bug:
            import pdb; pdb.set_trace()

        return out_s, out_c, out_t

    def separate_mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        if self.de_bug:
            import pdb; pdb.set_trace()

        # To prevent query confusion, render foreground and background separately
        # according to the masks and recompose them afterwards.
        qs, qc, qt = q.chunk(3)

        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)

        if self.cur_step < self.style_attn_step:
            out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c_fg, out_c_bg = out_c.chunk(2)
            out_t = out_c_fg * spatial_mask_source + out_c_bg * (1 - spatial_mask_source)
        else:
            out_t = self.attn_batch_fg_bg(qt, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_t_fg, out_t_bg = out_t.chunk(2)
            out_c_fg, out_c_bg = out_c.chunk(2)
            if self.style_guidance >= 0:
                out_t_fg = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
                out_t_bg = out_c_bg + (out_t_bg - out_c_bg) * self.style_guidance
            out_t = out_t_fg * spatial_mask_source + out_t_bg * (1 - spatial_mask_source)

        # The content slot reuses the recomposed output, since out_c here holds the
        # stacked fg/bg halves rather than a single batch.
        return out_s, out_t, out_t
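

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; the tensor shapes and the
    # `scale` kwarg are assumptions based on how attn_batch consumes its inputs,
    # not part of the original Space). It drives the controller with random
    # self-attention tensors for a (style, source, target) triple of images.
    torch.manual_seed(0)
    num_heads, n_tokens, head_dim = 8, 16 * 16, 40
    q = torch.randn(3 * num_heads, n_tokens, head_dim)
    k = torch.randn(3 * num_heads, n_tokens, head_dim)
    v = torch.randn(3 * num_heads, n_tokens, head_dim)
    scale = head_dim ** -0.5
    sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale
    attn = sim.softmax(-1)

    # Binary masks at latent resolution, shaped (1, 1, H, W) so that
    # F.interpolate inside forward() can downsample them to the token grid.
    style_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()
    source_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()

    editor = MaskPromptedStyleAttentionControl(
        start_step=0, start_layer=0, style_attn_step=35, total_steps=50,
        style_guidance=0.1, only_masked_region=False,
        style_mask=style_mask, source_mask=source_mask)

    out = editor(q, k, v, sim, attn, is_cross=False, place_in_unet="up",
                 num_heads=num_heads, scale=scale)
    print(out.shape)  # expected: torch.Size([3, 256, 320])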