Spaces:
Build error
Build error
| # Copyright (c) 2024, Tri Dao. | |
| # Implement dropout + residual + layer_norm / rms_norm. | |
| # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html | |
| # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. | |
| # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. | |
| # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| import triton | |
| import triton.language as tl | |
| from typing import Callable | |
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    """Wrap ``dec`` so that ``device_type="cuda"`` is forced when requested.

    Args:
        dec: The callable to forward to.
        cuda_amp_deprecated: When true, every call made through the returned
            wrapper passes ``device_type="cuda"`` (overriding any value the
            caller supplied for that keyword).

    Returns:
        A callable that forwards all positional and keyword arguments to
        ``dec`` and returns whatever ``dec`` returns.
    """
    forced_kwargs = {"device_type": "cuda"}

    def decorator(*args, **kwargs):
        if not cuda_amp_deprecated:
            return dec(*args, **kwargs)
        # The forced entry is merged last so it wins over a caller-supplied value.
        return dec(*args, **{**kwargs, **forced_kwargs})

    return decorator
# Pick the AMP decorators from whichever namespace this torch build provides.
# When ``torch.amp.custom_fwd`` exists, the wrappers built below inject
# ``device_type="cuda"`` into every call; otherwise the ``torch.cuda.amp``
# versions are used and called unchanged.
deprecated = hasattr(torch.amp, "custom_fwd")  # type: ignore[attr-defined]
if deprecated:
    from torch.amp import custom_bwd, custom_fwd  # type: ignore[attr-defined]
else:
    from torch.cuda.amp import custom_bwd, custom_fwd
custom_fwd, custom_bwd = (
    custom_amp_decorator(custom_fwd, deprecated),
    custom_amp_decorator(custom_bwd, deprecated),
)
def triton_autotune_configs():
    """Build autotune configs whose warp counts are valid for the current device.

    Returns:
        A list of ``triton.Config({}, num_warps=k)`` for every power of two
        ``k`` (starting at 1) such that ``k * warp_size`` does not exceed the
        per-block thread limit.
    """
    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
    thread_limit = 1024
    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    # Default to warp size 32 if not defined by device
    threads_per_warp = getattr(device_props, "warp_size", 32)
    # Collect the power-of-two warp counts that fit within the thread limit
    warp_counts = []
    candidate = 1
    while candidate * threads_per_warp <= thread_limit:
        warp_counts.append(candidate)
        candidate *= 2
    return [triton.Config({}, num_warps=count) for count in warp_counts]
