Update modeling_super_linear.py

modeling_super_linear.py (+32 -11)
@@ -511,6 +511,10 @@ class superLinear(nn.Module):


     def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
+
+        if inf_pred_len is None:
+            inf_pred_len = self.inf_pred_len
+
         if len(x_enc.shape) > 2:
             x = x_enc.permute(0, 2, 1)
             B, V, L = x.shape
@@ -518,8 +522,26 @@
             x = x_enc
             B, L = x.shape
             V = 1
+
+        short_lookback = False
+        if L < self.seq_len:
+            # print("test!")
+            #ceil - very bad heuristic!
+            scale_factor = self.seq_len / L
+            scale_factor = int(np.ceil(scale_factor))
+            orig_pred_len = inf_pred_len
+
+            inf_pred_len = inf_pred_len * scale_factor
+            x = interpolate(x_enc.permute(0, 2, 1), scale_factor=scale_factor, mode='linear')
+
+            x = x[:, :, -self.seq_len:]
+            orig_L = L
+            L = self.seq_len
+
+            short_lookback = True

         x = x.reshape(B * V, L)
+
         expert_probs = None

         if get_prob:
@@ -527,23 +549,22 @@
         else:
             out, self.moe_loss = self.moe(x)

-
-        if self.auto_regressive and self.max_horizon < self.inf_pred_len:
-            #print("bitch")
+        if self.auto_regressive and self.max_horizon < inf_pred_len:
             outputs = [out]
             ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
-            for i in range(0, self.inf_pred_len, self.max_horizon):
+            for i in range(0, inf_pred_len, self.max_horizon):
                 ar_out, _ = self.moe(ar_x)
                 outputs.append(ar_out)
                 ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
-            out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
-
-
-            out = out.reshape(B, V, out.shape[-1])
-            result = out.permute(0, 2, 1)
-        else:
-            result = out
+            out = torch.cat(outputs, dim=1)[:, :inf_pred_len]
+        out = out.reshape(B, V, out.shape[-1])
+

+        if short_lookback:
+            out = interpolate(out, scale_factor=1 / scale_factor, mode='linear')
+            # print(out.shape)
+            out = out[:, :, :orig_pred_len]
+        result = out.permute(0, 2, 1)
         if get_prob:
             expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
         return result, expert_probs
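In short, the commit lets `forward` handle lookbacks shorter than `self.seq_len` and horizons longer than `self.max_horizon`: a short input is linearly upsampled to the trained lookback length, the forecast is produced (autoregressively when needed) at the upsampled scale, and the result is downsampled back to the requested horizon. The sketch below illustrates that control flow in isolation. It is a minimal illustration, not the repository's API: `predict` is a hypothetical stand-in for the MoE forecaster, `seq_len`, `max_horizon`, and `inf_pred_len` are passed as plain arguments, and it assumes `inf_pred_len` reaches `forward` as a keyword argument defaulting to None (the new hunk reads it, but the displayed signature is unchanged).

    # Minimal sketch of the new control flow (hypothetical names; `predict` stands in
    # for self.moe, and shapes are kept univariate (B, L) for brevity).
    import numpy as np
    import torch
    from torch.nn.functional import interpolate


    def predict(x: torch.Tensor, horizon: int) -> torch.Tensor:
        # Stand-in forecaster: repeat the last observed value `horizon` times.
        return x[:, -1:].repeat(1, horizon)


    def forecast(x_enc: torch.Tensor, seq_len: int, max_horizon: int, inf_pred_len: int) -> torch.Tensor:
        B, L = x_enc.shape
        x = x_enc

        # Short lookback: upsample to the trained lookback length and scale the
        # requested horizon by the same ceil-ed factor (the "very bad heuristic"
        # the diff's own comment flags).
        short_lookback = L < seq_len
        if short_lookback:
            scale_factor = int(np.ceil(seq_len / L))
            orig_pred_len = inf_pred_len
            inf_pred_len = inf_pred_len * scale_factor
            x = interpolate(x.unsqueeze(1), scale_factor=scale_factor, mode='linear')
            x = x[:, 0, -seq_len:]

        # Autoregressive rollout: forecast max_horizon steps at a time, feeding each
        # block back into a rolling window until inf_pred_len steps are covered.
        out = predict(x, max_horizon)
        if max_horizon < inf_pred_len:
            outputs = [out]
            ar_x = torch.cat([x, out], dim=1)[:, -seq_len:]
            for _ in range(0, inf_pred_len, max_horizon):
                ar_out = predict(ar_x, max_horizon)
                outputs.append(ar_out)
                ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -seq_len:]
            out = torch.cat(outputs, dim=1)[:, :inf_pred_len]

        # Short lookback: downsample the forecast back to the caller's horizon.
        if short_lookback:
            out = interpolate(out.unsqueeze(1), scale_factor=1 / scale_factor, mode='linear')
            out = out[:, 0, :orig_pred_len]
        return out


    hist = torch.randn(2, 48)                                   # lookback shorter than seq_len
    pred = forecast(hist, seq_len=96, max_horizon=24, inf_pred_len=60)
    print(pred.shape)                                           # torch.Size([2, 60])

The integer ceiling of `seq_len / L` guarantees the upsampled series is at least `seq_len` long before cropping, but when the ratio is not a whole number the crop discards part of the interpolated history and the final downsample lands on a slightly different time grid, which is presumably why the diff itself labels the factor a rough heuristic.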