drbh committed
Commit: 89e2950 · Parent(s): aa23f77
feat: support shared experts layer and tests

Browse files:
- tests/test_mb_moe_shared_expert.py        +139 -0
- tests/test_mb_moe_shared_expert_multi.py  +200 -0
- torch-ext/megablocks/layers.py            +267 -3
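
For orientation, a minimal usage sketch of the API this commit adds (names are taken from the diff below; the `router` and `experts` sub-modules that `MegaBlocksMoeMLP` relies on for routing are assumed to be configured by the host model, so only the shared-expert wiring is shown):

import torch
from megablocks.layers import (
    MegaBlocksMoeMLPWithSharedExpert,
    create_shared_expert_weights,
)

mlp = MegaBlocksMoeMLPWithSharedExpert()

# Allocate dense shared-expert weights; biases come back as None by default.
up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=128,
    shared_expert_hidden_size=256,
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.xavier_uniform_,
)

# Attach them; weighted_sum=True blends shared and routed outputs as
# 1/(top_k+1) and top_k/(top_k+1) instead of a plain sum.
mlp.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    weighted_sum=True,
    activation_fn=torch.nn.functional.gelu,
)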

tests/test_mb_moe_shared_expert.py  ADDED
@@ -0,0 +1,139 @@
+import torch
+import megablocks
+from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
+
+
+def test_megablocks_moe_mlp_with_shared_expert_import():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+    assert hasattr(mlp, 'shared_up_proj_weight')
+    assert hasattr(mlp, 'shared_down_proj_weight')
+    assert hasattr(mlp, 'set_shared_expert_weights')
+
+
+def test_set_shared_expert_weights():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float32
+
+    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
+    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
+    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
+    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)
+
+    mlp.set_shared_expert_weights(
+        up_proj_weight=up_proj_weight,
+        down_proj_weight=down_proj_weight,
+        up_proj_bias=up_proj_bias,
+        down_proj_bias=down_proj_bias,
+        weighted_sum=True,
+        activation_fn=torch.nn.functional.gelu
+    )
+
+    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
+    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
+    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
+    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
+    assert mlp.shared_expert_weighted_sum == True
+    assert mlp.shared_activation_fn == torch.nn.functional.gelu
+
+
+def test_create_shared_expert_weights():
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float32
+
+    def init_method(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+        init_method=init_method
+    )
+
+    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
+    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
+    assert up_proj_weight.device.type == device.type
+    assert down_proj_weight.device.type == device.type
+    assert up_proj_weight.dtype == dtype
+    assert down_proj_weight.dtype == dtype
+    assert up_proj_bias is None
+    assert down_proj_bias is None
+
+
+def test_shared_expert_weights_none_by_default():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    assert mlp.shared_up_proj_weight is None
+    assert mlp.shared_down_proj_weight is None
+    assert mlp.shared_up_proj_bias is None
+    assert mlp.shared_down_proj_bias is None
+    assert mlp.shared_expert_weighted_sum == False
+    assert mlp.shared_activation_fn is None
+
+
+def test_inheritance_from_megablocks_moe_mlp():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    from megablocks.layers import MegaBlocksMoeMLP
+    assert isinstance(mlp, MegaBlocksMoeMLP)
+    assert hasattr(mlp, 'forward')
+
+
+def test_shared_expert_weights_custom_init():
+    hidden_size = 64
+    shared_expert_hidden_size = 128
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float16
+
+    def custom_init(tensor):
+        torch.nn.init.constant_(tensor, 0.5)
+
+    def custom_output_init(tensor):
+        torch.nn.init.constant_(tensor, 0.1)
+
+    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+        init_method=custom_init,
+        output_layer_init_method=custom_output_init
+    )
+
+    assert torch.all(up_proj_weight == 0.5)
+    assert torch.all(down_proj_weight == 0.1)
+    assert up_proj_weight.dtype == dtype
+    assert down_proj_weight.dtype == dtype
+
+
+def test_shared_expert_weights_dimensions():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    batch_size = 4
+    seq_len = 16
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
+    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
+
+    mlp.set_shared_expert_weights(
+        up_proj_weight=up_proj_weight,
+        down_proj_weight=down_proj_weight
+    )
+
+    x = torch.randn(seq_len, batch_size, hidden_size, device=device)
+
+    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
+    expected_down_output_shape = (seq_len, batch_size, hidden_size)
+
+    assert up_proj_weight.shape[1] == x.shape[-1]
+    assert down_proj_weight.shape[0] == x.shape[-1]
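
These are plain pytest test functions, so assuming the extension is built and `megablocks.layers` exposes the new symbols, something like `pytest tests/test_mb_moe_shared_expert.py -v` should collect and run them. CUDA is optional: each test that allocates tensors picks `'cuda' if torch.cuda.is_available() else 'cpu'`, so the assertions also hold on CPU.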

tests/test_mb_moe_shared_expert_multi.py  ADDED
@@ -0,0 +1,200 @@
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import os
+import pytest
+from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
+
+
+def run_distributed_shared_expert_test(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12356"
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    dist.init_process_group(
+        backend="gloo",
+        rank=rank,
+        world_size=world_size,
+    )
+
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 128
+    shared_expert_hidden_size = 192
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def simple_init(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=torch.device(device),
+        dtype=torch.float32,
+        init_method=simple_init
+    )
+
+    model.set_shared_expert_weights(
+        up_proj_weight=shared_up_proj_weight,
+        down_proj_weight=shared_down_proj_weight,
+        up_proj_bias=shared_up_proj_bias,
+        down_proj_bias=shared_down_proj_bias,
+        weighted_sum=True,
+        activation_fn=torch.nn.functional.gelu
+    )
+
+    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
+    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
+    assert model.shared_expert_weighted_sum == True, f"Weighted sum not set correctly on rank {rank}"
+
+    print(f"Rank {rank}: Shared expert setup test passed!")
+
+    dist.destroy_process_group()
+
+
+def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12357"
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    dist.init_process_group(
+        backend="gloo",
+        rank=rank,
+        world_size=world_size,
+    )
+
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 64
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def simple_init(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=96,
+        device=torch.device(device),
+        dtype=torch.float32,
+        init_method=simple_init
+    )
+
+    model.set_shared_expert_weights(
+        up_proj_weight=shared_up_proj_weight,
+        down_proj_weight=shared_down_proj_weight,
+        weighted_sum=False,
+        activation_fn=torch.nn.functional.relu
+    )
+
+    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
+    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
+    assert model.shared_expert_weighted_sum == False, f"Weighted sum not set correctly on rank {rank}"
+    assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}"
+
+    print(f"Rank {rank}: Weighted sum setup test passed!")
+
+    dist.destroy_process_group()
+
+
+@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
+def test_shared_expert_distributed_functionality(world_size):
+    if world_size == 1:
+        # Single process test
+        model = MegaBlocksMoeMLPWithSharedExpert()
+
+        hidden_size = 128
+        shared_expert_hidden_size = 192
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        def simple_init(tensor):
+            torch.nn.init.xavier_uniform_(tensor)
+
+        shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
+            hidden_size=hidden_size,
+            shared_expert_hidden_size=shared_expert_hidden_size,
+            device=torch.device(device),
+            dtype=torch.float32,
+            init_method=simple_init
+        )
+
+        model.set_shared_expert_weights(
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            weighted_sum=True,
+            activation_fn=torch.nn.functional.gelu
+        )
+
+        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
+        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
+        assert model.shared_expert_weighted_sum == True, "Weighted sum not set correctly"
+
+        print("Single process shared expert setup test passed!")
+    else:
+        # Multi-process test
+        mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)
+        print("Multi-process shared expert test completed successfully!")
+
+
+@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
+def test_shared_expert_distributed_weighted_sum(world_size):
+    if world_size == 1:
+        # Single process test
+        model = MegaBlocksMoeMLPWithSharedExpert()
+
+        hidden_size = 64
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        def simple_init(tensor):
+            torch.nn.init.xavier_uniform_(tensor)
+
+        shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
+            hidden_size=hidden_size,
+            shared_expert_hidden_size=96,
+            device=torch.device(device),
+            dtype=torch.float32,
+            init_method=simple_init
+        )
+
+        model.set_shared_expert_weights(
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            weighted_sum=False,
+            activation_fn=torch.nn.functional.relu
+        )
+
+        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
+        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
+        assert model.shared_expert_weighted_sum == False, "Weighted sum not set correctly"
+        assert model.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"
+
+        print("Single process weighted sum setup test passed!")
+    else:
+        # Multi-process test
+        mp.spawn(run_distributed_shared_expert_weighted_sum_test, args=(world_size,), nprocs=world_size, join=True)
+        print("Multi-process shared expert weighted sum test completed successfully!")
+
+
+def test_shared_expert_single_process():
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    assert model.shared_up_proj_weight is None
+    assert model.shared_down_proj_weight is None
+    assert hasattr(model, 'set_shared_expert_weights')
+
+    print("Single process shared expert basic test passed!")
+
+
+if __name__ == "__main__":
+    test_shared_expert_single_process()
+    print("Single process test passed!")
+
+    os.environ['WORLD_SIZE'] = '2'
+    test_shared_expert_distributed_functionality()
+    print("Distributed functionality test passed!")
+
+    test_shared_expert_distributed_weighted_sum()
+    print("Distributed weighted sum test passed!")
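
The `weighted_sum` flag toggled in these tests maps onto `combine_expert_shared_outputs` in the layers.py diff below. A small self-contained check of the arithmetic it implements (the tensor values here are illustrative, not taken from the commit):

import torch

moe_top_k = 4
shared_out = torch.ones(2, 3)           # stand-in shared-expert output
expert_out = torch.full((2, 3), 2.0)    # stand-in routed-expert output

# weighted_sum=True: the shared expert counts as one more expert alongside the top_k routed ones.
shared_weight = 1.0 / (moe_top_k + 1)        # 0.2
expert_weight = moe_top_k / (moe_top_k + 1)  # 0.8
weighted = shared_out * shared_weight + expert_out * expert_weight
assert torch.allclose(weighted, torch.full((2, 3), 1.8))

# weighted_sum=False: plain addition of the two outputs.
assert torch.allclose(shared_out + expert_out, torch.full((2, 3), 3.0))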

torch-ext/megablocks/layers.py  CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
    return torch.bmm(x, w2) + w2_bias[..., None, :]


+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []

@@ -680,6 +740,125 @@ def moe_forward(
    return x, expert_weights, router_scores


+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
 def get_device_mesh(model):
     # Extract device_mesh from child's unused pre_hook closure
     try:
@@ -687,7 +866,7 @@ def get_device_mesh(model):
        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
        # Extract the device_mesh from the closure
        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
-    except:
+    except Exception:
        return None


@@ -703,8 +882,11 @@ class MegaBlocksMoeMLP(torch.nn.Module):
        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

-
-
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
        forward_fn = parallel_forward_once if has_parallel else forward_once

@@ -734,4 +916,86 @@
            hidden_size=self.experts.hidden_size,
            mlp_impl=mlp_impl,
        )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
        return output, expert_weights_out
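
For intuition, a minimal plain-PyTorch sketch of the dense path that `shared_mlp_forward` above computes, without the megablocks-specific gradient scaling or DTensor resolution (shapes are illustrative assumptions):

import torch
import torch.nn.functional as F

hidden_size, shared_expert_hidden_size = 128, 256
x = torch.randn(16, 4, hidden_size)                         # (seq, batch, hidden)
up_w = torch.randn(shared_expert_hidden_size, hidden_size)  # F.linear weights are (out, in)
down_w = torch.randn(hidden_size, shared_expert_hidden_size)

h = F.gelu(F.linear(x, up_w))   # up projection + default GELU activation
y = F.linear(h, down_w)         # down projection back to hidden_size
assert y.shape == x.shape       # applied densely to every token, unlike the routed experts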