Spaces:
Runtime error
Runtime error
| # Copyright 2023 Bytedance Ltd. and/or its affiliates | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from torch import nn | |
| from .attention import SpatioTemporalTransformerModel | |
| from .resnet import DownsamplePseudo3D, ResnetBlockPseudo3D, UpsamplePseudo3D | |
| import glob | |
| import json | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.utils.checkpoint | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.utils import BaseOutput, logging | |
| from diffusers.models.embeddings import TimestepEmbedding, Timesteps | |
| from .unet_3d_blocks import ( | |
| CrossAttnDownBlockPseudo3D, | |
| CrossAttnUpBlockPseudo3D, | |
| DownBlockPseudo3D, | |
| UNetMidBlockPseudo3DCrossAttn, | |
| UpBlockPseudo3D, | |
| get_down_block, | |
| get_up_block, | |
| ) | |
| from .resnet import PseudoConv3d | |
| from diffusers.models.cross_attention import AttnProcessor | |
| from typing import Dict | |
def set_zero_parameters(module):
    """Zero every parameter of ``module`` in place and return the module.

    The values are overwritten without recording the operation in autograd;
    the parameter objects themselves (and their ``requires_grad`` flags) are
    left as they were.

    Args:
        module: Any object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The same ``module`` instance, to allow use inside expressions.
    """
    # In-place writes to leaf parameters must happen outside of autograd.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
| # ControlNet: Zero Convolution | |
def zero_conv(channels):
    """Create a ``PseudoConv3d`` whose parameters are all zero.

    The layer is constructed as ``PseudoConv3d(channels, channels, 1,
    padding=0)`` (same channel count in and out) and then passed through
    ``set_zero_parameters``.

    Args:
        channels: Channel count used for both the input and output side.

    Returns:
        The zero-initialised ``PseudoConv3d`` instance.
    """
    conv = PseudoConv3d(channels, channels, 1, padding=0)
    return set_zero_parameters(conv)
class ControlNetInputHintBlock(nn.Module):
    """Stack of ``PseudoConv3d`` + ``SiLU`` layers applied to a hint tensor.

    Seven convolution/activation pairs widen the hint from ``hint_channels``
    to 256 channels (three of the convolutions are built with ``stride=2``),
    followed by a final convolution to ``channels`` whose parameters are
    zeroed via ``set_zero_parameters``.
    """

    def __init__(self, hint_channels: int = 3, channels: int = 320):
        """Build the layer stack.

        Args:
            hint_channels: Channel count of the incoming hint.
            channels: Channel count produced by the final, zeroed convolution.
        """
        super().__init__()
        # Layer configurations are from reference implementation.
        # Each entry: (in_channels, out_channels, extra constructor kwargs).
        # ``stride`` is only passed where the reference passes it.
        conv_plan = (
            (hint_channels, 16, {}),
            (16, 16, {}),
            (16, 32, {"stride": 2}),
            (32, 32, {}),
            (32, 96, {"stride": 2}),
            (96, 96, {}),
            (96, 256, {"stride": 2}),
        )
        layers = []
        for in_ch, out_ch, extra in conv_plan:
            layers.append(PseudoConv3d(in_ch, out_ch, 3, padding=1, **extra))
            layers.append(nn.SiLU())
        # Final projection starts with all-zero parameters.
        layers.append(
            set_zero_parameters(PseudoConv3d(256, channels, 3, padding=1))
        )
        self.input_hint_block = nn.Sequential(*layers)

    def forward(self, hint: torch.Tensor):
        """Run ``hint`` through the sequential stack and return the result."""
        return self.input_hint_block(hint)
class ControlNetPseudoZeroConv3dBlock(nn.Module):
    """Zero-initialised convolutions applied to down-block and mid-block outputs.

    Holds one convolution for the first residual (``input_zero_conv``), a
    list of convolutions for the remaining residuals (``zero_convs``) and one
    for the mid-block sample (``mid_zero_conv``). All are created through
    ``zero_conv``.
    """

    def __init__(
        self,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlockPseudo3D",
            "CrossAttnDownBlockPseudo3D",
            "CrossAttnDownBlockPseudo3D",
            "DownBlockPseudo3D",
        ),
        layers_per_block: int = 2,
    ):
        """Create the zero convolutions.

        Args:
            block_out_channels: Channel count for each level.
            down_block_types: One entry per level; only the number of entries
                is used here, the names themselves are not inspected.
            layers_per_block: Number of convolutions created per level, before
                the extra one added for every non-final level.
        """
        super().__init__()
        self.input_zero_conv = zero_conv(block_out_channels[0])

        final_level = len(block_out_channels) - 1
        level_convs = []
        for level, _block_type in enumerate(down_block_types):
            width = block_out_channels[level]
            level_convs.extend(zero_conv(width) for _ in range(layers_per_block))
            # Every level except the last gets one additional convolution
            # (presumably for the downsampler's residual -- confirm against
            # the UNet that produces the residuals).
            if level != final_level:
                level_convs.append(zero_conv(width))
        self.zero_convs = nn.ModuleList(level_convs)

        self.mid_zero_conv = zero_conv(block_out_channels[-1])

    def forward(
        self,
        down_block_res_samples: List[torch.Tensor],
        mid_block_sample: torch.Tensor,
    ) -> List[torch.Tensor]:
        """Apply the zero convolutions to the given samples.

        Args:
            down_block_res_samples: Residuals; the first goes through
                ``input_zero_conv``, the rest are paired in order with
                ``zero_convs``.
            mid_block_sample: Sample passed through ``mid_zero_conv``.

        Returns:
            A list holding the processed first residual, then the processed
            remaining residuals, then the processed mid-block sample last.
        """
        projected = [self.input_zero_conv(down_block_res_samples[0])]
        # NOTE: zip stops at the shorter sequence, so surplus residuals or
        # surplus convolutions are silently ignored.
        projected.extend(
            conv(residual)
            for residual, conv in zip(down_block_res_samples[1:], self.zero_convs)
        )
        projected.append(self.mid_zero_conv(mid_block_sample))
        return projected