Spaces:
Running
Running
| # Copyright (C) 2024 Apple Inc. All Rights Reserved. | |
| # Field of View network architecture. | |
| from typing import Optional | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
class FOVNetwork(nn.Module):
    """Convolutional head that regresses a field-of-view value.

    The module runs in one of two configurations, chosen at construction:

    * without ``fov_encoder``: ``head`` holds all four convolutions and is
      applied directly to the low-resolution feature map;
    * with ``fov_encoder``: the first strided convolution is split off into
      ``downsample``, the encoder output is projected by a linear layer and
      added to the downsampled features, and ``head`` holds the remaining
      three convolutions.
    """

    def __init__(
        self,
        num_features: int,
        fov_encoder: Optional[nn.Module] = None,
    ):
        """Build the field-of-view estimation layers.

        Args:
        ----
            num_features: Channel count of the low-resolution feature map fed
                to ``forward``. Channels are narrowed to ``num_features // 2``,
                ``// 4``, ``// 8`` and finally 1.
            fov_encoder: Optional extra encoder applied to the image. It must
                expose an ``embed_dim`` attribute, which is used as the input
                width of the projection layer.

        """
        super().__init__()

        half = num_features // 2
        quarter = num_features // 4
        eighth = num_features // 8

        def halving_conv(in_channels: int, out_channels: int) -> nn.Conv2d:
            # 3x3 convolution with stride 2 and padding 1.
            return nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, padding=1
            )

        # Layers are instantiated in a fixed order (the four convolutions
        # first, the linear projection last) so that seeded initialisation
        # is reproducible.
        # Shape notes carried over from the original comments; they appear to
        # assume num_features=256 and a 48 x 48 input - TODO confirm.
        stem = [
            halving_conv(num_features, half),  # 128 x 24 x 24
            nn.ReLU(True),
        ]
        tail = [
            halving_conv(half, quarter),  # 64 x 12 x 12
            nn.ReLU(True),
            halving_conv(quarter, eighth),  # 32 x 6 x 6
            nn.ReLU(True),
            # Unpadded 6x6 kernel: a 6 x 6 map collapses to a single value.
            nn.Conv2d(eighth, 1, kernel_size=6, stride=1, padding=0),
        ]

        if fov_encoder is None:
            # No auxiliary encoder: one sequential head does everything.
            # `encoder` / `downsample` are deliberately left undefined because
            # `forward` dispatches on the presence of `encoder`.
            self.head = nn.Sequential(*stem, *tail)
            return

        projection = nn.Linear(fov_encoder.embed_dim, half)
        self.encoder = nn.Sequential(fov_encoder, projection)
        self.downsample = nn.Sequential(*stem)
        self.head = nn.Sequential(*tail)

    def forward(self, x: torch.Tensor, lowres_feature: torch.Tensor) -> torch.Tensor:
        """Estimate the field of view.

        Args:
        ----
            x (torch.Tensor): Input image. Only used when the module was built
                with a ``fov_encoder``.
            lowres_feature (torch.Tensor): Low resolution feature map.

        Returns:
        -------
            The field of view tensor produced by ``head``.

        """
        if not hasattr(self, "encoder"):
            return self.head(lowres_feature)

        # Feed the auxiliary encoder a quarter-resolution copy of the image.
        quarter_res = F.interpolate(
            x,
            size=None,
            scale_factor=0.25,
            mode="bilinear",
            align_corners=False,
        )
        tokens = self.encoder(quarter_res)
        # Drop the first entry along dim 1 (presumably a class token - TODO
        # confirm against the encoder) and swap the last two dims.
        tokens = tokens[:, 1:].permute(0, 2, 1)

        reduced = self.downsample(lowres_feature)
        # Lay the encoder output out like the downsampled features, then add.
        fused = tokens.reshape_as(reduced) + reduced
        return self.head(fused)