# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# AnyPrecisionAdamW: a flexible-precision AdamW optimizer
# with optional Kahan summation for high-precision weight updates.
# Allows direct control over momentum, variance and auxiliary compensation
# buffer dtypes.
# Optional Kahan summation is used to offset the precision loss in
# the weight updates. This allows full training in BFloat16 (equal or
# better than FP32 results in many cases) due to high-precision weight updates.

import torch
from torch.optim.optimizer import Optimizer


class AnyPrecisionAdamW(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        use_kahan_summation=False,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    ):
| """ | |
| Args: | |
| params (iterable): iterable of parameters to optimize or dicts defining | |
| parameter groups | |
| lr (float, optional): learning rate (default: 1e-3) | |
| betas (Tuple[float, float], optional): coefficients used for computing | |
| running averages of gradient and its square (default: (0.9, 0.999)) | |
| eps (float, optional): term added to the denominator to improve | |
| numerical stability (default: 1e-8) | |
| weight_decay (float, optional): weight decay coefficient (default: 1e-2) | |
| # Any Precision specific | |
| use_kahan_summation = creates auxiliary buffer to ensure high precision | |
| model param updates (default: False) | |
| momentum_dtype = dtype for momentum (default: BFloat32) | |
| variance_dtype = dtype for uncentered variance (default: BFloat16) | |
| compensation_buffer_dtype = dtype for Kahan summation | |
| buffer (default: BFloat16) | |
| # Usage | |
| This optimizer implements optimizer states, and Kahan summation | |
| for high precision updates, all in user controlled dtypes. | |
| Defaults are variance in BF16, Momentum in FP32. | |
| This can be run in FSDP mixed precision, amp, or full precision, | |
| depending on what training pipeline you wish to work with. | |
| Setting to use_kahan_summation = False, and changing momentum and | |
| variance dtypes to FP32, reverts this to a standard AdamW optimizer. | |
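
        # Example
        A minimal usage sketch; `model` and `inputs` below are illustrative
        placeholders rather than part of this module::

            optimizer = AnyPrecisionAdamW(
                model.parameters(),
                lr=1e-4,
                weight_decay=0.01,
                use_kahan_summation=True,
            )

            loss = model(inputs).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()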
| """ | |
| defaults = dict( | |
| lr=lr, | |
| betas=betas, | |
| eps=eps, | |
| weight_decay=weight_decay, | |
| use_kahan_summation=use_kahan_summation, | |
| momentum_dtype=momentum_dtype, | |
| variance_dtype=variance_dtype, | |
| compensation_buffer_dtype=compensation_buffer_dtype, | |
| ) | |
| super().__init__(params, defaults) | |

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """

        if closure is not None:
            with torch.enable_grad():
                # to fix linter, we do not keep the returned loss for use atm.
                closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]

            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError(
                        "AnyPrecisionAdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)

                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p,
                        dtype=momentum_dtype,
                    )

                    # variance uncentered - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p,
                        dtype=variance_dtype,
                    )

                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(
                            p,
                            dtype=compensation_buffer_dtype,
                        )

                # main processing -------------------------

                # update the steps for each param group update
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                grad = p.grad

                # weight decay, AdamW style
                if weight_decay:
                    p.data.mul_(1 - lr * weight_decay)
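                    # (decoupled decay: the weights are shrunk directly by
                    # 1 - lr * weight_decay instead of adding a decay term
                    # to the gradient)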

                # update momentum
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # update uncentered variance
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # adjust using bias1
                bias_correction1 = 1 - beta1**step

                step_size = lr / bias_correction1

                # adjust using bias2
                denom_correction = (1 - beta2**step) ** 0.5  # avoids math import

                centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(
                    eps, alpha=1
                )
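                # note: despite its name, `centered_variance` is the bias-corrected
                # root of the uncentered second moment plus eps, i.e. the Adam
                # denominator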

                # lr update to compensation
                if use_kahan_summation:
                    compensation = state["compensation"]

                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)

                    # update weights with compensation (Kahan summation)
                    # save error back to compensation for next iteration
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))
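                    # `compensation` now holds the part of the intended update
                    # that was lost to rounding when it was added to the (possibly
                    # lower precision) weights; it is re-applied on the next step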

                else:
                    # usual AdamW updates
                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
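

# A minimal smoke-test sketch (illustrative only, not part of the optimizer
# itself): it trains a tiny linear layer for a few steps with the default
# BFloat16 state dtypes and Kahan summation enabled.
if __name__ == "__main__":
    torch.manual_seed(0)

    model = torch.nn.Linear(16, 4)
    optimizer = AnyPrecisionAdamW(
        model.parameters(),
        lr=1e-3,
        weight_decay=0.01,
        use_kahan_summation=True,
    )

    data = torch.randn(8, 16)
    target = torch.randn(8, 4)

    for _ in range(5):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(data), target)
        loss.backward()
        optimizer.step()

    print(f"final loss: {loss.item():.6f}")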