Update modeling_super_linear.py
modeling_super_linear.py  CHANGED  (+9 -68)
@@ -1,12 +1,12 @@
-from
-import torch
+from typing import Optional, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import numpy as np
-import matplotlib.pyplot as plt
-import os
 
-from transformers
-from transformers.modeling_outputs
-from .configuration_super_linear
+from transformers import PreTrainedModel, GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from .configuration_super_linear import SuperLinearConfig
 
 
 "-------------------------------------------------------------------------------------------------------------------"
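The import change drops the unused matplotlib/os imports and pulls in PreTrainedModel, GenerationMixin, and SuperLinearConfig, which is what the Transformers remote-code loading path expects. A minimal usage sketch, assuming the repo ships this file alongside configuration_super_linear.py and maps the classes in its config; the repo id below is a placeholder, not confirmed by this commit:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "org/super-linear"  # placeholder repo id

# trust_remote_code tells Transformers to import modeling_super_linear.py from the repo
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()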
@@ -170,7 +170,7 @@ class SparseMoE(nn.Module):
         if self.use_fft:
             self.gating_network = nn.Linear(self.fft_len//2, self.num_experts, bias=True)
         else:
-            self.gating_network = nn.Linear(configs.
+            self.gating_network = nn.Linear(configs.train_seq_len, self.num_experts, bias=True)
 
         if self.moe_norm:
             self.gate_norm = nn.BatchNorm1d(self.num_experts)
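With FFT gating disabled, the changed line sizes the gate to score experts directly from the raw lookback window (configs.train_seq_len inputs) rather than from its spectrum. A standalone sketch of how such a linear gate typically feeds a sparse top-k mixture; the sizes and names are illustrative, not the exact SparseMoE.forward in this file:

import torch
import torch.nn as nn
import torch.nn.functional as F

seq_len, num_experts, top_k = 96, 4, 2          # illustrative sizes
gate = nn.Linear(seq_len, num_experts, bias=True)

x = torch.randn(8, seq_len)                     # (batch, lookback window)
logits = gate(x)                                # (batch, num_experts)
top_val, top_idx = logits.topk(top_k, dim=-1)   # keep the k best experts per sample
weights = F.softmax(top_val, dim=-1)            # renormalize over the kept experts
# output = sum over k of weights[:, k] * expert_{top_idx[:, k]}(x)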
@@ -289,11 +289,6 @@ class Model(nn.Module):
         self.train_seq_len = configs.train_seq_len
         self.resample_long_lookback = configs.resample_long_lookback
         self.layer_type = configs.layer_type
-        self.load_weights_full = configs.load_weights_full
-        self.load_linear = configs.load_linear
-
-        if self.load_weights_full:
-            pass  # TODO: implement full weight loading
 
         # Parse frequency experts from configuration
         if configs.freq_experts == "":
@@ -303,9 +298,6 @@ class Model(nn.Module):
 
         self.top_k_experts = configs.top_k_experts
         self.freeze_experts = configs.freeze_experts
-        path = configs.linear_freq_weights_path
-        linear_freq_dirs = os.listdir(path) if os.path.exists(path) else []
-        checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in linear_freq_dirs]
 
         # Initialize experts based on frequency specification or create generic experts
         self.experts = {}
@@ -324,19 +316,6 @@ class Model(nn.Module):
                 else:
                     # Default to RLinear if unknown layer type
                     self.experts[expert_freq] = RLinear(self.train_seq_len, self.train_pred_len)
-
-                if self.load_linear and checkpoints_paths:
-                    cycle = self.map_to_cycle(expert_freq)
-                    cycle_str = f'cycle_{cycle}/'
-                    cycle_checkpoint_path = [cp for cp in checkpoints_paths if (cycle_str in cp and self.layer_type in cp)]
-                    if len(cycle_checkpoint_path) > 0:
-                        cycle_checkpoint_path = cycle_checkpoint_path[0]
-                        print(f'Loading checkpoint: {cycle_checkpoint_path}')
-                        self.experts[expert_freq].load_state_dict(torch.load(cycle_checkpoint_path))
-
-                if self.freeze_experts:
-                    for param in self.experts[expert_freq].parameters():
-                        param.requires_grad = False
         else:
             # Create generic experts
             for i in range(configs.n_experts):
@@ -359,12 +338,8 @@ class Model(nn.Module):
                     # Default to RLinear if unknown layer type
                     self.experts[f"comp_{i}"] = RLinear(self.train_seq_len, self.train_pred_len)
 
-        # Initialize the MoE layer
+        # Initialize the MoE layer
         self.moe = SparseMoE(configs, experts=self.experts.values())
-
-        # Load pre-trained weights if specified
-        if configs.load_weights_full:
-            pass  # TODO: implement full weight loading
 
         print("Experts:", self.experts.keys())
 
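With the checkpoint scanning, per-expert load_state_dict calls, and the freeze_experts loop removed from __init__ in the hunks above, equivalent behavior would have to be applied after construction if it is still wanted. A hedged sketch, assuming model is an instance of the Model class above and model.experts is the dict built in __init__:

import torch

# Freeze every expert after the model is built, mirroring the removed freeze_experts branch
for expert in model.experts.values():
    for param in expert.parameters():
        param.requires_grad = False

# Optionally restore one expert from a checkpoint (path and key are placeholders)
# state = torch.load("checkpoint.pth", map_location="cpu")
# model.experts["h"].load_state_dict(state)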
@@ -470,40 +445,6 @@ class Model(nn.Module):
             expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
             return result, expert_probs
         return result
-
-    def map_to_cycle(self, freq):
-        """Map frequency string to cycle length for expert loading."""
-        if "/" in freq:
-            cycle = int(freq.split("/")[1])
-        elif "h" in freq:
-            cycle = 24
-        elif "2h" in freq:
-            cycle = 12
-        elif "3h" in freq:
-            cycle = 8
-        elif "4h" in freq:
-            cycle = 6
-        elif "D" in freq:
-            cycle = 7
-        elif "DM" in freq:
-            cycle = 30
-        elif "W" in freq:
-            cycle = 52
-        elif "M" in freq:
-            cycle = 12
-        elif "min" in freq:
-            cycle = 1440
-        elif "5min" in freq:
-            cycle = 288
-        elif "10min" in freq:
-            cycle = 144
-        elif "15min" in freq:
-            cycle = 96
-        elif "30min" in freq:
-            cycle = 48
-        else:
-            cycle = int(freq)
-        return cycle
 "-------------------------------------------------------------------------------------------------------------------"
 class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
     config_class = SuperLinearConfig
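For reference, the removed map_to_cycle helper resolved frequencies with a substring elif chain, so the earlier "h", "D", and "min" branches shadowed the more specific "2h", "DM", "5min", and similar entries. A hedged external replacement using exact-match lookup; this is a sketch, not part of this commit:

CYCLE_BY_FREQ = {
    "h": 24, "2h": 12, "3h": 8, "4h": 6,
    "D": 7, "DM": 30, "W": 52, "M": 12,
    "min": 1440, "5min": 288, "10min": 144, "15min": 96, "30min": 48,
}

def map_to_cycle(freq: str) -> int:
    if "/" in freq:                  # "freq/<cycle>" encodes the cycle explicitly
        return int(freq.split("/")[1])
    if freq in CYCLE_BY_FREQ:
        return CYCLE_BY_FREQ[freq]
    return int(freq)                 # fall back to a plain integer string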