Spaces: Running on Zero
| from enum import Enum, unique | |
| from typing import Any | |
| import torch | |
| import torchvision.transforms.v2 as transforms | |
| from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel | |
| from torch import Tensor, nn | |
| from transformers import ( | |
| AutoImageProcessor, | |
| AutoModel, | |
| AutoProcessor, | |
| CLIPImageProcessor, | |
| CLIPVisionModel, | |
| SiglipImageProcessor, | |
| SiglipVisionModel, | |
| ) | |
| class TryOffDiff(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet") | |
| self.transformer = torch.nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True) | |
| self.proj = nn.Linear(1024, 77) | |
| self.norm = nn.LayerNorm(768) | |
| def forward(self, noisy_latents, t, cond_emb): | |
| cond_emb = self.transformer(cond_emb) | |
| cond_emb = self.proj(cond_emb.transpose(1, 2)) | |
| cond_emb = self.norm(cond_emb.transpose(1, 2)) | |
| return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb).sample | |
| class TryOffDiffv2(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.unet = UNet2DConditionModel( | |
| sample_size=64, | |
| in_channels=4, | |
| out_channels=4, | |
| layers_per_block=2, | |
| block_out_channels=(320, 640, 1280, 1280), | |
| down_block_types=( | |
| "CrossAttnDownBlock2D", | |
| "CrossAttnDownBlock2D", | |
| "CrossAttnDownBlock2D", | |
| "DownBlock2D", | |
| ), | |
| up_block_types=( | |
| "UpBlock2D", | |
| "CrossAttnUpBlock2D", | |
| "CrossAttnUpBlock2D", | |
| "CrossAttnUpBlock2D", | |
| ), | |
| cross_attention_dim=768, | |
| class_embed_type=None, | |
| num_class_embeds=3, | |
| ) | |
| # Load the pretrained weights into the custom model, skipping incompatible keys | |
| pretrained_state_dict = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").state_dict() | |
| self.unet.load_state_dict(pretrained_state_dict, strict=False) | |
| self.proj = nn.Linear(1024, 77) | |
| self.norm = nn.LayerNorm(768) | |
| def forward(self, noisy_latents, t, cond_emb, class_labels): | |
| cond_emb = self.proj(cond_emb.transpose(1, 2)) | |
| cond_emb = self.norm(cond_emb.transpose(1, 2)) | |
| return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb, class_labels=class_labels).sample | |
| class TryOffDiffv2Single(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet") | |
| self.proj = nn.Linear(1024, 77) | |
| self.norm = nn.LayerNorm(768) | |
| def forward(self, noisy_latents, t, cond_emb): | |
| cond_emb = self.proj(cond_emb.transpose(1, 2)) | |
| cond_emb = self.norm(cond_emb.transpose(1, 2)) | |
| return self.unet(noisy_latents, t, encoder_hidden_states=cond_emb).sample | |
@unique
class ModelName(Enum):
    """Registry mapping model names to their model classes.

    `@unique` (already imported at the top of the file but previously unused)
    guards against two member names accidentally aliasing the same class.
    """

    TryOffDiff = TryOffDiff
    TryOffDiffv2 = TryOffDiffv2
    TryOffDiffv2Single = TryOffDiffv2Single
def create_model(model_name: str, **kwargs: Any) -> Any:
    """Instantiate a model by its `ModelName` member name.

    Args:
        model_name: Name of a `ModelName` member (e.g. "TryOffDiff").
        **kwargs: Forwarded to the model class constructor.

    Returns:
        A freshly constructed model instance.

    Raises:
        KeyError: If `model_name` is not a known `ModelName` member; the
            message lists the valid names (the bare KeyError raised before
            gave no hint what the valid options were).
    """
    try:
        model_class = ModelName[model_name].value
    except KeyError:
        valid = ", ".join(member.name for member in ModelName)
        raise KeyError(
            f"Unknown model name {model_name!r}; expected one of: {valid}"
        ) from None
    return model_class(**kwargs)