razmars commited on
Commit
d3e03c1
·
verified ·
1 Parent(s): 0b03843

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +4 -4
modeling_super_linear.py CHANGED
@@ -210,13 +210,13 @@ class RLinear(nn.Module):
210
  if L < self.seq_len:
211
  in_features = L
212
  W = self.Linear.weight.detach()
213
- fixed_weights = self.weights[:, :L]
214
- dynamic_weights = self.weights[:, L:]
215
 
216
- if in_features != self.weights.size(1) or out_features != self.weights.size(0):
217
  dynamic_weights = F.interpolate(dynamic_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, in_features-self.seq_len), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
218
  if self.fixed_in != 0:
219
- fixed_weights = F.interpolate(fixed_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, self.fixed_in), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
220
 
221
  x = self.revin_layer(x, 'norm')
222
  x = F.linear(x, torch.cat((fixed_weights, dynamic_weights), dim=1))
 
210
  if L < self.seq_len:
211
  in_features = L
212
  W = self.Linear.weight.detach()
213
+ fixed_weights = W[:, :L]
214
+ dynamic_weights = W[:, L:]
215
 
216
+ if in_features != self.weights.size(1):
217
  dynamic_weights = F.interpolate(dynamic_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, in_features-self.seq_len), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
218
  if self.fixed_in != 0:
219
+ fixed_weights = F.interpolate(fixed_weights.unsqueeze(0).unsqueeze(0), size=(self.horizon, L), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
220
 
221
  x = self.revin_layer(x, 'norm')
222
  x = F.linear(x, torch.cat((fixed_weights, dynamic_weights), dim=1))