lirannoc committed on
Commit
5a08ee9
·
verified ·
1 Parent(s): fc594cb

Upload 7 files

Browse files
README.md CHANGED
@@ -1,3 +1,64 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SuperLinear: A Mixture of Experts Time Series Forecasting Model
2
+
3
+ SuperLinear is a novel time series forecasting model that employs a Mixture of Experts (MoE) architecture to achieve superior performance across various forecasting tasks. The model routes inputs to the most relevant experts based on frequency-domain analysis using FFT-based gating networks.
4
+
5
+ ## Model Architecture
6
+
7
+ The SuperLinear model consists of:
8
+
9
+ - **Sparse Mixture of Experts (MoE)**: Routes inputs to the top-k most relevant experts
10
+ - **FFT-based Gating Network**: Uses frequency domain analysis to determine expert routing
11
+ - **Frequency-specific Experts**: Pre-trained experts specialized for different temporal patterns
12
+
13
+ ## Key Features
14
+
15
+ - **Adaptive Expert Selection**: Dynamic routing based on input characteristics
16
+ - **Frequency-aware Processing**: Leverages FFT analysis for intelligent expert selection
17
+ - **Auto-regressive Capabilities**: Supports long-horizon forecasting
18
+ - **Multi-scale Processing**: Handles various sequence lengths through resampling
19
+
20
+ ## Usage
21
+
22
+ ```python
23
+ from transformers import AutoModelForCausalLM, AutoConfig
24
+ import torch
25
+
26
+ # Load the model
27
+ model = AutoModelForCausalLM.from_pretrained("path/to/superlinear", trust_remote_code=True)
28
+
29
+ # Prepare input time series data
30
+ # Shape: [batch_size, sequence_length, features]
31
+ input_data = torch.randn(1, 512, 1)
32
+
33
+ # Generate predictions
34
+ with torch.no_grad():
35
+ outputs = model(inputs_embeds=input_data, pred_len=96)
36
+ predictions = outputs.logits # Shape: [batch_size, prediction_length, features]
37
+ ```
38
+
39
+ ## Configuration
40
+
41
+ Key configuration parameters:
42
+
43
+ - `train_seq_len`: Training sequence length (default: 512)
44
+ - `train_pred_len`: Training prediction length (default: 96)
45
+ - `top_k_experts`: Number of experts to use (default: 12)
46
+ - `use_fft`: Whether to use FFT-based gating (default: True)
47
+ - `freq_experts`: Frequency-specific expert configuration
48
+ - `moe_temp`: Temperature for expert selection during inference (default: 1)
49
+
50
+ ## Citation
51
+
52
+ If you use SuperLinear in your research, please cite:
53
+
54
+ ```bibtex
55
+ @article{superlinear2024,
56
+ title={SuperLinear: Mixture of Experts for Time Series Forecasting},
57
+ author={Your Name},
58
+ year={2024}
59
+ }
60
+ ```
61
+
62
+ ## License
63
+
64
+ This model is released under the MIT License.
config.json CHANGED
@@ -31,6 +31,16 @@
31
  "_comment_training": "Training parameters",
32
  "resample_long_lookback": false,
33
 
 
 
 
 
 
 
 
 
 
 
34
  "_comment_system": "System and framework parameters",
35
  "model_type": "super_linear",
36
  "torch_dtype": "float32",
 
31
  "_comment_training": "Training parameters",
32
  "resample_long_lookback": false,
33
 
34
+ "_comment_horizon": "Auto-regressive and horizon parameters",
35
+ "long_horizon_scaling": 1,
36
+
37
+ "_comment_resampling": "Resampling and lookback-based parameters",
38
+ "lookback_resampling": 1,
39
+ "scale_list": "2,4,6",
40
+ "threshold": 0.2,
41
+ "freq_bound": 0.25,
42
+ "penalty_scale": 2.0,
43
+
44
  "_comment_system": "System and framework parameters",
45
  "model_type": "super_linear",
46
  "torch_dtype": "float32",
configuration_super_linear.py CHANGED
@@ -41,6 +41,16 @@ class SuperLinearConfig(PretrainedConfig):
41
  # Training parameters
42
  resample_long_lookback=False,
43
 
 
 
 
 
 
 
 
 
 
 
44
  **kwargs,
45
  ):
46
  # Model architecture parameters
@@ -66,4 +76,14 @@ class SuperLinearConfig(PretrainedConfig):
66
  # Training parameters
67
  self.resample_long_lookback = resample_long_lookback
68
 
 
 
 
 
 
 
 
 
 
 
69
  super().__init__(**kwargs)
 
41
  # Training parameters
42
  resample_long_lookback=False,
43
 
44
+ # Auto-regressive and horizon parameters
45
+ long_horizon_scaling=1,
46
+
47
+ # Resampling and lookback-based parameters
48
+ lookback_resampling=1,
49
+ scale_list=[2,4,6],
50
+ threshold=0.2,
51
+ freq_bound=0.25,
52
+ penalty_scale=2.0,
53
+
54
  **kwargs,
55
  ):
56
  # Model architecture parameters
 
76
  # Training parameters
77
  self.resample_long_lookback = resample_long_lookback
78
 
79
+ # Auto-regressive and horizon parameters
80
+ self.long_horizon_scaling = long_horizon_scaling
81
+
82
+ # Resampling and lookback-based parameters
83
+ self.lookback_resampling = lookback_resampling
84
+ self.scale_list = scale_list
85
+ self.threshold = threshold
86
+ self.freq_bound = freq_bound
87
+ self.penalty_scale = penalty_scale
88
+
89
  super().__init__(**kwargs)
