| import torch | |
| import torch.nn as nn | |
| class ClassificationHead(nn.Module): | |
| """Classification head for the network.""" | |
| def __init__(self, id_to_gps): | |
| super().__init__() | |
| self.id_to_gps = id_to_gps | |
| def forward(self, x): | |
| """Forward pass of the network. | |
| x : Union[torch.Tensor, dict] with the output of the backbone. | |
| """ | |
| gps = self.id_to_gps(x.argmax(dim=-1)) | |
| return {"label": x, **gps} | |