lirannoc commited on
Commit
19a850b
·
verified ·
1 Parent(s): 7a346d0

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +9 -0
modeling_super_linear.py CHANGED
@@ -420,6 +420,15 @@ class Model(nn.Module):
420
  else:
421
  out = self.moe(x)
422
 
 
 
 
 
 
 
 
 
 
423
  # Reshape back
424
  out = out.reshape(B, V, out.size(-1))
425
 
 
420
  else:
421
  out = self.moe(x)
422
 
423
+ if self.max_horizon < inf_pred_len:
424
+ outputs = [out]
425
+ ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
426
+ for i in range(0, inf_pred_len, self.max_horizon):
427
+ ar_out, _ = self.moe(ar_x)
428
+ outputs.append(ar_out)
429
+ ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
430
+ out = torch.cat(outputs, dim=1)[:, :inf_pred_len]
431
+
432
  # Reshape back
433
  out = out.reshape(B, V, out.size(-1))
434