razmars commited on
Commit
38cedcd
·
verified ·
1 Parent(s): c039a8b

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +32 -11
modeling_super_linear.py CHANGED
@@ -511,6 +511,10 @@ class superLinear(nn.Module):
511
 
512
 
513
  def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
 
 
 
 
514
  if len(x_enc.shape) > 2:
515
  x = x_enc.permute(0, 2, 1)
516
  B, V, L = x.shape
@@ -518,8 +522,26 @@ class superLinear(nn.Module):
518
  x = x_enc
519
  B, L = x.shape
520
  V = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
 
522
  x = x.reshape(B * V, L)
 
523
  expert_probs = None
524
 
525
  if get_prob:
@@ -527,23 +549,22 @@ class superLinear(nn.Module):
527
  else:
528
  out, self.moe_loss = self.moe(x)
529
 
530
-
531
- if self.auto_regressive and self.max_horizon < self.inf_pred_len:
532
- #print("bitch")
533
  outputs = [out]
534
  ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
535
- for i in range(0, self.inf_pred_len, self.max_horizon):
536
  ar_out, _ = self.moe(ar_x)
537
  outputs.append(ar_out)
538
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
539
- out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
540
-
541
- if len(x_enc.shape) > 2:
542
- out = out.reshape(B, V, out.shape[-1])
543
- result = out.permute(0, 2, 1)
544
- else:
545
- result = out
546
 
 
 
 
 
 
547
  if get_prob:
548
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
549
  return result, expert_probs
 
511
 
512
 
513
  def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
514
+
515
+ if inf_pred_len is None:
516
+ inf_pred_len = self.inf_pred_len
517
+
518
  if len(x_enc.shape) > 2:
519
  x = x_enc.permute(0, 2, 1)
520
  B, V, L = x.shape
 
522
  x = x_enc
523
  B, L = x.shape
524
  V = 1
525
+
526
+ short_lookback = False
527
+ if L < self.seq_len:
528
+ # print("test!")
529
+ #ceil - very bad heuristic!
530
+ scale_factor = self.seq_len / L
531
+ scale_factor = int(np.ceil(scale_factor))
532
+ orig_pred_len = inf_pred_len
533
+
534
+ inf_pred_len = inf_pred_len*scale_factor
535
+ x = interpolate(x_enc.permute(0, 2, 1), scale_factor=scale_factor, mode='linear')
536
+
537
+ x = x[:,: , -self.seq_len:]
538
+ orig_L = L
539
+ L = self.seq_len
540
+
541
+ short_lookback = True
542
 
543
  x = x.reshape(B * V, L)
544
+
545
  expert_probs = None
546
 
547
  if get_prob:
 
549
  else:
550
  out, self.moe_loss = self.moe(x)
551
 
552
+ if self.auto_regressive and self.max_horizon < inf_pred_len:
 
 
553
  outputs = [out]
554
  ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
555
+ for i in range(0, inf_pred_len, self.max_horizon):
556
  ar_out, _ = self.moe(ar_x)
557
  outputs.append(ar_out)
558
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
559
+ out = torch.cat(outputs, dim=1)[:,:inf_pred_len]
560
+ out = out.reshape(B, V, out.shape[-1])
561
+
 
 
 
 
562
 
563
+ if short_lookback:
564
+ out = interpolate(out, scale_factor=1/scale_factor, mode='linear')
565
+ # print(out.shape)
566
+ out = out[:, :,:orig_pred_len]
567
+ result = out.permute(0, 2, 1)
568
  if get_prob:
569
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
570
  return result, expert_probs