import torch
import torch.nn as nn

from ldm.modules.attention import default, zero_module, checkpoint
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from ldm.modules.diffusionmodules.util import timestep_embedding


class DepthAttention(nn.Module):
    """Per-pixel cross-attention from a 2D feature map to a 3D feature volume.

    Queries come from the 2D features, keys/values from the volume; attention is
    computed independently at every spatial location over the depth axis.
    """

    def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Conv2d(query_dim, inner_dim, 1, 1, bias=False)
        self.to_k = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
        self.to_v = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
        if output_bias:
            self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1)
        else:
            self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1, bias=False)

    def forward(self, x, context):
        """
        @param x: b,f0,h,w 2D feature map (queries)
        @param context: b,f1,d,h,w 3D feature volume (keys/values)
        @return: b,f0,h,w
        """
        hn, hd = self.heads, self.dim_head
        b, _, h, w = x.shape
        b, _, d, h, w = context.shape
        q = self.to_q(x).reshape(b, hn, hd, h, w)           # b,hn,hd,h,w
        k = self.to_k(context).reshape(b, hn, hd, d, h, w)  # b,hn,hd,d,h,w
        v = self.to_v(context).reshape(b, hn, hd, d, h, w)  # b,hn,hd,d,h,w

        # attention logits over the depth axis; equivalent to
        # torch.einsum('bncxy,bncdxy->bndxy', q, k) * self.scale
        sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale  # b,hn,d,h,w
        attn = sim.softmax(dim=2)

        # weighted sum of values over depth: b,hn,hd,d,h,w * b,hn,1,d,h,w
        out = torch.sum(v * attn.unsqueeze(2), 3)  # b,hn,hd,h,w
        out = out.reshape(b, hn * hd, h, w)
        return self.to_out(out)
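
# Illustrative shape walkthrough for DepthAttention (the numbers are assumptions
# for illustration only, not values taken from this module or its configs):
#   x: (2, 40, 8, 8), context: (2, 16, 12, 8, 8), heads=4, dim_head=8
#   q -> (2, 4, 8, 8, 8), k/v -> (2, 4, 8, 12, 8, 8)
#   sim/attn -> (2, 4, 12, 8, 8), output -> (2, 40, 8, 8)
# A runnable version of this example is sketched at the bottom of the file.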


class DepthTransformer(nn.Module):
    """Residual block that conditions 2D UNet features on a 3D feature volume
    via DepthAttention. The last conv of proj_out is zero-initialized, so the
    block starts out as an identity mapping on x and the volume conditioning
    is blended in gradually during training."""

    def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
        super().__init__()
        inner_dim = n_heads * d_head
        self.proj_in = nn.Sequential(
            nn.Conv2d(dim, inner_dim, 1, 1),
            nn.GroupNorm(8, inner_dim),
            nn.SiLU(True),
        )
        self.proj_context = nn.Sequential(
            nn.Conv3d(context_dim, context_dim, 1, 1, bias=False),  # no bias
            nn.GroupNorm(8, context_dim),
            nn.ReLU(True),  # ReLU only, so that a zero input maps to a zero output
        )
        # cross-attention from the projected 2D features (query) to the 3D volume (context)
        self.depth_attn = DepthAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, output_bias=False)
        self.proj_out = nn.Sequential(
            nn.GroupNorm(8, inner_dim),
            nn.ReLU(True),
            nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
            nn.GroupNorm(8, inner_dim),
            nn.ReLU(True),
            zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
        )
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context):
        x_in = x
        x = self.proj_in(x)
        context = self.proj_context(context)
        x = self.depth_attn(x, context)
        x = self.proj_out(x) + x_in
        return x


class DepthWiseAttention(UNetModel):
    def __init__(self, volume_dims=(5, 16, 32, 64), *args, **kwargs):
        super().__init__(*args, **kwargs)
        # num_heads = 4
        model_channels = kwargs['model_channels']
        channel_mult = kwargs['channel_mult']
        d0, d1, d2, d3 = volume_dims

        # feature resolution 4 (middle block)
        ch = model_channels * channel_mult[2]
        self.middle_conditions = DepthTransformer(ch, 4, d3 // 2, context_dim=d3)

        self.output_conditions = nn.ModuleList()
        # maps output-block index -> index into self.output_conditions
        self.output_b2c = {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7, 11: 8}
        # feature resolution 8
        ch = model_channels * channel_mult[2]
        self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2))  # 0
        self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2))  # 1
        # feature resolution 16
        self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1))  # 2
        ch = model_channels * channel_mult[1]
        self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1))  # 3
        self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1))  # 4
        # feature resolution 32
        self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0))  # 5
        ch = model_channels * channel_mult[0]
        self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0))  # 6
        self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0))  # 7
        self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0))  # 8

    def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs):
        """source_dict maps feature-map resolution (h.shape[-1]) to the 3D volume features used as context."""
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x.type(self.dtype)
        for index, module in enumerate(self.input_blocks):
            h = module(h, emb, context)
            hs.append(h)

        h = self.middle_block(h, emb, context)
        h = self.middle_conditions(h, context=source_dict[h.shape[-1]])

        for index, module in enumerate(self.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if index in self.output_b2c:
                layer = self.output_conditions[self.output_b2c[index]]
                h = layer(h, context=source_dict[h.shape[-1]])

        h = h.type(x.dtype)
        return self.out(h)

    def get_trainable_parameters(self):
        # only the depth-conditioning layers added on top of the UNet are returned as trainable
        return list(self.middle_conditions.parameters()) + list(self.output_conditions.parameters())
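

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original module. The sizes
    # below -- batch 2, 40-channel 2D features, a 16-channel volume with depth 12,
    # and 8x8 spatial resolution -- are illustrative assumptions only.
    # DepthWiseAttention is not exercised here because it requires a full UNetModel config.
    b, c2d, c3d, d, h, w = 2, 40, 16, 12, 8, 8
    feat2d = torch.randn(b, c2d, h, w)      # 2D feature map (query)
    volume = torch.randn(b, c3d, d, h, w)   # 3D feature volume (context)

    attn = DepthAttention(query_dim=c2d, context_dim=c3d, heads=4, dim_head=8)
    print(attn(feat2d, volume).shape)       # torch.Size([2, 40, 8, 8])

    block = DepthTransformer(dim=c2d, n_heads=4, d_head=8, context_dim=c3d, checkpoint=False)
    print(block(feat2d, volume).shape)      # torch.Size([2, 40, 8, 8])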