lirannoc committed on
Commit ad2aa43 · verified · 1 Parent(s): 5086864

Update modeling_super_linear.py

Files changed (1)
  1. modeling_super_linear.py +9 -68
modeling_super_linear.py CHANGED
@@ -1,12 +1,12 @@
 from typing import Optional, Tuple
-import torch, torch.nn as nn, torch.nn.functional as F
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import numpy as np
-import matplotlib.pyplot as plt
-import os
 
-from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
+from transformers import PreTrainedModel, GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from .configuration_super_linear import SuperLinearConfig
 
 
 "-------------------------------------------------------------------------------------------------------------------"
@@ -170,7 +170,7 @@ class SparseMoE(nn.Module):
         if self.use_fft:
             self.gating_network = nn.Linear(self.fft_len//2, self.num_experts, bias=True)
         else:
-            self.gating_network = nn.Linear(configs.seq_len, self.num_experts, bias=True)
+            self.gating_network = nn.Linear(configs.train_seq_len, self.num_experts, bias=True)
 
         if self.moe_norm:
             self.gate_norm = nn.BatchNorm1d(self.num_experts)
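
The gating fix above matters because, when use_fft is off, the gate consumes the raw lookback window, so its input width must match the sequence length the model is trained with (configs.train_seq_len), not a possibly different configs.seq_len. A minimal standalone sketch of that shape contract (the concrete sizes and the softmax step are illustrative assumptions, not code from this repo):

    import torch
    import torch.nn as nn

    train_seq_len, num_experts = 512, 8                # illustrative sizes
    gating_network = nn.Linear(train_seq_len, num_experts, bias=True)

    x = torch.randn(32, train_seq_len)                 # (batch, lookback) seen by the gate
    logits = gating_network(x)                         # -> (32, num_experts)
    weights = torch.softmax(logits, dim=-1)            # mixture weights over experts
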
@@ -289,11 +289,6 @@ class Model(nn.Module):
         self.train_seq_len = configs.train_seq_len
         self.resample_long_lookback = configs.resample_long_lookback
         self.layer_type = configs.layer_type
-        self.load_weights_full = configs.load_weights_full
-        self.load_linear = configs.load_linear
-
-        if self.load_weights_full:
-            pass # TODO: implement full weight loading
 
         # Parse frequency experts from configuration
         if configs.freq_experts == "":
@@ -303,9 +298,6 @@ class Model(nn.Module):
 
         self.top_k_experts = configs.top_k_experts
         self.freeze_experts = configs.freeze_experts
-        path = configs.linear_freq_weights_path
-        linear_freq_dirs = os.listdir(path) if os.path.exists(path) else []
-        checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in linear_freq_dirs]
 
         # Initialize experts based on frequency specification or create generic experts
         self.experts = {}
@@ -324,19 +316,6 @@ class Model(nn.Module):
                 else:
                     # Default to RLinear if unknown layer type
                     self.experts[expert_freq] = RLinear(self.train_seq_len, self.train_pred_len)
-
-                if self.load_linear and checkpoints_paths:
-                    cycle = self.map_to_cycle(expert_freq)
-                    cycle_str = f'cycle_{cycle}/'
-                    cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
-                    if len(cycle_checkpoint_path) > 0:
-                        cycle_checkpoint_path = cycle_checkpoint_path[0]
-                        print(f'Loading checkpoint: {cycle_checkpoint_path}')
-                        self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
-
-                    if self.freeze_experts:
-                        for param in self.experts[expert_freq].parameters():
-                            param.requires_grad = False
         else:
             # Create generic experts
             for i in range(configs.n_experts):
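
With the block above removed, pre-trained per-frequency weights are no longer loaded in the constructor. If they are still needed, an equivalent step can run after the model is built; a hedged sketch, where the <root>/<dir>/checkpoint.pth layout and the freeze behavior mirror the removed code, but load_expert_weights itself is a hypothetical helper:

    import os
    import torch

    def load_expert_weights(model, root, layer_type, freeze=False):
        # Checkpoints laid out as <root>/<dir>/checkpoint.pth, as in the removed code.
        dirs = os.listdir(root) if os.path.exists(root) else []
        paths = [os.path.join(root, d, "checkpoint.pth") for d in dirs]
        for expert in model.experts.values():
            # The removed code additionally filtered by a per-expert cycle_<n>/ segment.
            match = [p for p in paths if layer_type in p]
            if match:
                expert.load_state_dict(torch.load(match[0], map_location="cpu"))
            if freeze:
                for param in expert.parameters():
                    param.requires_grad = False
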
@@ -359,12 +338,8 @@ class Model(nn.Module):
                     # Default to RLinear if unknown layer type
                     self.experts[f"comp_{i}"] = RLinear(self.train_seq_len, self.train_pred_len)
 
-        # Initialize the MoE layer and dropout
+        # Initialize the MoE layer
         self.moe = SparseMoE(configs, experts=self.experts.values())
-
-        # Load pre-trained weights if specified
-        if configs.load_weights_full:
-            pass # TODO: implement full weight loading
 
         print("Experts:", self.experts.keys())
 
@@ -470,40 +445,6 @@ class Model(nn.Module):
             expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
             return result, expert_probs
         return result
-
-    def map_to_cycle(self, freq):
-        """Map frequency string to cycle length for expert loading."""
-        if "/" in freq:
-            cycle = int(freq.split("/")[1])
-        elif "h" in freq:
-            cycle = 24
-        elif "2h" in freq:
-            cycle = 12
-        elif "3h" in freq:
-            cycle = 8
-        elif "4h" in freq:
-            cycle = 6
-        elif "D" in freq:
-            cycle = 7
-        elif "DM" in freq:
-            cycle = 30
-        elif "W" in freq:
-            cycle = 52
-        elif "M" in freq:
-            cycle = 12
-        elif "min" in freq:
-            cycle = 1440
-        elif "5min" in freq:
-            cycle = 288
-        elif "10min" in freq:
-            cycle = 144
-        elif "15min" in freq:
-            cycle = 96
-        elif "30min" in freq:
-            cycle = 48
-        else:
-            cycle = int(freq)
-        return cycle
 "-------------------------------------------------------------------------------------------------------------------"
 class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
     config_class = SuperLinearConfig
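One observation on the deleted map_to_cycle: its substring checks ran in order, so the multi-character branches were unreachable. For example, "2h" already satisfied '"h" in freq' and returned 24 instead of 12, and "15min" satisfied '"min" in freq' and returned 1440 instead of 96. An exact-match table sidesteps that; a sketch of the idea, not part of this commit:

    _CYCLES = {
        "h": 24, "2h": 12, "3h": 8, "4h": 6,
        "D": 7, "DM": 30, "W": 52, "M": 12,
        "min": 1440, "5min": 288, "10min": 144, "15min": 96, "30min": 48,
    }

    def map_to_cycle(freq: str) -> int:
        if "/" in freq:
            return int(freq.split("/")[1])  # "<x>/<n>" -> n, as in the removed code
        return _CYCLES[freq] if freq in _CYCLES else int(freq)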