"""
Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt.

Contents:
  - zeropower_via_newtonschulz5: approximate orthogonalization of a (batched) matrix.
  - Muon: single-process Muon optimizer.
  - DistMuon: Muon with its own torch.distributed gradient averaging and weight replication.
"""
import torch
from torch import Tensor
import torch.distributed as dist


@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: tensor with at least 2 dims. The last two dims are treated as the matrix and any
           leading dims as a batch (all matrix ops below use .mT and dims (-2, -1)).
        steps: number of Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as G (the computation is done in bfloat16 and the
        result is not cast back to G's dtype).
    """
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    # Quintic coefficients: each iteration maps every singular value s of X to a*s + b*s^3 + c*s^5.
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    # Work with the "wide" orientation (rows <= cols) so the Gram matrix X @ X.mT below is the
    # smaller of the two possible ones; the transpose is undone before returning.
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1 (dividing by the Frobenius norm suffices, since the
    # spectral norm never exceeds it; 1e-7 guards against division by zero).
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations: X <- a*X + (b*A + c*A^2) @ X with A = X @ X^T,
    # i.e. a*X + b*(XX^T)X + c*(XX^T)^2 X.
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    # Restore the original orientation.
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer should not be used for the embedding layer, the final fully connected layer,
      or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - step() asserts that every parameter has a .grad. With nesterov=True, each parameter's .grad
      tensor is overwritten in place with the Nesterov-blended gradient.

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
    """
    def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5) -> None:
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        # One param group per distinct element count (numel). The groups carry no per-group
        # hyperparameters, so every group inherits lr/momentum/nesterov/ns_steps from `defaults`.
        # NOTE(review): step() below updates params one at a time, so this grouping only affects
        # the layout of param_groups; presumably kept for parity with a batched variant — confirm.
        param_groups = []
        for size in {p.numel() for p in params}:
            group = dict(params=[p for p in params if p.numel() == size])
            param_groups.append(group)
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self) -> None:
        """Apply one Muon update to every parameter, in place.

        For each parameter: update its momentum buffer, form the (optionally Nesterov) update,
        orthogonalize it with zeropower_via_newtonschulz5, and add it to the parameter scaled by
        -lr * sqrt(max(1, rows / cols)).
        """
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for p in params:
                g = p.grad
                assert g is not None
                state = self.state[p]
                # Lazily create the momentum buffer on first use.
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf: Tensor = state["momentum_buffer"]
                # buf <- momentum * buf + (1 - momentum) * g
                buf.lerp_(g, 1 - group["momentum"])
                # Nesterov: g <- (1 - momentum) * g + momentum * buf, written into p.grad in place.
                # Otherwise the update is the momentum buffer itself.
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                # Scale up the step for tall matrices (rows > cols); wide/square matrices use lr as-is.
                p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)


