Update modeling_super_linear.py
Browse files- modeling_super_linear.py +9 -0
modeling_super_linear.py
CHANGED
|
@@ -420,6 +420,15 @@ class Model(nn.Module):
|
|
| 420 |
else:
|
| 421 |
out = self.moe(x)
|
| 422 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
# Reshape back
|
| 424 |
out = out.reshape(B, V, out.size(-1))
|
| 425 |
|
|
|
|
| 420 |
else:
|
| 421 |
out = self.moe(x)
|
| 422 |
|
| 423 |
+
if self.max_horizon < inf_pred_len:
|
| 424 |
+
outputs = [out]
|
| 425 |
+
ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
|
| 426 |
+
for i in range(0, inf_pred_len, self.max_horizon):
|
| 427 |
+
ar_out, _ = self.moe(ar_x)
|
| 428 |
+
outputs.append(ar_out)
|
| 429 |
+
ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
|
| 430 |
+
out = torch.cat(outputs, dim=1)[:, :inf_pred_len]
|
| 431 |
+
|
| 432 |
# Reshape back
|
| 433 |
out = out.reshape(B, V, out.size(-1))
|
| 434 |
|