def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    zero_centered_weight=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    """Pure-PyTorch reference for dropout + residual + LayerNorm.

    The pre-normalization value is built as
    ``dropout(x * rowscale) + dropout(x1) + residual`` (each term only when the
    corresponding argument is given / ``dropout_p > 0``) and then normalized
    over the last dimension with ``F.layer_norm``.

    Args:
        x: Input tensor; its dtype is the dtype of the normalized outputs.
        weight, bias: Affine parameters of the first normalization
            (``bias`` may be None).
        residual: Optional tensor added before normalization.
        x1: Optional second input added to ``x`` (not allowed with ``rowscale``).
        weight1, bias1: Optional second set of affine parameters; when
            ``weight1`` is given a second normalized output is produced.
        eps: Epsilon passed to ``F.layer_norm``.
        dropout_p: Dropout probability; dropout is applied only when > 0.
        rowscale: Optional scale broadcast over the last dimension of ``x``.
        prenorm: When true, the pre-normalization value is returned as well.
        zero_centered_weight: When true, 1.0 is added to the weight(s).
        dropout_mask, dropout_mask1: Optional boolean keep-masks used instead
            of random dropout for ``x`` / ``x1``.
        upcast: When true, all tensor arguments are converted with ``.float()``
            before computing.

    Returns:
        ``out``, ``(out, x)``, ``(out, out1)`` or ``(out, out1, x)`` depending
        on ``weight1`` and ``prenorm``, where the trailing ``x`` is the
        pre-normalization value.
    """
    result_dtype = x.dtype

    def as_float(tensor):
        return None if tensor is None else tensor.float()

    def drop(tensor, keep):
        # Use the supplied keep-mask when there is one, else random dropout.
        if keep is None:
            return F.dropout(tensor, p=dropout_p)
        return tensor.masked_fill(~keep, 0.0) / (1.0 - dropout_p)

    if upcast:
        x, weight = x.float(), weight.float()
        bias, residual, x1 = as_float(bias), as_float(residual), as_float(x1)
        weight1, bias1 = as_float(weight1), as_float(bias1)
    if zero_centered_weight:
        weight = weight + 1.0
        weight1 = weight1 + 1.0 if weight1 is not None else None
    has_x1 = x1 is not None
    if has_x1:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        # x is dropped before x1 so the random stream is consumed in a fixed order.
        x = drop(x, dropout_mask)
        if has_x1:
            x1 = drop(x1, dropout_mask1)
    if has_x1:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)

    feature_shape = x.shape[-1:]
    outputs = [
        F.layer_norm(x.to(weight.dtype), feature_shape, weight=weight, bias=bias, eps=eps).to(
            result_dtype
        )
    ]
    if weight1 is not None:
        outputs.append(
            F.layer_norm(
                x.to(weight1.dtype), feature_shape, weight=weight1, bias=bias1, eps=eps
            ).to(result_dtype)
        )
    if prenorm:
        outputs.append(x)
    return outputs[0] if len(outputs) == 1 else tuple(outputs)
def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    zero_centered_weight=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    """Pure-PyTorch reference for dropout + residual + RMSNorm.

    The pre-normalization value is built as
    ``dropout(x * rowscale) + dropout(x1) + residual`` (each term only when the
    corresponding argument is given / ``dropout_p > 0``). It is then scaled by
    ``1 / sqrt(mean(x**2, dim=-1) + eps)`` and by the weight, plus the bias
    when one is given.

    Args:
        x: Input tensor; its dtype is the dtype of the normalized outputs.
        weight, bias: Affine parameters of the first output (``bias`` may be None).
        residual: Optional tensor added before normalization.
        x1: Optional second input added to ``x`` (not allowed with ``rowscale``).
        weight1, bias1: Optional second set of affine parameters; when
            ``weight1`` is given a second output is produced.
        eps: Value added to the mean square before the square root.
        dropout_p: Dropout probability; dropout is applied only when > 0.
        rowscale: Optional scale broadcast over the last dimension of ``x``.
        prenorm: When true, the pre-normalization value is returned as well.
        zero_centered_weight: When true, 1.0 is added to the weight(s).
        dropout_mask, dropout_mask1: Optional boolean keep-masks used instead
            of random dropout for ``x`` / ``x1``.
        upcast: When true, all tensor arguments are converted with ``.float()``
            before computing.

    Returns:
        ``out``, ``(out, x)``, ``(out, out1)`` or ``(out, out1, x)`` depending
        on ``weight1`` and ``prenorm``, where the trailing ``x`` is the
        pre-normalization value.
    """
    result_dtype = x.dtype

    def as_float(tensor):
        return None if tensor is None else tensor.float()

    def drop(tensor, keep):
        # Use the supplied keep-mask when there is one, else random dropout.
        if keep is None:
            return F.dropout(tensor, p=dropout_p)
        return tensor.masked_fill(~keep, 0.0) / (1.0 - dropout_p)

    if upcast:
        x, weight = x.float(), weight.float()
        bias, residual, x1 = as_float(bias), as_float(residual), as_float(x1)
        weight1, bias1 = as_float(weight1), as_float(bias1)
    if zero_centered_weight:
        weight = weight + 1.0
        weight1 = weight1 + 1.0 if weight1 is not None else None
    has_x1 = x1 is not None
    if has_x1:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        # x is dropped before x1 so the random stream is consumed in a fixed order.
        x = drop(x, dropout_mask)
        if has_x1:
            x1 = drop(x1, dropout_mask1)
    if has_x1:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)

    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)

    def affine(scale, shift):
        scaled = x * rstd * scale
        if shift is not None:
            scaled = scaled + shift
        return scaled.to(result_dtype)

    outputs = [affine(weight, bias)]
    if weight1 is not None:
        outputs.append(affine(weight1, bias1))
    if prenorm:
        outputs.append(x)
    return outputs[0] if len(outputs) == 1 else tuple(outputs)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# NOTE(review): no decorator (e.g. @triton.jit) is visible on this function in this listing, yet
