Spaces:
Running
Running
| from abc import abstractmethod | |
| from typing import Literal | |
| from torch import Tensor | |
| from torch.nn import Module | |
| from torch.nn.functional import silu, gelu | |
| class GatedUnit(Module): | |
| def __init__(self, dim: int = -1) -> None: | |
| super().__init__() | |
| self.dim = dim | |
| def _activation(self, x: Tensor) -> Tensor: | |
| ... | |
| def forward(self, x: Tensor) -> Tensor: | |
| f, g = x.chunk(2, dim=self.dim) | |
| return self._activation(f) * g | |
| class SwiGLU(GatedUnit): | |
| def __init__(self, dim: int = -1) -> None: | |
| super().__init__(dim) | |
| def _activation(self, x: Tensor) -> Tensor: | |
| return silu(x) | |
| class GeGLU(GatedUnit): | |
| def __init__( | |
| self, | |
| dim: int = -1, | |
| approximate: Literal["tanh", "none"] = "tanh" | |
| ) -> None: | |
| super().__init__(dim) | |
| self.approximate = approximate | |
| def _activation(self, x: Tensor) -> Tensor: | |
| return gelu(x, self.approximate) | |