razmars committed on
Commit
431f8e0
·
verified ·
1 Parent(s): 28a2d64

Upload 5 files

Browse files
config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "super_linear",
3
+ "architectures": [
4
+ "SuperLinearForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_super_linear.SuperLinearConfig",
8
+ "AutoModelForCausalLM": "modeling_super_linear.SuperLinearForCausalLM"
9
+ },
10
+ "auto_regressive": 1,
11
+ "con": 0,
12
+ "d_model": 512,
13
+ "dropout": 0.0,
14
+ "fft_len": 10000,
15
+ "freeze_experts": 1,
16
+ "freq_experts": "mean_naive_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600",
17
+ "inf_pred_len": 96,
18
+ "ker_len": 50,
19
+ "layer_type": "RLinear",
20
+ "linear_checkpoints_dir": "checkpoints5",
21
+ "linear_checkpoints_path": "/cs/azencot_fsas/MoE/",
22
+ "load_linear": 0,
23
+ "manual_moe": 0,
24
+ "max_horizon": 96,
25
+ "misc_moe": 1,
26
+ "mlp_gating": 1,
27
+ "model_type": "super_linear",
28
+ "moe": 1,
29
+ "moe_n_experts": 8,
30
+ "moe_temp": 1,
31
+ "noisy_gating_std": 0.1,
32
+ "noisy_gating_std_decay": 1,
33
+ "pred_len": 96,
34
+ "seq_len": 512,
35
+ "top_k_experts": 3,
36
+ "torch_dtype": "float32",
37
+ "transformers_version": "4.40.1",
38
+ "use_fft": 1
39
+ }
configuration_super_linear.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ import torch, torch.nn as nn, torch.nn.functional as F
3
+
4
+ from transformers import (
5
+ PretrainedConfig,
6
+ PreTrainedModel,
7
+ GenerationMixin,
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ )
11
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
12
+
13
+ # 1) --------------------------------------------------------------------------
14
+ # CONFIG
15
+ # -----------------------------------------------------------------------------
16
+
17
+
18
class SuperLinearConfig(PretrainedConfig):
    """
    Configuration for the SuperLinear MoE time-series foundation model.

    Only *model_type* must be unique inside transformers; every other
    argument is stored verbatim as an attribute of the same name and read by
    the model code.

    All hyper-parameters are explicit keyword arguments so that they show up
    in the signature and can be overridden directly; their defaults are the
    values that used to be hard-coded in the body.

    Sequence / horizon:
        seq_len, pred_len, inf_pred_len, max_horizon, auto_regressive
    Mixture of experts:
        moe, moe_n_experts, top_k_experts, freq_experts ('_'-separated expert
        names), freeze_experts, layer_type, manual_moe, misc_moe
    Expert checkpoints:
        linear_checkpoints_path, linear_checkpoints_dir, load_linear
    Gating network:
        noisy_gating_std, noisy_gating_std_decay, ker_len, con, d_model,
        mlp_gating, moe_temp, use_fft, fft_len
    Regularisation:
        dropout

    Any remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "super_linear"

    def __init__(
        self,
        seq_len=512,
        pred_len=96,
        inf_pred_len=96,
        max_horizon=96,
        auto_regressive=1,
        moe_n_experts=8,
        top_k_experts=3,
        moe=1,
        freq_experts='mean_naive_1/6_1/7_1/8_1/12_1/14_1/16_1/21_1/24_1/28_1/30_1/32_1/36_1/42_1/48_1/52_1/56_1/60_1/72_1/84_1/96_1/120_1/144_1/168_1/180_1/224_1/252_1/288_1/336_1/365_1/504_1/672_1/1008_1/1440_1/2016_1/3600',
        freeze_experts=1,
        layer_type="RLinear",
        linear_checkpoints_path='/cs/azencot_fsas/MoE/',
        linear_checkpoints_dir="checkpoints5",
        load_linear=0,
        manual_moe=0,
        misc_moe=1,
        noisy_gating_std=0.1,
        noisy_gating_std_decay=1,
        ker_len=50,
        con=0,
        d_model=512,
        mlp_gating=1,
        moe_temp=1,
        use_fft=1,
        fft_len=10000,
        dropout=0.0,
        **kwargs,  # any extra CLI args
    ):
        # Sequence / forecasting horizon.
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.inf_pred_len = inf_pred_len
        self.max_horizon = max_horizon
        self.auto_regressive = auto_regressive

        # Mixture-of-experts layout.
        self.moe = moe
        self.moe_n_experts = moe_n_experts
        self.top_k_experts = top_k_experts
        self.freq_experts = freq_experts
        self.freeze_experts = freeze_experts
        self.layer_type = layer_type
        self.manual_moe = manual_moe
        self.misc_moe = misc_moe

        # Location of pre-trained expert checkpoints (only read by the model
        # when load_linear is truthy).
        self.linear_checkpoints_path = linear_checkpoints_path
        self.linear_checkpoints_dir = linear_checkpoints_dir
        self.load_linear = load_linear

        # Gating network.
        self.noisy_gating_std = noisy_gating_std
        self.noisy_gating_std_decay = noisy_gating_std_decay
        self.ker_len = ker_len
        self.con = con
        self.d_model = d_model
        self.mlp_gating = mlp_gating
        self.moe_temp = moe_temp
        self.use_fft = use_fft
        self.fft_len = fft_len

        self.dropout = dropout

        # Called last, as in the original, so base-class handling of the
        # remaining kwargs still runs after our attributes are set.
        super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.40.1"
4
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b75b2af5dad25c465306bed1ba128eb39b76cad2107e90b42579e7c3e5d192b
3
+ size 17419560
modeling_super_linear.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ import torch, torch.nn as nn, torch.nn.functional as F
3
+
4
+ from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
5
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
6
+ from SuperLinear.model.super_linear_config import SuperLinearConfig
7
+
8
+
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import os
12
+ import numpy as np
13
+
14
+
15
+ "-------------------------------------------------------------------------------------------------------------------"
16
class RevIN(nn.Module):
    """Reversible instance normalisation.

    Call with ``mode='norm'`` to record per-instance statistics on ``self``
    and normalise the input, then with ``mode='denorm'`` to invert that
    transform using the statistics stored by the preceding ``'norm'`` call.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, norm_type = None, subtract_last = False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param norm_type: "l1" or "l2" to additionally divide by that norm of
            the input; any other value disables the extra scaling
        :param subtract_last: if True, centre on the last time step instead
            of the mean
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.norm_type = norm_type
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode:str):
        """Normalise (``'norm'``) or de-normalise (``'denorm'``) ``x``."""
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError

    def _init_params(self):
        # Learnable per-feature scale and shift, shape (num_features,).
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store the centring / scaling statistics of ``x`` on the module."""
        # Every axis except the first and the last is reduced.
        reduce_dims = tuple(range(1, x.ndim - 1))

        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()

        # NOTE(review): indentation was lost in the source; stdev is computed
        # in both branches because _denormalize multiplies by it
        # unconditionally.
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

        if self.norm_type == "l1":
            self.denom = torch.sum(torch.abs(x), dim=reduce_dims, keepdim=True).detach()
        elif self.norm_type == "l2":
            self.denom = torch.sqrt(torch.sum(x**2, dim=reduce_dims, keepdim=True)).detach()

    def _normalize(self, x):
        """Centre, scale and (optionally) apply the affine transform."""
        offset = self.last if self.subtract_last else self.mean
        x = (x - offset) / self.stdev

        if self.norm_type in ["l1", "l2"]:
            x = x / self.denom

        if self.affine:
            x = x * self.affine_weight + self.affine_bias
        return x

    def _denormalize(self, x):
        """Undo ``_normalize`` step by step, in reverse order."""
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards against division by a zero weight.
            x = x / (self.affine_weight + self.eps*self.eps)
        if self.norm_type in ["l1", "l2"]:
            x = x * self.denom
        x = x * self.stdev
        offset = self.last if self.subtract_last else self.mean
        return x + offset
89
+ "-------------------------------------------------------------------------------------------------------------------"
90
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Smooth ``x`` of shape [Batch, Input length] along its last axis.

        Both ends are padded by repeating the edge value
        ``(kernel_size - 1) // 2`` times before average pooling.
        """
        pad = (self.kernel_size - 1) // 2
        left = x[:, 0:1].repeat(1, pad)
        right = x[:, -1:].repeat(1, pad)
        padded = torch.cat([left, x, right], dim=1)
        # AvgPool1d wants a channel axis; add it and strip it again.
        pooled = self.avg(padded.unsqueeze(1))
        return pooled.squeeze(1)
