import numbers

import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm: y = x / sqrt(mean(x^2) + eps) * weight.

    Unlike LayerNorm, no mean is subtracted and no bias is applied.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        # Accept either an int or a shape tuple for the normalized dims.
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        if elementwise_affine:
            # Learnable per-element scale, initialized to 1.
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Mean of squares over the last dim, computed in float32 for
        # numerical stability. This is the squared RMS, not a variance
        # in the strict sense: no mean is subtracted.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        if self.weight is not None:
            # Convert to half precision if necessary so the affine scale
            # is applied in the weight's dtype.
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        # Cast back to the caller's dtype on every path.
        return hidden_states.to(input_dtype)
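

# A minimal usage sketch, not part of the original module: the feature
# size, tensor shape, and eps below are illustrative assumptions only.
if __name__ == "__main__":
    norm = RMSNorm(dim=64, eps=1e-6)
    x = torch.randn(2, 16, 64)  # (batch, seq_len, dim), hypothetical shape
    y = norm(x)
    # With weight initialized to ones, each position ends up with
    # (roughly) unit RMS over the last dimension.
    rms = y.float().pow(2).mean(-1).sqrt()
    print(y.shape, rms.mean().item())  # torch.Size([2, 16, 64]), ~1.0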