# NOTE(review): removed non-code residue ("Spaces: / Paused / Paused") left by a
# web-page capture; it was not valid Python and carried no information.
import torch
import torch.nn as nn
class FCN(nn.Module):
    """Prediction head producing command logits and argument logits.

    Args:
        d_model: feature dimension of the incoming sequence.
        n_commands: number of command classes.
        n_args: number of arguments per command.
        args_dim: number of discrete bins per argument; ignored when
            ``abs_targets`` is True.
        abs_targets: when True, regress one continuous value per argument
            instead of classifying over ``args_dim`` bins.
    """

    def __init__(self, d_model, n_commands, n_args, args_dim=256, abs_targets=False):
        super().__init__()
        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets
        self.command_fcn = nn.Linear(d_model, n_commands)
        # Regression emits one scalar per argument; classification emits
        # args_dim logits per argument.
        arg_out = n_args if abs_targets else n_args * args_dim
        self.args_fcn = nn.Linear(d_model, arg_out)

    def forward(self, out):
        seq_len, batch, _ = out.shape
        command_logits = self.command_fcn(out)  # [S, N, n_commands]
        args_logits = self.args_fcn(out)        # [S, N, n_args * args_dim] or [S, N, n_args]
        if not self.abs_targets:
            # Unflatten into per-argument bins: [S, N, n_args, args_dim].
            # (Linear output is contiguous, so view is safe here.)
            args_logits = args_logits.view(seq_len, batch, self.n_args, self.args_dim)
        return command_logits, args_logits
class ArgumentFCN(nn.Module):
    """Argument-only prediction head (classification or regression).

    Args:
        d_model: feature dimension of the incoming sequence.
        n_args: number of arguments per command.
        args_dim: number of discrete bins per argument; ignored when
            ``abs_targets`` is True.
        abs_targets: when True, output one continuous value per argument
            via a two-stage projection; otherwise output per-bin logits.
    """

    def __init__(self, d_model, n_args, args_dim=256, abs_targets=False):
        super().__init__()
        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets
        if abs_targets:
            # Two stacked projections widen to n_args * args_dim, then
            # collapse to one value per argument.
            # NOTE(review): there is no nonlinearity between the two Linear
            # layers, so this composes to a single linear map — confirm this
            # is intentional before changing it.
            self.args_fcn = nn.Sequential(
                nn.Linear(d_model, n_args * args_dim),
                nn.Linear(n_args * args_dim, n_args),
            )
        else:
            self.args_fcn = nn.Linear(d_model, n_args * args_dim)

    def forward(self, out):
        seq_len, batch, _ = out.shape
        args_logits = self.args_fcn(out)  # [S, N, n_args * args_dim] or [S, N, n_args]
        if not self.abs_targets:
            # Unflatten into per-argument bins: [S, N, n_args, args_dim].
            args_logits = args_logits.view(seq_len, batch, self.n_args, self.args_dim)
        return args_logits
class HierarchFCN(nn.Module):
    """Group-level head: visibility logits plus a latent-code projection.

    Both heads read the ``dim_z``-sized group embedding. ``d_model`` is
    unused in the current layer sizes but kept in the signature for
    interface compatibility with callers.
    """

    def __init__(self, d_model, dim_z):
        super().__init__()
        self.visibility_fcn = nn.Linear(dim_z, 2)
        self.z_fcn = nn.Linear(dim_z, dim_z)

    def forward(self, out):
        num_groups, batch, _ = out.shape
        visibility_logits = self.visibility_fcn(out)  # [G, N, 2]
        z = self.z_fcn(out)                           # [G, N, dim_z]
        # Prepend a singleton leading axis: [1, G, N, ...].
        return visibility_logits[None], z[None]
class ResNet(nn.Module):
    """Four-stage residual MLP: each stage adds ReLU(Linear(z)) back onto z.

    Args:
        d_model: width of every hidden layer (input and output dimension).
    """

    def __init__(self, d_model):
        super().__init__()
        # Four separately-named attributes (rather than a ModuleList) so the
        # state-dict keys match the original layout: linear1 ... linear4.
        self.linear1 = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())
        self.linear2 = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())
        self.linear3 = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())
        self.linear4 = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())

    def forward(self, z):
        # Residual update applied stage by stage, in order.
        for stage in (self.linear1, self.linear2, self.linear3, self.linear4):
            z = z + stage(z)
        return z