116
+
117
+
118
class series_decomp(nn.Module):
    """
    Series decomposition block

    Splits a series into its moving-average trend and the residual that
    remains after subtracting that trend.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(residual, trend)`` for ``x``."""
        trend = self.moving_avg(x)
        return x - trend, trend
130
+
131
+
132
class DLinear(nn.Module):
    """Decomposition-linear forecaster.

    The input is split into a residual (seasonal) part and a moving-average
    trend; each part goes through its own linear layer and the two
    projections are summed.
    """

    def __init__(self, input_len, output_len, kernel_size = 25):
        super().__init__()
        self.seasonal = nn.Linear(input_len, output_len)
        self.trend = nn.Linear(input_len, output_len)
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.moving_avg = moving_avg(kernel_size, stride=1)
        self.decompsition = series_decomp(kernel_size)

    def forward(self, x):
        """Map ``x`` (last axis of size ``input_len``) to ``output_len``."""
        seasonal_part, trend_part = self.decompsition(x)
        return self.seasonal(seasonal_part) + self.trend(trend_part)
147
+
148
class Linear(nn.Module):
    """Single linear projection from ``input_len`` to ``output_len``."""

    def __init__(self, input_len, output_len):
        super().__init__()
        self.Linear = nn.Linear(input_len, output_len)

    def forward(self, x):
        """Project ``x`` of shape [Batch*Channel, Input length].

        The input is cloned before the projection and the result is cloned
        again, so the returned tensor shares no storage with the input.
        """
        projected = self.Linear(x.clone())
        return projected.clone()
