from typing import Any, Dict, Optional

import torch
from torch import nn
from diffusers.models.attention import Attention
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
import math
import torch.nn.functional as F

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class RowwiseMVAttention(Attention):
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        processor = XFormersMVAttnProcessor()
        self.set_processor(processor)
        # print("using xformers attention processor")


class IPCDAttention(Attention):
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        processor = XFormersIPCDAttnProcessor()
        self.set_processor(processor)
        # print("using xformers attention processor")
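
# Minimal usage sketch (illustrative only, not part of the original module): it assumes a
# small query_dim/heads/dim_head configuration and shows how the subclasses above swap in
# their xformers processors through diffusers' standard hook. The helper name is hypothetical.
def _example_enable_xformers():
    attn = RowwiseMVAttention(query_dim=320, heads=8, dim_head=40)
    # diffusers invokes this hook when enable_xformers_memory_efficient_attention()
    # walks the module tree; here we call it directly for illustration.
    attn.set_use_memory_efficient_attention_xformers(True)
    assert isinstance(attn.processor, XFormersMVAttnProcessor)
    return attn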
class XFormersMVAttnProcessor:
    r"""
    xformers processor for row-wise multi-view attention: tokens from the same image row
    of all views attend to each other in one memory-efficient attention call.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        cd_attention_mid=False,
    ):
        # print(num_views)
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        height = int(math.sqrt(sequence_length))  # assumes square latent maps
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            print('Warning: using group norm, pay attention to use it in row-wise attention')
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        # print('query', query.shape, 'key', key.shape, 'value', value.shape)
        # pdb.set_trace()

        def transpose(tensor):
            # place the two domain halves of the batch side by side along the width,
            # so cross-domain tokens share the same row-wise attention sequence
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)  # b v h w c
            tensor = torch.cat([tensor_0, tensor_1], dim=3)  # b v h 2w c
            tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
            return tensor

        # print(mvcd_attention)
        # import pdb;pdb.set_trace()
        if cd_attention_mid:
            key = transpose(key_raw)
            value = transpose(value_raw)
            query = transpose(query)
        else:
            key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)  # torch.Size([192, 384, 320])

        query = attn.head_to_batch_dim(query)  # torch.Size([960, 384, 64])
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if cd_attention_mid:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
            hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2)  # b v h w c
            hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)  # 2b v h w c
            hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
        else:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
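
# Illustrative sketch (assumption: not part of the original file) of the row-wise reshape
# used above: tokens from the same image row of all num_views views are gathered into one
# attention sequence, so attention mixes information across views within each row.
def _example_rowwise_reshape(num_views: int = 4, height: int = 8, channels: int = 16):
    batch = 2
    tokens = torch.randn(batch * num_views, height * height, channels)  # (b v) (h w) c
    rowwise = rearrange(tokens, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
    # each of the (b h) sequences now holds num_views * width tokens from one row
    assert rowwise.shape == (batch * height, num_views * height, channels)
    # the inverse mapping restores the per-view token layout exactly
    restored = rearrange(rowwise, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
    assert torch.equal(restored, tokens)
    return rowwise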
class XFormersIPCDAttnProcessor:
    r"""
    xformers processor for joint cross-domain (normal/color) attention with an extra
    face view: both domains attend over a shared token sequence, and the face view is
    blended back into the first view afterwards.
    """

    def process(self,
                attn: Attention,
                hidden_states,
                encoder_hidden_states=None,
                attention_mask=None,
                temb=None,
                num_tasks=2,
                num_views=6):
        ### TODO: num_views
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        height = int(math.sqrt(sequence_length))  # assumes square latent maps
        height_st = height // 3
        height_end = height - height_st
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        # ip attn
        # hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c', v=num_views)
        # body_hidden_states, face_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1, :, :]
        # print(body_hidden_states.shape, face_hidden_states.shape)
        # import pdb;pdb.set_trace()
        # hidden_states = body_hidden_states + attn.ip_scale * repeat(head_hidden_states.detach(), 'b l c -> (b v) l c', v=n_view)
        # hidden_states = rearrange(
        #     torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states.unsqueeze(1)], dim=1),
        #     'b v l c -> (b v) l c')

        # face cross attention
        # ip_hidden_states = repeat(face_hidden_states.detach(), 'b l c -> (b v) l c', v=num_views-1)
        # ip_key = attn.to_k_ip(ip_hidden_states)
        # ip_value = attn.to_v_ip(ip_hidden_states)
        # ip_key = attn.head_to_batch_dim(ip_key).contiguous()
        # ip_value = attn.head_to_batch_dim(ip_value).contiguous()
        # ip_query = attn.head_to_batch_dim(body_hidden_states).contiguous()
        # ip_hidden_states = xformers.ops.memory_efficient_attention(ip_query, ip_key, ip_value, attn_bias=attention_mask)
        # ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
        # ip_hidden_states = attn.to_out_ip[0](ip_hidden_states)
        # ip_hidden_states = attn.to_out_ip[1](ip_hidden_states)
        # import pdb;pdb.set_trace()

        def transpose(tensor):
            # concatenate the normal and color halves of the batch along the token axis
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)  # bv hw c
            tensor = torch.cat([tensor_0, tensor_1], dim=1)  # bv 2hw c
            # tensor = rearrange(tensor, "(b v) l c -> b v l c", v=num_views+1)
            # body, face = tensor[:, :-1, :], tensor[:, -1:, :]  # b,v,l,c; b,1,l,c
            # face = face.repeat(1, num_views, 1, 1)  # b,v,l,c
            # tensor = torch.cat([body, face], dim=2)  # b, v, 4hw, c
            # tensor = rearrange(tensor, "b v l c -> (b v) l c")
            return tensor

        key = transpose(key)
        value = transpose(value)
        query = transpose(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # split the joint token sequence back into the normal and color domains
        hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)  # bv, hw, c

        # blend the (detached) face view into the top-centre region of the first view, normal domain
        hidden_states_normal = rearrange(hidden_states_normal, "(b v) (h w) c -> b v h w c", v=num_views + 1, h=height)
        face_normal = rearrange(hidden_states_normal[:, -1, :, :, :], 'b h w c -> b c h w').detach()
        face_normal = rearrange(F.interpolate(face_normal, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c')
        hidden_states_normal = hidden_states_normal.clone()  # avoid modifying the original tensor in place
        hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_normal
        # hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.1 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.9 * face_normal
        hidden_states_normal = rearrange(hidden_states_normal, "b v h w c -> (b v) (h w) c")

        # same blending for the color domain
        hidden_states_color = rearrange(hidden_states_color, "(b v) (h w) c -> b v h w c", v=num_views + 1, h=height)
        face_color = rearrange(hidden_states_color[:, -1, :, :, :], 'b h w c -> b c h w').detach()
        face_color = rearrange(F.interpolate(face_color, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c')
        hidden_states_color = hidden_states_color.clone()  # avoid modifying the original tensor in place
        hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_color
        # hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.1 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.9 * face_color
        hidden_states_color = rearrange(hidden_states_color, "b v h w c -> (b v) (h w) c")

        hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0)  # 2bv hw c

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        hidden_states = self.process(attn, hidden_states, encoder_hidden_states, attention_mask, temb, num_tasks)
        # hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c')
        # body_hidden_states, head_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1:, :, :]
        # import pdb;pdb.set_trace()
        # hidden_states = body_hidden_states + attn.ip_scale * head_hidden_states.detach().repeat(1, views, 1, 1)
        # hidden_states = rearrange(
        #     torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states], dim=1),
        #     'b v l c -> (b v) l c')
        return hidden_states
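
# Illustrative sketch (assumption: not part of the original file) of the face-blending step
# in XFormersIPCDAttnProcessor.process: the last (face) view is resized to a
# height_st x height_st patch and averaged into the top-centre region of the first view.
def _example_face_blend(num_views: int = 6, height: int = 24, channels: int = 8):
    batch = 1
    height_st = height // 3
    height_end = height - height_st
    # (b, v+1, h, w, c): num_views body views plus one face view
    feats = torch.randn(batch, num_views + 1, height, height, channels)
    face = rearrange(feats[:, -1], "b h w c -> b c h w")
    face = rearrange(
        F.interpolate(face, size=(height_st, height_st), mode="bilinear"), "b c h w -> b h w c"
    )
    feats = feats.clone()
    # 50/50 blend of the existing features and the resized face patch
    feats[:, 0, :height_st, height_st:height_end, :] = (
        0.5 * feats[:, 0, :height_st, height_st:height_end, :] + 0.5 * face
    )
    return feats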
class IPCrossAttn(Attention):
    r"""
    Attention layer for IP-Adapter style cross-attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self,
                 query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, ip_scale=1.0):
        super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention)
        self.ip_scale = ip_scale

        # self.num_tokens = num_tokens
        # self.to_k_ip = nn.Linear(query_dim, self.inner_dim, bias=False)
        # self.to_v_ip = nn.Linear(query_dim, self.inner_dim, bias=False)
        # self.to_out_ip = nn.ModuleList([])
        # self.to_out_ip.append(nn.Linear(self.inner_dim, self.inner_dim, bias=bias))
        # self.to_out_ip.append(nn.Dropout(dropout))
        # nn.init.zeros_(self.to_k_ip.weight.data)
        # nn.init.zeros_(self.to_v_ip.weight.data)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        processor = XFormersIPCrossAttnProcessor()
        self.set_processor(processor)


class XFormersIPCrossAttnProcessor:
    r"""
    xformers processor for plain cross-attention between image tokens and the
    conditioning `encoder_hidden_states`; the IP-specific branches are currently disabled.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # fall back to self-attention if no conditioning is provided
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # ip attn
        # hidden_states = rearrange(hidden_states, '(b v) l c -> b v l c', v=num_views)
        # body_hidden_states, face_hidden_states = rearrange(hidden_states[:, :-1, :, :], 'b v l c -> (b v) l c'), hidden_states[:, -1, :, :]
        # print(body_hidden_states.shape, face_hidden_states.shape)
        # import pdb;pdb.set_trace()
        # hidden_states = body_hidden_states + attn.ip_scale * repeat(head_hidden_states.detach(), 'b l c -> (b v) l c', v=n_view)
        # hidden_states = rearrange(
        #     torch.cat([rearrange(hidden_states, '(b v) l c -> b v l c'), head_hidden_states.unsqueeze(1)], dim=1),
        #     'b v l c -> (b v) l c')

        # face cross attention
        # ip_hidden_states = repeat(face_hidden_states.detach(), 'b l c -> (b v) l c', v=num_views-1)
        # ip_key = attn.to_k_ip(ip_hidden_states)
        # ip_value = attn.to_v_ip(ip_hidden_states)
        # ip_key = attn.head_to_batch_dim(ip_key).contiguous()
        # ip_value = attn.head_to_batch_dim(ip_value).contiguous()
        # ip_query = attn.head_to_batch_dim(body_hidden_states).contiguous()
        # ip_hidden_states = xformers.ops.memory_efficient_attention(ip_query, ip_key, ip_value, attn_bias=attention_mask)
        # ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
        # ip_hidden_states = attn.to_out_ip[0](ip_hidden_states)
        # ip_hidden_states = attn.to_out_ip[1](ip_hidden_states)
        # import pdb;pdb.set_trace()
        # body_hidden_states = body_hidden_states + attn.ip_scale * ip_hidden_states
        # hidden_states = rearrange(
        #     torch.cat([rearrange(body_hidden_states, '(b v) l c -> b v l c', v=num_views-1), face_hidden_states.unsqueeze(1)], dim=1),
        #     'b v l c -> (b v) l c')
        # import pdb;pdb.set_trace()

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        # TODO: region control
        # if len(region_control.prompt_image_conditioning) == 1:
        #     region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
        #     if region_mask is not None:
        #         h, w = region_mask.shape[:2]
        #         ratio = (h * w / query.shape[1]) ** 0.5
        #         mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
        #     else:
        #         mask = torch.ones_like(ip_hidden_states)
        #     ip_hidden_states = ip_hidden_states * mask

        return hidden_states
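
# Minimal construction sketch (illustrative; the parameter values are assumptions, not taken
# from the original training code): IPCrossAttn keeps the stock diffusers cross-attention
# weights and only installs XFormersIPCrossAttnProcessor when xformers is enabled.
def _example_build_ip_cross_attn():
    attn = IPCrossAttn(
        query_dim=320,
        cross_attention_dim=768,
        heads=8,
        dim_head=40,
        dropout=0.0,
        bias=False,
        upcast_attention=False,
        ip_scale=1.0,
    )
    attn.set_use_memory_efficient_attention_xformers(True)
    assert isinstance(attn.processor, XFormersIPCrossAttnProcessor)
    return attn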
class RowwiseMVProcessor:
    r"""
    Non-xformers fallback for row-wise multi-view attention; mirrors
    XFormersMVAttnProcessor but uses the standard attention-score path.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        cd_attention_mid=False,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        height = int(math.sqrt(sequence_length))  # assumes square latent maps
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # print('query', query.shape, 'key', key.shape, 'value', value.shape)
        # query torch.Size([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
        # pdb.set_trace()

        # multi-view self-attention
        def transpose(tensor):
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)  # b v h w c
            tensor = torch.cat([tensor_0, tensor_1], dim=3)  # b v h 2w c
            tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
            return tensor

        if cd_attention_mid:
            key = transpose(key)
            value = transpose(value)
            query = transpose(query)
        else:
            key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)  # torch.Size([192, 384, 320])

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if cd_attention_mid:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
            hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2)  # b v h w c
            hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)  # 2b v h w c
            hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
        else:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class CDAttention(Attention):
    # def __init__(self, ip_scale,
    #              query_dim, heads, dim_head, dropout, bias, cross_attention_dim, upcast_attention, processor):
    #     super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, processor=processor)
    #     self.ip_scale = ip_scale
    #     self.to_k_ip = nn.Linear(query_dim, self.inner_dim, bias=False)
    #     self.to_v_ip = nn.Linear(query_dim, self.inner_dim, bias=False)
    #     nn.init.zeros_(self.to_k_ip.weight.data)
    #     nn.init.zeros_(self.to_v_ip.weight.data)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        processor = XFormersCDAttnProcessor()
        self.set_processor(processor)
        # print("using xformers attention processor")


class XFormersCDAttnProcessor:
    r"""
    xformers processor for cross-domain attention: the normal and color halves of the
    batch are joined along the token axis so the two domains attend to each other.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        def transpose(tensor):
            # concatenate the normal and color halves of the batch along the token axis
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)  # bv hw c
            tensor = torch.cat([tensor_0, tensor_1], dim=1)  # bv 2hw c
            return tensor

        key = transpose(key)
        value = transpose(value)
        query = transpose(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # undo transpose(): split the joint token sequence back into the two domains
        # and restack them along the batch axis
        hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=1, chunks=2)  # bv hw c
        hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)  # 2bv hw c

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
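
# Illustrative round-trip sketch (assumption: not in the original file) of the cross-domain
# "transpose" used by XFormersCDAttnProcessor: the two domain halves are concatenated along
# the token axis so they attend jointly, then split back after the output projection.
def _example_cross_domain_roundtrip(tokens: int = 64, channels: int = 16):
    bv = 3  # batch * views for one domain
    normal = torch.randn(bv, tokens, channels)
    color = torch.randn(bv, tokens, channels)
    joint = torch.cat([normal, color], dim=1)              # bv, 2*tokens, c (as in transpose())
    # ... attention over the joint sequence would happen here ...
    out_normal, out_color = torch.chunk(joint, chunks=2, dim=1)
    restored = torch.cat([out_normal, out_color], dim=0)   # 2*bv, tokens, c
    assert restored.shape == (2 * bv, tokens, channels)
    return restored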