example_usage.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example usage of SuperLinear model for time series forecasting.
4
+ """
5
+
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoConfig
8
+
9
+ def main():
10
+ # Load model configuration and model
11
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
12
+ model = AutoModelForCausalLM.from_pretrained("./", trust_remote_code=True)
13
+
14
+ # Set model to evaluation mode
15
+ model.eval()
16
+
17
+ # Create sample time series data
18
+ # Shape: [batch_size, sequence_length, features]
19
+ batch_size = 4
20
+ sequence_length = 512
21
+ num_features = 1
22
+ prediction_length = 96
23
+
24
+ # Generate synthetic time series data
25
+ t = torch.linspace(0, 10, sequence_length)
26
+ sample_data = torch.sin(t).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, num_features)
27
+
28
+ print(f"Input shape: {sample_data.shape}")
29
+
30
+ # Generate predictions
31
+ with torch.no_grad():
32
+ outputs = model(inputs_embeds=sample_data, pred_len=prediction_length)
33
+ predictions = outputs.logits
34
+
35
+ print(f"Prediction shape: {predictions.shape}")
36
+ print(f"Sample predictions: {predictions[0, :5, 0]}") # First 5 predictions of first batch
37
+
38
+ # Demonstrate with different prediction lengths
39
+ for pred_len in [24, 48, 96, 192]:
40
+ with torch.no_grad():
41
+ outputs = model(inputs_embeds=sample_data, pred_len=pred_len)
42
+ predictions = outputs.logits
43
+ print(f"Prediction length {pred_len}: output shape {predictions.shape}")
44
+
45
+ if __name__ == "__main__":
46
+ main()
modeling_super_linear.py CHANGED
@@ -1,17 +1,19 @@
1
- from typing import Optional, Tuple
 
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- import numpy as np
6
 
7
- from transformers import PreTrainedModel, GenerationMixin
8
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
9
  from .configuration_super_linear import SuperLinearConfig
10
 
11
 
12
  "-------------------------------------------------------------------------------------------------------------------"
13
  class RevIN(nn.Module):
14
- def __init__(self, num_features: int, eps=1e-5, affine=True, norm_type = None, subtract_last = False):
15
  """
16
  :param num_features: the number of features or channels
17
  :param eps: a value added for numerical stability
@@ -26,13 +28,14 @@ class RevIN(nn.Module):
26
  if self.affine:
27
  self._init_params()
28
 
29
- def forward(self, x, mode:str):
30
  if mode == 'norm':
31
  self._get_statistics(x)
32
  x = self._normalize(x)
33
  elif mode == 'denorm':
34
  x = self._denormalize(x)
35
- else: raise NotImplementedError
 
36
  return x
37
 
38
  def _init_params(self):
@@ -44,18 +47,19 @@ class RevIN(nn.Module):
44
  dim2reduce = tuple(range(1, x.ndim-1))
45
 
46
  if self.subtract_last:
47
- self.last = x[:,-1,:].unsqueeze(1)
 
48
  else:
49
  self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
 
50
  self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
51
- if self.norm_type == "l1":
52
- self.denom = torch.sum(torch.abs(x), dim=dim2reduce, keepdim=True).detach()
53
- elif self.norm_type == "l2":
54
- self.denom = torch.sqrt(torch.sum(x**2, dim=dim2reduce, keepdim=True)).detach()
 
55
 
56
-
57
  def _normalize(self, x):
58
-
59
  if self.subtract_last:
60
  x = x - self.last
61
  else:
@@ -63,7 +67,7 @@ class RevIN(nn.Module):
63
  x = x / self.stdev
64
 
65
  if self.norm_type in ["l1", "l2"]:
66
- x = x / self.denom
67
 
68
  if self.affine:
69
  x = x * self.affine_weight
@@ -74,8 +78,10 @@ class RevIN(nn.Module):
74
  if self.affine:
75
  x = x - self.affine_bias
76
  x = x / (self.affine_weight + self.eps*self.eps)
 
77
  if self.norm_type in ["l1", "l2"]:
78
- x = x * self.denom
 
79
  x = x * self.stdev
80
  if self.subtract_last:
81
  x = x + self.last
@@ -173,7 +179,7 @@ class SparseMoE(nn.Module):
173
  self.gating_network = nn.Linear(configs.train_seq_len, self.num_experts, bias=True)
174
 
175
  if self.moe_norm:
176
- self.gate_norm = nn.BatchNorm1d(self.num_experts)
177
 
178
  def get_periodogram(self, inputs, n=10000):
179
  """
@@ -189,38 +195,28 @@ class SparseMoE(nn.Module):
189
  Returns:
190
  Normalized periodogram of the input signals
191
  """
192
- if inputs.dim() == 2:
193
- x_0 = inputs.unsqueeze(2)
194
- else:
195
- x_0 = inputs
196
- x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True) # Remove mean (DC component)
197
 
198
  # Compute FFT and normalize
199
  dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