158
+
159
class Naive(nn.Module):
    """Parameter-free baseline that repeats the last observed value."""

    def __init__(self, input_len, output_len):
        super().__init__()
        # input_len is accepted for signature parity with the other experts.
        self.output_len = output_len

    def forward(self, x):
        """Return [Batch*Channel, output_len] filled with ``x[:, -1]``."""
        last_value = x[:, -1].unsqueeze(1)
        return last_value.repeat(1, self.output_len)
169
+
170
class Mean(nn.Module):
    """Parameter-free baseline that repeats the mean of the input window."""

    def __init__(self, input_len, output_len):
        super().__init__()
        # input_len is accepted for signature parity with the other experts.
        self.output_len = output_len

    def forward(self, x):
        """Return [Batch*Channel, output_len] filled with the row mean."""
        row_mean = x.mean(dim=1).unsqueeze(1)
        return row_mean.repeat(1, self.output_len)
179
+
180
+
181
class NLinear(nn.Module):
    """Linear forecaster applied to the series offset by its last value."""

    def __init__(self, input_len, output_len):
        super().__init__()
        self.Linear = nn.Linear(input_len, output_len)

    def forward(self, x):
        """Subtract the (detached) ``x[:, -1:]`` slice, project, add it back."""
        anchor = x[:, -1:].detach()
        projected = self.Linear(x - anchor)
        return projected + anchor
192
+
193
+
194
class RLinear(nn.Module):
    """Linear forecaster wrapped in reversible instance normalisation."""

    def __init__(self, input_len, output_len):
        super().__init__()
        self.Linear = nn.Linear(input_len, output_len)
        # Non-affine RevIN: no learnable parameters, so num_features is unused.
        self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)

    def forward(self, x):
        """Forecast from ``x``.

        2-D input is given a trailing singleton axis for the duration of the
        call and squeezed back before returning. The linear layer acts on
        axis 1 (moved to the last position for the projection).
        """
        was_2d = len(x.shape) == 2
        if was_2d:
            x = x.unsqueeze(-1)

        normed = self.revin_layer(x.clone(), 'norm')
        projected = self.Linear(normed.permute(0, 2, 1)).permute(0, 2, 1).clone()
        restored = self.revin_layer(projected, 'denorm')

        if was_2d:
            restored = restored.squeeze(-1)
        return restored
