Spaces:
Runtime error
Runtime error
| from typing import Type | |
| from torch import nn | |
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLPBlock(nn.Module):
    """Stack of ``num_layers`` (Linear -> activation) stages plus a final Linear.

    The first stage maps ``input_dim -> hidden_dim``; every later stage maps
    ``hidden_dim -> hidden_dim``. The output projection ``fc`` has no
    activation and maps the width reaching it to ``output_dim``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every (Linear -> activation) stage.
        output_dim: Size of the last dimension of the output.
        num_layers: Number of (Linear -> activation) stages before ``fc``.
            ``0`` yields a single Linear ``input_dim -> output_dim``.
        act: Activation module class; instantiated once per stage with no
            arguments (e.g. ``nn.GELU``).

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act: Type[nn.Module],
    ) -> None:
        super().__init__()
        if num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        self.num_layers = num_layers
        # Stage i consumes input_dim for i == 0, otherwise the previous
        # stage's hidden_dim output. The Sequential(Linear, act) nesting is
        # kept so state_dict keys stay "layers.<i>.0.weight" etc.
        self.layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(input_dim if i == 0 else hidden_dim, hidden_dim),
                act(),
            )
            for i in range(num_layers)
        )
        # With no stages the raw input goes straight to fc, so its input
        # width must be input_dim rather than hidden_dim.
        fc_in = hidden_dim if num_layers > 0 else input_dim
        self.fc = nn.Linear(fc_in, output_dim)

    def forward(self, x):
        """Apply each (Linear -> activation) stage in order, then ``fc``.

        Args:
            x: Tensor whose last dimension is ``input_dim``; ``nn.Linear``
                treats all leading dimensions as batch dimensions.

        Returns:
            Tensor whose last dimension is ``output_dim``. Leading dimensions
            are preserved, assuming ``act`` is shape-preserving (as
            elementwise activations are).
        """
        for layer in self.layers:
            x = layer(x)
        return self.fc(x)