Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from torch import nn, Tensor | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from typing import Tuple, Union, Any, List, Iterable, Optional | |
| from .blocks import LayerNorm, Transformer, Bottleneck, AttentionPool2d | |
| class ModifiedResNet(nn.Module): | |
| """ | |
| A ResNet class that is similar to torchvision's but contains the following changes: | |
| - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. | |
| - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 | |
| - The final pooling layer is a QKV attention instead of an average pool | |
| """ | |
| def __init__( | |
| self, | |
| layers: Tuple[int, int, int, int], | |
| output_dim: int, | |
| input_resolution: int = 224, | |
| width: int = 64, | |
| heads: int = 8, | |
| features_only: bool = False, | |
| out_indices: Optional[Iterable[int]] = None, | |
| reduction: int = 32, | |
| **kwargs: Any, | |
| ) -> None: | |
| super().__init__() | |
| input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution | |
| assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}" | |
| self.input_resolution = input_resolution | |
| self.downsampling_rate = 32 # the rate at which the input is downsampled by the network | |
| # the 3-layer stem | |
| self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) | |
| self.bn1 = nn.BatchNorm2d(width // 2) | |
| self.relu1 = nn.ReLU(inplace=True) | |
| self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) | |
| self.bn2 = nn.BatchNorm2d(width // 2) | |
| self.relu2 = nn.ReLU(inplace=True) | |
| self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) | |
| self.bn3 = nn.BatchNorm2d(width) | |
| self.relu3 = nn.ReLU(inplace=True) | |
| self.avgpool = nn.AvgPool2d(2) | |
| # residual layers | |
| self._inplanes = width # this is a *mutable* variable used during construction | |
| self.layer1 = self._make_layer(width, layers[0]) | |
| self.layer2 = self._make_layer(width * 2, layers[1], stride=2) | |
| self.layer3 = self._make_layer(width * 4, layers[2], stride=2) | |
| self.layer4 = self._make_layer(width * 8, layers[3], stride=1 if reduction <= 16 else 2) | |
| self.features_only = features_only | |
| if features_only: | |
| self.out_indices = out_indices if out_indices is not None else range(5) | |
| self.out_indices = [idx + 5 if idx < 0 else idx for idx in self.out_indices] # map negative indices to positive indices | |
| self.out_indices = sorted(set(self.out_indices)) # remove duplicates and sort | |
| assert min(self.out_indices) >= 0 and max(self.out_indices) <= 4, f"out_indices={self.out_indices} is invalid for a ResNet with 5 stages" | |
| self.channels = width * 32 # the ResNet feature dimension | |
| else: | |
| self.out_indices = None | |
| embed_dim = width * 32 # the ResNet feature dimension | |
| self.attnpool = AttentionPool2d((input_resolution[0] // 32) * (input_resolution[1] // 32), embed_dim, heads, output_dim) | |
| self.channels = output_dim | |
| self.reduction = self.downsampling_rate // 2 if reduction <= 16 else self.downsampling_rate | |
| self.clip_embed_dim = output_dim | |
| def _make_layer(self, planes, blocks, stride=1): | |
| layers = [Bottleneck(self._inplanes, planes, stride)] | |
| self._inplanes = planes * Bottleneck.expansion | |
| for _ in range(1, blocks): | |
| layers.append(Bottleneck(self._inplanes, planes)) | |
| return nn.Sequential(*layers) | |
| def _stem(self, x: Tensor) -> Tensor: | |
| x = self.relu1(self.bn1(self.conv1(x))) | |
| x = self.relu2(self.bn2(self.conv2(x))) | |
| x = self.relu3(self.bn3(self.conv3(x))) | |
| x = self.avgpool(x) | |
| return x | |
| def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]: | |
| x = x.type(self.conv1.weight.dtype) | |
| x = self._stem(x) | |
| feats = [x] if self.features_only and 0 in self.out_indices else [] | |
| x = self.layer1(x) | |
| if self.features_only and 1 in self.out_indices: | |
| feats.append(x) | |
| x = self.layer2(x) | |
| if self.features_only and 2 in self.out_indices: | |
| feats.append(x) | |
| x = self.layer3(x) | |
| if self.features_only and 3 in self.out_indices: | |
| feats.append(x) | |
| x = self.layer4(x) | |
| if self.features_only and 4 in self.out_indices: | |
| feats.append(x) | |
| if self.features_only: | |
| if len(self.out_indices) == 1: | |
| return feats[0] | |
| else: | |
| return feats | |
| else: | |
| x = self.attnpool(x) | |
| return x | |
| class VisionTransformer(nn.Module): | |
| def __init__( | |
| self, | |
| input_resolution: Union[int, Tuple[int, int]], | |
| patch_size: Union[int, Tuple[int, int]], | |
| output_dim: int, | |
| width: int, | |
| layers: int, | |
| heads: int, | |
| features_only: bool = False, | |
| **kwargs: Any, | |
| ) -> None: | |
| super().__init__() | |
| input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution | |
| patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size | |
| assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}" | |
| assert isinstance(patch_size, tuple) and len(patch_size) == 2, f"patch_size should be a tuple of length 2, but got {patch_size}" | |
| assert patch_size[0] == patch_size[1], f"ViT only supports square patches, patch_size={patch_size} is invalid." | |
| assert input_resolution[0] % patch_size[0] == 0 and input_resolution[1] % patch_size[1] == 0, f"input_resolution {input_resolution} should be divisible by patch_size {patch_size}" | |
| self.input_resolution = input_resolution | |
| self.patch_size = patch_size | |
| self.downsampling_rate = patch_size[0] | |
| self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) | |
| scale = width ** -0.5 | |
| self.class_embedding = nn.Parameter(scale * torch.randn(width)) | |
| self.num_patches_h = int(input_resolution[0] // patch_size[0]) | |
| self.num_patches_w = int(input_resolution[1] // patch_size[1]) | |
| self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches_h * self.num_patches_w + 1, width)) | |
| self.ln_pre = LayerNorm(width) | |
| self.transformer = Transformer(width, layers, heads) | |
| self.ln_post = LayerNorm(width) | |
| self.features_only = features_only # if True, return the final patches instead of the CLS token | |
| if features_only: | |
| self.channels = width | |
| else: | |
| self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) | |
| self.channels = output_dim | |
| self.reduction = patch_size[0] | |
| self.clip_embed_dim = output_dim | |
| def adjust_pos_embed(self, h: int, w: int) -> None: | |
| """ | |
| Permanently adjust the size of the positional embedding matrix. | |
| Args: | |
| h: the height of the original input image. | |
| w: the width of the original input image. | |
| """ | |
| assert h % self.patch_size[0] == 0 and w % self.patch_size[1] == 0, f"input_resolution {h, w} should be divisible by patch_size {self.patch_size}" | |
| if self.input_resolution[0] != h or self.input_resolution[1] != w: | |
| new_num_patches_h = int(h // self.patch_size[0]) | |
| new_num_patches_w = int(w // self.patch_size[1]) | |
| positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension | |
| positional_embedding = F.interpolate(positional_embedding, size=(new_num_patches_h, new_num_patches_w), mode="bicubic", ).squeeze(0) # remove batch dimension | |
| positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c") | |
| self.positional_embedding = nn.Parameter(torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0)) | |
| self.input_resolution = (h, w) | |
| self.num_patches_h = new_num_patches_h | |
| self.num_patches_w = new_num_patches_w | |
| def _interpolate_pos_embed(self, h: int, w: int) -> Tensor: | |
| """ | |
| Interpolate the positional embedding matrix to match the size of the input image. | |
| Args: | |
| h: the required number of patches along the height dimension. | |
| w: the required number of patches along the width dimension. | |
| """ | |
| if h == self.num_patches_h and w == self.num_patches_w: | |
| return self.positional_embedding | |
| else: | |
| positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension | |
| positional_embedding = F.interpolate(positional_embedding, size=(h, w), mode="bicubic").squeeze(0) # remove batch dimension | |
| positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c") | |
| positional_embedding = torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0) | |
| return positional_embedding | |
| def forward(self, x: Tensor) -> Tensor: | |
| x = self.conv1(x) # shape = [*, width, grid, grid] | |
| num_patches_h, num_patches_w = x.shape[-2:] | |
| positional_embedding = self._interpolate_pos_embed(num_patches_h, num_patches_w).to(x.dtype) | |
| x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] | |
| x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] | |
| x = torch.cat([ | |
| self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), | |
| x | |
| ], dim=1) | |
| x = x + positional_embedding | |
| x = self.ln_pre(x) | |
| x = x.permute(1, 0, 2) # NLD -> LND. N: batch size, L: sequence length, D: feature dimension | |
| x = self.transformer(x) | |
| x = x.permute(1, 0, 2) # LND -> NLD | |
| x = self.ln_post(x) | |
| if self.features_only: | |
| x = x[:, 1:, :] # remove the CLS token | |
| x = rearrange(x, "n (h w) c -> n c h w", h=num_patches_h, w=num_patches_w) | |
| else: | |
| x = x[:, 0, :] | |
| x = x @ self.proj | |
| return x | |