213
+
214
+
215
+ "-------------------------------------------------------------------------------------------------------------------"
216
class SparseNoisyMoE(nn.Module):
    """Noisy top-k mixture of experts with an optional periodogram gate.

    The gating network scores every expert from either the raw input or its
    normalised periodogram (``configs.use_fft``). The ``k`` best-scoring
    experts are combined with softmax weights over their scores. During
    training Gaussian noise is added to the scores before the top-k
    selection; during evaluation the scores are divided by
    ``configs.moe_temp`` instead.

    Attributes written by ``forward`` for later inspection:
        gate_outputs, topk_values, topk_gates, D
    """

    def __init__(self, configs, experts=None):
        """
        :param configs: object exposing seq_len, top_k_experts,
            noisy_gating_std, noisy_gating_std_decay, ker_len, con, d_model,
            mlp_gating, moe_temp, use_fft and fft_len
        :param experts: iterable of expert modules (required)
        :raises TypeError: if ``experts`` is None
        """
        super(SparseNoisyMoE, self).__init__()
        if experts is None:
            raise TypeError("SparseNoisyMoE requires an iterable of expert modules")

        input_dim = configs.seq_len
        self.k = configs.top_k_experts
        self.noise_std = configs.noisy_gating_std
        self.noise_std_decay = configs.noisy_gating_std_decay
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(self.experts)
        self.ker_len = configs.ker_len
        self.con = configs.con
        self.d_model = configs.d_model
        self.mlp_gating = configs.mlp_gating
        self.moe_temp = configs.moe_temp
        self.use_fft = configs.use_fft
        self.fft_len = configs.fft_len

        if self.use_fft:
            # The gate sees the first fft_len // 2 periodogram bins.
            if self.mlp_gating:
                self.gating_network = nn.Sequential(
                    nn.Linear(self.fft_len//2, self.d_model),
                    nn.ReLU(),
                    nn.Linear(self.d_model, self.num_experts)
                )
            else:
                self.gating_network = nn.Linear(self.fft_len//2, self.num_experts, bias=True)
        else:
            self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)

    def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
        """Return the normalised periodogram of ``inputs`` along axis 1.

        :param inputs: 2-D or 3-D tensor; a 2-D tensor is treated as having a
            single trailing channel and the result is squeezed back to 2-D
        :param ker_len: moving-average kernel length used when ``con`` is
            truthy (capped at 50; ``None`` means ``n // 4`` before the cap)
        :param con: if truthy, subtract a moving average before the FFT
        :param n: FFT length; the first ``n // 2`` bins are kept
        :return: power per bin, normalised to sum to 1 along axis 1 (all
            zeros where the total power is zero)
        """
        if inputs.dim() == 2:
            x_0 = inputs.unsqueeze(2)
        else:
            x_0 = inputs
        x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)

        if con:
            if ker_len is None:
                ker_len = n // 4
            ker_len = min(ker_len, 50)

            x_0 = x_0.permute(0, 2, 1)
            ker = (torch.ones(1, 1, ker_len) / ker_len).to(x_0.device)
            x_c = F.conv1d(x_0, ker, padding="same")
            # Overwrite the zero-padded borders of the moving average with
            # the nearest fully-supported value.
            x_c[:, :, :ker_len // 2] = x_c[:, :, ker_len // 2:ker_len // 2 + 1]
            x_c[:, :, -ker_len // 2:] = x_c[:, :, -ker_len // 2 - 1:-ker_len // 2]
            x_0 = x_0 - x_c
            x_0 = x_0.permute(0, 2, 1)

        dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
        dft = dft[:, :n//2, :]
        I = torch.abs(dft) ** 2

        I_sum = torch.sum(I, dim=1, keepdim=True)
        # Avoid 0/0 for series with no power at all.
        I_sum[I_sum == 0] = 1
        I = I / I_sum

        if inputs.dim() == 2:
            I = I.squeeze(2)

        return I

    def forward(self, x, get_prob=False):
        """Run the mixture on ``x``.

        :param x: input whose first axis is the batch
        :param get_prob: if True, also return the softmax over all gate scores
        :return: ``(output, load_balancing_loss)`` or, with ``get_prob``,
            ``(output, load_balancing_loss, expert_probs)``
        """
        if self.use_fft:
            x_0 = self.get_periodogram(x, ker_len=self.ker_len, n=self.fft_len, con=self.con)
        else:
            x_0 = x

        self.gate_outputs = self.gating_network(x_0)

        if self.training:
            # Noise only perturbs which experts are selected (and their
            # mixing weights); it is not drawn at all in eval mode.
            noise = torch.randn_like(self.gate_outputs) * self.noise_std
            selection_scores = self.gate_outputs + noise
        else:
            self.gate_outputs = self.gate_outputs / self.moe_temp
            selection_scores = self.gate_outputs

        self.topk_values, topk_indices = torch.topk(selection_scores, self.k, dim=1)
        self.topk_gates = F.softmax(self.topk_values, dim=1)

        batch_size = x.size(0)
        # Every expert is evaluated; the top-k outputs are gathered afterwards.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
        sparse_expert_outputs = torch.gather(expert_outputs, 1, topk_indices_expanded)

        output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)

        load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)

        if get_prob:
            expert_probs = F.softmax(self.gate_outputs, dim=1)
            return output, load_balancing_loss, expert_probs

        return output, load_balancing_loss

    def calculate_load_balancing_loss(self, gate_outputs, batch_size):
        """Return ``num_experts * sum(D * P)``.

        ``D`` is the fraction of samples whose arg-max gate score is each
        expert (also stored as ``self.D``); ``P`` is the batch-mean softmax
        probability of each expert.
        """
        gate_probs = F.softmax(gate_outputs, dim=1)

        assignments = torch.argmax(gate_outputs, dim=1)
        # One vectorised count instead of a Python loop over experts.
        counts = torch.bincount(assignments, minlength=self.num_experts)
        self.D = counts.float() / batch_size

        P = torch.mean(gate_probs, dim=0)

        load_balancing_loss = torch.sum(self.D * P) * self.num_experts

        return load_balancing_loss
333
+
334
+
335
class superLinear(nn.Module):
    """SuperLinear backbone: a sparse mixture of linear forecasting experts.

    Experts are created from the '_'-separated names in
    ``configs.freq_experts`` ("naive"/"mean" baselines, otherwise a layer of
    type ``configs.layer_type``), or as ``configs.moe_n_experts`` anonymous
    layers when that string is empty. An extra "misc" expert is appended when
    ``configs.misc_moe == 1``. All experts are routed by a ``SparseNoisyMoE``
    stored as ``self.moe``.
    """

    # Frequency alias -> cycle length, matched as substrings by map_to_cycle
    # (longest alias first, so "15min" wins over "5min" and "min").
    _FREQ_TO_CYCLE = {
        "h": 24,
        "2h": 12,
        "3h": 8,
        "4h": 6,
        "D": 7,
        "DM": 30,
        "W": 52,
        "M": 12,
        "min": 1440,
        "5min": 288,
        "10min": 144,
        "15min": 96,
        "30min": 48,
    }

    def __init__(self, configs):
        """
        :param configs: object exposing the attributes of SuperLinearConfig
            (any attribute container works; only attribute access is used)
        :raises ValueError: if ``configs.load_linear`` is set and no
            checkpoint matches an expert's cycle and layer type
        """
        super(superLinear, self).__init__()

        self.configs = configs
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.inf_pred_len = configs.inf_pred_len
        self.max_horizon = configs.max_horizon
        self.auto_regressive = configs.auto_regressive
        self.n_experts = configs.moe_n_experts

        if configs.freq_experts == "":
            self.freq_experts = None
        else:
            self.freq_experts = configs.freq_experts.split('_')

        print("self.freq_experts:", self.freq_experts)

        self.moe_loss = None
        self.top_k_experts = configs.top_k_experts
        self.freeze_experts = configs.freeze_experts
        self.layer_type = configs.layer_type
        self.model_name = "SuperLinear"

        print("self.layer_type", self.layer_type)
        self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}

        # Only touch the checkpoint directory when checkpoints are actually
        # requested; listing it unconditionally made construction fail on
        # any machine where the configured path does not exist.
        path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir + "/"
        checkpoints_paths = []
        if configs.load_linear:
            checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in os.listdir(path)]

        self.experts = {}
        if self.freq_experts is not None:
            for expert_freq in self.freq_experts:
                if expert_freq == "naive" or expert_freq == "Naive":
                    self.experts[expert_freq] = Naive(self.seq_len, self.pred_len)
                elif expert_freq == "mean" or expert_freq == "Mean":
                    self.experts[expert_freq] = Mean(self.seq_len, self.pred_len)
                else:
                    self.experts[expert_freq] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
                    if configs.load_linear:
                        cycle = self.map_to_cycle(expert_freq)
                        cycle_str = f'cycle_{cycle}/'
                        cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
                        if len(cycle_checkpoint_path) > 0:
                            print()
                            print(cycle_str)
                            cycle_checkpoint_path = cycle_checkpoint_path[0]
                            print(cycle_checkpoint_path)
                            # Load on CPU; load_state_dict copies the values
                            # into the expert's own parameters.
                            state = torch.load(cycle_checkpoint_path, map_location="cpu")
                            self.experts[expert_freq].load_state_dict(state)
                        else:
                            print(f"Checkpoint for {cycle_str} not found in {path}")
                            raise ValueError(f"Checkpoint for {cycle_str} not found in {path}")
                if configs.freeze_experts:
                    for param in self.experts[expert_freq].parameters():
                        param.requires_grad = False

            self.n_experts = len(self.experts)
        else:
            for i in range(self.n_experts):
                print(f"creating expert {i}")
                self.experts[str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)

        self.manual_moe = configs.manual_moe

        # The catch-all expert is added after the freeze loop, so it stays
        # trainable.
        if configs.misc_moe == 1:
            self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)

        self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
        self.dropout = nn.Dropout(configs.dropout)

    def map_to_cycle(self, freq):
        """Translate an expert / frequency name into an integer cycle length.

        * ``"a/b"`` -> ``int(b)`` (e.g. ``"1/24"`` -> 24)
        * a name containing a known alias (``"h"``, ``"2h"``, ``"D"``,
          ``"15min"``, ...) -> that alias's cycle; the longest matching alias
          wins
        * anything else -> ``int(freq)``

        :raises ValueError: if ``freq`` matches nothing and is not an integer
        """
        if "/" in freq:
            return int(freq.split("/")[1])
        for alias in sorted(self._FREQ_TO_CYCLE, key=len, reverse=True):
            if alias in freq:
                return self._FREQ_TO_CYCLE[alias]
        return int(freq)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=False):
        """Forecast every channel of ``x_enc`` independently.

        :param x_enc: tensor of shape (batch, length, channels); channels are
            folded into the batch axis before the mixture is applied
        :param get_prob: if True, also return the gate probabilities reshaped
            to (batch, channels, n_gate_outputs)
        :return: forecast of shape (batch, horizon, channels), or
            ``(forecast, expert_probs)`` when ``get_prob`` is True

        ``x_mark_enc``, ``x_dec``, ``x_mark_dec``, ``mask`` and ``freq`` are
        accepted but not used. The load-balancing loss of the first mixture
        call is stored in ``self.moe_loss``.
        """
        x = x_enc.permute(0, 2, 1)
        B, V, L = x.shape
        x = x.reshape(B * V, L)

        expert_probs = None

        if get_prob:
            out, self.moe_loss, expert_probs = self.moe(x, get_prob=True)
        else:
            out, self.moe_loss = self.moe(x)

        if self.auto_regressive and self.max_horizon < self.inf_pred_len:
            # Roll the model forward on its own predictions, always feeding
            # the most recent seq_len steps, then trim to inf_pred_len.
            outputs = [out]
            ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
            for i in range(0, self.inf_pred_len, self.max_horizon):
                ar_out, _ = self.moe(ar_x)
                outputs.append(ar_out)
                ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
            out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]

        out = out.reshape(B, V, out.shape[-1])
        result = out.permute(0, 2, 1)

        if get_prob:
            expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
            return result, expert_probs
        return result
