import torch
import torch.distributed as dist
from typing import Optional, Any, TYPE_CHECKING

from . import _layers
from . import ops

if TYPE_CHECKING:
    # Typing-only stub so `@register_fake("ns::op")` checks as a decorator.
    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        try:
            # Older torch releases expose the same functionality as impl_abstract.
            from torch.library import impl_abstract as register_fake
        except ImportError:
            # Final fallback for very old torch releases: a no-op decorator.
            def register_fake(op_name):
                def decorator(fn):
                    return fn

                return decorator

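# `register_fake` is used as a decorator taking the op name, e.g.
# (hypothetical op name, shown for illustration only):
#
#     @register_fake("megablocks::histogram")
#     def _histogram_fake(x, max_val):
#         return torch.empty((max_val,), dtype=torch.int32, device=x.device)
#
# On torch >= 2.4 this registers a fake (meta) implementation through
# torch.library; with the no-op fallback above it leaves the function unchanged.
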
def _install_meta_kernels():
    """Install meta kernels for existing MegaBlocks operations.

    Each wrapper returns correctly shaped empty tensors while torch.compile
    is tracing and defers to the original kernel during eager execution.
    """
if hasattr(ops, "sort"): |
|
|
original_sort = ops.sort |
|
|
|
|
|
def sort_with_meta(x, end_bit=None): |
|
|
if torch.compiler.is_compiling(): |
|
|
print("Using meta kernel for sort") |
|
|
|
|
|
return torch.empty_like(x), torch.empty_like(x) |
|
|
|
|
|
return original_sort(x, end_bit) |
|
|
|
|
|
ops.sort = sort_with_meta |
|
|
|
|
|
|
|
|
if hasattr(ops, "histogram"): |
|
|
original_histogram = ops.histogram |
|
|
|
|
|
def histogram_with_meta(x, max_val): |
|
|
if torch.compiler.is_compiling(): |
|
|
|
|
|
return torch.empty((max_val,), dtype=torch.int32, device=x.device) |
|
|
return original_histogram(x, max_val) |
|
|
|
|
|
ops.histogram = histogram_with_meta |
|
|
|
|
|
|
|
|
if hasattr(ops, "inclusive_cumsum"): |
|
|
original_inclusive_cumsum = ops.inclusive_cumsum |
|
|
|
|
|
def inclusive_cumsum_with_meta(x, dim): |
|
|
if torch.compiler.is_compiling(): |
|
|
|
|
|
return torch.empty_like(x) |
|
|
return original_inclusive_cumsum(x, dim) |
|
|
|
|
|
ops.inclusive_cumsum = inclusive_cumsum_with_meta |
|
|
|
|
|
|
|
|
if hasattr(ops, "binned_gather"): |
|
|
original_binned_gather = ops.binned_gather |
|
|
|
|
|
def binned_gather_with_meta(x, indices, bins, bin_size, top_k): |
|
|
if torch.compiler.is_compiling(): |
|
|
|
|
|
if x.dim() >= 2: |
|
|
hidden_size = x.size(-1) |
|
|
return torch.empty( |
|
|
(bin_size, x.size(1), hidden_size), |
|
|
dtype=x.dtype, |
|
|
device=x.device, |
|
|
) |
|
|
else: |
|
|
return torch.empty((bin_size,), dtype=x.dtype, device=x.device) |
|
|
return original_binned_gather(x, indices, bins, bin_size, top_k) |
|
|
|
|
|
ops.binned_gather = binned_gather_with_meta |
|
|
|
|
|
|
|
|
if hasattr(ops, "binned_scatter"): |
|
|
original_binned_scatter = ops.binned_scatter |
|
|
|
|
|
def binned_scatter_with_meta(x, indices, weights, bins, top_k): |
|
|
if torch.compiler.is_compiling(): |
|
|
|
|
|
if x.dim() >= 3: |
|
|
return torch.empty( |
|
|
(x.size(1), x.size(2)), dtype=x.dtype, device=x.device |
|
|
) |
|
|
else: |
|
|
return torch.empty_like(x) |
|
|
return original_binned_scatter(x, indices, weights, bins, top_k) |
|
|
|
|
|
ops.binned_scatter = binned_scatter_with_meta |
|
|
|
|
|
|
|
|
if hasattr(ops, "gather"): |
|
|
original_gather = ops.gather |
|
|
|
|
|
def gather_with_meta(x, indices, bin_ids, bins, top_k): |
|
|
if torch.compiler.is_compiling(): |
|
|
|
|
|
if x.dim() >= 2: |
|
|
hidden_size = x.size(-1) |
|
|
return torch.empty( |
|
|
(indices.numel(), hidden_size), dtype=x.dtype, device=x.device |
|
|
) |
|
|
else: |
|
|
return torch.empty(indices.shape, dtype=x.dtype, device=x.device) |
|
|
return original_gather(x, indices, bin_ids, bins, top_k) |
|
|
|
|
|
ops.gather = gather_with_meta |
|
|
|
|
|
|
|
|
if hasattr(ops, "scatter"): |
|
|
original_scatter = ops.scatter |
|
|
|
|
|
def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k): |
|
|
if torch.compiler.is_compiling(): |
|
|
|
|
|
seq_len = ( |
|
|
indices.size(0) // top_k |
|
|
if indices.numel() > 0 and top_k > 0 |
|
|
else x.size(0) |
|
|
) |
|
|
if x.dim() >= 2: |
|
|
return torch.empty( |
|
|
(seq_len, x.size(-1)), dtype=x.dtype, device=x.device |
|
|
) |
|
|
else: |
|
|
return torch.empty((seq_len,), dtype=x.dtype, device=x.device) |
|
|
return original_scatter(x, indices, bin_ids, weights, bins, top_k) |
|
|
|
|
|
ops.scatter = scatter_with_meta |
|
|
|
|
|
|
|
|
if hasattr(ops, "replicate"): |
|
|
original_replicate = ops.replicate |
|
|
|
|
|
def replicate_with_meta(x, bins, num_outputs): |
|
|
if torch.compiler.is_compiling(): |
|
|
|
|
|
return torch.empty( |
|
|
(x.shape[0], num_outputs), dtype=x.dtype, device=x.device |
|
|
) |
|
|
return original_replicate(x, bins, num_outputs) |
|
|
|
|
|
ops.replicate = replicate_with_meta |
|
|
|
|
|
|
|
|
if hasattr(ops, "repeat"): |
|
|
original_repeat = ops.repeat |
|
|
|
|
|
def repeat_with_meta(x, repeats): |
|
|
if torch.compiler.is_compiling(): |
|
|
|
|
|
if isinstance(repeats, (tuple, list)): |
|
|
new_shape = list(x.shape) |
|
|
for i, rep in enumerate(repeats): |
|
|
if i < len(new_shape): |
|
|
new_shape[i] *= rep |
|
|
return torch.empty(new_shape, dtype=x.dtype, device=x.device) |
|
|
else: |
|
|
new_shape = [x.size(0) * repeats] + list(x.shape[1:]) |
|
|
return torch.empty(new_shape, dtype=x.dtype, device=x.device) |
|
|
return original_repeat(x, repeats) |
|
|
|
|
|
ops.repeat = repeat_with_meta |
|
|
|
|
|
|
|
|
|
|
|
try:
    _install_meta_kernels()
except Exception as e:
    import warnings

    warnings.warn(
        f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
    )

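# With these wrappers installed, tracing code that calls the ops no longer
# fails on the custom kernels: while torch.compiler.is_compiling() is True the
# wrapper returns correctly shaped empty tensors so shape propagation can
# proceed. A minimal sketch (hypothetical usage; assumes the CUDA ops are
# built):
#
#     ids = torch.randint(0, 8, (1024,), device="cuda")
#     counts = ops.histogram(ids, 8)  # eager: real kernel
#     # under tracing, histogram_with_meta yields an empty (8,) int32 tensor
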
def set_expert_model_parallel_attributes(
    tensor: torch.Tensor,
    is_parallel: bool,
):
    assert not hasattr(tensor, "expert_model_parallel")
    setattr(tensor, "expert_model_parallel", is_parallel)

def expert_sharding_degree(
    world_size: int,
    moe_num_experts: int,
) -> int:
    esd = min(world_size, moe_num_experts)
    if (moe_num_experts % esd) != 0:
        raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
    return esd

def hidden_sharding_degree(
    world_size: int,
    moe_num_experts: int,
    ffn_hidden_size: int,
) -> int:
    esd = expert_sharding_degree(world_size, moe_num_experts)
    hsd = world_size // esd
    if (ffn_hidden_size % hsd) != 0:
        raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
    if (esd * hsd) != world_size:
        raise ValueError(
            f"Invalid sharding. expert_sharding_degree ({esd}) * "
            f"hidden_sharding_degree ({hsd}) != world_size ({world_size})."
        )
    return hsd

def experts_per_rank(
    moe_num_experts: int,
    world_size: int,
) -> int:
    return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)

def features_per_rank(
    ffn_hidden_size: int, world_size: int, moe_num_experts: int
) -> int:
    return ffn_hidden_size // hidden_sharding_degree(
        world_size, moe_num_experts, ffn_hidden_size
    )

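# Worked example (hypothetical sizes): with world_size=8 and moe_num_experts=4,
# expert_sharding_degree is min(8, 4) = 4 and hidden_sharding_degree is
# 8 // 4 = 2, so each rank holds experts_per_rank = 4 // 4 = 1 expert and
# features_per_rank = ffn_hidden_size // 2 (ffn_hidden_size must be divisible
# by 2).
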
def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
    # Multiply by uniform noise drawn from [1 - eps, 1 + eps).
    low = 1.0 - moe_jitter_eps
    high = 1.0 + moe_jitter_eps
    noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
    return x * (low + noise * (high - low))

def compute_top_k(scores: torch.Tensor, moe_top_k: int):
    if moe_top_k == 1:
        # max is cheaper than topk for k=1 and returns the same (values, indices).
        return scores.max(dim=-1, keepdim=True)
    return torch.topk(scores, moe_top_k, dim=-1)

def route_tokens(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: Optional[torch.Tensor],
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: Optional[float] = None,
    moe_normalize_expert_weights: Optional[int] = None,
    uniform_expert_assignment: bool = False,
    training: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if training and moe_jitter_eps is not None:
        x = apply_jitter(x, moe_jitter_eps)

    x_flat = x.view(-1, x.shape[-1])
    logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
    expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
    expert_weights = expert_weights.softmax(dim=-1)
    if moe_normalize_expert_weights is not None:
        expert_weights = expert_weights / torch.norm(
            expert_weights,
            p=moe_normalize_expert_weights,
            dim=-1,
            keepdim=True,
        )
    if uniform_expert_assignment:
        expert_indices = _layers.router._uniform_expert_assignment(
            expert_indices,
            moe_num_experts,
        )

    return logits, expert_weights, expert_indices

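# Shape summary for route_tokens: x is flattened to (sl * bs, hidden), so
# logits come back as (sl * bs, num_experts) and expert_weights /
# expert_indices as (sl * bs, moe_top_k).
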
def scale_grad(
    w: torch.Tensor,
    gradient_scale: Optional[float] = None,
) -> torch.Tensor:
    if gradient_scale is None:
        return w
    return _layers.mlp.scale_gradient(w, gradient_scale)

def mlp_forward(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: torch.Tensor,
    w2_bias: torch.Tensor,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    limit: float = 7.0,
):
    # Optionally scale gradients flowing into the expert weights.
    w1 = scale_grad(w1, gradient_scale)
    w2 = scale_grad(w2, gradient_scale)
    w1_bias = scale_grad(w1_bias, gradient_scale)
    w2_bias = scale_grad(w2_bias, gradient_scale)

    # Materialize local shards if the weights are DTensors.
    w1 = _layers.mlp.resolve_dtensor(w1)
    w2 = _layers.mlp.resolve_dtensor(w2)
    w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
    w2_bias = _layers.mlp.resolve_dtensor(w2_bias)

    # Fused gate/up projection; gate and up values are interleaved along the
    # last dimension (even indices gate, odd indices up).
    gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)
    # Gated activation: gate * sigmoid(alpha * gate), alpha defaults to 1.702.
    glu = gate * torch.sigmoid(gate * alpha)
    next_states = torch.bmm(((up + 1) * glu), w2)
    next_states += w2_bias[..., None, :]
    return next_states

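# Shape sketch for the fused gate/up projection above (hypothetical sizes):
#
#     e, ec, h, f = 2, 8, 16, 32                 # experts, capacity, hidden, ffn
#     x = torch.randn(e, ec, h)
#     w1 = torch.randn(e, h, 2 * f)              # gate/up interleaved
#     gate_up = torch.bmm(x, w1)                 # (e, ec, 2 * f)
#     gate, up = gate_up[..., ::2], gate_up[..., 1::2]   # each (e, ec, f)
#
# The down projection w2 must then be (e, f, h) so the output is (e, ec, h).
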
def shared_mlp_forward(
    x: torch.Tensor,
    up_proj_weight: torch.Tensor,
    down_proj_weight: torch.Tensor,
    up_proj_bias: Optional[torch.Tensor] = None,
    down_proj_bias: Optional[torch.Tensor] = None,
    activation_fn: Optional[Any] = None,
    gradient_scale: Optional[float] = None,
) -> torch.Tensor:
    # Default to GELU when no activation is specified.
    if activation_fn is None:
        activation_fn = torch.nn.functional.gelu

    # Optionally scale gradients into the shared-expert weights.
    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
    if up_proj_bias is not None:
        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
    if down_proj_bias is not None:
        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)

    # Materialize local shards if the weights are DTensors.
    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
    if up_proj_bias is not None:
        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
    if down_proj_bias is not None:
        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)

    # Up-project, apply the activation, then down-project.
    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
    x = activation_fn(x)
    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)

    return x

def combine_expert_shared_outputs(
    shared_expert_out: torch.Tensor,
    expert_out: torch.Tensor,
    shared_expert_weighted_sum: bool = False,
    moe_top_k: int = 1,
) -> torch.Tensor:
    if shared_expert_weighted_sum:
        # Weight by the number of contributing experts: moe_top_k routed
        # experts plus one shared expert.
        total_experts = moe_top_k + 1
        shared_weight = 1.0 / total_experts
        expert_weight = moe_top_k / total_experts
        return shared_expert_out * shared_weight + expert_out * expert_weight
    else:
        # Simple addition of the two outputs.
        return shared_expert_out + expert_out

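# Example: with moe_top_k=3 and shared_expert_weighted_sum=True, the shared
# expert contributes 1/4 of the combined output and the routed experts 3/4;
# with the weighted sum disabled, the two outputs are simply summed.
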
_LOAD_BALANCING_LOSS = []


def save_load_balancing_loss(loss):
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.append(loss)


def get_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    return _LOAD_BALANCING_LOSS


def clear_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.clear()

def batched_load_balancing_loss(args):
    if args.moe_loss_weight == 0:
        return 0.0

    # tokens_per_expert[i].shape = (num_experts,)
    # expert_scores[i].shape = (tokens, num_experts)
    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
    num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
    if args.num_layers_per_virtual_pipeline_stage is not None:
        num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage

    if len(tokens_per_expert) != num_layers_per_pipeline_stage:
        raise ValueError(
            f"Expected {num_layers_per_pipeline_stage} token_per_experts "
            f"but found {len(tokens_per_expert)}.\nnum_layers = "
            f"{args.num_layers}\npipeline_model_parallel_size = "
            f"{args.pipeline_model_parallel_size}\n"
            "num_layers_per_virtual_pipeline_stage"
            f" = {args.num_layers_per_virtual_pipeline_stage}",
        )
    if len(expert_scores) != num_layers_per_pipeline_stage:
        raise ValueError(
            f"Expected {num_layers_per_pipeline_stage} expert_scores "
            f"but found {len(expert_scores)}.\nnum_layers = "
            f"{args.num_layers}\npipeline_model_parallel_size = "
            f"{args.pipeline_model_parallel_size}\n"
            "num_layers_per_virtual_pipeline_stage"
            f" = {args.num_layers_per_virtual_pipeline_stage}",
        )

    # Verify the shapes of the collected tensors.
    assert all(
        x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert
    )

    tokens = expert_scores[0].shape[0]
    assert all(
        (
            x.ndim == 2
            and x.shape[1] == args.moe_num_experts
            and x.shape[0] == tokens
        )
        for x in expert_scores
    )

    # Concatenate the contributions of each layer and convert to the correct
    # types and formats for the dot product.
    expert_scores = torch.cat(expert_scores, dim=1)
    if args.moe_lbl_in_fp32:
        expert_scores = expert_scores.float()
    if tokens != 0:
        expert_scores = expert_scores.mean(dim=0)
    else:
        expert_scores = expert_scores.sum(dim=0)
    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)

    expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
    assert tokens_per_expert.numel() == expected_values
    assert expert_scores.numel() == expected_values

    # Calculate the total scale across all factors:
    # loss_weight * num_experts / (num_layers * tokens * top_k).
    scale_numerator = args.moe_num_experts * args.moe_loss_weight
    scale_denominator = args.num_layers * tokens * args.moe_top_k
    scale = scale_numerator / scale_denominator
    return scale * torch.dot(tokens_per_expert, expert_scores)

def expert_capacity(
    tokens: int,
    top_k: int,
    num_experts: int,
    expert_parallel_group: Optional[torch.distributed.ProcessGroup],
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
) -> int:
    world_size = (
        dist.get_world_size(expert_parallel_group)
        if moe_expert_model_parallelism
        else 1
    )

    tokens_per_expert = top_k * tokens * world_size / num_experts
    return int(moe_capacity_factor * tokens_per_expert)

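# Example: tokens=1024, top_k=2, num_experts=8, world_size=1 and
# moe_capacity_factor=1.0 gives int(1.0 * 2 * 1024 * 1 / 8) = 256 slots per
# expert.
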
def load_balancing_loss(
    tokens_per_expert: torch.Tensor,
    expert_scores: torch.Tensor,
    top_k: int,
    num_experts: int,
):
    """Single-layer load-balancing loss:

    num_experts / (tokens * top_k) * dot(tokens_per_expert, mean expert_scores)
    """
    assert len(expert_scores.size()) == 2
    tokens, score_num_experts = expert_scores.size()
    assert score_num_experts == num_experts
    assert len(tokens_per_expert.size()) == 1
    (count_num_experts,) = tokens_per_expert.size()
    assert count_num_experts == num_experts
    scale = num_experts / (tokens * top_k)
    return scale * torch.dot(
        tokens_per_expert.to(expert_scores.dtype),
        expert_scores.mean(dim=0),
    )

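# Example: with tokens=4, top_k=1, num_experts=2, perfectly balanced routing
# (tokens_per_expert=[2, 2], mean scores [0.5, 0.5]) gives
# (2 / 4) * dot([2, 2], [0.5, 0.5]) = 1.0, while a fully collapsed router
# (tokens_per_expert=[4, 0], mean scores [1.0, 0.0]) gives
# (2 / 4) * dot([4, 0], [1.0, 0.0]) = 2.0.
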
def indices_and_bins(
    top_expert: torch.Tensor,
    sort_end_bit: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    top_expert = top_expert.int()

    # The sort kernel requires contiguous input.
    top_expert = top_expert.contiguous()

    # Sort tokens by expert id and build per-expert histograms and bins.
    with torch.cuda.device(top_expert.device):
        output = ops.sort(top_expert, sort_end_bit)
        bin_ids, indices = output
        tokens_per_expert = ops.histogram(top_expert, num_experts)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # Ensure bins is at least 1-D (the cumsum of a scalar is 0-D).
    bins = bins.view(1) if not len(bins.size()) else bins
    return indices, bin_ids, bins, tokens_per_expert

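# Example: top_expert = [1, 0, 1, 2] with num_experts=3 sorts to
# bin_ids = [0, 1, 1, 2] with indices = [1, 0, 2, 3],
# tokens_per_expert = [1, 2, 1] and bins = [1, 3, 4] (inclusive cumsum),
# i.e. expert e owns the sorted slots [bins[e-1], bins[e]).
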
# `expert_capacity_fn` is the name used by the forward paths below. Its body
# was identical to `expert_capacity` above, so it is kept as an alias rather
# than a second copy.
expert_capacity_fn = expert_capacity

def permute_and_compute(
    x,
    tokens_per_expert,
    indices,
    bin_ids,
    expert_weights,
    bins,
    expert_capacity,
    top_k,
    w1,
    w2,
    w1_bias,
    w2_bias,
    gradient_scale,
    alpha,
):
    # Flatten to one row per token.
    x = x.view(-1, x.shape[-1])

    # Group tokens by expert.
    with torch.cuda.device(x.device):
        x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)

    # Run all experts as one batched MLP.
    x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)

    # Un-permute back to token order, applying expert_weights when provided.
    with torch.cuda.device(x.device):
        out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
    return out

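# permute_and_compute is the per-device core of both forward paths:
# binned_gather packs tokens into per-expert bins of size expert_capacity,
# mlp_forward runs every expert in a single batched matmul, and
# binned_scatter restores token order.
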
def forward_once(
    x: torch.Tensor,
    expert_weights: torch.Tensor,
    top_experts: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: torch.Tensor,
    w2_bias: torch.Tensor,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    top_k: int = 4,
    num_experts: int = 128,
    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
    mlp_impl: Optional[str] = None,
):
    # Single-device MoE forward: no cross-rank communication.
    expert_weights = expert_weights.flatten()
    top_experts = top_experts.flatten()

    with torch.no_grad():
        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
            top_experts, sort_end_bit, num_experts
        )

        # Compute the per-expert capacity from the total token count.
        sl, bs, _ = x.size()

        expert_capacity = expert_capacity_fn(
            sl * bs,
            top_k,
            num_experts,
            expert_parallel_group,
            moe_capacity_factor,
            moe_expert_model_parallelism,
        )

        if expert_capacity == 0:
            # Fall back to the largest observed bin so no tokens are dropped.
            expert_capacity = torch.max(tokens_per_expert).item()

    x = permute_and_compute(
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capacity,
        top_k,
        w1,
        w2,
        w1_bias,
        w2_bias,
        gradient_scale,
        alpha,
    )
    return x, tokens_per_expert

def parallel_forward_once(
    x: torch.Tensor,
    expert_weights: torch.Tensor,
    top_experts: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: torch.Tensor,
    w2_bias: torch.Tensor,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    top_k: int = 4,
    num_experts: int = 128,
    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = True,
    hidden_size: int = 1152,
    mlp_impl: Optional[str] = "grouped",
):
    # Flatten the routing outputs.
    expert_weights = expert_weights.flatten()
    top_experts = top_experts.flatten()

    with torch.no_grad():
        # Sort tokens by their assigned expert.
        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
            top_experts, sort_end_bit, num_experts
        )

        # Figure out how experts and the hidden dimension are sharded.
        world_size = dist.get_world_size(expert_parallel_group)
        hidden_sharding_deg = hidden_sharding_degree(
            world_size, num_experts, hidden_size
        )
        experts_per_rank_val = experts_per_rank(num_experts, world_size)

        # Replicate the token counts for each hidden shard.
        repeated_tokens_per_expert = ops.repeat(
            tokens_per_expert, (hidden_sharding_deg,)
        )

        # Exchange token counts across the expert-parallel group.
        parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)

        tpe_handle = dist.all_to_all_single(
            parallel_tokens_per_expert,
            repeated_tokens_per_expert,
            group=expert_parallel_group,
            async_op=True,
        )

    # Permute tokens locally while the count exchange is in flight.
    x = x.view(-1, x.shape[-1])
    x = ops.gather(x, indices, bin_ids, bins, top_k)

    with torch.no_grad():
        tpe_handle.wait()

        # Reshape to (world_size, experts_per_rank) per-rank counts.
        repeated_tokens_per_expert = repeated_tokens_per_expert.view(
            world_size, experts_per_rank_val
        )
        parallel_tokens_per_expert = parallel_tokens_per_expert.view(
            world_size, experts_per_rank_val
        )

        # Host-side send/recv counts for the token all-to-all.
        send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()

        parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
        recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
        tokens_received = sum(recv_counts)

    # Replicate the token data along the hidden shard dimension.
    x = ops.repeat(x, (hidden_sharding_deg, 1))

    # Start the cross-device token exchange.
    parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
        x, recv_counts, send_counts, expert_parallel_group, async_op=True
    )

    with torch.no_grad():
        # Bins over the received token counts.
        replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
        replicate_bins = (
            replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
        )

        # Construct local expert ids for every received token.
        parallel_top_expert = torch.remainder(
            torch.arange(
                num_experts * hidden_sharding_deg,
                dtype=torch.int32,
                device=indices.device,
            ),
            experts_per_rank_val,
        )
        parallel_top_expert = ops.replicate(
            parallel_top_expert.unsqueeze(dim=0),
            replicate_bins,
            tokens_received,
        ).flatten()

        # Sort the received tokens by their local expert.
        parallel_bin_ids, parallel_indices = ops.sort(
            parallel_top_expert,
            sort_end_bit,
        )

        # Per-local-expert token counts and bins.
        parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
            dim=0, dtype=torch.int
        )
        parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
        parallel_bins = (
            parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
        )

        # Expert capacity for the received tokens.
        expert_capacity = expert_capacity_fn(
            tokens_received,
            top_k,
            experts_per_rank_val,
            expert_parallel_group,
            moe_capacity_factor,
            moe_expert_model_parallelism,
        )
        if expert_capacity == 0:
            expert_capacity = torch.max(parallel_tokens_per_expert).item()

    # Grouped MLP implementations consume host-side token counts.
    if mlp_impl == "grouped":
        parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
            dim=0,
            dtype=torch.int,
        )

    # Wait for the token exchange to finish before computing.
    parallel_x_handle.wait()

    parallel_x = permute_and_compute(
        parallel_x,
        parallel_tokens_per_expert,
        parallel_indices,
        parallel_bin_ids,
        None,  # expert_weights are applied after the return all-to-all.
        parallel_bins,
        expert_capacity,
        top_k=1,  # each received row is a single expert assignment.
        w1=w1,
        w2=w2,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        gradient_scale=gradient_scale,
        alpha=alpha,
    )

    # Send the processed tokens back to their source ranks.
    x, _ = _layers.all_to_all.all_to_all(
        parallel_x, send_counts, recv_counts, expert_parallel_group
    )

    # Reduce over the hidden shards.
    shape = (hidden_sharding_deg, -1, hidden_size)
    x = x.view(shape).sum(dim=0)

    # Un-permute and apply the expert weights.
    x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)

    return x, tokens_per_expert.flatten()

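# Communication pattern of parallel_forward_once: an async all_to_all_single
# exchanges per-expert token counts (overlapped with the local gather), an
# all_to_all moves the tokens to their expert's rank, the local experts run,
# and a reverse all_to_all plus a sum over hidden shards brings the results
# home before the final scatter un-permutes them.
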
def moe_forward(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: Optional[torch.Tensor],
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: Optional[float] = None,
    moe_normalize_expert_weights: Optional[int] = None,
    uniform_expert_assignment: bool = False,
    training: bool = False,
    w1: Optional[torch.Tensor] = None,
    w2: Optional[torch.Tensor] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
    forward_fn: Any = None,
    hidden_size: Optional[int] = None,
    mlp_impl: str = "grouped",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Route tokens to experts.
    logits, expert_weights, expert_indices = route_tokens(
        x,
        router_weight,
        router_bias,
        moe_top_k,
        moe_num_experts,
        moe_jitter_eps,
        moe_normalize_expert_weights,
        uniform_expert_assignment,
        training,
    )

    # Dense (num_experts, tokens) view of the routing weights.
    router_scores = (
        torch.zeros_like(logits)
        .scatter_(1, expert_indices, expert_weights)
        .transpose(0, 1)
    )

    in_shape = x.size()

    # Assemble the arguments for the selected forward function.
    forward_args = {
        "x": x,
        "expert_weights": expert_weights,
        "top_experts": expert_indices,
        "w1": w1,
        "w2": w2,
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
        "gradient_scale": gradient_scale,
        "alpha": alpha,
        "sort_end_bit": sort_end_bit,
        "top_k": moe_top_k,
        "num_experts": moe_num_experts,
        "expert_parallel_group": expert_parallel_group,
        "moe_capacity_factor": moe_capacity_factor,
        "moe_expert_model_parallelism": moe_expert_model_parallelism,
        "mlp_impl": mlp_impl,
    }

    # The parallel path also needs hidden_size; infer it from the input when
    # it is not provided.
    if moe_expert_model_parallelism and hidden_size is not None:
        forward_args["hidden_size"] = hidden_size
    elif moe_expert_model_parallelism and hidden_size is None:
        forward_args["hidden_size"] = x.shape[-1]

    # Compute the expert outputs.
    x, tokens_per_expert = forward_fn(**forward_args)

    # NOTE: load-balancing loss collection is currently disabled
    # (moe_loss_weight is hard-coded to 0.0).
    moe_loss_weight = 0.0
    if training and moe_loss_weight > 0:
        save_load_balancing_loss((tokens_per_expert, logits))

    # Restore the original input shape.
    x = x.view(in_shape)

    return x, expert_weights, router_scores

def moe_forward_with_shared_expert(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: Optional[torch.Tensor],
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: Optional[float] = None,
    moe_normalize_expert_weights: Optional[int] = None,
    uniform_expert_assignment: bool = False,
    training: bool = False,
    w1: Optional[torch.Tensor] = None,
    w2: Optional[torch.Tensor] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
    forward_fn: Any = None,
    hidden_size: Optional[int] = None,
    mlp_impl: str = "grouped",
    # Shared expert parameters.
    shared_up_proj_weight: Optional[torch.Tensor] = None,
    shared_down_proj_weight: Optional[torch.Tensor] = None,
    shared_up_proj_bias: Optional[torch.Tensor] = None,
    shared_down_proj_bias: Optional[torch.Tensor] = None,
    shared_expert_weighted_sum: bool = False,
    shared_activation_fn: Optional[Any] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Run the routed experts.
    expert_out, expert_weights, router_scores = moe_forward(
        x=x,
        router_weight=router_weight,
        router_bias=router_bias,
        moe_top_k=moe_top_k,
        moe_num_experts=moe_num_experts,
        moe_jitter_eps=moe_jitter_eps,
        moe_normalize_expert_weights=moe_normalize_expert_weights,
        uniform_expert_assignment=uniform_expert_assignment,
        training=training,
        w1=w1,
        w2=w2,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        gradient_scale=gradient_scale,
        alpha=alpha,
        sort_end_bit=sort_end_bit,
        expert_parallel_group=expert_parallel_group,
        moe_capacity_factor=moe_capacity_factor,
        moe_expert_model_parallelism=moe_expert_model_parallelism,
        forward_fn=forward_fn,
        hidden_size=hidden_size,
        mlp_impl=mlp_impl,
    )

    # Run the shared expert on all tokens and combine, if configured.
    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
        shared_expert_out = shared_mlp_forward(
            x=x,
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            up_proj_bias=shared_up_proj_bias,
            down_proj_bias=shared_down_proj_bias,
            activation_fn=shared_activation_fn,
            gradient_scale=gradient_scale,
        )

        combined_out = combine_expert_shared_outputs(
            shared_expert_out=shared_expert_out,
            expert_out=expert_out,
            shared_expert_weighted_sum=shared_expert_weighted_sum,
            moe_top_k=moe_top_k,
        )

        return combined_out, expert_weights, router_scores

    return expert_out, expert_weights, router_scores

def create_shared_expert_weights(
    hidden_size: int,
    shared_expert_hidden_size: int,
    device: torch.device,
    dtype: torch.dtype,
    init_method: Any,
    output_layer_init_method: Any = None,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    if output_layer_init_method is None:
        output_layer_init_method = init_method

    # Allocate the up/down projection weights.
    up_proj_weight = torch.empty(
        shared_expert_hidden_size,
        hidden_size,
        device=device,
        dtype=dtype,
    )
    down_proj_weight = torch.empty(
        hidden_size,
        shared_expert_hidden_size,
        device=device,
        dtype=dtype,
    )

    # Initialize in place with the provided init methods.
    init_method(up_proj_weight)
    output_layer_init_method(down_proj_weight)

    # The shared expert has no biases.
    return up_proj_weight, down_proj_weight, None, None

def get_device_mesh(model):
    """Best-effort lookup of the device mesh used by the experts module.

    Some parallelism wrappers capture `device_mesh` in the closure of a
    forward pre-hook; introspect the hooks to recover it.
    """
    try:
        # Find a forward pre-hook that closes over a `device_mesh` variable.
        hook = next(
            h
            for h in model.experts._forward_pre_hooks.values()
            if "device_mesh" in h.__code__.co_freevars
        )
        # Pull the captured value out of the hook's closure.
        return hook.__closure__[
            hook.__code__.co_freevars.index("device_mesh")
        ].cell_contents
    except Exception:
        return None

class MegaBlocksMoeMLP(torch.nn.Module):
    can_torch_compile: bool = True

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        gradient_scale = getattr(self.experts, "gradient_scale", None)
        alpha = getattr(self.experts, "alpha", 1.0)
        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
        moe_normalize_expert_weights = getattr(
            self.experts, "normalize_expert_weights", None
        )
        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

        expert_parallel_group = getattr(self, "expert_parallel_group", None)
        if expert_parallel_group is None:
            device_mesh = get_device_mesh(self)
            expert_parallel_group = device_mesh.get_group() if device_mesh else None

        has_parallel = (
            expert_parallel_group is not None
            and dist.is_initialized()
            and dist.get_world_size(expert_parallel_group) > 1
        )
        forward_fn = parallel_forward_once if has_parallel else forward_once

        # Number of bits needed to radix-sort the expert ids.
        sort_end_bit = max(
            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
        )
        mlp_impl = getattr(self, "mlp_impl", "grouped")
        output, expert_weights_out, *_ = moe_forward(
            x=x,
            router_weight=self.router.weight,
            router_bias=self.router.bias,
            moe_top_k=moe_top_k,
            moe_num_experts=moe_num_experts,
            moe_jitter_eps=moe_jitter_eps,
            moe_normalize_expert_weights=moe_normalize_expert_weights,
            uniform_expert_assignment=uniform_expert_assignment,
            training=self.training,
            w1=self.experts.gate_up_proj,
            w2=self.experts.down_proj,
            w1_bias=self.experts.gate_up_proj_bias,
            w2_bias=self.experts.down_proj_bias,
            gradient_scale=gradient_scale,
            alpha=alpha,
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
            moe_capacity_factor=moe_capacity_factor,
            moe_expert_model_parallelism=has_parallel,
            forward_fn=forward_fn,
            hidden_size=self.experts.hidden_size,
            mlp_impl=mlp_impl,
        )
        return output, expert_weights_out

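# Usage sketch (hypothetical wiring; in practice `router` and `experts` are
# populated by the surrounding model code):
#
#     mlp = MegaBlocksMoeMLP()
#     mlp.router = torch.nn.Linear(hidden_size, num_experts)
#     mlp.router.top_k = 2
#     mlp.experts = experts_module  # provides gate_up_proj, gate_up_proj_bias,
#                                   # down_proj, down_proj_bias, hidden_size, ...
#     out, expert_weights = mlp(x)  # x: (seq_len, batch, hidden_size)
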
class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):

    def __init__(self):
        super().__init__()
        # Shared-expert weights are attached after construction via
        # set_shared_expert_weights.
        self.shared_up_proj_weight = None
        self.shared_down_proj_weight = None
        self.shared_up_proj_bias = None
        self.shared_down_proj_bias = None
        self.shared_expert_weighted_sum = False
        self.shared_activation_fn = None

    def set_shared_expert_weights(
        self,
        up_proj_weight: torch.Tensor,
        down_proj_weight: torch.Tensor,
        up_proj_bias: Optional[torch.Tensor] = None,
        down_proj_bias: Optional[torch.Tensor] = None,
        weighted_sum: bool = False,
        activation_fn: Optional[Any] = None,
    ):
        self.shared_up_proj_weight = up_proj_weight
        self.shared_down_proj_weight = down_proj_weight
        self.shared_up_proj_bias = up_proj_bias
        self.shared_down_proj_bias = down_proj_bias
        self.shared_expert_weighted_sum = weighted_sum
        self.shared_activation_fn = activation_fn

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        gradient_scale = getattr(self.experts, "gradient_scale", None)
        alpha = getattr(self.experts, "alpha", 1.0)
        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
        moe_normalize_expert_weights = getattr(
            self.experts, "normalize_expert_weights", None
        )
        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

        expert_parallel_group = getattr(self, "expert_parallel_group", None)
        if expert_parallel_group is None:
            device_mesh = get_device_mesh(self)
            expert_parallel_group = device_mesh.get_group() if device_mesh else None

        has_parallel = (
            expert_parallel_group is not None
            and dist.is_initialized()
            and dist.get_world_size(expert_parallel_group) > 1
        )
        forward_fn = parallel_forward_once if has_parallel else forward_once

        # Number of bits needed to radix-sort the expert ids.
        sort_end_bit = max(
            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
        )
        mlp_impl = getattr(self, "mlp_impl", "grouped")

        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
            x=x,
            router_weight=self.router.weight,
            router_bias=self.router.bias,
            moe_top_k=moe_top_k,
            moe_num_experts=moe_num_experts,
            moe_jitter_eps=moe_jitter_eps,
            moe_normalize_expert_weights=moe_normalize_expert_weights,
            uniform_expert_assignment=uniform_expert_assignment,
            training=self.training,
            w1=self.experts.gate_up_proj,
            w2=self.experts.down_proj,
            w1_bias=self.experts.gate_up_proj_bias,
            w2_bias=self.experts.down_proj_bias,
            gradient_scale=gradient_scale,
            alpha=alpha,
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
            moe_capacity_factor=moe_capacity_factor,
            moe_expert_model_parallelism=has_parallel,
            forward_fn=forward_fn,
            hidden_size=self.experts.hidden_size,
            mlp_impl=mlp_impl,
            # Shared expert parameters.
            shared_up_proj_weight=self.shared_up_proj_weight,
            shared_down_proj_weight=self.shared_down_proj_weight,
            shared_up_proj_bias=self.shared_up_proj_bias,
            shared_down_proj_bias=self.shared_down_proj_bias,
            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
            shared_activation_fn=self.shared_activation_fn,
        )
        return output, expert_weights_out


__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]