Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # Copyright 2024 Wen-Chin Huang | |
| # MIT License (https://opensource.org/licenses/MIT) | |
| # LDNet modules | |
| # taken from: https://github.com/unilight/LDNet/blob/main/models/modules.py (written by myself) | |
| import torch | |
| from torch import nn | |
# Module-level constant.
# NOTE(review): STRIDE is not referenced anywhere in this visible chunk;
# presumably it is consumed by other code outside this view — confirm
# before changing or removing it.
STRIDE = 3
class Projection(nn.Module):
    """Two-layer MLP head mapping features to a scalar score or class logits.

    The network is ``Linear(in_dim, hidden_dim) -> activation() ->
    Dropout(dropout) -> Linear(hidden_dim, output_dim)``. ``output_dim`` is
    1 for ``output_type="scalar"`` and ``_output_dim`` for
    ``output_type="categorical"``.

    Args:
        in_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the hidden layer.
        activation: Activation *class* (not an instance). It is instantiated
            with no arguments, e.g. ``nn.ReLU``.
        output_type: Either ``"scalar"`` or ``"categorical"``.
        _output_dim: Number of output classes. Only used when
            ``output_type == "categorical"``.
        output_step: Score increment between consecutive classes. It is used
            to turn an argmax class index into a score at inference time
            (categorical only).
        range_clipping: If True and ``output_type == "scalar"``, squash the
            output with ``tanh`` into the open interval
            ``(output_min, output_max)``.
        dropout: Dropout probability between the two linear layers.
        output_min: Lower end of the score range. It is the clipping lower
            bound (scalar with ``range_clipping``) and the score assigned to
            class 0 at inference (categorical).
        output_max: Upper end of the clipping range (scalar with
            ``range_clipping`` only).

    The defaults (``dropout=0.3``, ``output_min=1``, ``output_max=5``)
    reproduce the previously hard-coded behaviour exactly:
    ``tanh(x) * 2.0 + 3`` and ``argmax * output_step + 1``.

    Raises:
        NotImplementedError: If ``output_type`` is not recognised.
        ValueError: If range clipping is requested with
            ``output_max <= output_min``.
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        activation,
        output_type,
        _output_dim,
        output_step=1.0,
        range_clipping=False,
        dropout=0.3,
        output_min=1,
        output_max=5,
    ):
        super().__init__()
        self.output_type = output_type
        self.range_clipping = range_clipping
        self.output_min = output_min
        self.output_max = output_max

        if output_type == "scalar":
            output_dim = 1
            if range_clipping:
                if output_max <= output_min:
                    raise ValueError(
                        "output_max ({}) must be greater than output_min ({})".format(
                            output_max, output_min
                        )
                    )
                self.proj = nn.Tanh()
                # Affine map from tanh's (-1, 1) onto (output_min, output_max).
                # With the defaults this is exactly "* 2.0 + 3.0".
                self._half_range = (output_max - output_min) / 2
                self._mid_point = (output_max + output_min) / 2
        elif output_type == "categorical":
            output_dim = _output_dim
            self.output_step = output_step
        else:
            raise NotImplementedError("wrong output_type: {}".format(output_type))

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, inference=False):
        """Apply the projection head.

        Args:
            x: Input tensor whose last dimension is ``in_dim``.
            inference: Categorical only. If True, return
                ``argmax(logits) * output_step + output_min`` instead of the
                raw logits. It is ignored for scalar outputs.

        Returns:
            Scalar: a tensor whose last dimension is 1. When
            ``range_clipping`` is set, values lie in
            ``(output_min, output_max)``.
            Categorical: logits whose last dimension is ``_output_dim``.
            When ``inference`` is True, the return is instead a score tensor
            with the last dimension reduced away.
        """
        output = self.net(x)

        if self.output_type == "scalar":
            if self.range_clipping:
                return self.proj(output) * self._half_range + self._mid_point
            return output

        # categorical
        if inference:
            return torch.argmax(output, dim=-1) * self.output_step + self.output_min
        return output