Update modeling_super_linear.py
Browse files- modeling_super_linear.py +4 -4
modeling_super_linear.py
CHANGED
|
@@ -210,13 +210,13 @@ class RLinear(nn.Module):
|
|
| 210 |
if L < self.seq_len:
|
| 211 |
in_features = L
|
| 212 |
W = self.Linear.weight.detach()
|
| 213 |
-
fixed_weights =
|
| 214 |
-
dynamic_weights =
|
| 215 |
|
| 216 |
-
if in_features != self.weights.size(1)
|
| 217 |
dynamic_weights = F.interpolate(dynamic_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, in_features-self.seq_len), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
|
| 218 |
if self.fixed_in != 0:
|
| 219 |
-
fixed_weights = F.interpolate(fixed_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon,
|
| 220 |
|
| 221 |
x = self.revin_layer(x, 'norm')
|
| 222 |
x = F.linear(x, torch.cat((fixed_weights, dynamic_weights), dim=1))
|
|
|
|
| 210 |
if L < self.seq_len:
|
| 211 |
in_features = L
|
| 212 |
W = self.Linear.weight.detach()
|
| 213 |
+
fixed_weights = W[:, :L]
|
| 214 |
+
dynamic_weights = W[:, L:]
|
| 215 |
|
| 216 |
+
if in_features != self.weights.size(1):
|
| 217 |
dynamic_weights = F.interpolate(dynamic_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, in_features-self.seq_len), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
|
| 218 |
if self.fixed_in != 0:
|
| 219 |
+
fixed_weights = F.interpolate(fixed_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, L), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
|
| 220 |
|
| 221 |
x = self.revin_layer(x, 'norm')
|
| 222 |
x = F.linear(x, torch.cat((fixed_weights, dynamic_weights), dim=1))
|