lirannoc committed on
Commit
b7195ad
·
verified ·
1 Parent(s): 3b2a0e1

Update configuration_super_linear.py

Browse files
Files changed (1) hide show
  1. configuration_super_linear.py +100 -62
configuration_super_linear.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Optional, Tuple
2
- import torch, torch.nn as nn, torch.nn.functional as F
3
  from transformers import (
4
  PretrainedConfig,
5
  PreTrainedModel,
@@ -15,75 +15,113 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
15
 
16
 
17
  class SuperLinearConfig(PretrainedConfig):
 
 
 
 
 
18
 
19
  model_type = "super_linear"
 
20
  def __init__(
21
  self,
 
 
 
22
  seq_len=512,
23
  pred_len=96,
24
  inf_pred_len=96,
25
  max_horizon=96,
26
- moe_n_experts=12,
27
- top_k_experts=5,
28
- moe =1,
29
- freq_experts= 'mean_naive_1/4_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/90_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600',
30
- auto_regressive= 1,
31
- d_model= 128,
32
- dropout= 0.0,
33
- fft_len= 5000,
34
- freeze_experts= 1,
35
- layer_type= "RLinear",
36
- linear_checkpoints_dir= "checkpoints5",
37
- linear_checkpoints_path= "/cs/azencot_fsas/MoE/",
38
- load_linear = 0,
39
- load_weights =0,
40
- misc_moe = 10,
41
- mlp_gating = 0,
42
- moe_norm = 0,
43
- model_type= "super_linear",
44
- moe_temp = 1,
45
- noisy_gating_std = 0.1,
46
- noisy_gating_std_decay = 1,
47
- torch_dtype = "float32",
48
- transformers_version = "4.40.1",
49
- use_fft = 1,
50
- train_epochs = 30,
51
- patience = 5,
52
- lradj = "constant",
53
- learning_rate = 0.05,
54
- channel_ind = 0,
55
- full_size = 0,
56
- **kwargs, # any extra CLI args
 
 
 
 
 
 
 
 
 
 
 
57
  ):
58
- self.seq_len = seq_len
59
- self.moe = moe
60
- self.pred_len = pred_len
61
- self.inf_pred_len = inf_pred_len
62
- self.max_horizon = max_horizon
 
 
63
  self.auto_regressive = auto_regressive
64
- self.moe_n_experts = moe_n_experts
65
- self.top_k_experts = top_k_experts
66
- self.freq_experts = freq_experts
67
- self.freeze_experts = freeze_experts
68
- self.layer_type = layer_type
69
- self.linear_checkpoints_path = linear_checkpoints_path
70
- self.linear_checkpoints_dir = linear_checkpoints_dir
71
- self.load_linear = load_linear
72
- self.load_weights = load_weights
73
- self.misc_moe = misc_moe
74
- self.noisy_gating_std = noisy_gating_std
75
- self.noisy_gating_std_decay = noisy_gating_std_decay
76
- self.d_model = d_model
77
- self.mlp_gating = mlp_gating
78
- self.moe_norm = moe_norm
79
- self.moe_temp = moe_temp
80
- self.use_fft = use_fft
81
- self.fft_len = fft_len
82
- self.dropout = dropout
83
- self.train_epochs = train_epochs
84
- self.patience = patience
85
- self.lradj = lradj
86
- self.learning_rate = learning_rate
87
- self.channel_ind = channel_ind
88
- self.full_size = full_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  super().__init__(**kwargs)
 
1
  from typing import Optional, Tuple
2
+
3
  from transformers import (
4
  PretrainedConfig,
5
  PreTrainedModel,
 
15
 
16
 
17
  class SuperLinearConfig(PretrainedConfig):
18
+ """
19
+ Configuration for the SuperLinear MoE time–series foundation model.
20
+ Only *model_type* must be unique inside transformers; the rest mirrors
21
+ the __init__ arguments of your original Config object.
22
+ """
23
 
24
  model_type = "super_linear"
25
+
26
  def __init__(
27
  self,
28
+ # Model architecture parameters
29
+ train_seq_len=512,
30
+ train_pred_len=96,
31
  seq_len=512,
32
  pred_len=96,
33
  inf_pred_len=96,
34
  max_horizon=96,
35
+ auto_regressive=1,
36
+
37
+ # MoE parameters
38
+ moe_n_experts=4,
39
+ top_k_experts=12,
40
+ noisy_gating_std=0.1,
41
+ moe_temp=1.0,
42
+ moe_norm=False,
43
+ layer_type='RLinear',
44
+ n_experts=4,
45
+ comp_moe=12,
46
+ freeze_experts=True,
47
+ moe=1,
48
+
49
+ # FFT-based gating parameters
50
+ use_fft=True,
51
+ fft_len=5000,
52
+
53
+ # Expert configuration
54
+ freq_experts='mean_naive_1/4_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/90_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600',
55
+
56
+ # Model loading and saving
57
+ load_linear=True,
58
+ load_weights_full=True,
59
+ linear_freq_weights_path='./weights/linear_freq_weights/',
60
+ full_weights_path='./weights/full_weights/checkpoint.pth',
61
+
62
+ # Training parameters
63
+ resample_long_lookback=False,
64
+
65
+ # Legacy parameters for backward compatibility
66
+ linear_checkpoints_path='/cs/azencot_fsas/MoE/',
67
+ linear_checkpoints_dir="checkpoints5",
68
+ manual_moe=0,
69
+ misc_moe=1,
70
+ noisy_gating_std_decay=1,
71
+ ker_len=50,
72
+ con=0,
73
+ d_model=512,
74
+ mlp_gating=1,
75
+ dropout=0.0,
76
+ **kwargs,
77
  ):
78
+ # Model architecture parameters
79
+ self.train_seq_len = train_seq_len
80
+ self.train_pred_len = train_pred_len
81
+ self.seq_len = seq_len
82
+ self.pred_len = pred_len
83
+ self.inf_pred_len = inf_pred_len
84
+ self.max_horizon = max_horizon
85
  self.auto_regressive = auto_regressive
86
+
87
+ # MoE parameters
88
+ self.moe = moe
89
+ self.moe_n_experts = moe_n_experts
90
+ self.top_k_experts = top_k_experts
91
+ self.noisy_gating_std = noisy_gating_std
92
+ self.moe_temp = moe_temp
93
+ self.moe_norm = moe_norm
94
+ self.layer_type = layer_type
95
+ self.n_experts = n_experts
96
+ self.comp_moe = comp_moe
97
+ self.freeze_experts = freeze_experts
98
+
99
+ # FFT-based gating parameters
100
+ self.use_fft = use_fft
101
+ self.fft_len = fft_len
102
+
103
+ # Expert configuration
104
+ self.freq_experts = freq_experts
105
+
106
+ # Model loading and saving
107
+ self.load_linear = load_linear
108
+ self.load_weights_full = load_weights_full
109
+ self.linear_freq_weights_path = linear_freq_weights_path
110
+ self.full_weights_path = full_weights_path
111
+
112
+ # Training parameters
113
+ self.resample_long_lookback = resample_long_lookback
114
+
115
+ # Legacy parameters for backward compatibility
116
+ self.linear_checkpoints_path = linear_checkpoints_path
117
+ self.linear_checkpoints_dir = linear_checkpoints_dir
118
+ self.manual_moe = manual_moe
119
+ self.misc_moe = misc_moe
120
+ self.noisy_gating_std_decay = noisy_gating_std_decay
121
+ self.ker_len = ker_len
122
+ self.con = con
123
+ self.d_model = d_model
124
+ self.mlp_gating = mlp_gating
125
+ self.dropout = dropout
126
+
127
  super().__init__(**kwargs)