200
- dft = dft[:, :n//2, :] # Keep only positive frequencies
201
  I = torch.abs(dft) ** 2 # Power spectral density
202
 
203
  # Normalize periodogram
204
  I_sum = torch.sum(I, dim=1, keepdim=True)
205
  I_sum[I_sum == 0] = 1 # Avoid division by zero
206
  I = I / I_sum
207
-
208
- if torch.any(I_sum == 0):
209
- print("Zeros in the sum")
210
- raise ValueError
211
-
212
- if inputs.dim() == 2:
213
- I = I.squeeze(2)
214
 
215
  return I
216
 
217
- def forward(self, x, get_prob=False):
218
  """
219
  Forward pass through the Mixture of Experts.
220
 
221
  Args:
222
  x: Input tensor of shape [batch_size, sequence_length]
223
  get_prob: Whether to return expert selection probabilities
 
224
 
225
  Returns:
226
  - Output tensor from the selected experts
@@ -233,27 +229,30 @@ class SparseMoE(nn.Module):
233
  x_0 = x
234
 
235
  # Get gating logits
236
- self.gate_outputs = self.gating_network(x_0) # Raw gating scores
237
 
238
  if self.moe_norm:
239
- self.gate_outputs = self.gate_norm(self.gate_outputs)
240
 
241
  # Apply temperature scaling during inference
242
  if not self.training:
243
- self.gate_outputs = self.gate_outputs / self.moe_temp
 
 
 
 
244
 
245
  # Add noise to gating logits during training (for exploration)
246
- noise = torch.randn_like(self.gate_outputs).to(x.device) * self.noise_std
247
  if self.training:
248
- noisy_gate_outputs = self.gate_outputs + noise
249
- self.topk_values, topk_indices = torch.topk(noisy_gate_outputs, self.k, dim=1)
 
250
  else:
251
- self.topk_values, topk_indices = torch.topk(self.gate_outputs, self.k, dim=1)
252
 
253
  # Normalize the gate values with softmax
254
- self.topk_gates = F.softmax(self.topk_values, dim=1)
255
 
256
- batch_size = x.size(0)
257
  # Get outputs from all experts
258
  expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)
259
 
@@ -262,10 +261,10 @@ class SparseMoE(nn.Module):
262
  sparse_expert_outputs = torch.gather(expert_outputs, 1, topk_indices_expanded)
263
 
264
  # Combine expert outputs using the gate values
265
- output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
266
 
267
  if get_prob:
268
- expert_probs = F.softmax(self.gate_outputs, dim=1)
269
  return output, expert_probs
270
 
271
  return output
@@ -283,19 +282,36 @@ class Model(nn.Module):
283
  """
284
  def __init__(self, configs):
285
  super(Model, self).__init__()
286
- self.configs = configs
287
- self.model_name = "SuperLinear"
 
 
288
  self.train_pred_len = configs.train_pred_len
289
  self.train_seq_len = configs.train_seq_len
290
- self.resample_long_lookback = configs.resample_long_lookback
291
  self.layer_type = configs.layer_type
292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  # Parse frequency experts from configuration
294
- if configs.freq_experts == "":
 
295
  self.freq_experts = None
296
  else:
297
- self.freq_experts = configs.freq_experts.split('_')
298
 
 
299
  self.top_k_experts = configs.top_k_experts
300
  self.freeze_experts = configs.freeze_experts
301
 
@@ -303,88 +319,222 @@ class Model(nn.Module):
303
  self.experts = {}
304
  if self.freq_experts is not None:
305
  for expert_freq in self.freq_experts:
306
- if expert_freq == "naive" or expert_freq == "Naive":
307
  self.experts[expert_freq] = Naive(self.train_seq_len, self.train_pred_len)
308
- elif expert_freq == "mean" or expert_freq == "Mean":
309
  self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
310
  else:
311
- # Use the appropriate expert class based on layer_type
312
- expert_classes = {'Linear': Linear, 'RLinear': RLinear}
313
- if self.layer_type in expert_classes:
314
- expert_class = expert_classes[self.layer_type]
315
- self.experts[expert_freq] = expert_class(self.train_seq_len, self.train_pred_len)
316
- else:
317
- # Default to RLinear if unknown layer type
318
- self.experts[expert_freq] = RLinear(self.train_seq_len, self.train_pred_len)
319
  else:
320
- raise ValueError("No frequency experts specified in configuration.")
321
 
322
  # Create additional complementary experts if specified
323
- if configs.comp_moe > 0:
324
- for i in range(configs.comp_moe):
325
- expert_classes = {'Linear': Linear, 'RLinear': RLinear}
326
- if self.layer_type in expert_classes:
327
- expert_class = expert_classes[self.layer_type]
328
- self.experts[f"comp_{i}"] = expert_class(self.train_seq_len, self.train_pred_len)
329
- else:
330
- # Default to RLinear if unknown layer type
331
- self.experts[f"comp_{i}"] = RLinear(self.train_seq_len, self.train_pred_len)
332
 
333
- # Initialize the MoE layer
334
  self.moe = SparseMoE(configs, experts=self.experts.values())
335
 
336
  print("Experts:", self.experts.keys())
337
 
338
- def add_experts(self, experts: dict):
339
  """
340
  Add new experts to the model.
341
 
342
  Args:
343
  experts: Dictionary of expert instances to add
 
 
 
344
  """
345
  for name, expert in experts.items():
346
- self.experts[name] = expert
 
 
 
 
347
  # Reinitialize the MoE layer with the updated experts
348
  self.moe = SparseMoE(self.configs, experts=self.experts.values())
349
  return self.moe
350
 
351
- def resample_seq_len(self, x, pred_len, inverse=False, orig_pred_len=None):
352
  """
353
- Resample sequence length for handling inputs shorter than expected training length.
 
 
 
354
 
355
  Args:
356
- x: Input tensor
357
- pred_len: Prediction length
358
- inverse: If True, downsample back to original scale; if False, upsample
359
- orig_pred_len: Original prediction length (required for inverse=True)
360
 
361
  Returns:
362
- Tuple of (resampled_tensor, updated_pred_len, scale_factor, orig_pred_len)
363
- For inverse=True: returns (resampled_tensor, None, None, None)
364
  """
365
- if not inverse:
366
- # Upsample if input is shorter than training length
367
- if x.size(-1) < self.train_seq_len:
368
- scale_factor = self.train_seq_len / x.size(-1)
369
- x_resampled = F.interpolate(x.unsqueeze(1), size=self.train_seq_len, mode='linear', align_corners=False).squeeze(1)
370
- pred_len_resampled = int(pred_len * scale_factor)
371
- return x_resampled, pred_len_resampled, scale_factor, pred_len
372
- else:
373
- return x, pred_len, None, None
374
- else:
375
- # Downsample back to original scale
376
- if orig_pred_len is not None:
377
- x_resampled = F.interpolate(x.unsqueeze(1), size=orig_pred_len, mode='linear', align_corners=False).squeeze(1)
378
- return x_resampled, None, None, None
379
- else:
380
- return x, None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
- def forward(self, x_in, get_prob=False, pred_len=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  """
384
  Forward pass through the model.
385
 
386
  Args:
387
- x_in: Encoder input tensor
388
  get_prob: Whether to return expert selection probabilities
389
  pred_len: Override for prediction length
390
 
@@ -398,97 +548,142 @@ class Model(nn.Module):
398
  x = x_in
399
  # If input is 2D, add a channel dimension
400
  if x_in.dim() == 2:
401
- x = x.unsqueeze(-1)
402
 
403
- # Permute to shape [batch_size, features, sequence_length]
404
- x = x.permute(0, 2, 1)
405
  B, V, L = x.shape
406
 
407
- scale_factor = None
408
- orig_pred_len = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
- # Handle resampling if input is shorter than training length
411
- if self.resample_long_lookback and L < self.train_seq_len:
412
- x, pred_len, scale_factor, orig_pred_len = self.resample_seq_len(x, pred_len, inverse=False)
 
 
 
 
413
 
414
- # Reshape for MoE processing
415
- x = x.reshape(B * V, x.size(-1))
416
 
417
- # Forward through MoE
 
 
 
 
418
  if get_prob:
419
  out, expert_probs = self.moe(x, get_prob=True)
420
  else:
421
  out = self.moe(x)
422
 
 
423
  if self.train_pred_len < pred_len:
424
  outputs = [out]
425
  ar_x = torch.cat([x, out], dim=1)[:, -self.train_seq_len:]
426
  for i in range(0, pred_len, self.train_pred_len):
427
  ar_out = self.moe(ar_x)
 
428
  outputs.append(ar_out)
429
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.train_seq_len:]
430
  out = torch.cat(outputs, dim=1)[:, :pred_len]
431
 
432
- # Reshape back
433
- out = out.reshape(B, V, out.size(-1))
434
 
435
- # Handle resampling back to original scale if needed
436
- if scale_factor is not None:
437
- out, _, _, _ = self.resample_seq_len(out, None, inverse=True, orig_pred_len=orig_pred_len)
438
 
439
- # Return to original shape conventions
440
- result = out.permute(0, 2, 1)
441
-
442
- if x_in.dim() == 2:
443
- result = result.squeeze(-1)
444
 
 
 
 
 
445
  if get_prob:
446
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
447
- return result, expert_probs
448
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
  "-------------------------------------------------------------------------------------------------------------------"
450
- class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
451
  config_class = SuperLinearConfig
452
 
453
  def __init__(self, config: SuperLinearConfig):
454
  super().__init__(config)
455
-
456
-
457
- # the backbone keeps its own Config dataclass, so build one on‑the‑fly:
458
- backbone_cfg = type("Cfg", (), config.to_dict())()
459
- self.args = backbone_cfg
460
- self.backbone = Model(backbone_cfg)
461
- self.post_init()
462
 
463
  # ------------------------------------------------------------------
464
  # Forward pass expected by AutoModelForCausalLM
465
  # ------------------------------------------------------------------
466
  def forward(self,
467
- inputs_embeds: torch.Tensor = None,
468
- attention_mask: Optional[torch.Tensor] = None,
469
- past_key_values: Optional[Tuple] = None,
470
- use_cache: bool = True,
471
- labels: Optional[torch.Tensor] = None,
472
- **kwargs,) -> CausalLMOutputWithCrossAttentions:
473
 
474
-
475
  if inputs_embeds is None:
476
- raise ValueError("Pass the time‑series as `inputs_embeds`")
477
 
478
- # backbone expects (B, C, L)
479
  x_enc = inputs_embeds
480
 
481
  # backbone returns (B, pred_len, C)
482
- preds = self.backbone(x_enc, pred_len=kwargs.get("pred_len", None))
483
- return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
 
 
 
 
 
 
 
 
 
 
484
 
485
 
486
- def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
487
- if past_key_values is not None:
488
- # only feed the last new step
489
- inputs_embeds = inputs_embeds[:, -1:, :]
490
- return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values}
491
 
492
- def _reorder_cache(self, past, beam_idx, **kwargs):
493
- return past # backbone keeps no KV cache
494
 
 
1
+ from typing import Optional, Tuple, Dict, List, Union
2
+ import copy
3
+ import numpy as np
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
+ from torch.nn.functional import interpolate
8
 
9
+ from transformers import PreTrainedModel
10
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
11
  from .configuration_super_linear import SuperLinearConfig
12
 
13
 
14
  "-------------------------------------------------------------------------------------------------------------------"
15
  class RevIN(nn.Module):
16
+ def __init__(self, num_features: int, eps=1e-5, affine=True, norm_type=None, subtract_last=False):
17
  """
18
  :param num_features: the number of features or channels
19
  :param eps: a value added for numerical stability
 
28
  if self.affine:
29
  self._init_params()
30
 
31
+ def forward(self, x, mode: str):
32
  if mode == 'norm':
33
  self._get_statistics(x)
34
  x = self._normalize(x)
35
  elif mode == 'denorm':
36
  x = self._denormalize(x)
37
+ else:
38
+ raise NotImplementedError
39
  return x
40
 
41
  def _init_params(self):
 
47
  dim2reduce = tuple(range(1, x.ndim-1))
48
 
49
  if self.subtract_last:
50
+ self.last = x[:, -1:, :].detach()
51
+ self.mean = torch.mean(x[:, :-1, :], dim=dim2reduce, keepdim=True).detach()
52
  else:
53
  self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
54
+
55
  self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
56
+
57
+ if self.norm_type == "l1":
58
+ self.stdev = torch.mean(torch.abs(x - self.mean), dim=dim2reduce, keepdim=True).detach()
59
+ elif self.norm_type == "l2":
60
+ self.stdev = torch.sqrt(torch.mean((x - self.mean) ** 2, dim=dim2reduce, keepdim=True) + self.eps).detach()
61
 
 
62
  def _normalize(self, x):
 
63
  if self.subtract_last:
64
  x = x - self.last
65
  else:
 
67
  x = x / self.stdev
68
 
69
  if self.norm_type in ["l1", "l2"]:
70
+ x = x / self.stdev
71
 
72
  if self.affine:
73
  x = x * self.affine_weight
 
78
  if self.affine:
79
  x = x - self.affine_bias
80
  x = x / (self.affine_weight + self.eps*self.eps)
81
+
82
  if self.norm_type in ["l1", "l2"]:
83
+ x = x * self.stdev
84
+
85
  x = x * self.stdev
86
  if self.subtract_last:
87
  x = x + self.last
 
179
  self.gating_network = nn.Linear(configs.train_seq_len, self.num_experts, bias=True)
180
 
181
  if self.moe_norm:
182
+ self.batch_norm = nn.BatchNorm1d(self.num_experts)
183
 
184
  def get_periodogram(self, inputs, n=10000):
185
  """
 
195
  Returns:
196
  Normalized periodogram of the input signals
197
  """
198
+ x_0 = inputs - torch.mean(inputs, dim=1, keepdim=True) # Remove mean (DC component)
 
 
 
 
199
 
200
  # Compute FFT and normalize
201
  dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
202
+ dft = dft[:, :n//2] # Keep only positive frequencies
203
  I = torch.abs(dft) ** 2 # Power spectral density
204
 
205
  # Normalize periodogram
206
  I_sum = torch.sum(I, dim=1, keepdim=True)
207
  I_sum[I_sum == 0] = 1 # Avoid division by zero
208
  I = I / I_sum
 
 
 
 
 
 
 
209
 
210
  return I
211
 
212
+ def forward(self, x, get_prob=False, get_prob_only=False):
213
  """
214
  Forward pass through the Mixture of Experts.
215
 
216
  Args:
217
  x: Input tensor of shape [batch_size, sequence_length]
218
  get_prob: Whether to return expert selection probabilities
219
+ get_prob_only: Whether to return only the expert selection probabilities, skipping the expert forward passes
220
 
221
  Returns:
222
  - Output tensor from the selected experts
 
229
  x_0 = x
230
 
231
  # Get gating logits
232
+ gate_outputs = self.gating_network(x_0) # Raw gating scores
233
 
234
  if self.moe_norm:
235
+ gate_outputs = self.batch_norm(gate_outputs)
236
 
237
  # Apply temperature scaling during inference
238
  if not self.training:
239
+ gate_outputs = gate_outputs / self.moe_temp
240
+
241
+ if get_prob_only:
242
+ expert_probs = F.softmax(gate_outputs, dim=1)
243
+ return expert_probs
244
 
245
  # Add noise to gating logits during training (for exploration)
 
246
  if self.training:
247
+ noise = torch.randn_like(gate_outputs).to(x.device) * self.noise_std
248
+ noisy_gate_outputs = gate_outputs + noise
249
+ topk_values, topk_indices = torch.topk(noisy_gate_outputs, self.k, dim=1)
250
  else:
251
+ topk_values, topk_indices = torch.topk(gate_outputs, self.k, dim=1)
252
 
253
  # Normalize the gate values with softmax
254
+ topk_gates = F.softmax(topk_values, dim=1)
255
 
 
256
  # Get outputs from all experts
257
  expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)
258
 
 
261
  sparse_expert_outputs = torch.gather(expert_outputs, 1, topk_indices_expanded)
262
 
263
  # Combine expert outputs using the gate values
264
+ output = torch.sum(topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
265
 
266
  if get_prob:
267
+ expert_probs = F.softmax(gate_outputs, dim=1)
268
  return output, expert_probs
269
 
270
  return output
 
282
  """
283
  def __init__(self, configs):
284
  super(Model, self).__init__()
285
+
286
+ self.configs = copy.deepcopy(configs)
287
+
288
+ # Core model configuration
289
  self.train_pred_len = configs.train_pred_len
290
  self.train_seq_len = configs.train_seq_len
 
291
  self.layer_type = configs.layer_type
292
 
293
+
294
+ # Read additional attributes directly from configs (no fallback defaults are applied here)
295
+ self.long_horizon_scaling = configs.long_horizon_scaling
296
+ self.lookback_resampling = configs.lookback_resampling
297
+ lookback_scale_str = configs.scale_list
298
+ if isinstance(lookback_scale_str, str):
299
+ self.scale_list = [float(x.strip()) for x in lookback_scale_str.split(',')]
300
+ else:
301
+ self.scale_list = lookback_scale_str # Already a list
302
+ self.threshold = configs.threshold
303
+ self.freq_bound = configs.freq_bound
304
+ self.penalty_scale = configs.penalty_scale
305
+ self.fft_len = configs.fft_len
306
+
307
  # Parse frequency experts from configuration
308
+ freq_experts_str = configs.freq_experts
309
+ if freq_experts_str == "":
310
  self.freq_experts = None
311
  else:
312
+ self.freq_experts = freq_experts_str.split('_')
313
 
314
+ # Expert configuration
315
  self.top_k_experts = configs.top_k_experts
316
  self.freeze_experts = configs.freeze_experts
317
 
 
319
  self.experts = {}
320
  if self.freq_experts is not None:
321
  for expert_freq in self.freq_experts:
322
+ if expert_freq.lower() == "naive":
323
  self.experts[expert_freq] = Naive(self.train_seq_len, self.train_pred_len)
324
+ elif expert_freq.lower() == "mean":
325
  self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
326
  else:
327
+ self.experts[expert_freq] = RLinear(self.train_seq_len, self.train_pred_len)
328
+ self.n_experts = len(self.experts)
 
 
 
 
 
 
329
  else:
330
+ raise ValueError("Please specify experts in the configuration.")
331
 
332
  # Create additional complementary experts if specified
333
+ comp_moe = configs.comp_moe
334
+ if comp_moe > 0:
335
+ if comp_moe == 1:
336
+ print("Creating complementary expert")
337
+ self.experts["comp"] = RLinear(self.train_seq_len, self.train_pred_len)
338
+ else:
339
+ for i in range(comp_moe):
340
+ print(f"Creating complementary expert {i}")
341
+ self.experts["comp_"+str(i)] = RLinear(self.train_seq_len, self.train_pred_len)
342
 
343
+ # Initialize the MoE layer
344
  self.moe = SparseMoE(configs, experts=self.experts.values())
345
 
346
  print("Experts:", self.experts.keys())
347
 
348
+ def add_experts(self, experts: Dict[str, nn.Module]) -> nn.Module:
349
  """
350
  Add new experts to the model.
351
 
352
  Args:
353
  experts: Dictionary of expert instances to add
354
+
355
+ Returns:
356
+ Updated MoE layer
357
  """
358
  for name, expert in experts.items():
359
+ if name not in self.experts:
360
+ self.experts[name] = expert
361
+ print(f"Added expert: {name}")
362
+ else:
363
+ print(f"Expert {name} already exists. Skipping addition.")
364
  # Reinitialize the MoE layer with the updated experts
365
  self.moe = SparseMoE(self.configs, experts=self.experts.values())
366
  return self.moe
367
 
368
+ def apply_long_horizon_scaling(self, ar_out: torch.Tensor, ar_x: torch.Tensor) -> torch.Tensor:
369
  """
370
+ Apply scaling to auto-regressive outputs to maintain statistical properties during long horizon prediction.
371
+
372
+ This function identifies cases where the variance of the new predictions exceeds the variance
373
+ of the input sequence and applies scaling to maintain consistent statistical properties.
374
 
375
  Args:
376
+ ar_out: Auto-regressive output tensor of shape [batch_size * features, pred_len]
377
+ ar_x: Input sequence tensor of shape [batch_size * features, seq_len]
 
 
378
 
379
  Returns:
380
+ Scaled auto-regressive output tensor
 
381
  """
382
+ if not (self.long_horizon_scaling and not self.training):
383
+ return ar_out
384
+
385
+ # Calculate statistics for scaling
386
+ std_new = torch.std(ar_out, dim=1, keepdim=True)
387
+ mean_new = torch.mean(ar_out, dim=1, keepdim=True)
388
+ std_old = torch.std(ar_x, dim=1, keepdim=True)
389
+
390
+ # Find indices where new variance exceeds old variance
391
+ inds = torch.where(std_new / std_old > 1)[0]
392
+
393
+ if len(inds) > 0:
394
+ # Center the outputs around their mean
395
+ ar_out_centered = ar_out[inds] - mean_new[inds]
396
+
397
+ # Calculate scaling factor to match old variance
398
+ scaling = std_old[inds] / (std_new[inds] + 1e-8)
399
+
400
+ # Scale and shift back to mean_new
401
+ ar_out_adjusted = ar_out_centered * scaling + mean_new[inds]
402
+ ar_out[inds] = ar_out_adjusted
403
+
404
+ return ar_out
405
+
406
+ def lookback_resample_search(self, x, scale_list=[2,4,6], min_lookback=512):
407
+ """
408
+ Search for optimal resampling scale based on lookback analysis of expert selection.
409
+
410
+ This function analyzes the frequency content and the entropy of the expert-selection probabilities to determine
411
+ the best resampling scale for each input sequence, potentially improving model performance
412
+ by matching input characteristics to expert capabilities.
413
+
414
+ Args:
415
+ x: Input tensor of shape [batch_size, features, sequence_length]
416
+ scale_list: List of potential downsampling scales to evaluate
417
+ min_lookback: Minimum sequence length required after resampling
418
+
419
+ Returns:
420
+ Tuple of (resampled_input, final_scales) where:
421
+ - resampled_input: Optimally resampled input tensor
422
+ - final_scales: Scale factors used for each sample
423
+ """
424
+ B, V, L = x.shape
425
+
426
+ lookback = self.train_seq_len
427
+ x_0 = x.reshape(B*V, L)[:, -lookback:]
428
+ output_x = x_0.clone()[:, -lookback:]
429
+
430
+ x_reshape = x.reshape(B*V, L)
431
+ x_fft_init = self.moe.get_periodogram(x_reshape, n=self.fft_len)
432
+
433
+ right_cumsum = torch.cumsum(x_fft_init, dim=-1)
434
+ mask = right_cumsum > 1-self.threshold
435
+ j_threshold = mask.float().argmax(dim=-1)
436
+
437
+ freqs = np.array([np.linspace(0, 0.5, self.fft_len//2)])
438
+ threshhold_freqs = np.take_along_axis(freqs, j_threshold.unsqueeze(-1).detach().cpu().numpy(), axis=1)
439
+
440
+ # Where threshhold_freqs is 0, set it to freq_bound (yielding a max scale factor of 1) to avoid division by zero
441
+ threshhold_freqs[threshhold_freqs == 0] = self.freq_bound
442
+ max_scale_factor = (self.freq_bound/ threshhold_freqs).astype(int).flatten()
443
 
444
+
445
+ if self.threshold==0:
446
+ max_scale_factor = np.inf * np.ones(B*V, dtype=int)
447
+
448
+ # Compute energy loss penalty for each potential scale
449
+ energy_loss_penalties = {}
450
+ total_energy = torch.sum(x_fft_init, dim=-1) # Total energy per sample
451
+
452
+ for scale in scale_list:
453
+ if scale <= 1:
454
+ continue # No penalty for upsampling or no scaling
455
+
456
+ # Calculate Nyquist frequency after downsampling
457
+ nyquist_after_downsample = 0.5 / scale
458
+
459
+ # Find frequency bins that will be lost (above new Nyquist)
460
+ freq_bins = torch.linspace(0, 0.5, self.fft_len//2, device=x_fft_init.device)
461
+ lost_freq_mask = freq_bins > nyquist_after_downsample
462
+
463
+ # Calculate energy that will be lost
464
+ lost_energy = torch.sum(x_fft_init[:, lost_freq_mask], dim=-1)
465
+ # Energy loss fraction (0 = no loss, 1 = all energy lost)
466
+ energy_loss_fraction = lost_energy / (total_energy + 1e-10)
467
+ energy_loss_penalties[scale] = energy_loss_fraction
468
+
469
+ # Get initial entropy
470
+ prob = self.moe(x_0, get_prob_only=True)
471
+ best_scores = -torch.sum(prob * torch.log(prob + 1e-10), dim=-1)
472
+ final_scales = torch.ones(B*V, device=x.device)
473
+
474
+ for scale in scale_list:
475
+ x_interp = torch.nn.functional.interpolate(
476
+ x, scale_factor=1/scale, mode='linear', align_corners=True
477
+ )
478
+
479
+ if x_interp.shape[2] >= min_lookback:
480
+ x_interp_reshaped = x_interp.reshape(B*V, x_interp.shape[-1])
481
+ x_interp_reshaped = x_interp_reshaped[:, -lookback:]
482
+ prob = self.moe(x_interp_reshaped, get_prob_only=True)
483
+
484
+ scores = -torch.sum(prob * torch.log(prob + 1e-10), dim=-1)
485
+
486
+ # Add energy loss penalty
487
+ if scale in energy_loss_penalties:
488
+ energy_penalty = energy_loss_penalties[scale]
489
+ scores = scores + energy_penalty*self.penalty_scale
490
+
491
+ idx = np.where((scores < best_scores).cpu() & torch.tensor(max_scale_factor >= scale))[0]
492
+
493
+ if len(idx) > 0:
494
+ output_x[idx] = x_interp_reshaped[idx]
495
+ best_scores[idx] = scores[idx]
496
+ final_scales[idx] = scale
497
+
498
+ return output_x.reshape(B, V, output_x.shape[-1]), final_scales
499
+
500
def lookback_resample_reverse(self, y, final_scales, inf_pred_len=None):
    """
    Undo the input resampling on the model output.

    Rows whose recorded scale is greater than 1 are linearly upsampled by that scale and
    truncated to ``inf_pred_len``; all other rows are just truncated. The input tensor ``y``
    is never modified.

    Args:
        y: Output tensor from model of shape [batch_size, features, pred_len]
        final_scales: Tensor of shape [batch_size * features] with the scale used per series
        inf_pred_len: Target prediction length; defaults to the length of ``y`` when None.
            Must not exceed the length of ``y``.

    Returns:
        Tensor of shape [batch_size, features, inf_pred_len]
    """
    B, V, L = y.shape
    if inf_pred_len is None:
        inf_pred_len = L

    # reshape (not view) so non-contiguous inputs are accepted as well.
    y_flat = y.reshape(B * V, L)
    # Clone: the row writes below must not leak into the caller's tensor.
    y_out = y_flat[:, :inf_pred_len].clone()

    for scale in torch.unique(final_scales):
        scale_val = scale.item()  # Convert tensor to scalar
        if scale_val <= 1:
            continue  # Rows that were not downsampled need no upsampling

        idx = torch.where(final_scales == scale)[0]
        y_interp = torch.nn.functional.interpolate(
            y_flat[idx].unsqueeze(1), scale_factor=scale_val, mode='linear', align_corners=True
        )
        y_out[idx] = y_interp.squeeze(1)[:, :inf_pred_len]

    return y_out.reshape(B, V, inf_pred_len)
531
+
532
+ def forward(self, x_in: torch.Tensor, get_prob: bool = False, pred_len: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
533
  """
534
  Forward pass through the model.
535
 
536
  Args:
537
+ x_in: Encoder input tensor of shape [batch_size, sequence_length] or [batch_size, features, sequence_length]
538
  get_prob: Whether to return expert selection probabilities
539
  pred_len: Override for prediction length
540
 
 
548
  x = x_in
549
  # If input is 2D, add a channel dimension
550
  if x_in.dim() == 2:
551
+ x = x.unsqueeze(1)
552
 
 
 
553
  B, V, L = x.shape
554
 
555
+ short_lookback = False
556
+ orig_pred_len = pred_len
557
+
558
+ if L < self.train_seq_len:
559
+ # Handle case where input sequence is shorter than expected
560
+ # by interpolating to the required length
561
+ scale_factor = self.train_seq_len / L
562
+ scale_factor = int(np.ceil(scale_factor))
563
+
564
+ pred_len = pred_len * scale_factor
565
+ x = interpolate(x, scale_factor=scale_factor, mode='linear')
566
+
567
+ x = x[:, :, -self.train_seq_len:]
568
+ L = self.train_seq_len
569
+
570
+ short_lookback = True
571
+
572
+ # lookback resampling logic
573
+ final_scales = None
574
+
575
+ if self.lookback_resampling and L > self.train_seq_len:
576
 
577
+ x_resampled, final_scales = self.lookback_resample_search(
578
+ x, self.scale_list, self.train_seq_len
579
+ )
580
+
581
+ # Update x and L for the resampled input
582
+ x = x_resampled
583
+ L = x.shape[-1]
584
 
 
 
585
 
586
+ # Reshape to process each feature independently
587
+ x = x.reshape(B * V, L)
588
+ expert_probs = None
589
+
590
+ # Forward pass through MoE
591
  if get_prob:
592
  out, expert_probs = self.moe(x, get_prob=True)
593
  else:
594
  out = self.moe(x)
595
 
596
+ # Auto-regressive prediction for long horizons
597
  if self.train_pred_len < pred_len:
598
  outputs = [out]
599
  ar_x = torch.cat([x, out], dim=1)[:, -self.train_seq_len:]
600
  for i in range(0, pred_len, self.train_pred_len):
601
  ar_out = self.moe(ar_x)
602
+ ar_out = self.apply_long_horizon_scaling(ar_out, ar_x)
603
  outputs.append(ar_out)
604
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.train_seq_len:]
605
  out = torch.cat(outputs, dim=1)[:, :pred_len]
606
 
607
+ # Reshape back to batch format
608
+ out = out.reshape(B, V, out.shape[-1])
609
 
610
+ # Apply lookback resampling reverse if it was used
611
+ if self.lookback_resampling and final_scales is not None and not short_lookback:
612
+ out = self.lookback_resample_reverse(out, final_scales, orig_pred_len)
613
 
614
+ # If we used interpolation earlier, now downsample back to original scale
615
+ if short_lookback:
616
+ out = interpolate(out, scale_factor=1/scale_factor, mode='linear')
617
+ out = out[:, :, :orig_pred_len]
 
618
 
619
+
620
+ if x_in.dim() == 2:
621
+ out = out.squeeze(1)
622
+
623
  if get_prob:
624
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
625
+ # expert_probs = expert_probs.permute(0, 2, 1) # [batch_size, num_experts, sequence_length]
626
+ if x_in.dim() == 2:
627
+ expert_probs = expert_probs.squeeze(-1)
628
+ return out, expert_probs
629
+
630
+ return out
631
+
632
def map_to_cycle(self, freq: str) -> int:
    """
    Extract the cycle length encoded in a frequency string.

    The string is split on "/" and its second field is parsed as an integer, so the
    expected form is "<name>/<cycle>" (for example "h/24").

    Args:
        freq: String with at least two "/"-separated fields, the second an integer literal

    Returns:
        Integer value of the second "/"-separated field
    """
    fields = freq.split("/")
    return int(fields[1])
644
+
645
  "-------------------------------------------------------------------------------------------------------------------"
646
class SuperLinearForCausalLM(PreTrainedModel):
    """
    Wrapper that exposes the forecasting backbone through the causal-LM model interface.

    The forecast is returned in the ``logits`` field of the output object; expert routing
    probabilities, when requested, are carried in its ``attentions`` field.
    """

    config_class = SuperLinearConfig

    def __init__(self, config: SuperLinearConfig):
        super().__init__(config)

        # The backbone reads its settings as plain attributes, so turn the config
        # dictionary into an ad-hoc attribute container.
        self.args = type("Cfg", (), config.to_dict())()
        self.backbone = Model(self.args)
        self.post_init()

    # ------------------------------------------------------------------
    # Forward pass expected by AutoModelForCausalLM
    # ------------------------------------------------------------------
    def forward(self,
                inputs_embeds: torch.Tensor = None,
                pred_len: Optional[int] = None,
                get_prob: bool = False,
                **kwargs) -> CausalLMOutputWithCrossAttentions:
        """
        Run the backbone on a batch of series and wrap the forecast.

        Args:
            inputs_embeds: Input series, passed to the backbone unchanged
                (the backbone expects (B, C, L) or (B, L))
            pred_len: Prediction length forwarded to the backbone
            get_prob: Whether to also return expert selection probabilities
            **kwargs: Accepted for interface compatibility and ignored

        Returns:
            CausalLMOutputWithCrossAttentions with the forecast in ``logits`` and the
            expert probabilities (or None) in ``attentions``

        Raises:
            ValueError: If ``inputs_embeds`` is None
        """
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided")

        want_probs = bool(get_prob)
        # NOTE(review): the backbone forward in this file appears to return
        # (B, C, pred_len), or (B, pred_len) for 2D input - confirm.
        result = self.backbone(inputs_embeds, pred_len=pred_len, get_prob=want_probs)
        preds, probs = result if want_probs else (result, None)

        return CausalLMOutputWithCrossAttentions(
            logits=preds,
            hidden_states=None,
            attentions=probs,
        )
685
+
686
 
687
 
 
 
 
 
 
688
 
 
 
689
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ pyyaml
2
+ numpy
3
+ pandas
4
+ torch
5
+ scikit-learn
6
+ ipykernel
7
+ transformers>=4.40.1
8
+ datasets>=2.18.0
9
+ accelerate>=0.28.0