lirannoc committed
Commit 3b2a0e1 · verified · 1 Parent(s): db751a5

Update modeling_super_linear.py

Files changed (1)
  1. modeling_super_linear.py +289 -326
modeling_super_linear.py CHANGED
@@ -1,24 +1,13 @@
-
- from typing import Optional, Tuple, Union
  import torch, torch.nn as nn, torch.nn.functional as F

  from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
  from .configuration_super_linear import SuperLinearConfig

- from typing import Tuple, Union
-
-
- import math
- import torch
- import numpy as np
- import torch.nn as nn
- import torch.nn.functional as F
- import matplotlib.pyplot as plt
- import os
- from torch.nn.functional import interpolate
-
- import datetime

  "-------------------------------------------------------------------------------------------------------------------"
  class RevIN(nn.Module):
@@ -95,117 +84,45 @@ class RevIN(nn.Module):

          return x
  "-------------------------------------------------------------------------------------------------------------------"
- class moving_avg(nn.Module):
-     """
-     Moving average block to highlight the trend of time series
-     """
-     def __init__(self, kernel_size, stride):
-         super(moving_avg, self).__init__()
-         self.kernel_size = kernel_size
-         self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
-
-     def forward(self, x):
-         # x: [Batch, Input length]
-         # padding on the both ends of time series
-         front = x[:, 0:1].repeat(1, (self.kernel_size - 1) // 2)
-         end = x[:, -1:].repeat(1, (self.kernel_size - 1) // 2)
-         x = torch.cat([front, x, end], dim=1)
-         x = self.avg(x.unsqueeze(1)).squeeze(1)
-         return x
-
-
- class series_decomp(nn.Module):
-     """
-     Series decomposition block
-     """
-     def __init__(self, kernel_size):
-         super(series_decomp, self).__init__()
-         self.moving_avg = moving_avg(kernel_size, stride=1)
-
-     def forward(self, x):
-         moving_mean = self.moving_avg(x)
-         res = x - moving_mean
-         return res, moving_mean
-
-
- class DLinear(nn.Module):
-     def __init__(self, input_len, output_len, kernel_size = 25):
-         super(DLinear, self).__init__()
-         self.seasonal = nn.Linear(input_len, output_len)
-         self.trend = nn.Linear(input_len, output_len)
-         self.moving_avg = moving_avg(kernel_size, stride=1)
-         self.decompsition = series_decomp(kernel_size)
-
-     def forward(self, x):
-         # x: [Batch*Input length,Channel]
-         seasonal_init, trend_init = self.decompsition(x)
-         seasonal_output = self.seasonal(seasonal_init)
-         trend_output = self.trend(trend_init)
-         x = seasonal_output + trend_output
-         return x # to [Batch, Output length, Channel]
-
  class Linear(nn.Module):
      def __init__(self, input_len, output_len):
          super(Linear, self).__init__()
          self.Linear = nn.Linear(input_len, output_len)

      def forward(self, x):
          # x: [Batch*Channel, Input length]
-         x_shape = x.shape
-         if len(x_shape) == 2:
-             x = x.unsqueeze(-1)
-         x = self.Linear(x)
-         if len(x_shape) == 2:
-             x = x.squeeze(-1)
          return x # to [Batch, Output length, Channel]

  class Naive(nn.Module):
      def __init__(self, input_len, output_len):
          super(Naive, self).__init__()
          self.output_len = output_len

-
      def forward(self, x):
          # x: [Batch*Channel, Input length]
-
-
-         x = x[:,-1].unsqueeze(1).repeat(1, self.output_len)
-
-
          return x # to [Batch, Output length, Channel]

  class Mean(nn.Module):
      def __init__(self, input_len, output_len):
          super(Mean, self).__init__()
          self.output_len = output_len

      def forward(self, x):
          # x: [Batch*Channel, Input length]
-
-         x = x.mean(dim=1).unsqueeze(1).repeat(1, self.output_len)
-
          return x # to [Batch, Output length, Channel]
-

- class NLinear(nn.Module):
-     def __init__(self, input_len, output_len):
-         super(NLinear, self).__init__()
-         self.Linear = nn.Linear(input_len, output_len)
-
-     def forward(self, x):
-         # x: [Batch* Input length,Channel]
-         seq_last = x[:,-1:].detach()
-         x = x - seq_last
-         x = self.Linear(x)
-
-         x = x + seq_last
-         return x
-
-
  class RLinear(nn.Module):
      def __init__(self, input_len, output_len):
          super(RLinear, self).__init__()
-         self.Linear = nn.Linear(input_len, output_len)
          self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)

      def forward(self, x):
@@ -223,60 +140,69 @@ class RLinear(nn.Module):
          return x # to [Batch, Output length, Channel]

  "-------------------------------------------------------------------------------------------------------------------"
- class SparseNoisyMoE(nn.Module):
      def __init__(self, configs, experts=None):
-         super(SparseNoisyMoE, self).__init__()
-         input_dim = configs.seq_len
-         output_dim = configs.pred_len
-
          self.noise_std = configs.noisy_gating_std
-         self.noise_std_decay = configs.noisy_gating_std_decay
-         self.experts = nn.ModuleList(experts)
          self.num_experts = len(experts)
          self.k = configs.top_k_experts
          if self.k > self.num_experts:
-             print(f"Warning: k ({self.k}) is greater than the number of experts ({self.num_experts}). Setting k to {self.num_experts}.")
              self.k = self.num_experts
-         self.d_model = configs.d_model
-         self.mlp_gating = configs.mlp_gating
          self.moe_temp = configs.moe_temp
          self.use_fft = configs.use_fft
          self.fft_len = configs.fft_len
          self.moe_norm = configs.moe_norm
-
-
          if self.use_fft:
-             if self.mlp_gating:
-                 self.gating_network = nn.Sequential(
-                     nn.Linear(self.fft_len//2, self.d_model),
-                     nn.ReLU(),
-                     nn.Linear(self.d_model, self.num_experts)
-                 )
-
-             else:
-                 self.gating_network = nn.Linear(self.fft_len//2, self.num_experts, bias=True)
          else:
-             self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)

          if self.moe_norm:
-             self.batch_norm = nn.BatchNorm1d(self.num_experts)
-

-
-     def get_periodogram(self, inputs, n=10000):
          if inputs.dim() == 2:
              x_0 = inputs.unsqueeze(2)
          else:
              x_0 = inputs
-         x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)

-         v = torch.arange(0, n) / n
          dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
-         dft = dft[:, :n//2, :]
-         I = torch.abs(dft) ** 2

          I_sum = torch.sum(I, dim=1, keepdim=True)
-         I_sum[I_sum == 0] = 1
          I = I / I_sum

          if torch.any(I_sum == 0):
@@ -289,279 +215,314 @@ class SparseNoisyMoE(nn.Module):
          return I

      def forward(self, x, get_prob=False):
          if self.use_fft:
-             # x_0 = self.get_periodogram(x, ker_len=self.ker_len, n=self.fft_len, con=self.con)
-             x_0 = self.get_periodogram(x, n=self.fft_len)
          else:
              x_0 = x

-         self.gate_outputs = self.gating_network(x_0) # g(X)
          if self.moe_norm:
-             # self.gate_outputs = self.batch_norm(self.gate_outputs)
-             self.gate_outputs = self.batch_norm(self.gate_outputs)
-
-         #

          if not self.training:
              self.gate_outputs = self.gate_outputs / self.moe_temp

-         # original
          noise = torch.randn_like(self.gate_outputs).to(x.device) * self.noise_std
          if self.training:
              noisy_gate_outputs = self.gate_outputs + noise
-             self.topk_values, topk_indices = torch.topk(noisy_gate_outputs, self.k, dim=1) # N = 35, k=6,12,20
          else:
              self.topk_values, topk_indices = torch.topk(self.gate_outputs, self.k, dim=1)

-
          self.topk_gates = F.softmax(self.topk_values, dim=1)

          batch_size = x.size(0)
          expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)

          topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
          sparse_expert_outputs = torch.gather(expert_outputs, 1, topk_indices_expanded)

          output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
-
-         load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)

          if get_prob:
              expert_probs = F.softmax(self.gate_outputs, dim=1)
-             return output, load_balancing_loss, expert_probs

-         return output, load_balancing_loss
-
-     def calculate_load_balancing_loss(self, gate_outputs, batch_size):
-         gate_probs = F.softmax(gate_outputs, dim=1)
-
-         assignments = torch.argmax(gate_outputs, dim=1)
-         self.D = torch.zeros(self.num_experts, device=gate_outputs.device)
-         for i in range(self.num_experts):
-             self.D[i] = torch.sum(assignments == i).float() / batch_size
-
-         P = torch.mean(gate_probs, dim=0)
-
-         load_balancing_loss = torch.sum(self.D * P) * self.num_experts
-
-         return load_balancing_loss


-
- class superLinear(nn.Module):
      def __init__(self, configs):
-         super(superLinear, self).__init__()
-
          self.configs = configs
-         self.pred_len = configs.pred_len
-         self.seq_len = configs.seq_len
-         self.inf_pred_len = configs.inf_pred_len
-         self.max_horizon = configs.max_horizon
-         self.auto_regressive = configs.auto_regressive
-         self.n_experts = configs.moe_n_experts
-         self.moe = configs.moe
          self.model_name = "SuperLinear"
-
          if configs.freq_experts == "":
              self.freq_experts = None
          else:
              self.freq_experts = configs.freq_experts.split('_')

-
-
-         self.moe_loss = None
          self.top_k_experts = configs.top_k_experts
-         self.n_experts = configs.moe_n_experts
          self.freeze_experts = configs.freeze_experts
-         self.layer_type = configs.layer_type
-         self.model_name = "SuperLinear"
-
-
-         self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
-
-         # path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
-         # dirs = os.listdir(path)
-         # checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in dirs]
-
-         if self.freq_experts == "all":
-             self.freq_experts = []
-             for cp in checkpoints_paths:
-                 if self.layer_type in cp:
-                     cycle = cp.split("/")

          self.experts = {}
          if self.freq_experts is not None:
              for expert_freq in self.freq_experts:
                  if expert_freq == "naive" or expert_freq == "Naive":
-                     self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
                  elif expert_freq == "mean" or expert_freq == "Mean":
-                     self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
                  else:
-                     self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
-                     # if configs.load_linear:
-                     #     cycle = self.map_to_cycle(expert_freq)
-                     #     cycle_str = f'cycle_{cycle}/'
-                     #     cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
-                     #     if len(cycle_checkpoint_path) > 0:
-                     #         print()
-                     #         print(cycle_str)
-                     #         cycle_checkpoint_path = cycle_checkpoint_path[0]
-                     #         #print(f'loading checkpoint with layer type: {self.layer_type} and cycle: {cycle_str}')
-                     #         print(cycle_checkpoint_path)
-                     #         self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
-                     #     else:
-                     #         print(f"Checkpoint for {cycle_str} not found in {path}")
-                     #         raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
-                     # if configs.freeze_experts:
-                     #     for param in self.experts[expert_freq].parameters():
-                     #         param.requires_grad = False

-             self.n_experts = len(self.experts)
          else:
-             for i in range(self.n_experts):
-                 print(f"creating expert {i}")
-                 self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)


-         if configs.misc_moe>0:
-             if configs.misc_moe == 1:
-                 #print("Creating misc expert")
-                 self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
              else:
-                 for i in range(configs.misc_moe):
-                     #print(f"Creating misc expert {i}")
-                     self.experts["misc_"+str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
-


-         self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
-         self.dropout = nn.Dropout(configs.dropout)
-
-         # if configs.load_weights:
-         #     print(f"Loading weights from {path}")
-         #     path = configs.load_weights_path + "" + configs.load_weights_dir + "/" + "checkpoint.pth"
-         #     if os.path.exists(path):
-         #         checkpoint = torch.load(path)
-         #         print(len(self.experts.keys()))
-         #         print(self.experts.keys())
-         #         print(self.state_dict().keys())
-         #         print(checkpoint.keys())
-         #         self.load_state_dict(checkpoint)
-         #     else:
-         #         print(f"Path {path} does not exist. Skipping loading weights.")
-
-
-     # def map_to_cycle(self, freq):
-     #     if "/" in freq:
-     #         cycle = int(freq.split("/")[1])
-     #     elif "h" in freq:
-     #         cycle = 24
-     #     elif "2h":
-     #         cycle = 12
-     #     elif "3h" in freq:
-     #         cycle = 8
-     #     elif "4h" in freq:
-     #         cycle = 6
-     #     elif "D" in freq:
-     #         cycle = 7
-     #     elif "DM" in freq:
-     #         cycle = 30
-     #     elif "W" in freq:
-     #         cycle = 52
-     #     elif "M" in freq:
-     #         cycle = 12
-     #     elif "min" in freq:
-     #         cycle = 1440
-     #     elif "5min" in freq:
-     #         cycle = 288
-     #     elif "10min" in freq:
-     #         cycle = 144
-     #     elif "15min" in freq:
-     #         cycle = 96
-     #     elif "30min" in freq:
-     #         cycle = 48
-     #     else:
-     #         cycle = int(freq)
-     #     return cycle
-
-
-     def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False, inf_pred_len=None):
-
-         if inf_pred_len is None:
-             inf_pred_len = self.inf_pred_len
-
-         if len(x_enc.shape) > 2:
-             x = x_enc.permute(0, 2, 1)
-             B, V, L = x.shape
-         else:
-             x = x_enc
-             B, L = x.shape
-             V = 1

-         short_lookback = False
-         if L<self.seq_len:
-             # print("test!")
-             #ceil - very bad heuristic!
-             scale_factor = self.seq_len / L
-             scale_factor = int(np.ceil(scale_factor))
-             orig_pred_len = inf_pred_len

-             inf_pred_len = inf_pred_len*scale_factor
-             x = interpolate(x_enc.permute(0, 2, 1), scale_factor=scale_factor, mode='linear')

-             x = x[:,: , -self.seq_len:]
-             orig_L = L
-             L = self.seq_len

-             short_lookback = True

-         x = x.reshape(B * V, L)
-
-         expert_probs = None
-
          if get_prob:
-             out, self.moe_loss, expert_probs = self.moe(x, get_prob=True)
          else:
-             out, self.moe_loss = self.moe(x)
-
-         if self.auto_regressive and self.max_horizon < inf_pred_len:
-             outputs = [out]
-             ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
-             for i in range(0, inf_pred_len, self.max_horizon):
-                 ar_out, _ = self.moe(ar_x)
-                 outputs.append(ar_out)
-                 ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
-             out = torch.cat(outputs, dim=1)[:,:inf_pred_len]
-         out = out.reshape(B, V, out.shape[-1])
-

-         if short_lookback:
-             out = interpolate(out, scale_factor=1/scale_factor, mode='linear')
-             out = out[:, :,:orig_pred_len]
          result = out.permute(0, 2, 1)
          if get_prob:
              expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
              return result, expert_probs
          return result
-
  "-------------------------------------------------------------------------------------------------------------------"
  class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
-
      config_class = SuperLinearConfig

      def __init__(self, config: SuperLinearConfig):
          super().__init__(config)

          # the backbone keeps its own Config dataclass, so build one on‑the‑fly:
-         backbone_cfg = type("Cfg", (), config.to_dict())()
-         self.args = backbone_cfg
-         self.backbone = superLinear(backbone_cfg)
          self.post_init()

-
      def forward(self,
-                 inputs_embeds: torch.Tensor = None,
-                 prediction_len: int = None,
                  attention_mask: Optional[torch.Tensor] = None,
                  past_key_values: Optional[Tuple] = None,
                  use_cache: bool = True,
@@ -573,17 +534,19 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
              raise ValueError("Pass the time‑series as `inputs_embeds`")

          # backbone expects (B, C, L)
-         preds = self.backbone(inputs_embeds, inf_pred_len=prediction_len)
          return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)


-     def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, prediction_len=None, **kwargs):
          if past_key_values is not None:
              # only feed the last new step
              inputs_embeds = inputs_embeds[:, -1:, :]
-         return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values, "prediction_len": prediction_len}

      def _reorder_cache(self, past, beam_idx, **kwargs):
          return past # backbone keeps no KV cache

-
@@ -1,24 +1,13 @@
+ from typing import Optional, Tuple
  import torch, torch.nn as nn, torch.nn.functional as F
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import os

  from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
  from .configuration_super_linear import SuperLinearConfig


  "-------------------------------------------------------------------------------------------------------------------"
  class RevIN(nn.Module):
@@ -95,117 +84,45 @@ class RevIN(nn.Module):

          return x
  "-------------------------------------------------------------------------------------------------------------------"
  class Linear(nn.Module):
+     """Simple linear layer expert."""
      def __init__(self, input_len, output_len):
          super(Linear, self).__init__()
          self.Linear = nn.Linear(input_len, output_len)

      def forward(self, x):
          # x: [Batch*Channel, Input length]
+         x = x.clone()
+         x = self.Linear(x).clone()
          return x # to [Batch, Output length, Channel]

  class Naive(nn.Module):
+     """Naive forecasting expert - repeats last value."""
      def __init__(self, input_len, output_len):
          super(Naive, self).__init__()
          self.output_len = output_len

      def forward(self, x):
          # x: [Batch*Channel, Input length]
+         x = x[:,-1].unsqueeze(1).repeat(1, self.output_len)
          return x # to [Batch, Output length, Channel]

  class Mean(nn.Module):
+     """Mean forecasting expert - repeats mean value."""
      def __init__(self, input_len, output_len):
          super(Mean, self).__init__()
          self.output_len = output_len

      def forward(self, x):
          # x: [Batch*Channel, Input length]
+         x = x.mean(dim=1).unsqueeze(1).repeat(1, self.output_len)
          return x # to [Batch, Output length, Channel]

  class RLinear(nn.Module):
+     """Reversible Instance Normalization Linear layer expert."""
      def __init__(self, input_len, output_len):
          super(RLinear, self).__init__()
+         self.Linear = nn.Linear(input_len, output_len)
          self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)

      def forward(self, x):
@@ -223,60 +140,69 @@ class RLinear(nn.Module):
          return x # to [Batch, Output length, Channel]

  "-------------------------------------------------------------------------------------------------------------------"
+ class SparseMoE(nn.Module):
+     """
+     Sparse Mixture of Experts (MoE) module that routes inputs to the most relevant experts.
+
+     This implementation uses a gating network to determine which experts should process each input.
+     Only the top-k experts are used for each input, creating a sparse computation pattern.
+
+     Args:
+         configs: Configuration object containing MoE parameters
+         experts: Collection of expert modules (neural networks)
+     """
      def __init__(self, configs, experts=None):
+         super(SparseMoE, self).__init__()
          self.noise_std = configs.noisy_gating_std
+         self.experts = nn.ModuleList(experts)  # Store experts in ModuleList for proper registration
          self.num_experts = len(experts)
          self.k = configs.top_k_experts
+
          if self.k > self.num_experts:
              self.k = self.num_experts
+
          self.moe_temp = configs.moe_temp
          self.use_fft = configs.use_fft
          self.fft_len = configs.fft_len
          self.moe_norm = configs.moe_norm
+
+         # Initialize gating network based on configuration
          if self.use_fft:
+             self.gating_network = nn.Linear(self.fft_len//2, self.num_experts, bias=True)
          else:
+             self.gating_network = nn.Linear(configs.seq_len, self.num_experts, bias=True)

          if self.moe_norm:
+             self.gate_norm = nn.BatchNorm1d(self.num_experts)

+     def get_periodogram(self, inputs, n=10000):
+         """
+         Calculate the periodogram (power spectral density) of input time series.
+
+         The periodogram is used as a frequency-domain representation of the signal
+         to help the gating network identify periodic patterns.
+
+         Args:
+             inputs: Input time series tensor of shape [batch_size, sequence_length] or [batch_size, sequence_length, features]
+             n: Number of points in FFT computation
+
+         Returns:
+             Normalized periodogram of the input signals
+         """
          if inputs.dim() == 2:
              x_0 = inputs.unsqueeze(2)
          else:
              x_0 = inputs
+         x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)  # Remove mean (DC component)

+         # Compute FFT and normalize
          dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
+         dft = dft[:, :n//2, :]  # Keep only positive frequencies
+         I = torch.abs(dft) ** 2  # Power spectral density

+         # Normalize periodogram
          I_sum = torch.sum(I, dim=1, keepdim=True)
+         I_sum[I_sum == 0] = 1  # Avoid division by zero
          I = I / I_sum

          if torch.any(I_sum == 0):
@@ -289,279 +215,314 @@ class SparseNoisyMoE(nn.Module):
          return I

      def forward(self, x, get_prob=False):
+         """
+         Forward pass through the Mixture of Experts.
+
+         Args:
+             x: Input tensor of shape [batch_size, sequence_length]
+             get_prob: Whether to return expert selection probabilities
+
+         Returns:
+             - Output tensor from the selected experts
+             - (Optional) Expert selection probabilities if get_prob is True
+         """
+         # Preprocess input if using FFT-based gating
          if self.use_fft:
+             x_0 = self.get_periodogram(x, n=self.fft_len)
          else:
              x_0 = x

+         # Get gating logits
+         self.gate_outputs = self.gating_network(x_0)  # Raw gating scores
+
          if self.moe_norm:
+             self.gate_outputs = self.gate_norm(self.gate_outputs)

+         # Apply temperature scaling during inference
          if not self.training:
              self.gate_outputs = self.gate_outputs / self.moe_temp

+         # Add noise to gating logits during training (for exploration)
          noise = torch.randn_like(self.gate_outputs).to(x.device) * self.noise_std
          if self.training:
              noisy_gate_outputs = self.gate_outputs + noise
+             self.topk_values, topk_indices = torch.topk(noisy_gate_outputs, self.k, dim=1)
          else:
              self.topk_values, topk_indices = torch.topk(self.gate_outputs, self.k, dim=1)

+         # Normalize the gate values with softmax
          self.topk_gates = F.softmax(self.topk_values, dim=1)

          batch_size = x.size(0)
+         # Get outputs from all experts
          expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)

+         # Select only the outputs from the top-k experts
          topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
          sparse_expert_outputs = torch.gather(expert_outputs, 1, topk_indices_expanded)

+         # Combine expert outputs using the gate values
          output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)

          if get_prob:
              expert_probs = F.softmax(self.gate_outputs, dim=1)
+             return output, expert_probs

+         return output


+ class Model(nn.Module):
+     """
+     Main model class that employs a Mixture of Experts for time series forecasting.
+
+     This model can work with various types of linear layers as experts and supports
+     both standard prediction and auto-regressive prediction for longer horizons.
+
+     Args:
+         configs: Configuration object containing model parameters
+     """
      def __init__(self, configs):
+         super(Model, self).__init__()
          self.configs = configs
          self.model_name = "SuperLinear"
+         self.train_pred_len = configs.train_pred_len
+         self.train_seq_len = configs.train_seq_len
+         self.resample_long_lookback = configs.resample_long_lookback
+         self.layer_type = configs.layer_type
+         self.load_weights_full = configs.load_weights_full
+         self.load_linear = configs.load_linear
+
+         if self.load_weights_full:
+             pass  # TODO: implement full weight loading
+
+         # Parse frequency experts from configuration
          if configs.freq_experts == "":
              self.freq_experts = None
          else:
              self.freq_experts = configs.freq_experts.split('_')

          self.top_k_experts = configs.top_k_experts
          self.freeze_experts = configs.freeze_experts
+         path = configs.linear_freq_weights_path
+         linear_freq_dirs = os.listdir(path) if os.path.exists(path) else []
+         checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in linear_freq_dirs]

+         # Initialize experts based on frequency specification or create generic experts
          self.experts = {}
          if self.freq_experts is not None:
              for expert_freq in self.freq_experts:
                  if expert_freq == "naive" or expert_freq == "Naive":
+                     self.experts[expert_freq] = Naive(self.train_seq_len, self.train_pred_len)
                  elif expert_freq == "mean" or expert_freq == "Mean":
+                     self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
                  else:
+                     # Use the appropriate expert class based on layer_type
+                     expert_classes = {'Linear': Linear, 'RLinear': RLinear}
+                     if self.layer_type in expert_classes:
+                         expert_class = expert_classes[self.layer_type]
+                         self.experts[expert_freq] = expert_class(self.train_seq_len, self.train_pred_len)
+                     else:
+                         # Default to RLinear if unknown layer type
+                         self.experts[expert_freq] = RLinear(self.train_seq_len, self.train_pred_len)
+
+                 if self.load_linear and checkpoints_paths:
+                     cycle = self.map_to_cycle(expert_freq)
+                     cycle_str = f'cycle_{cycle}/'
+                     cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
+                     if len(cycle_checkpoint_path) > 0:
+                         cycle_checkpoint_path = cycle_checkpoint_path[0]
+                         print(f'Loading checkpoint: {cycle_checkpoint_path}')
+                         self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))

+                 if self.freeze_experts:
+                     for param in self.experts[expert_freq].parameters():
+                         param.requires_grad = False
          else:
+             # Create generic experts
+             for i in range(configs.n_experts):
+                 expert_classes = {'Linear': Linear, 'RLinear': RLinear}
+                 if self.layer_type in expert_classes:
+                     expert_class = expert_classes[self.layer_type]
+                     self.experts[str(i)] = expert_class(self.train_seq_len, self.train_pred_len)
+                 else:
+                     # Default to RLinear if unknown layer type
+                     self.experts[str(i)] = RLinear(self.train_seq_len, self.train_pred_len)
+
+         # Create additional complementary experts if specified
+         if configs.comp_moe > 0:
+             for i in range(configs.comp_moe):
+                 expert_classes = {'Linear': Linear, 'RLinear': RLinear}
+                 if self.layer_type in expert_classes:
+                     expert_class = expert_classes[self.layer_type]
+                     self.experts[f"comp_{i}"] = expert_class(self.train_seq_len, self.train_pred_len)
+                 else:
+                     # Default to RLinear if unknown layer type
+                     self.experts[f"comp_{i}"] = RLinear(self.train_seq_len, self.train_pred_len)
+
+         # Initialize the MoE layer and dropout
+         self.moe = SparseMoE(configs, experts=self.experts.values())
+
+         # Load pre-trained weights if specified
+         if configs.load_weights_full:
+             pass  # TODO: implement full weight loading
+
+         print("Experts:", self.experts.keys())

+     def add_experts(self, experts: dict):
+         """
+         Add new experts to the model.
+
+         Args:
+             experts: Dictionary of expert instances to add
+         """
+         for name, expert in experts.items():
+             self.experts[name] = expert
+         # Reinitialize the MoE layer with the updated experts
+         self.moe = SparseMoE(self.configs, experts=self.experts.values())
+         return self.moe

+     def resample_seq_len(self, x, pred_len, inverse=False, orig_pred_len=None):
+         """
+         Resample sequence length for handling inputs shorter than expected training length.
+
+         Args:
+             x: Input tensor
+             pred_len: Prediction length
+             inverse: If True, downsample back to original scale; if False, upsample
+             orig_pred_len: Original prediction length (required for inverse=True)
+
+         Returns:
+             Tuple of (resampled_tensor, updated_pred_len, scale_factor, orig_pred_len)
+             For inverse=True: returns (resampled_tensor, None, None, None)
+         """
+         if not inverse:
+             # Upsample if input is shorter than training length
+             if x.size(-1) < self.train_seq_len:
+                 scale_factor = self.train_seq_len / x.size(-1)
+                 x_resampled = F.interpolate(x.unsqueeze(1), size=self.train_seq_len, mode='linear', align_corners=False).squeeze(1)
+                 pred_len_resampled = int(pred_len * scale_factor)
+                 return x_resampled, pred_len_resampled, scale_factor, pred_len
              else:
+                 return x, pred_len, None, None
+         else:
+             # Downsample back to original scale
+             if orig_pred_len is not None:
+                 x_resampled = F.interpolate(x.unsqueeze(1), size=orig_pred_len, mode='linear', align_corners=False).squeeze(1)
+                 return x_resampled, None, None, None
+             else:
+                 return x, None, None, None

+     def forward(self, x_in, get_prob=False, pred_len=None):
+         """
+         Forward pass through the model.
+
+         Args:
+             x_in: Encoder input tensor
+             get_prob: Whether to return expert selection probabilities
+             pred_len: Override for prediction length

+         Returns:
+             - Prediction tensor
+             - (Optional) Expert selection probabilities if get_prob is True
+         """
+         if pred_len is None:
+             pred_len = self.train_pred_len
+
+         x = x_in
+         # If input is 2D, add a channel dimension
+         if x_in.dim() == 2:
+             x = x.unsqueeze(-1)

+         # Permute to shape [batch_size, features, sequence_length]
+         x = x.permute(0, 2, 1)
+         B, V, L = x.shape

+         scale_factor = None
+         orig_pred_len = None

+         # Handle resampling if input is shorter than training length
+         if self.resample_long_lookback and L < self.train_seq_len:
+             x, pred_len, scale_factor, orig_pred_len = self.resample_seq_len(x, pred_len, inverse=False)

+         # Reshape for MoE processing
+         x = x.reshape(B * V, x.size(-1))

+         # Forward through MoE
          if get_prob:
+             out, expert_probs = self.moe(x, get_prob=True)
          else:
+             out = self.moe(x)

+         # Reshape back
+         out = out.reshape(B, V, out.size(-1))
+
+         # Handle resampling back to original scale if needed
+         if scale_factor is not None:
+             out, _, _, _ = self.resample_seq_len(out, None, inverse=True, orig_pred_len=orig_pred_len)
+
+         # Return to original shape conventions
          result = out.permute(0, 2, 1)
+
+         if x_in.dim() == 2:
+             result = result.squeeze(-1)
+
          if get_prob:
              expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
              return result, expert_probs
          return result
+
+     def map_to_cycle(self, freq):
+         """Map frequency string to cycle length for expert loading."""
+         if "/" in freq:
+             cycle = int(freq.split("/")[1])
+         elif "h" in freq:
+             cycle = 24
+         elif "2h" in freq:
+             cycle = 12
+         elif "3h" in freq:
+             cycle = 8
+         elif "4h" in freq:
+             cycle = 6
+         elif "D" in freq:
+             cycle = 7
+         elif "DM" in freq:
+             cycle = 30
+         elif "W" in freq:
+             cycle = 52
+         elif "M" in freq:
+             cycle = 12
+         elif "min" in freq:
+             cycle = 1440
+         elif "5min" in freq:
+             cycle = 288
+         elif "10min" in freq:
+             cycle = 144
+         elif "15min" in freq:
+             cycle = 96
+         elif "30min" in freq:
+             cycle = 48
+         else:
+             cycle = int(freq)
+         return cycle
  "-------------------------------------------------------------------------------------------------------------------"
  class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
      config_class = SuperLinearConfig

      def __init__(self, config: SuperLinearConfig):
          super().__init__(config)

+
          # the backbone keeps its own Config dataclass, so build one on‑the‑fly:
+         backbone_cfg = type("Cfg", (), config.to_dict())()
+         self.args = backbone_cfg
+         self.backbone = Model(backbone_cfg)
          self.post_init()

+     # ------------------------------------------------------------------
+     # Forward pass expected by AutoModelForCausalLM
+     # ------------------------------------------------------------------
      def forward(self,
+                 inputs_embeds: torch.Tensor = None,
                  attention_mask: Optional[torch.Tensor] = None,
                  past_key_values: Optional[Tuple] = None,
                  use_cache: bool = True,
@@ -573,17 +534,19 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
              raise ValueError("Pass the time‑series as `inputs_embeds`")

          # backbone expects (B, C, L)
+         x_enc = inputs_embeds
+
+         # backbone returns (B, pred_len, C)
+         preds = self.backbone(x_enc)
          return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)


+     def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
          if past_key_values is not None:
              # only feed the last new step
              inputs_embeds = inputs_embeds[:, -1:, :]
+         return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values}

      def _reorder_cache(self, past, beam_idx, **kwargs):
          return past # backbone keeps no KV cache
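For reference, a minimal usage sketch of the interface this version exposes through transformers. The repo id, channel count and lookback length below are placeholders (the lookback is assumed to equal the configured train_seq_len, taken here as 512); per the new Model.forward, which permutes the input with (0, 2, 1), the series goes in as (batch, length, channels) and the forecast comes back in logits as (batch, pred_len, channels).

    import torch
    from transformers import AutoModelForCausalLM

    # Placeholder repo id; substitute the actual SuperLinear checkpoint.
    model = AutoModelForCausalLM.from_pretrained("org/SuperLinear", trust_remote_code=True)
    model.eval()

    series = torch.randn(1, 512, 2)          # (batch, lookback, channels); 512 assumed = train_seq_len

    with torch.no_grad():
        out = model(inputs_embeds=series)    # wrapper passes the series straight to the Model backbone

    forecast = out.logits                    # (batch, pred_len, channels)
    print(forecast.shape)

The routing in SparseMoE.forward reduces to a top-k selection over the gate logits, a softmax over only those k values, and a gate-weighted sum of the selected expert forecasts; the standalone sketch below reproduces just that arithmetic on random tensors (all sizes are illustrative).

    import torch
    import torch.nn.functional as F

    batch, num_experts, pred_len, k = 4, 6, 96, 2

    gate_logits = torch.randn(batch, num_experts)               # output of the gating network
    expert_outputs = torch.randn(batch, num_experts, pred_len)  # one forecast per expert

    topk_values, topk_indices = torch.topk(gate_logits, k, dim=1)
    topk_gates = F.softmax(topk_values, dim=1)                  # weights over the selected experts only

    selected = torch.gather(
        expert_outputs, 1, topk_indices.unsqueeze(-1).expand(-1, -1, pred_len)
    )
    output = (topk_gates.unsqueeze(-1) * selected).sum(dim=1)   # (batch, pred_len)
    print(output.shape)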