# _layer_norm_fwd launches it as `_layer_norm_fwd_1pass_kernel[(M,)](...)` and passes no values
# for HAS_X1 / HAS_W1 / HAS_B1. Presumably decorator lines were lost; confirm against the
# upstream file.
#
# Fused forward kernel; each program processes one row (the row index is the program id).
# Per row it: loads X, optionally applies ROWSCALE and dropout, optionally adds a second input
# X1 (with its own dropout) and a residual, optionally stores that pre-normalization sum to
# RESIDUAL_OUT, computes mean (LayerNorm only) and rstd, then writes the normalized row times
# W (+ B) to Y and, when HAS_W1, times W1 (+ B1) to Y1.
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,  # pointer to the second input (read only when HAS_X1)
    W1,  # pointer to the second weights (read only when HAS_W1)
    B1,  # pointer to the second biases (read only when HAS_B1)
    Y1,  # pointer to the second output (written only when HAS_W1)
    RESIDUAL_OUT,  # pointer to where the pre-normalization sum is stored
    ROWSCALE,  # pointer to the per-row scale (read only when HAS_ROWSCALE)
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,  # pointer to where the keep mask is stored (when STORE_DROPOUT_MASK)
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    zero_centered_weight,  # If true, add 1.0 to the weight
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    # Advance every row-strided pointer that will be used to this program's row.
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    # Columns at or beyond N are padding: they load as 0.0 and all computation is in float32.
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        # Kept elements are rescaled by 1 / (1 - dropout_p); dropped ones become 0.
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            # The mask is indexed with a row stride of N.
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            # NOTE: _layer_norm_fwd asserts rowscale is None whenever x1 is given, so this
            # branch is not taken when launched through that wrapper.
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            # The seed (and, below, the stored mask row) for X1 lives at offset M + row,
            # i.e. after the M entries used for X.
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        # Store the value that is about to be normalized.
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        # Zero the padding columns so they do not contribute to the variance.
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMSNorm: no mean is computed or stored; "var" is the mean of squares.
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if zero_centered_weight:
        w += 1.0
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    # `mean` only exists on the LayerNorm path; the conditional selects on IS_RMS_NORM.
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        # Second affine transform of the same normalized row, written to Y1.
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if zero_centered_weight:
            w1 += 1.0
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    zero_centered_weight=False,
    is_rms_norm=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None,
):
    """Validate inputs, allocate buffers and launch the fused forward kernel.

    Args:
        x: 2D input ``(M, N)`` with unit stride in the last dimension.
        weight, bias: 1D ``(N,)`` affine parameters (``bias`` may be None).
        eps: Epsilon forwarded to the kernel.
        residual: Optional ``(M, N)`` residual; when given, its dtype is used
            as the residual dtype.
        x1: Optional second input with the same shape as ``x`` (requires
            ``rowscale`` to be None).
        weight1, bias1: Optional second ``(N,)`` affine parameters; a second
            output is allocated when ``weight1`` is given.
        dropout_p: Dropout probability; per-row seeds are drawn when > 0.
        rowscale: Optional contiguous ``(M,)`` per-row scale.
        out_dtype: Dtype of ``out`` when it has to be allocated here.
        residual_dtype: Dtype of the residual buffer when ``residual`` is None.
        zero_centered_weight, is_rms_norm: Forwarded to the kernel.
        return_dropout_mask: Allocate a boolean mask buffer for the kernel to
            fill (only when ``dropout_p > 0``).
        out, residual_out: Optional preallocated buffers.

    Returns:
        ``(out, y1, mean, rstd, residual_or_x, seeds, dropout_mask,
        dropout_mask1)``. ``mean`` is None for RMSNorm, and the fifth entry is
        ``x`` itself when no residual buffer was needed.

    Raises:
        RuntimeError: If ``N`` exceeds the largest supported block size.
    """
    has_residual = residual is not None
    has_x1 = x1 is not None
    has_rowscale = rowscale is not None
    if has_residual:
        residual_dtype = residual.dtype
    n_rows, n_cols = x.shape
    assert x.stride(-1) == 1
    if has_residual:
        assert residual.stride(-1) == 1
        assert residual.shape == (n_rows, n_cols)
    assert weight.shape == (n_cols,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (n_cols,)
    if has_x1:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (n_cols,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (n_cols,)
        assert bias1.stride(-1) == 1
    if has_rowscale:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (n_rows,)

    # Primary output: reuse the caller's buffer when given, otherwise allocate.
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x, dtype=out_dtype if out_dtype is not None else x.dtype)
    assert out.stride(-1) == 1
    # Second output only exists for the second set of weights.
    y1 = None
    if weight1 is not None:
        y1 = torch.empty_like(out)
        assert y1.stride(-1) == 1

    # A separate residual buffer is only needed when the pre-norm value can differ from x.
    has_dropout = dropout_p > 0.0
    dtype_differs = residual_dtype is not None and residual_dtype != x.dtype
    if has_residual or dtype_differs or has_dropout or has_rowscale or has_x1:
        if residual_out is not None:
            assert residual_out.shape == x.shape
        else:
            residual_out = torch.empty(
                n_rows,
                n_cols,
                device=x.device,
                dtype=x.dtype if residual_dtype is None else residual_dtype,
            )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None

    mean = None if is_rms_norm else torch.empty((n_rows,), dtype=torch.float32, device=x.device)
    rstd = torch.empty((n_rows,), dtype=torch.float32, device=x.device)

    # x and x1 each get their own block of per-row seeds / mask rows.
    n_streams = 2 * n_rows if has_x1 else n_rows
    seeds = None
    if dropout_p > 0.0:
        seeds = torch.randint(2**32, (n_streams,), device=x.device, dtype=torch.int64)
    dropout_mask = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(n_streams, n_cols, device=x.device, dtype=torch.bool)

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    if n_cols > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    def row_stride(tensor):
        return tensor.stride(0) if tensor is not None else 0

    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(n_rows,)](
            x,
            out,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            row_stride(residual),
            row_stride(residual_out),
            row_stride(x1),
            row_stride(y1),
            n_rows,
            n_cols,
            eps,
            dropout_p,
            zero_centered_weight,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )

    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    dropout_mask1 = None
    if dropout_mask is not None and has_x1:
        # First half of the rows belongs to x, second half to x1.
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    return (
        out,
        y1,
        mean,
        rstd,
        x if residual_out is None else residual_out,
        seeds,
        dropout_mask,
        dropout_mask1,
    )
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
# NOTE(review): no decorator (e.g. @triton.jit) is visible on this function in this listing, yet
# _layer_norm_bwd launches it as `_layer_norm_bwd_kernel[grid](...)` and passes no values for
# HAS_ROWSCALE / HAS_DY1 / HAS_DX1 / HAS_B1 / RECOMPUTE_OUTPUT. Presumably decorator lines were
# lost; confirm against the upstream file.
#
# Fused backward kernel. Each program handles a contiguous block of `rows_per_program` rows:
# for every row it computes the input gradient (DX, and optionally DX1 / DRESIDUAL_IN) and
# accumulates the weight/bias gradients across its rows, writing one partial-sum row per
# program to DW / DB (and DW1 / DB1). Optionally also recomputes the forward output into Y.
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,  # pointer to an incoming gradient added to dx (read when HAS_DRESIDUAL)
    W1,  # pointer to the second weights (read when HAS_DY1)
    DY1,  # pointer to the second output gradient (read when HAS_DY1)
    DX1,  # pointer to the second input gradient (written when HAS_DX1)
    DW1,  # pointer to the partial sum of second weights gradient
    DB1,  # pointer to the partial sum of second biases gradient
    DRESIDUAL_IN,  # pointer to where dx is stored before dropout/rowscale (when STORE_DRESIDUAL)
    ROWSCALE,  # pointer to the per-row scale (read when HAS_ROWSCALE)
    SEEDS,  # pointer to the per-row dropout seeds (read when HAS_DROPOUT)
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    zero_centered_weight,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    # Advance every row-strided pointer that will be used to this program's first row.
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    # Weights (and bias, if the output is recomputed) are loaded once and reused for all rows.
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if zero_centered_weight:
        w += 1.0
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if zero_centered_weight:
            w1 += 1.0
    # float32 accumulators for this program's partial weight/bias gradients.
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    # Clamp to M so the last program(s) do not run past the final row.
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        # Zero the padding columns so they do not contribute to the reductions below.
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            # The second output shares xhat, so its gradient is added into wdy.
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            # RMSNorm has no mean subtraction, hence no c2 term.
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            # Stored before dropout / rowscale are applied to dx below.
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                # Regenerate the keep mask from the seed at offset M + row, the same offset
                # the forward kernel uses for X1.
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            # Regenerate the keep mask for X from this row's seed.
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)
        # Step all row pointers to the next row.
        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    # One partial-sum row per program; _layer_norm_bwd reduces these with .sum(0).
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    zero_centered_weight=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    """Validate inputs, allocate buffers and launch the fused backward kernel.

    Args:
        dy: ``(M, N)`` gradient of the first output, unit stride in the last dim.
        x: ``(M, N)`` tensor that was normalized in the forward pass.
        weight, bias: ``(N,)`` affine parameters (``bias`` may be None).
        eps: Epsilon forwarded to the kernel.
        mean, rstd: Per-row statistics saved by the forward pass.
        dresidual: Optional ``(M, N)`` gradient added to ``dx`` inside the kernel.
        dy1: Optional gradient of the second output (requires ``weight1``).
        weight1, bias1: Optional second ``(N,)`` affine parameters.
        seeds: Optional contiguous dropout seeds, ``(M,)`` or ``(2 * M,)`` when
            ``has_x1``.
        dropout_p: Dropout probability used in the forward pass.
        rowscale: Optional contiguous ``(M,)`` per-row scale.
        has_residual: Whether the forward pass received a residual.
        has_x1: Whether the forward pass received a second input.
        zero_centered_weight, is_rms_norm: Forwarded to the kernel.
        x_dtype: Dtype of ``dx``; defaults to ``x``'s dtype.
        recompute_output: Also recompute the forward output (not supported
            together with ``weight1``).

    Returns:
        ``(dx, dw, db, dresidual_in, dx1, dw1, db1)``, with the recomputed
        output ``y`` appended when ``recompute_output`` is true. Entries that
        do not apply are None.

    Raises:
        RuntimeError: If ``N`` exceeds the largest supported block size.
    """
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    # A separate residual gradient is only needed when dx gets modified after
    # the residual gradient is taken (dtype change, dropout, rowscale) or when x1 exists.
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    if recompute_output:
        assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
    # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
    # One partial-sum row per program; reduced with .sum(0) after the launch.
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    # Allocate independently of _db: bias1 may be given while bias is None, in
    # which case _db is None and torch.empty_like(_db) would raise.
    _db1 = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias1.device)
        if bias1 is not None
        else None
    )
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            zero_centered_weight,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    # Reduce the per-program partial sums and cast back to the parameter dtypes.
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    # Without dropout the gradient of x1 equals the gradient of x.
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )
def _pack_layer_norm_outputs(
    y, y1, residual_out, dropout_mask, dropout_mask1, has_weight1, prenorm, return_dropout_mask
):
    """Assemble the return value of ``LayerNormFn.forward``.

    The layout is ``y``, then ``y1`` (only when a second weight was given),
    then ``residual_out`` (only when ``prenorm``), then both dropout masks
    (only when ``return_dropout_mask``). A single ``y`` is returned bare.
    """
    outputs = [y]
    if has_weight1:
        outputs.append(y1)
    if prenorm:
        outputs.append(residual_out)
    if return_dropout_mask:
        outputs.extend((dropout_mask, dropout_mask1))
    return outputs[0] if len(outputs) == 1 else tuple(outputs)


