import torch
from torch import nn, Tensor

from .blocks import LayerNorm, Transformer

class CLIPTextEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
    ) -> None:
        super().__init__()
        self.context_length = context_length
        self.vocab_size = vocab_size

        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

        # Initialise the free parameters following the original CLIP scheme;
        # these values are normally overwritten when a pretrained checkpoint is loaded.
        nn.init.normal_(self.positional_embedding, std=0.01)
        nn.init.normal_(self.text_projection, std=transformer_width ** -0.5)

    def build_attention_mask(self) -> Tensor:
        # Lazily create the causal attention mask for the text tokens.
        # PyTorch uses an additive attention mask: fill with -inf, then
        # zero out the lower triangle so each token attends only to itself
        # and to earlier positions.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    @property
    def dtype(self) -> torch.dtype:
        return self.transformer.resblocks[0].attn.in_proj_weight.dtype

    def forward(self, text: Tensor) -> Tensor:
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, transformer_width]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # Take the features at the end-of-text position (the eot token has the
        # highest id in each sequence) and project into the joint embedding space.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
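

# Minimal usage sketch, not part of the original module. The hyperparameters
# below are assumptions that mirror the CLIP ViT-B/32 text tower (77-token
# context, 49408-entry BPE vocabulary, width 512, 8 heads, 12 layers); adjust
# them to whatever checkpoint you actually load. Because of the relative
# import above, run this as a module (python -m <package>.<module>), not as a
# standalone script.
if __name__ == "__main__":
    encoder = CLIPTextEncoder(
        embed_dim=512,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
    )
    # Dummy token ids stand in for real tokenizer output; in practice the CLIP
    # BPE tokenizer produces these, with the end-of-text token holding the
    # highest id in each sequence (which is what forward() relies on).
    tokens = torch.randint(0, encoder.vocab_size, (2, 77))
    with torch.no_grad():
        features = encoder(tokens)
    print(features.shape)  # torch.Size([2, 512])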