class DistMuon(torch.optim.Optimizer):
    """
    Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
    finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
      - reduce_scatter(AVG) for gradient averaging
      - all_gather to replicate updated weights

    Notes:
      * Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D); the constructor
        asserts p.ndim == 2. Do not use for 0D/1D params such as biases or scalars. Embedding
        layers, although 2D, should also be left to a standard optimizer (see the Muon warnings).
      * Momentum buffers are maintained only on the 'owner' rank for each parameter. Ownership is
        round-robin within each shape group: the param at index i is owned by rank i % world_size.
        If you checkpoint optimizer state on a single rank, consolidate states beforehand.
      * Every rank must build this optimizer from the same parameters (same shapes, same order) and
        call step() together: collectives are issued in a fixed loop order and matched across ranks
        by that order, and ownership is decided purely by index.
      * step() asserts every parameter has a .grad. On the owner rank, p.grad is overwritten with
        the cross-rank average (and, with nesterov=True, then with the Nesterov-blended gradient).
      * NOTE(review): ReduceOp.AVG is documented as NCCL-only in torch.distributed; confirm the
        backend in use supports it.

    Args:
        params: iterable of Tensors
        lr: learning rate
        momentum: momentum coefficient in [0,1)
        nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
        ns_steps: number of Newton–Schulz iterations for the orthogonalization
    """
    def __init__(self, params, lr: float = 0.02, momentum: float = 0.95, nesterov: bool = True, ns_steps: int = 5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params = list(params)
        assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
        # Requires an initialized default process group.
        rank = dist.get_rank()
        # Group all parameters by their shape
        shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
        param_groups = []
        for shape in shapes:
            group_params = [p for p in params if p.shape == shape]
            # Collectives in step() mix tensors from the same group, so they must agree on device/dtype.
            device, dtype = group_params[0].device, group_params[0].dtype
            assert all(p.device == device for p in group_params)
            assert all(p.dtype == dtype for p in group_params)
            # Log the grouping once, from rank 0 only.
            if rank == 0:
                print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
            # zero_buffer: a zero tensor of the group's shape/device/dtype, used in step() to pad the
            # last (partial) chunk of world_size params in reduce_scatter/all_gather.
            # NOTE(review): it lives in the param-group dict, so it is likely included in
            # state_dict()["param_groups"]; confirm that is intended.
            param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        """Average grads across ranks, update owned params, and replicate updated weights.

        Phase 1 launches one async reduce_scatter per chunk of world_size params. Phase 2 walks
        the same chunks in the same order. For each chunk it waits for that chunk's reduce_scatter.
        The owning rank then applies the Muon update, and every rank launches an async all_gather
        that copies each updated param into the other ranks' param tensors. Finally it blocks until
        all all_gathers finish.
        """
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Ensure all grads exist
        assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"

        # Kick off all the reduce scatter operations to average up the gradients across all ranks
        # (the list is named all_reduce_futures but holds reduce_scatter futures).
        all_reduce_futures = []
        for group in self.param_groups:
            params = group["params"]
            zero_buffer = group["zero_buffer"]
            # Go through params in groups of world_size.
            for base_i in range(0, len(params), world_size):
                # The compute owner of each param is rank i % world_size
                owner_idx = base_i + rank
                # each rank stacks up its chunk of world_size params into a list
                rs_input = [p.grad for p in params[base_i:base_i + world_size]]
                # pad rs_input with the zero buffer to complete the group
                rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
                # Slot `rank` of the reduced result lands in this rank's output. The owner receives the
                # averaged grad directly in its param's .grad. In a final partial chunk, a rank with no
                # param gets a scratch tensor that is discarded.
                rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
                # reduce scatter the gradients within this group of world_size params
                work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
                all_reduce_futures.append(work)

        # Now each rank computes the update and gathers
        # These loops mirror phase 1 exactly, so future_idx indexes this chunk's reduce_scatter future.
        future_idx = 0
        all_gather_futures = []
        for group in self.param_groups:
            params = group["params"]
            zero_buffer = group["zero_buffer"]
            # Go through params in groups of world_size.
            for base_i in range(0, len(params), world_size):
                # The compute owner of each param is rank i % world_size
                owner_idx = base_i + rank # calculate the index of the param that this rank owns
                # Wait for the reduce scatter to complete
                all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
                future_idx += 1
                # Owner computes the Muon update, result is in its param
                if owner_idx < len(params):
                    p = params[owner_idx]
                    g = p.grad  # now averaged across ranks
                    state = self.state[p]
                    # Momentum state exists only on the owner rank for this param.
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf: Tensor = state["momentum_buffer"]
                    # buf <- momentum * buf + (1 - momentum) * g
                    buf.lerp_(g, 1.0 - group["momentum"])
                    # Nesterov blend is written into p.grad in place; otherwise the update is buf itself.
                    g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                    g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                    # Scale up the step for tall matrices (rows > cols).
                    scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                    p.add_(g, alpha=-group["lr"] * scale)
                # Replicate updated parameters to all ranks
                ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
                # Slicing makes a new list, so the padding below does not modify group["params"].
                # The real entries are the param tensors themselves, so all_gather writes the owners'
                # updated values into them in place. Padding entries are scratch tensors.
                ag_output = params[base_i:base_i + world_size]
                ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
                work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
                all_gather_futures.append(work)

        # Wait for all work to finish (reduce_scatter futures were already waited on above).
        torch.futures.collect_all(all_gather_futures).wait()