# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# DepthProEncoder combining patch and image encoders.

from __future__ import annotations

import math
from typing import Iterable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthProEncoder(nn.Module):
    """DepthPro Encoder.

    An encoder aimed at creating multi-resolution encodings from Vision Transformers.
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        patch_encoder: nn.Module,
        image_encoder: nn.Module,
        hook_block_ids: Iterable[int],
        decoder_features: int,
    ):
        """Initialize DepthProEncoder.

        The framework
            1. creates an image pyramid,
            2. generates overlapping patches with a sliding window at each pyramid level,
            3. creates batched encodings via vision transformer backbones,
            4. produces multi-resolution encodings.

        Args:
        ----
            dims_encoder: Dimensions of the encoder at different layers.
            patch_encoder: Backbone used for patches.
            image_encoder: Backbone used for the global image encoder.
            hook_block_ids: Hooks to obtain intermediate features for the patch encoder model.
            decoder_features: Number of output features in the decoder.

        """
        super().__init__()

        self.dims_encoder = list(dims_encoder)
        self.patch_encoder = patch_encoder
        self.image_encoder = image_encoder
        self.hook_block_ids = list(hook_block_ids)

        patch_encoder_embed_dim = patch_encoder.embed_dim
        image_encoder_embed_dim = image_encoder.embed_dim

        self.out_size = int(
            patch_encoder.patch_embed.img_size[0] // patch_encoder.patch_embed.patch_size[0]
        )
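        # With the default backbone configuration assumed in the comments below
        # (384x384 backbone resolution, 16-pixel patches), out_size works out to
        # 384 // 16 = 24, i.e. every encoded patch becomes a 24x24 feature grid.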

        def _create_project_upsample_block(
            dim_in: int,
            dim_out: int,
            upsample_layers: int,
            dim_int: Optional[int] = None,
        ) -> nn.Module:
            if dim_int is None:
                dim_int = dim_out

            # Projection.
            blocks = [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=dim_int,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                )
            ]

            # Upsampling.
            blocks += [
                nn.ConvTranspose2d(
                    in_channels=dim_int if i == 0 else dim_out,
                    out_channels=dim_out,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    bias=False,
                )
                for i in range(upsample_layers)
            ]

            return nn.Sequential(*blocks)
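
        # Each block built above first projects the ViT embedding to `dim_out`
        # channels with a 1x1 convolution, then doubles the spatial resolution once
        # per transposed convolution, i.e. upscales by a factor of 2**upsample_layers.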
        self.upsample_latent0 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim,
            dim_int=self.dims_encoder[0],
            dim_out=decoder_features,
            upsample_layers=3,
        )
        self.upsample_latent1 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[0], upsample_layers=2
        )

        self.upsample0 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=1
        )
        self.upsample1 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
        )
        self.upsample2 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
        )

        self.upsample_lowres = nn.ConvTranspose2d(
            in_channels=image_encoder_embed_dim,
            out_channels=self.dims_encoder[3],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
        )
        self.fuse_lowres = nn.Conv2d(
            in_channels=(self.dims_encoder[3] + self.dims_encoder[3]),
            out_channels=self.dims_encoder[3],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Obtain intermediate outputs of the blocks.
        self.patch_encoder.blocks[self.hook_block_ids[0]].register_forward_hook(
            self._hook0
        )
        self.patch_encoder.blocks[self.hook_block_ids[1]].register_forward_hook(
            self._hook1
        )

    def _hook0(self, model, input, output):
        self.backbone_highres_hook0 = output

    def _hook1(self, model, input, output):
        self.backbone_highres_hook1 = output

    @property
    def img_size(self) -> int:
        """Return the full image size of the SPN network."""
        return self.patch_encoder.patch_embed.img_size[0] * 4

    def _create_pyramid(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create a 3-level image pyramid."""
        # Original resolution: 1536 by default.
        x0 = x

        # Middle resolution: 768 by default.
        x1 = F.interpolate(
            x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False
        )

        # Low resolution: 384 by default, corresponding to the backbone resolution.
        x2 = F.interpolate(
            x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False
        )

        return x0, x1, x2

    def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
        """Split the input into small patches with sliding window."""
        patch_size = 384
        patch_stride = int(patch_size * (1 - overlap_ratio))

        image_size = x.shape[-1]
        steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1

        x_patch_list = []
        for j in range(steps):
            j0 = j * patch_stride
            j1 = j0 + patch_size

            for i in range(steps):
                i0 = i * patch_stride
                i1 = i0 + patch_size
                x_patch_list.append(x[..., j0:j1, i0:i1])

        return torch.cat(x_patch_list, dim=0)

    def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
        """Merge the patched input into an image with sliding window."""
        steps = int(math.sqrt(x.shape[0] // batch_size))

        idx = 0
        output_list = []
        for j in range(steps):
            output_row_list = []
            for i in range(steps):
                output = x[batch_size * idx : batch_size * (idx + 1)]

                if j != 0:
                    output = output[..., padding:, :]
                if i != 0:
                    output = output[..., :, padding:]
                if j != steps - 1:
                    output = output[..., :-padding, :]
                if i != steps - 1:
                    output = output[..., :, :-padding]

                output_row_list.append(output)
                idx += 1

            output_row = torch.cat(output_row_list, dim=-1)
            output_list.append(output_row)

        output = torch.cat(output_list, dim=-2)
        return output

    def reshape_feature(
        self, embeddings: torch.Tensor, width, height, cls_token_offset=1
    ):
        """Discard class token and reshape 1D feature map to a 2D grid."""
        b, hw, c = embeddings.shape

        # Remove class token.
        if cls_token_offset > 0:
            embeddings = embeddings[:, cls_token_offset:, :]

        # Shape: (batch, height, width, dim) -> (batch, dim, height, width)
        embeddings = embeddings.reshape(b, height, width, c).permute(0, 3, 1, 2)
        return embeddings

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode input at multiple resolutions.

        Args:
        ----
            x (torch.Tensor): Input image.

        Returns:
        -------
            Multi-resolution encoded features.

        """
        batch_size = x.shape[0]

        # Step 0: create a 3-level image pyramid.
        x0, x1, x2 = self._create_pyramid(x)

        # Step 1: split to create batched overlapped mini-images at the backbone (BeiT/ViT/Dino)
        # resolution.
        # 5x5 @ 384x384 at the highest resolution (1536x1536).
        x0_patches = self.split(x0, overlap_ratio=0.25)
        # 3x3 @ 384x384 at the middle resolution (768x768).
        x1_patches = self.split(x1, overlap_ratio=0.5)
        # 1x1 @ 384x384 at the lowest resolution (384x384).
        x2_patches = x2

        # Concatenate all the sliding window patches and form a batch of size (35=5x5+3x3+1x1).
        x_pyramid_patches = torch.cat(
            (x0_patches, x1_patches, x2_patches),
            dim=0,
        )

        # Step 2: Run the backbone (BeiT) model and get encodings for the large patch batch.
        x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
        x_pyramid_encodings = self.reshape_feature(
            x_pyramid_encodings, self.out_size, self.out_size
        )

        # Step 3: merging.
        # Merge highres latent encoding.
        x_latent0_encodings = self.reshape_feature(
            self.backbone_highres_hook0,
            self.out_size,
            self.out_size,
        )
        x_latent0_features = self.merge(
            x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
        )

        x_latent1_encodings = self.reshape_feature(
            self.backbone_highres_hook1,
            self.out_size,
            self.out_size,
        )
        x_latent1_features = self.merge(
            x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
        )

        # Split the batch of 35 pyramid encodings back into 5x5 + 3x3 + 1x1.
        x0_encodings, x1_encodings, x2_encodings = torch.split(
            x_pyramid_encodings,
            [len(x0_patches), len(x1_patches), len(x2_patches)],
            dim=0,
        )

        # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
        x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)

        # 48x48 feature maps by merging 3x3 @ 24x24 patches with overlaps.
        x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)

        # 24x24 feature maps.
        x2_features = x2_encodings

        # Apply the image encoder model.
        x_global_features = self.image_encoder(x2_patches)
        x_global_features = self.reshape_feature(
            x_global_features, self.out_size, self.out_size
        )

        # Upsample feature maps.
        x_latent0_features = self.upsample_latent0(x_latent0_features)
        x_latent1_features = self.upsample_latent1(x_latent1_features)
        x0_features = self.upsample0(x0_features)
        x1_features = self.upsample1(x1_features)
        x2_features = self.upsample2(x2_features)

        x_global_features = self.upsample_lowres(x_global_features)
        x_global_features = self.fuse_lowres(
            torch.cat((x2_features, x_global_features), dim=1)
        )

        return [
            x_latent0_features,
            x_latent1_features,
            x0_features,
            x1_features,
            x_global_features,
        ]
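

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `_ToyViT` below is a hypothetical stand-in
# backbone, not the project's pretrained ViT: it merely exposes the attributes
# DepthProEncoder relies on (`embed_dim`, `patch_embed.img_size`,
# `patch_embed.patch_size`, `blocks`) and returns ViT-style tokens with a
# leading class token. Channel dims, depth, and hook ids are arbitrary choices
# made for this sketch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyViT(nn.Module):
        """Minimal hypothetical backbone exposing the interface used above."""

        def __init__(self, img_size=384, patch_size=16, embed_dim=64, depth=4):
            super().__init__()
            self.embed_dim = embed_dim
            self.patch_embed = nn.Module()
            self.patch_embed.img_size = (img_size, img_size)
            self.patch_embed.patch_size = (patch_size, patch_size)
            self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.blocks = nn.ModuleList(nn.Linear(embed_dim, embed_dim) for _ in range(depth))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            tokens = self.proj(x).flatten(2).transpose(1, 2)  # (B, H*W, C)
            cls = self.cls_token.expand(tokens.shape[0], -1, -1)
            tokens = torch.cat((cls, tokens), dim=1)  # prepend class token
            for blk in self.blocks:
                tokens = blk(tokens)
            return tokens

    encoder = DepthProEncoder(
        dims_encoder=[64, 128, 256, 512],
        patch_encoder=_ToyViT(),
        image_encoder=_ToyViT(),
        hook_block_ids=[1, 2],
        decoder_features=32,
    )
    image = torch.randn(1, 3, encoder.img_size, encoder.img_size)  # 1536x1536
    with torch.no_grad():
        features = encoder(image)
    # Expected shapes for this toy configuration: (1, 32, 768, 768),
    # (1, 64, 384, 384), (1, 128, 192, 192), (1, 256, 96, 96), (1, 512, 48, 48).
    for feat in features:
        print(tuple(feat.shape))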