Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn as nn | |
class _SwiGLU(nn.Module):
    """SwiGLU feed-forward block implemented with plain PyTorch ops.

    The input is projected by a single bias-free linear layer to
    ``2 * hidden_dim`` features, split along the last dimension into two
    halves ``x1`` and ``x2``, combined as ``silu(x1) * x2`` and projected
    back to ``dim`` features by a second bias-free linear layer.

    Args:
        dim: Size of the last dimension of both the input and the output.
        hidden_dim: Size of each half of the gated hidden representation.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ) -> None:
        super().__init__()
        # One fused projection produces both halves in a single matmul;
        # ``forward`` splits the result in two.
        self.w12 = nn.Linear(dim, hidden_dim * 2, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``.

        Args:
            x: Tensor whose last dimension has size ``dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``dim``.
        """
        x1, x2 = self.w12(x).chunk(2, dim=-1)
        return self.w3(torch.nn.functional.silu(x1) * x2)


# NOTE: an optional fast path that imported ``SwiGLU`` from ``xformers.ops``
# (falling back to the implementation above on failure) used to live here and
# is disabled. The disabled snippet imported the class under the alias ``aa``
# but then assigned ``SwiGLU = SwiGLU``, so it needs fixing before it is
# re-enabled. Until then the pure-PyTorch module is always exported.
SwiGLU = _SwiGLU