477
+
478
+ "-------------------------------------------------------------------------------------------------------------------"
479
class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face wrapper exposing the ``superLinear`` backbone through the
    ``AutoModelForCausalLM`` interface.

    The time series is passed as ``inputs_embeds``; the forecast is returned
    in the ``logits`` field of the output.
    """

    config_class = SuperLinearConfig

    def __init__(self, config: SuperLinearConfig):
        super().__init__(config)

        # The backbone only reads attributes, so expose the config dict as a
        # throw-away attribute container.
        backbone_cfg = type("Cfg", (), config.to_dict())()
        self.backbone = superLinear(backbone_cfg)

        # Optional final projection, created only when the config defines a
        # vocab_size.
        self.vocab_size = getattr(config, "vocab_size", None)
        if self.vocab_size is not None:
            self.lm_head = nn.Linear(backbone_cfg.pred_len, self.vocab_size)

        self.post_init()  # HF weight init

    # ------------------------------------------------------------------
    # Forward pass expected by AutoModelForCausalLM
    # ------------------------------------------------------------------
    def forward(self,
                inputs_embeds: torch.Tensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[Tuple] = None,
                use_cache: bool = True,
                labels: Optional[torch.Tensor] = None,
                **kwargs,) -> CausalLMOutputWithCrossAttentions:
        """Forecast from a batch of time series.

        :param inputs_embeds: series of shape (batch, length, channels), the
            layout the backbone permutes from
        :param labels: if given, a shifted cross-entropy loss is computed
        :return: output whose ``logits`` hold the backbone forecast of shape
            (batch, horizon, channels) when no ``vocab_size`` is configured
        :raises ValueError: if ``inputs_embeds`` is None

        ``attention_mask``, ``past_key_values`` and ``use_cache`` are
        accepted for API compatibility and ignored.
        """
        if inputs_embeds is None:
            raise ValueError("Pass the time‑series as `inputs_embeds`")

        x_enc = inputs_embeds

        # The backbone returns a single tensor here (get_prob is False).
        # Indexing it with [0] would keep only the first batch element, so
        # the full batch is used.
        preds = self.backbone(x_enc)

        # if we keep continuous values, treat them as logits directly
        # NOTE(review): the lm_head branch applies Linear(pred_len, vocab) to
        # the last axis of preds, which is the channel axis - confirm the
        # intended layout before enabling vocab_size.
        logits = (preds if self.vocab_size is None else self.lm_head(preds).transpose(1, 2))

        loss = None
        if labels is not None:
            # shift for causal objective
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(loss=loss,logits=logits,past_key_values=None,hidden_states=None,attentions=None,)

    def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
        """Build the forward kwargs for one generation step."""
        if past_key_values is not None:
            # only feed the last new step
            inputs_embeds = inputs_embeds[:, -1:, :]
        return {"inputs_embeds": inputs_embeds, "past_key_values": past_key_values}

    def _reorder_cache(self, past, beam_idx, **kwargs):
        return past  # backbone keeps no KV cache
542
+
543
+ "-------------------------------------------------------------------------------------------------------------------"
544
+ # 3) --------------------------------------------------------------------------
545
+ # REGISTRATION (one‑liner you run **once** before .from_pretrained)
546
+ # -----------------------------------------------------------------------------
547
+
548
+
549
# Register the custom classes with the Auto* factories: the config under its
# model_type string, the model under its config class.
AutoConfig.register(
    SuperLinearConfig.model_type,
    SuperLinearConfig,
)
AutoModelForCausalLM.register(
    SuperLinearConfig,
    SuperLinearForCausalLM,
)