drbh
fix: include torch compile flag
8176cbe
import torch
import torch.distributed as dist
from typing import Optional, Any, TYPE_CHECKING
from . import _layers
from . import ops
# Conditional import for meta kernel registration
if TYPE_CHECKING:
def register_fake(fn):
return lambda name: fn
else:
try:
from torch.library import register_fake
except ImportError:
try:
from torch.library import impl_abstract as register_fake
except ImportError:
# Fallback for older PyTorch versions
def register_fake(op_name):
def decorator(fn):
return fn
return decorator
# Meta kernel implementations for torch.compile compatibility
def _install_meta_kernels():
"""Install meta kernels for existing MegaBlocks operations"""
# Create wrapper functions that check for compilation and return meta tensors
# Patch ops.sort
if hasattr(ops, "sort"):
original_sort = ops.sort
def sort_with_meta(x, end_bit=None):
if torch.compiler.is_compiling():
print("Using meta kernel for sort")
# Meta implementation - return tensors with correct shape/dtype/device
return torch.empty_like(x), torch.empty_like(x)
# print("Using original sort kernel")
return original_sort(x, end_bit)
ops.sort = sort_with_meta
# Patch ops.histogram
if hasattr(ops, "histogram"):
original_histogram = ops.histogram
def histogram_with_meta(x, max_val):
if torch.compiler.is_compiling():
# Meta implementation
return torch.empty((max_val,), dtype=torch.int32, device=x.device)
return original_histogram(x, max_val)
ops.histogram = histogram_with_meta
# Patch ops.inclusive_cumsum
if hasattr(ops, "inclusive_cumsum"):
original_inclusive_cumsum = ops.inclusive_cumsum
def inclusive_cumsum_with_meta(x, dim):
if torch.compiler.is_compiling():
# Meta implementation
return torch.empty_like(x)
return original_inclusive_cumsum(x, dim)
ops.inclusive_cumsum = inclusive_cumsum_with_meta
# Patch ops.binned_gather
if hasattr(ops, "binned_gather"):
original_binned_gather = ops.binned_gather
def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
if torch.compiler.is_compiling():
# Meta implementation - output shape based on bin_size
if x.dim() >= 2:
hidden_size = x.size(-1)
return torch.empty(
(bin_size, x.size(1), hidden_size),
dtype=x.dtype,
device=x.device,
)
else:
return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
return original_binned_gather(x, indices, bins, bin_size, top_k)
ops.binned_gather = binned_gather_with_meta
# Patch ops.binned_scatter
if hasattr(ops, "binned_scatter"):
original_binned_scatter = ops.binned_scatter
def binned_scatter_with_meta(x, indices, weights, bins, top_k):
if torch.compiler.is_compiling():
# Meta implementation - typically reduces to 2D
if x.dim() >= 3:
return torch.empty(
(x.size(1), x.size(2)), dtype=x.dtype, device=x.device
)
else:
return torch.empty_like(x)
return original_binned_scatter(x, indices, weights, bins, top_k)
ops.binned_scatter = binned_scatter_with_meta
# Patch ops.gather
if hasattr(ops, "gather"):
original_gather = ops.gather
def gather_with_meta(x, indices, bin_ids, bins, top_k):
if torch.compiler.is_compiling():
# Meta implementation
if x.dim() >= 2:
hidden_size = x.size(-1)
return torch.empty(
(indices.numel(), hidden_size), dtype=x.dtype, device=x.device
)
else:
return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
return original_gather(x, indices, bin_ids, bins, top_k)
ops.gather = gather_with_meta
# Patch ops.scatter
if hasattr(ops, "scatter"):
original_scatter = ops.scatter
def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
if torch.compiler.is_compiling():
# Meta implementation - restore sequence shape
seq_len = (
indices.size(0) // top_k
if indices.numel() > 0 and top_k > 0
else x.size(0)
)
if x.dim() >= 2:
return torch.empty(
(seq_len, x.size(-1)), dtype=x.dtype, device=x.device
)
else:
return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
return original_scatter(x, indices, bin_ids, weights, bins, top_k)
ops.scatter = scatter_with_meta
# Patch ops.replicate
if hasattr(ops, "replicate"):
original_replicate = ops.replicate
def replicate_with_meta(x, bins, num_outputs):
if torch.compiler.is_compiling():
# Meta implementation
return torch.empty(
(x.shape[0], num_outputs), dtype=x.dtype, device=x.device
)
return original_replicate(x, bins, num_outputs)
ops.replicate = replicate_with_meta
# Patch ops.repeat (if it's a regular function)
if hasattr(ops, "repeat"):
original_repeat = ops.repeat
def repeat_with_meta(x, repeats):
if torch.compiler.is_compiling():
# Meta implementation
if isinstance(repeats, (tuple, list)):
new_shape = list(x.shape)
for i, rep in enumerate(repeats):
if i < len(new_shape):
new_shape[i] *= rep
return torch.empty(new_shape, dtype=x.dtype, device=x.device)
else:
new_shape = [x.size(0) * repeats] + list(x.shape[1:])
return torch.empty(new_shape, dtype=x.dtype, device=x.device)
return original_repeat(x, repeats)
ops.repeat = repeat_with_meta
# Install meta kernels on import
try:
_install_meta_kernels()
except Exception as e:
# If meta kernel installation fails, continue without them
# torch.compile may not work but the library will still function
import warnings
warnings.warn(
f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
)
# Set the expert model parallel attributes on a tensor
def set_expert_model_parallel_attributes(
tensor: torch.Tensor,
is_parallel: bool,
):
assert not hasattr(tensor, "expert_model_parallel")
setattr(tensor, "expert_model_parallel", is_parallel)
# Get the expert model parallel attributes from a tensor
def expert_sharding_degree(
world_size: int,
moe_num_experts: int,
) -> int:
esd = min(world_size, moe_num_experts)
if (moe_num_experts % esd) != 0:
raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
return esd
# Calculate the hidden sharding degree based on world size and expert sharding degree
def hidden_sharding_degree(
world_size: int,
moe_num_experts: int,
ffn_hidden_size: int,
) -> int:
esd = expert_sharding_degree(world_size, moe_num_experts)
hsd = world_size // esd
if (ffn_hidden_size % hsd) != 0:
raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
if (esd * hsd) != world_size:
raise ValueError(
f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
)
return hsd
# Calculate the number of experts per rank based on world size and expert sharding degree
def experts_per_rank(
moe_num_experts: int,
world_size: int,
) -> int:
return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
def features_per_rank(
ffn_hidden_size: int, world_size: int, moe_num_experts: int
) -> int:
return ffn_hidden_size // hidden_sharding_degree(
world_size, moe_num_experts, ffn_hidden_size
)
# Apply jitter to the input tensor
def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
low = 1.0 - moe_jitter_eps
high = 1.0 + moe_jitter_eps
noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
return x * (low + noise * (high - low))
# Compute the top-k scores from the logits
def compute_top_k(scores: torch.Tensor, moe_top_k: int):
if moe_top_k == 1:
return scores.max(dim=-1, keepdim=True)
return torch.topk(scores, moe_top_k, dim=-1)
# Route tokens to experts and compute expert weights and indices
def route_tokens(
x: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
moe_top_k: int,
moe_num_experts: int,
moe_jitter_eps: float = None,
moe_normalize_expert_weights: int = None,
uniform_expert_assignment: bool = False,
training: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if training and moe_jitter_eps is not None:
x = apply_jitter(x, moe_jitter_eps)
x_flat = x.view(-1, x.shape[-1])
logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
expert_weights = expert_weights.softmax(dim=-1)
if moe_normalize_expert_weights is not None:
expert_weights = expert_weights / torch.norm(
expert_weights,
p=moe_normalize_expert_weights,
dim=-1,
keepdim=True,
)
if uniform_expert_assignment:
expert_indices = _layers.router._uniform_expert_assignment(
expert_indices,
moe_num_experts,
)
return logits, expert_weights, expert_indices
# Scale the gradient of the weights
def scale_grad(
w: torch.Tensor,
gradient_scale: Optional[float] = None,
) -> torch.Tensor:
if gradient_scale is None:
return w
return _layers.mlp.scale_gradient(w, gradient_scale)
# Forward pass for the MLP layer
def mlp_forward(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_bias: torch.Tensor,
w2_bias: torch.Tensor,
gradient_scale: Optional[float] = None,
alpha: float = 1.702,
limit: float = 7.0,
):
# Scale weights
w1 = scale_grad(w1, gradient_scale)
w2 = scale_grad(w2, gradient_scale)
w1_bias = scale_grad(w1_bias, gradient_scale)
w2_bias = scale_grad(w2_bias, gradient_scale)
# Resolve dtensors
w1 = _layers.mlp.resolve_dtensor(w1)
w2 = _layers.mlp.resolve_dtensor(w2)
w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
# Forward pass
gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
next_states = torch.bmm(((up + 1) * glu), w2)
next_states += w2_bias[..., None, :]
return next_states
# Shared expert MLP forward pass
def shared_mlp_forward(
x: torch.Tensor,
up_proj_weight: torch.Tensor,
down_proj_weight: torch.Tensor,
up_proj_bias: Optional[torch.Tensor] = None,
down_proj_bias: Optional[torch.Tensor] = None,
activation_fn: Optional[Any] = None,
gradient_scale: Optional[float] = None,
) -> torch.Tensor:
# Default activation function
if activation_fn is None:
activation_fn = torch.nn.functional.gelu
# Scale weights
up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
if up_proj_bias is not None:
up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
if down_proj_bias is not None:
down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
# Resolve dtensors
up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
if up_proj_bias is not None:
up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
if down_proj_bias is not None:
down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
# Up projection
x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
# Activation
x = activation_fn(x)
# Down projection
x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
return x
# Combine outputs from shared expert and regular experts
def combine_expert_shared_outputs(
shared_expert_out: torch.Tensor,
expert_out: torch.Tensor,
shared_expert_weighted_sum: bool = False,
moe_top_k: int = 1,
) -> torch.Tensor:
if shared_expert_weighted_sum:
# Weighted sum based on number of experts used
total_experts = moe_top_k + 1
shared_weight = 1.0 / total_experts
expert_weight = moe_top_k / total_experts
return shared_expert_out * shared_weight + expert_out * expert_weight
else:
# Simple addition
return shared_expert_out + expert_out
# Global variable to store load balancing loss
_LOAD_BALANCING_LOSS = []
def save_load_balancing_loss(loss):
global _LOAD_BALANCING_LOSS
_LOAD_BALANCING_LOSS.append(loss)
def get_load_balancing_loss():
global _LOAD_BALANCING_LOSS
return _LOAD_BALANCING_LOSS
def clear_load_balancing_loss():
global _LOAD_BALANCING_LOSS
_LOAD_BALANCING_LOSS.clear()
def batched_load_balancing_loss(args):
if args.moe_loss_weight == 0:
return 0.0
tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
if args.num_layers_per_virtual_pipeline_stage is not None:
num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
if len(tokens_per_expert) != num_layers_per_pipeline_stage:
raise ValueError(
f"Expected {num_layers_per_pipeline_stage} token_per_experts "
f"but found {len(tokens_per_expert)}.\nnum_layers = "
f"{args.num_layers}\npipeline_model_parallel_size = "
f"{args.pipeline_model_parallel_size}\n"
"num_layers_per_virtual_pipeline_stage"
f" = {args.num_layers_per_virtual_pipeline_stage}",
)
if len(expert_scores) != num_layers_per_pipeline_stage:
raise ValueError(
f"Expected {num_layers_per_pipeline_stage} expert_scores "
f"but found {len(tokens_per_expert)}.\nnum_layers = "
f"{args.num_layers}\npipeline_model_parallel_size = "
f"{args.pipeline_model_parallel_size}\n"
"num_layers_per_virtual_pipeline_stage"
f" = {args.num_layers_per_virtual_pipeline_stage}",
)
# Verify the shape of the tokens_per_expert and expert_scores tensors.
assert all(
(x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
)
tokens = expert_scores[0].shape[0]
assert all(
(
(
x.ndim == 2
and x.shape[1] == args.moe_num_experts
and x.shape[0] == tokens
)
for x in expert_scores
)
)
# Concatenate the contributions of each layer and convert to
# the correct types and formats for the dot product.
expert_scores = torch.cat(expert_scores, dim=1)
if args.moe_lbl_in_fp32:
expert_scores = expert_scores.float()
if tokens != 0:
expert_scores = expert_scores.mean(dim=0)
else:
expert_scores = expert_scores.sum(dim=0)
tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
assert tokens_per_expert.numel() == expected_values
assert expert_scores.numel() == expected_values
# Calculate the total scale across all factors.
#
# loss_weight * num_experts / (num_layers * tokens * top_k)
scale_numerator = args.moe_num_experts * args.moe_loss_weight
scale_denominator = args.num_layers * tokens * args.moe_top_k
scale = scale_numerator / scale_denominator
return scale * torch.dot(tokens_per_expert, expert_scores)
# Calculate the expert capacity based on tokens, top_k, number of experts,
# expert parallel group, capacity factor, and whether expert model parallelism is used.
def expert_capacity(
tokens: int,
top_k: int,
num_experts: int,
expert_parallel_group: int,
moe_capacity_factor: float,
moe_expert_model_parallelism: bool,
) -> int:
world_size = (
dist.get_world_size(expert_parallel_group)
if moe_expert_model_parallelism
else 1
)
tokens_per_expert = top_k * tokens * world_size / num_experts
return int(moe_capacity_factor * tokens_per_expert)
def load_balancing_loss(
tokens_per_expert: torch.Tensor,
expert_scores: torch.Tensor,
top_k: int,
num_experts: int,
):
assert len(expert_scores.size()) == 2
tokens, num_experts = expert_scores.size()
assert num_experts == num_experts
assert len(tokens_per_expert.size()) == 1
(num_experts,) = tokens_per_expert.size()
assert num_experts == num_experts
scale = num_experts / (tokens * top_k)
return scale * torch.dot(
tokens_per_expert.to(expert_scores.dtype),
expert_scores.mean(dim=0),
)
def indices_and_bins(
top_expert: torch.Tensor,
sort_end_bit: int,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
top_expert = top_expert.int()
# Ensure contiguous memory layout
top_expert = top_expert.contiguous()
# Ensure CUB knows which device to use
with torch.cuda.device(top_expert.device):
output = ops.sort(top_expert, sort_end_bit)
bin_ids, indices = output
tokens_per_expert = ops.histogram(top_expert, num_experts)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = bins.view(1) if not len(bins.size()) else bins
return indices, bin_ids, bins, tokens_per_expert
def expert_capacity_fn(
tokens: int,
top_k: int,
num_experts: int,
expert_parallel_group: torch.distributed.ProcessGroup,
moe_capacity_factor: float = 1.0,
moe_expert_model_parallelism: bool = False,
) -> int:
world_size = (
dist.get_world_size(expert_parallel_group)
if moe_expert_model_parallelism
else 1
)
tokens_per_expert = top_k * tokens * world_size / num_experts
return int(moe_capacity_factor * tokens_per_expert)
def permute_and_compute(
x,
tokens_per_expert,
indices,
bin_ids,
expert_weights,
bins,
expert_capacity,
top_k,
w1,
w2,
w1_bias,
w2_bias,
gradient_scale,
alpha,
):
# Route tokens to experts
x = x.view(-1, x.shape[-1])
# Ensure CUB knows which device to use
with torch.cuda.device(x.device):
x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
# Expert computation
x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
# Ensure CUB knows which device to use
with torch.cuda.device(x.device):
# Route tokens back
out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
return out
def forward_once(
x: torch.Tensor,
expert_weights: torch.Tensor,
top_experts: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_bias: torch.Tensor,
w2_bias: torch.Tensor,
gradient_scale: Optional[float] = None,
alpha: float = 1.702,
sort_end_bit: int = 0,
top_k: int = 4,
num_experts: int = 128,
expert_parallel_group: int = None,
moe_capacity_factor: float = 1.0,
moe_expert_model_parallelism: bool = False,
mlp_impl: Optional[str] = None,
):
# x: [sl, bs, hs]
# expert_weights: [sl * bs, top-k]
# top_experts: [sl * bs, top-k]
expert_weights = expert_weights.flatten()
top_experts = top_experts.flatten()
with torch.no_grad():
indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
top_experts, sort_end_bit, num_experts
)
# Calculate expert capacity
sl, bs, _ = x.size()
expert_capacity = expert_capacity_fn(
sl * bs,
top_k,
num_experts,
expert_parallel_group,
moe_capacity_factor,
moe_expert_model_parallelism,
)
if expert_capacity == 0:
expert_capacity = torch.max(tokens_per_expert).item()
x = permute_and_compute(
x,
tokens_per_expert,
indices,
bin_ids,
expert_weights,
bins,
expert_capacity,
top_k,
w1,
w2,
w1_bias,
w2_bias,
gradient_scale,
alpha,
)
return x, tokens_per_expert
def parallel_forward_once(
x: torch.Tensor,
expert_weights: torch.Tensor,
top_experts: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_bias: torch.Tensor,
w2_bias: torch.Tensor,
gradient_scale: Optional[float] = None,
alpha: float = 1.702,
sort_end_bit: int = 0,
top_k: int = 4,
num_experts: int = 128,
expert_parallel_group: torch.distributed.ProcessGroup = None,
moe_capacity_factor: float = 1.0,
moe_expert_model_parallelism: bool = True,
hidden_size: int = 1152,
mlp_impl: Optional[str] = "grouped",
):
# Flatten inputs
expert_weights = expert_weights.flatten()
top_experts = top_experts.flatten()
# TODO: remove debugging var
# my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
with torch.no_grad():
# Step 1: Local permutation setup
indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
top_experts, sort_end_bit, num_experts
)
# Calculate sharding parameters
world_size = dist.get_world_size(expert_parallel_group)
hidden_sharding_deg = hidden_sharding_degree(
world_size, num_experts, hidden_size
)
experts_per_rank_val = experts_per_rank(num_experts, world_size)
# Replicate token counts for hidden sharding
repeated_tokens_per_expert = ops.repeat(
tokens_per_expert, (hidden_sharding_deg,)
)
# Exchange token counts across devices
parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
# Ensure CUB knows which device to use
tpe_handle = dist.all_to_all_single(
parallel_tokens_per_expert,
repeated_tokens_per_expert,
group=expert_parallel_group,
async_op=True,
)
# Step 2: Local permutation - group tokens by target device
x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
x = ops.gather(x, indices, bin_ids, bins, top_k)
# Step 3: Compute communication counts and exchange tokens
with torch.no_grad():
tpe_handle.wait()
# Reshape for per-device calculations
repeated_tokens_per_expert = repeated_tokens_per_expert.view(
world_size, experts_per_rank_val
)
parallel_tokens_per_expert = parallel_tokens_per_expert.view(
world_size, experts_per_rank_val
)
# Calculate send/recv counts
send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
# recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
tokens_received = sum(recv_counts)
# Replicate for hidden sharding
x = ops.repeat(x, (hidden_sharding_deg, 1))
# Cross-device token exchange
parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
x, recv_counts, send_counts, expert_parallel_group, async_op=True
)
with torch.no_grad():
# Step 4: Setup for local expert computation
replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
replicate_bins = (
replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
)
# Create expert indices for received tokens
parallel_top_expert = torch.remainder(
torch.arange(
num_experts * hidden_sharding_deg,
dtype=torch.int32,
device=indices.device,
),
experts_per_rank_val,
)
parallel_top_expert = ops.replicate(
parallel_top_expert.unsqueeze(dim=0),
replicate_bins,
tokens_received,
).flatten()
# Sort tokens by expert assignment
parallel_bin_ids, parallel_indices = ops.sort(
parallel_top_expert,
sort_end_bit,
)
# Calculate bins for local experts
parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
dim=0, dtype=torch.int
)
parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
parallel_bins = (
parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
)
# Calculate expert capacity
expert_capacity = expert_capacity_fn(
tokens_received,
top_k,
experts_per_rank_val,
expert_parallel_group,
moe_capacity_factor,
moe_expert_model_parallelism,
)
if expert_capacity == 0:
expert_capacity = torch.max(parallel_tokens_per_expert).item()
# Locally permute the tokens and perform the expert computation.
# Block to make sure that the cross-device permutation is complete.
if mlp_impl == "grouped":
# GroupedMLP requires counts on CPU. We can use the tensor already
# moved to CPU for the prior all_to_all, which avoids an extra
# device synchronization.
parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
dim=0,
dtype=torch.int,
)
# Step 5: Expert computation
parallel_x_handle.wait()
parallel_x = permute_and_compute(
parallel_x,
parallel_tokens_per_expert,
parallel_indices,
parallel_bin_ids,
None, # expert_weights
parallel_bins,
expert_capacity,
top_k=1,
w1=w1,
w2=w2,
w1_bias=w1_bias,
w2_bias=w2_bias,
gradient_scale=gradient_scale,
alpha=alpha,
)
# Step 6: Reverse communication - send results back
x, _ = _layers.all_to_all.all_to_all(
parallel_x, send_counts, recv_counts, expert_parallel_group
)
# Step 7: Reduce across hidden sharding dimension
shape = (hidden_sharding_deg, -1, hidden_size)
x = x.view(shape).sum(dim=0)
# Step 8: Final local unpermutation
x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
return x, tokens_per_expert.flatten()
def moe_forward(
x: torch.Tensor,
router_weight: torch.Tensor,
router_bias: Optional[torch.Tensor],
moe_top_k: int,
moe_num_experts: int,
moe_jitter_eps: float = None,
moe_normalize_expert_weights: int = None,
uniform_expert_assignment: bool = False,
training: bool = False,
w1: torch.Tensor = None,
w2: torch.Tensor = None,
w1_bias: torch.Tensor = None,
w2_bias: torch.Tensor = None,
gradient_scale: Optional[float] = None,
alpha: float = 1.702,
sort_end_bit: int = 0,
expert_parallel_group: torch.distributed.ProcessGroup = None,
moe_capacity_factor: float = 1.0,
moe_expert_model_parallelism: bool = False,
forward_fn: Any = None,
hidden_size: int = None,
mlp_impl: str = "grouped",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Route tokens to experts
logits, expert_weights, expert_indices = route_tokens(
x,
router_weight,
router_bias,
moe_top_k,
moe_num_experts,
moe_jitter_eps,
moe_normalize_expert_weights,
uniform_expert_assignment,
training,
)
# Create router scores for output
router_scores = (
torch.zeros_like(logits)
.scatter_(1, expert_indices, expert_weights)
.transpose(0, 1)
)
in_shape = x.size()
# Prepare forward function arguments
forward_args = {
"x": x,
"expert_weights": expert_weights,
"top_experts": expert_indices,
"w1": w1,
"w2": w2,
"w1_bias": w1_bias,
"w2_bias": w2_bias,
"gradient_scale": gradient_scale,
"alpha": alpha,
"sort_end_bit": sort_end_bit,
"top_k": moe_top_k,
"num_experts": moe_num_experts,
"expert_parallel_group": expert_parallel_group,
"moe_capacity_factor": moe_capacity_factor,
"moe_expert_model_parallelism": moe_expert_model_parallelism,
"mlp_impl": mlp_impl,
}
# Add hidden_size for parallel forward
if moe_expert_model_parallelism and hidden_size is not None:
forward_args["hidden_size"] = hidden_size
elif moe_expert_model_parallelism and hidden_size is None:
# Infer hidden_size from input shape
forward_args["hidden_size"] = x.shape[-1]
# Compute expert outputs
x, tokens_per_expert = forward_fn(**forward_args)
# Save load balancing loss if needed
moe_loss_weight = 0.0 # Can be made configurable
if training and moe_loss_weight > 0:
save_load_balancing_loss((tokens_per_expert, logits))
# Restore original shape
x = x.view(in_shape)
return x, expert_weights, router_scores
def moe_forward_with_shared_expert(
x: torch.Tensor,
router_weight: torch.Tensor,
router_bias: Optional[torch.Tensor],
moe_top_k: int,
moe_num_experts: int,
moe_jitter_eps: float = None,
moe_normalize_expert_weights: int = None,
uniform_expert_assignment: bool = False,
training: bool = False,
w1: torch.Tensor = None,
w2: torch.Tensor = None,
w1_bias: torch.Tensor = None,
w2_bias: torch.Tensor = None,
gradient_scale: Optional[float] = None,
alpha: float = 1.702,
sort_end_bit: int = 0,
expert_parallel_group: torch.distributed.ProcessGroup = None,
moe_capacity_factor: float = 1.0,
moe_expert_model_parallelism: bool = False,
forward_fn: Any = None,
hidden_size: int = None,
mlp_impl: str = "grouped",
# Shared expert parameters
shared_up_proj_weight: Optional[torch.Tensor] = None,
shared_down_proj_weight: Optional[torch.Tensor] = None,
shared_up_proj_bias: Optional[torch.Tensor] = None,
shared_down_proj_bias: Optional[torch.Tensor] = None,
shared_expert_weighted_sum: bool = False,
shared_activation_fn: Optional[Any] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# First, compute regular MoE forward pass
expert_out, expert_weights, router_scores = moe_forward(
x=x,
router_weight=router_weight,
router_bias=router_bias,
moe_top_k=moe_top_k,
moe_num_experts=moe_num_experts,
moe_jitter_eps=moe_jitter_eps,
moe_normalize_expert_weights=moe_normalize_expert_weights,
uniform_expert_assignment=uniform_expert_assignment,
training=training,
w1=w1,
w2=w2,
w1_bias=w1_bias,
w2_bias=w2_bias,
gradient_scale=gradient_scale,
alpha=alpha,
sort_end_bit=sort_end_bit,
expert_parallel_group=expert_parallel_group,
moe_capacity_factor=moe_capacity_factor,
moe_expert_model_parallelism=moe_expert_model_parallelism,
forward_fn=forward_fn,
hidden_size=hidden_size,
mlp_impl=mlp_impl,
)
# If shared expert weights provided, compute shared expert output
if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
shared_expert_out = shared_mlp_forward(
x=x,
up_proj_weight=shared_up_proj_weight,
down_proj_weight=shared_down_proj_weight,
up_proj_bias=shared_up_proj_bias,
down_proj_bias=shared_down_proj_bias,
activation_fn=shared_activation_fn,
gradient_scale=gradient_scale,
)
# Combine expert outputs
combined_out = combine_expert_shared_outputs(
shared_expert_out=shared_expert_out,
expert_out=expert_out,
shared_expert_weighted_sum=shared_expert_weighted_sum,
moe_top_k=moe_top_k,
)
return combined_out, expert_weights, router_scores
# Return regular MoE output if no shared expert
return expert_out, expert_weights, router_scores
def create_shared_expert_weights(
hidden_size: int,
shared_expert_hidden_size: int,
device: torch.device,
dtype: torch.dtype,
init_method: Any,
output_layer_init_method: Any = None,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
if output_layer_init_method is None:
output_layer_init_method = init_method
# Create weight tensors
up_proj_weight = torch.empty(
shared_expert_hidden_size,
hidden_size,
device=device,
dtype=dtype,
)
down_proj_weight = torch.empty(
hidden_size,
shared_expert_hidden_size,
device=device,
dtype=dtype,
)
# Initialize weights
init_method(up_proj_weight)
output_layer_init_method(down_proj_weight)
# No bias by default
return up_proj_weight, down_proj_weight, None, None
# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
# This exists because device_mesh is trapped in hook closures with no model attribute
# Fragile - breaks if hook structure changes or Python internals change
# TODO: Replace with a more robust solution when available
def get_device_mesh(model):
# Extract device_mesh from child's unused pre_hook closure
try:
# Find the pre-hook that contains 'device_mesh' in its closure
hook = next(
h
for h in model.experts._forward_pre_hooks.values()
if "device_mesh" in h.__code__.co_freevars
)
# Extract the device_mesh from the closure
return hook.__closure__[
hook.__code__.co_freevars.index("device_mesh")
].cell_contents
except Exception:
return None
class MegaBlocksMoeMLP(torch.nn.Module):
can_torch_compile: bool = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
moe_top_k = getattr(self.router, "top_k", 4)
moe_num_experts = getattr(self.experts, "num_experts", 128)
gradient_scale = getattr(self.experts, "gradient_scale", None)
alpha = getattr(self.experts, "alpha", 1.0)
moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
moe_normalize_expert_weights = getattr(
self.experts, "normalize_expert_weights", None
)
uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
expert_parallel_group = getattr(self, "expert_parallel_group", None)
if expert_parallel_group is None:
device_mesh = get_device_mesh(self)
expert_parallel_group = device_mesh.get_group() if device_mesh else None
has_parallel = (
expert_parallel_group is not None
and dist.is_initialized()
and dist.get_world_size(expert_parallel_group) > 1
)
forward_fn = parallel_forward_once if has_parallel else forward_once
sort_end_bit = max(
int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
)
mlp_impl = getattr(self, "mlp_impl", "grouped")
output, expert_weights_out, *_ = moe_forward(
x=x,
router_weight=self.router.weight,
router_bias=self.router.bias,
moe_top_k=moe_top_k,
moe_num_experts=moe_num_experts,
moe_jitter_eps=moe_jitter_eps,
moe_normalize_expert_weights=moe_normalize_expert_weights,
uniform_expert_assignment=uniform_expert_assignment,
training=self.training,
w1=self.experts.gate_up_proj,
w2=self.experts.down_proj,
w1_bias=self.experts.gate_up_proj_bias,
w2_bias=self.experts.down_proj_bias,
gradient_scale=gradient_scale,
alpha=alpha,
sort_end_bit=sort_end_bit,
expert_parallel_group=expert_parallel_group,
moe_capacity_factor=moe_capacity_factor,
moe_expert_model_parallelism=has_parallel,
forward_fn=forward_fn,
hidden_size=self.experts.hidden_size,
mlp_impl=mlp_impl,
)
return output, expert_weights_out
# Export main classes
__all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
def __init__(self):
super().__init__()
# Shared expert weights will be set by the user
self.shared_up_proj_weight = None
self.shared_down_proj_weight = None
self.shared_up_proj_bias = None
self.shared_down_proj_bias = None
self.shared_expert_weighted_sum = False
self.shared_activation_fn = None
def set_shared_expert_weights(
self,
up_proj_weight: torch.Tensor,
down_proj_weight: torch.Tensor,
up_proj_bias: Optional[torch.Tensor] = None,
down_proj_bias: Optional[torch.Tensor] = None,
weighted_sum: bool = False,
activation_fn: Optional[Any] = None,
):
self.shared_up_proj_weight = up_proj_weight
self.shared_down_proj_weight = down_proj_weight
self.shared_up_proj_bias = up_proj_bias
self.shared_down_proj_bias = down_proj_bias
self.shared_expert_weighted_sum = weighted_sum
self.shared_activation_fn = activation_fn
def forward(self, x: torch.Tensor) -> torch.Tensor:
moe_top_k = getattr(self.router, "top_k", 4)
moe_num_experts = getattr(self.experts, "num_experts", 128)
gradient_scale = getattr(self.experts, "gradient_scale", None)
alpha = getattr(self.experts, "alpha", 1.0)
moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
moe_normalize_expert_weights = getattr(
self.experts, "normalize_expert_weights", None
)
uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
expert_parallel_group = getattr(self, "expert_parallel_group", None)
if expert_parallel_group is None:
device_mesh = get_device_mesh(self)
expert_parallel_group = device_mesh.get_group() if device_mesh else None
has_parallel = (
expert_parallel_group is not None
and dist.is_initialized()
and dist.get_world_size(expert_parallel_group) > 1
)
forward_fn = parallel_forward_once if has_parallel else forward_once
sort_end_bit = max(
int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
)
mlp_impl = getattr(self, "mlp_impl", "grouped")
output, expert_weights_out, *_ = moe_forward_with_shared_expert(
x=x,
router_weight=self.router.weight,
router_bias=self.router.bias,
moe_top_k=moe_top_k,
moe_num_experts=moe_num_experts,
moe_jitter_eps=moe_jitter_eps,
moe_normalize_expert_weights=moe_normalize_expert_weights,
uniform_expert_assignment=uniform_expert_assignment,
training=self.training,
w1=self.experts.gate_up_proj,
w2=self.experts.down_proj,
w1_bias=self.experts.gate_up_proj_bias,
w2_bias=self.experts.down_proj_bias,
gradient_scale=gradient_scale,
alpha=alpha,
sort_end_bit=sort_end_bit,
expert_parallel_group=expert_parallel_group,
moe_capacity_factor=moe_capacity_factor,
moe_expert_model_parallelism=has_parallel,
forward_fn=forward_fn,
hidden_size=self.experts.hidden_size,
mlp_impl=mlp_impl,
# Shared expert parameters
shared_up_proj_weight=self.shared_up_proj_weight,
shared_down_proj_weight=self.shared_down_proj_weight,
shared_up_proj_bias=self.shared_up_proj_bias,
shared_down_proj_bias=self.shared_down_proj_bias,
shared_expert_weighted_sum=self.shared_expert_weighted_sum,
shared_activation_fn=self.shared_activation_fn,
)
return output, expert_weights_out