class LayerNormFn(torch.autograd.Function):
    """Autograd function for fused dropout + residual + LayerNorm / RMSNorm.

    ``forward`` flattens the input to 2D, calls ``_layer_norm_fwd`` and saves
    what ``backward`` needs; ``backward`` calls ``_layer_norm_bwd``. Inputs
    with zero elements take a fast path that launches no kernel.
    """

    # torch.autograd.Function expects forward/backward to be static methods.
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        zero_centered_weight=False,
        is_rms_norm=False,
        return_dropout_mask=False,
        out=None,
        residual_out=None,
    ):
        """Run the fused forward pass.

        Returns ``y``, optionally followed by ``y1`` (when ``weight1`` is
        given), ``residual_out`` (when ``prenorm``) and the two dropout masks
        (when ``return_dropout_mask``); all tensors have the shape of ``x``.
        """
        x_shape_og = x.shape
        # Check for zero sequence length
        if x.numel() == 0:
            ctx.zero_seq_length = True
            # Only save minimal required tensors for backward
            # ctx.save_for_backward(weight, bias, weight1, bias1)
            ctx.x_shape_og = x_shape_og
            ctx.weight_shape = weight.shape
            ctx.weight_dtype = weight.dtype
            ctx.weight_device = weight.device
            ctx.has_bias = bias is not None
            ctx.bias_shape = bias.shape if bias is not None else None
            ctx.bias_dtype = bias.dtype if bias is not None else None
            ctx.bias_device = bias.device if bias is not None else None
            ctx.has_weight1 = weight1 is not None
            ctx.weight1_shape = weight1.shape if weight1 is not None else None
            ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
            ctx.weight1_device = weight1.device if weight1 is not None else None
            ctx.has_bias1 = bias1 is not None
            ctx.bias1_shape = bias1.shape if bias1 is not None else None
            ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
            ctx.bias1_device = bias1.device if bias1 is not None else None
            ctx.has_residual = residual is not None
            ctx.has_x1 = x1 is not None
            ctx.dropout_p = dropout_p
            # Handle output tensors with correct dtype
            y = x  # Preserve input tensor properties
            # The second output exists iff weight1 is given (as in _layer_norm_fwd),
            # which is also the condition used to decide whether y1 is returned.
            y1 = torch.empty_like(x) if weight1 is not None else None
            # Only create residual_out if prenorm is True. Its dtype follows the
            # residual when one is given, as on the regular path.
            if prenorm:
                if residual is not None:
                    zero_residual_dtype = residual.dtype
                elif residual_in_fp32:
                    zero_residual_dtype = torch.float32
                else:
                    zero_residual_dtype = x.dtype
                residual_out = torch.empty(x.shape, dtype=zero_residual_dtype, device=x.device)
            else:
                residual_out = None
            # Handle dropout masks: boolean, and only when dropout is active,
            # matching what _layer_norm_fwd produces.
            dropout_mask = None
            dropout_mask1 = None
            if return_dropout_mask and dropout_p > 0.0:
                dropout_mask = torch.empty_like(x, dtype=torch.bool)
                if x1 is not None:
                    dropout_mask1 = torch.empty_like(x, dtype=torch.bool)
            # Return based on configuration
            return _pack_layer_norm_outputs(
                y,
                y1,
                residual_out,
                dropout_mask,
                dropout_mask1,
                weight1 is not None,
                prenorm,
                return_dropout_mask,
            )
        ctx.zero_seq_length = False
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        if out is not None:
            out = out.reshape(-1, out.shape[-1])
        if residual_out is not None:
            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            zero_centered_weight=zero_centered_weight,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
            out=out,
            residual_out=residual_out,
        )
        # residual_out (the pre-normalization value) is what backward treats as "x".
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.zero_centered_weight = zero_centered_weight
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        return _pack_layer_norm_outputs(
            y,
            y1,
            residual_out,
            dropout_mask,
            dropout_mask1,
            weight1 is not None,
            prenorm,
            return_dropout_mask,
        )

    @staticmethod
    def backward(ctx, dy, *args):
        """Compute gradients for the inputs of ``forward``.

        ``args`` holds the gradients of the extra forward outputs, in order:
        ``dy1`` (when a second weight was used) and then the gradient of the
        pre-normalization output (when ``prenorm``). Returns one entry per
        forward argument; the ten non-differentiable ones are None.
        """
        if ctx.zero_seq_length:
            return (
                torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
                torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
                torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device)
                if ctx.has_bias
                else None,
                torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device)
                if ctx.has_residual
                else None,
                # The regular path yields a gradient for x1 whenever x1 was given
                # (with or without dropout), so do the same here.
                torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device)
                if ctx.has_x1
                else None,
                torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device)
                if ctx.has_weight1
                else None,
                torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device)
                if ctx.has_bias1
                else None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            # The second output's gradient comes first among the extra gradients.
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.zero_centered_weight,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    zero_centered_weight=False,
    is_rms_norm=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None
):
    """Functional entry point that hands every argument to ``LayerNormFn.apply``.

    All arguments are forwarded unchanged and positionally, in the order of
    this signature.

    Returns:
        Whatever ``LayerNormFn.apply`` returns for these arguments.
    """
    # Gathered in signature order because the call below is purely positional.
    forwarded = (
        x, weight, bias,
        residual, x1, weight1, bias1,
        eps, dropout_p, rowscale,
        prenorm, residual_in_fp32, zero_centered_weight,
        is_rms_norm, return_dropout_mask,
        out, residual_out,
    )
    return LayerNormFn.apply(*forwarded)
