Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from functools import partial | |
| # from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test | |
class AbstractEncoder(nn.Module):
    """Base ``nn.Module`` for encoders; ``encode`` is left to subclasses."""

    def __init__(self):
        """Initialise the underlying ``nn.Module``; no extra state is set."""
        super().__init__()

    def encode(self, *args, **kwargs):
        """Encode the given inputs.

        Accepts arbitrary positional and keyword arguments.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class ClassEmbedder(nn.Module):
    """Look up a learned embedding for the class labels stored in a batch."""

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        """Create the embedding table.

        Args:
            embed_dim: Size of each embedding vector.
            n_classes: Number of rows in the embedding table.
            key: Default batch entry from which labels are read.
        """
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.key = key

    def forward(self, batch, key=None):
        """Embed ``batch[key]`` after inserting a new axis at position 1.

        Args:
            batch: Indexable container holding the label tensor.
            key: Entry of ``batch`` to read; ``None`` selects ``self.key``.

        Returns:
            The result of the embedding lookup on the expanded labels.
        """
        lookup_key = self.key if key is None else key
        # The extra axis is for use in cross-attention.
        labels = batch[lookup_key][:, None]
        return self.embedding(labels)
class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers.

    NOTE(review): ``TransformerWrapper`` and ``Encoder`` are not imported in
    the visible part of this file (the x_transformer import is commented
    out) -- confirm they are in scope before instantiating this class.
    """

    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77):
        """Build the wrapped transformer.

        Args:
            n_embed: Passed to ``Encoder`` as ``dim``.
            n_layer: Passed to ``Encoder`` as ``depth``.
            vocab_size: Passed to ``TransformerWrapper`` as ``num_tokens``.
            max_seq_len: Passed to ``TransformerWrapper`` as ``max_seq_len``.
        """
        super().__init__()
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=n_embed, depth=n_layer),
        )

    def forward(self, tokens):
        """Run ``tokens`` through the transformer with ``return_embeddings=True``."""
        return self.transformer(tokens, return_embeddings=True)

    def encode(self, x):
        """Alias for calling the module on ``x``."""
        return self(x)
class BERTTokenizer(AbstractEncoder):
    """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        """Load the ``bert-base-uncased`` fast tokenizer.

        Args:
            device: Accepted but not used anywhere in this class.
            vq_interface: Selects the return format of ``encode``.
            max_length: Length that inputs are truncated and padded to.
        """
        super().__init__()
        # Function-scope import, as in the original code.
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        """Tokenize ``text`` and return the ``input_ids`` entry of the result.

        The tokenizer is asked to truncate and pad to ``self.max_length``
        and to return PyTorch tensors (``return_tensors="pt"``).
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded["input_ids"]

    def encode(self, text):
        """Tokenize ``text``.

        Returns:
            ``(None, None, [None, None, tokens])`` when ``self.vq_interface``
            is truthy, otherwise just ``tokens``.
        """
        tokens = self(text)
        if self.vq_interface:
            return None, None, [None, None, tokens]
        return tokens

    def decode(self, text):
        """Return ``text`` unchanged."""
        return text
class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers.

    NOTE(review): ``TransformerWrapper`` and ``Encoder`` are not imported in
    the visible part of this file (the x_transformer import is commented
    out) -- confirm they are in scope before instantiating this class.
    """

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 ckpt_path=None, ignore_keys=(), device="cuda", use_tokenizer=True,
                 embedding_dropout=0.0):
        """Build the optional tokenizer and the transformer.

        Args:
            n_embed: Passed to ``Encoder`` as ``dim``.
            n_layer: Passed to ``Encoder`` as ``depth``.
            vocab_size: Passed to ``TransformerWrapper`` as ``num_tokens``.
            max_seq_len: Maximum sequence length for tokenizer and transformer.
            ckpt_path: If not ``None``, weights are restored from this path.
            ignore_keys: Key prefixes to drop from the checkpoint before loading.
            device: Accepted but not used anywhere in this class.
            use_tokenizer: If true, ``forward`` tokenizes its input first.
            embedding_dropout: Passed to ``TransformerWrapper`` as ``emb_dropout``.
        """
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load a state dict from ``path`` non-strictly, skipping ignored keys.

        Args:
            path: Checkpoint file readable by ``torch.load``.
            ignore_keys: Iterable of string prefixes; any checkpoint key that
                starts with one of them is removed before loading.
        """
        # NOTE(security): torch.load may unpickle arbitrary objects; only use
        # this with checkpoints from a trusted source.
        sd = torch.load(path, map_location="cpu")
        prefixes = tuple(ignore_keys)
        # Snapshot the keys so entries can be deleted while looping.
        for k in list(sd.keys()):
            # One tuple check deletes each key at most once, even if several
            # prefixes match (a per-prefix ``del`` raised KeyError).
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")

    def forward(self, text):
        """Embed ``text`` with the transformer.

        ``text`` is tokenized first when the tokenizer is enabled; otherwise
        it is treated as an already tokenized tensor.
        """
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)
        else:
            tokens = text
        device = self.transformer.token_emb.weight.device  # a trick to get device
        tokens = tokens.to(device)
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        # output of length 77
        return self(text)
class SpatialRescaler(nn.Module):
    """Rescale an input by repeated interpolation, with optional 1x1 conv.

    The input is interpolated ``n_stages`` times by ``multiplier``. If
    ``out_channels`` is given, a kernel-size-1 ``nn.Conv2d`` is applied
    after the resizing.
    """

    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        """Configure the interpolation and the optional channel mapper.

        Args:
            n_stages: Number of interpolation passes; must be >= 0.
            method: Interpolation mode handed to ``interpolate``.
            multiplier: ``scale_factor`` used in every pass.
            in_channels: Input channels of the optional convolution.
            out_channels: Output channels; ``None`` disables the convolution.
            bias: Whether the optional convolution has a bias term.
        """
        super().__init__()
        supported_modes = ('nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area')
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in supported_modes
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if not self.remap_output:
            return
        print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
        self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        """Apply the interpolation passes, then the channel mapper if enabled."""
        resized = x
        for _ in range(self.n_stages):
            resized = self.interpolator(resized, scale_factor=self.multiplier)
        if not self.remap_output:
            return resized
        return self.channel_mapper(resized)

    def encode(self, x):
        """Alias for calling the module on ``x``."""
        return self(x)