# Copyright (C) 2024 Apple Inc. All Rights Reserved.

try:
    from timm.layers import resample_abs_pos_embed
except ImportError as err:
    print(f"ImportError: {err}")

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
def make_vit_b16_backbone(
    model,
    encoder_feature_dims,
    encoder_feature_layer_ids,
    vit_features,
    start_index=1,
    use_grad_checkpointing=False,
) -> nn.Module:
    """Make a ViT-b16 backbone for the DPT model."""
    if use_grad_checkpointing:
        model.set_grad_checkpointing()

    vit_model = nn.Module()
    vit_model.hooks = encoder_feature_layer_ids
    vit_model.model = model
    vit_model.features = encoder_feature_dims
    vit_model.vit_features = vit_features
    vit_model.model.start_index = start_index

    vit_model.model.patch_size = vit_model.model.patch_embed.patch_size
    vit_model.model.is_vit = True
    # Bypass the classification head: forward() now returns patch tokens.
    vit_model.model.forward = vit_model.model.forward_features

    return vit_model
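
# A minimal usage sketch (not part of the original file): wrap a timm ViT-B/16
# so a DPT-style decoder can hook intermediate blocks. The feature dims and
# layer ids below are illustrative placeholders, not a shipped configuration.
#
#   import timm
#
#   vit = timm.create_model("vit_base_patch16_384", pretrained=True)
#   backbone = make_vit_b16_backbone(
#       vit,
#       encoder_feature_dims=[96, 192, 384, 768],
#       encoder_feature_layer_ids=[2, 5, 8, 11],
#       vit_features=768,
#   )
#   tokens = backbone.model(torch.randn(1, 3, 384, 384))  # (1, 577, 768)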
def forward_features_eva_fixed(self, x):
    """Encode features (EVA forward pass with checkpoint-safe rope handling)."""
    x = self.patch_embed(x)
    x, rot_pos_embed = self._pos_embed(x)
    for blk in self.blocks:
        if self.grad_checkpointing:
            # Pass the rotary position embedding positionally: reentrant
            # torch.utils.checkpoint only forwards positional arguments.
            x = checkpoint(blk, x, rot_pos_embed)
        else:
            x = blk(x, rot_pos_embed)
    x = self.norm(x)
    return x
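
# Patching sketch (an assumption, not in the original file): bind this function
# as a method on a timm EVA encoder so its gradient checkpointing works. The
# model name is illustrative; any timm EVA/EVA02 checkpoint should do.
#
#   import types
#   import timm
#
#   eva = timm.create_model("eva02_base_patch14_224", pretrained=True)
#   eva.forward_features = types.MethodType(forward_features_eva_fixed, eva)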
def resize_vit(model: nn.Module, img_size) -> nn.Module:
    """Resample the ViT module to the given image size."""
    patch_size = model.patch_embed.patch_size
    model.patch_embed.img_size = img_size
    grid_size = tuple(s // p for s, p in zip(img_size, patch_size))
    model.patch_embed.grid_size = grid_size

    pos_embed = resample_abs_pos_embed(
        model.pos_embed,
        grid_size,  # new token grid: img_size // patch_size
        num_prefix_tokens=(
            0 if getattr(model, "no_embed_class", False) else model.num_prefix_tokens
        ),
    )
    model.pos_embed = torch.nn.Parameter(pos_embed)

    return model
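
# Usage sketch (an assumption, not in the original file): adapt a 224px
# checkpoint to 384x384 inputs by resampling its absolute position embedding.
#
#   vit = timm.create_model("vit_base_patch16_224", pretrained=True)
#   vit = resize_vit(vit, img_size=(384, 384))  # token grid: 14x14 -> 24x24
#   out = vit.forward_features(torch.randn(1, 3, 384, 384))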
def resize_patch_embed(model: nn.Module, new_patch_size=(16, 16)) -> nn.Module:
    """Resample the ViT patch size to the given one."""
    # Interpolate the patch-embedding convolution kernel to the new size.
    if hasattr(model, "patch_embed"):
        old_patch_size = model.patch_embed.patch_size

        if (
            new_patch_size[0] != old_patch_size[0]
            or new_patch_size[1] != old_patch_size[1]
        ):
            patch_embed_proj = model.patch_embed.proj.weight
            patch_embed_proj_bias = model.patch_embed.proj.bias
            use_bias = patch_embed_proj_bias is not None
            _, _, h, w = patch_embed_proj.shape

            new_patch_embed_proj = torch.nn.functional.interpolate(
                patch_embed_proj,
                size=[new_patch_size[0], new_patch_size[1]],
                mode="bicubic",
                align_corners=False,
            )
            # Rescale the kernel so the convolution's response magnitude is
            # preserved after resampling.
            new_patch_embed_proj = (
                new_patch_embed_proj * (h / new_patch_size[0]) * (w / new_patch_size[1])
            )

            model.patch_embed.proj = nn.Conv2d(
                in_channels=model.patch_embed.proj.in_channels,
                out_channels=model.patch_embed.proj.out_channels,
                kernel_size=new_patch_size,
                stride=new_patch_size,
                bias=use_bias,
            )

            if use_bias:
                model.patch_embed.proj.bias = patch_embed_proj_bias
            model.patch_embed.proj.weight = torch.nn.Parameter(new_patch_embed_proj)

            # Scale the expected image size by the same ratio as the patch
            # size, so the token grid (and thus pos_embed) is unchanged.
            model.patch_size = new_patch_size
            model.patch_embed.patch_size = new_patch_size
            model.patch_embed.img_size = (
                int(
                    model.patch_embed.img_size[0]
                    * new_patch_size[0]
                    / old_patch_size[0]
                ),
                int(
                    model.patch_embed.img_size[1]
                    * new_patch_size[1]
                    / old_patch_size[1]
                ),
            )

    return model
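
# Usage sketch (an assumption, not in the original file): shrink the patch
# size from 16 to 8. The expected img_size scales by the same ratio
# (224 -> 112), so the 14x14 token grid and the position embedding are
# untouched; combine with resize_vit to change the grid afterwards.
#
#   vit = timm.create_model("vit_base_patch16_224", pretrained=True)
#   vit = resize_patch_embed(vit, new_patch_size=(8, 8))
#   out = vit.forward_features(torch.randn(1, 3, 112, 112))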