def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    zero_centered_weight=False,
    return_dropout_mask=False,
    out=None,
    residual_out=None
):
    """RMS-norm flavoured entry point built on ``LayerNormFn.apply``.

    Every argument is forwarded unchanged and positionally; the only
    difference from the caller's arguments is that the ``is_rms_norm`` slot
    (between ``zero_centered_weight`` and ``return_dropout_mask``) is always
    filled with ``True``.

    Returns:
        Whatever ``LayerNormFn.apply`` returns for these arguments.
    """
    leading = (
        x, weight, bias,
        residual, x1, weight1, bias1,
        eps, dropout_p, rowscale,
        prenorm, residual_in_fp32, zero_centered_weight,
    )
    trailing = (return_dropout_mask, out, residual_out)
    # The literal True occupies the is_rms_norm position of the argument list.
    return LayerNormFn.apply(*leading, True, *trailing)
class RMSNorm(torch.nn.Module):
    """Module front-end for ``rms_norm_fn`` with a learnable weight and no bias.

    The weight has ``hidden_size`` elements. It is initialised to ones, or to
    zeros when ``zero_centered_weight`` is set. ``bias`` is registered as a
    ``None`` parameter and is passed through to ``rms_norm_fn`` as such.
    """

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
                 device=None, dtype=None):
        super().__init__()
        self.eps = eps
        # A Dropout module is kept only when dropout is requested; forward()
        # reads nothing from it but its probability ``p``.
        self.drop = torch.nn.Dropout(dropout_p) if dropout_p > 0.0 else None
        self.zero_centered_weight = zero_centered_weight
        uninitialised = torch.empty(hidden_size, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(uninitialised)
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight with zeros if zero-centered, otherwise with ones."""
        fill_ = torch.nn.init.zeros_ if self.zero_centered_weight else torch.nn.init.ones_
        fill_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        """Call ``rms_norm_fn`` with this module's weight, eps and dropout setting.

        The dropout probability passed on is ``self.drop.p`` only when a
        Dropout module exists and the module is in training mode; otherwise
        it is ``0.0``.
        """
        dropout_active = self.drop is not None and self.training
        p = self.drop.p if dropout_active else 0.0
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=p,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            zero_centered_weight=self.zero_centered_weight,
        )
class LayerNormLinearFn(torch.autograd.Function):
    """Autograd function that applies a layer/RMS norm followed by a linear layer.

    The normalized activation is not saved for backward; ``backward`` asks
    ``_layer_norm_bwd`` to recompute it (``recompute_output=True``) so that the
    linear weight gradient can be formed.
    """

    # NOTE(review): this module also defines ``custom_fwd`` / ``custom_bwd``
    # wrappers near the top of the file; confirm whether forward/backward are
    # meant to be decorated with them in addition to ``@staticmethod``.

    @staticmethod
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        """Normalize ``x`` (optionally with ``residual``) and apply the linear layer.

        Args:
            ctx: autograd context used to stash tensors and flags for backward.
            x: input tensor; all leading dimensions are flattened into one.
            norm_weight: norm weight, made contiguous before use.
            norm_bias: optional norm bias, made contiguous before use.
            linear_weight: weight passed to ``F.linear``.
            linear_bias: optional bias passed to ``F.linear``.
            residual: optional tensor that must have the same shape as ``x``.
            eps: epsilon forwarded to ``_layer_norm_fwd``.
            prenorm: when true, the residual output is returned as well.
            residual_in_fp32: when no ``residual`` is given, request a float32
                residual dtype from ``_layer_norm_fwd``.
            is_rms_norm: forwarded to ``_layer_norm_fwd``.

        Returns:
            ``out`` when ``prenorm`` is false, otherwise the tuple
            ``(out, residual_out)`` with ``residual_out`` reshaped like ``x``.
        """
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        # Query the autocast state once; None means autocast is disabled.
        autocast_dtype = (
            torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else None
        )
        y, _, mean, rstd, residual_out, *_ = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=autocast_dtype,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        # Under autocast the linear layer runs in the autocast dtype,
        # otherwise in the dtype of the norm output.
        dtype = autocast_dtype if autocast_dtype is not None else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients for the ten inputs of ``forward``.

        Args:
            ctx: autograd context populated by ``forward``.
            dout: gradient with respect to the linear output.
            *args: when ``ctx.prenorm`` is set, ``args[0]`` is the gradient
                with respect to the returned residual output.

        Returns:
            A 10-tuple matching ``forward``'s inputs: gradients for ``x``,
            ``norm_weight``, ``norm_bias``, ``linear_weight``, ``linear_bias``
            and ``residual`` (``None`` where not applicable), followed by
            ``None`` for the four non-tensor arguments.
        """
        # The first saved tensor is the residual_out returned by _layer_norm_fwd.
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        # y is the recomputed norm output; contract over the flattened batch axis.
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )
def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Functional entry point that hands every argument to ``LayerNormLinearFn.apply``.

    All arguments are forwarded unchanged and positionally, in the order of
    this signature.

    Returns:
        Whatever ``LayerNormLinearFn.apply`` returns for these arguments.
    """
    # Tensor-like inputs first, then the scalar/flag options, in signature order.
    tensor_inputs = (x, norm_weight, norm_bias, linear_weight, linear_bias, residual)
    options = (eps, prenorm, residual_in_fp32, is_rms_norm)
    return LayerNormLinearFn.apply(*tensor_inputs, *options)