JTP-3-Demo / glu.py
RedHotTensors's picture
JTP-3 Hydra Release
d62ba4b
raw
history blame
968 Bytes
from abc import ABC, abstractmethod
from typing import Literal

from torch import Tensor
from torch.nn import Module
from torch.nn.functional import gelu, silu
class GatedUnit(Module, ABC):
    """Base class for gated linear units (GLU variants).

    Splits the input into two halves along ``dim`` and returns
    ``activation(f) * g``, where the activation is supplied by a
    subclass via :meth:`_activation`.
    """

    def __init__(self, dim: int = -1) -> None:
        """
        Args:
            dim: Dimension along which the input is split into the
                activated half and the gate half. Defaults to the last
                dimension. The input must have an even size along it.
        """
        super().__init__()
        self.dim = dim

    @abstractmethod
    def _activation(self, x: Tensor) -> Tensor:
        """Activation applied to the first half before gating.

        Inheriting ``ABC`` makes direct instantiation of ``GatedUnit``
        fail early with a TypeError; previously ``@abstractmethod`` was
        unenforced (no ABCMeta) and the ``...`` stub returned ``None``,
        producing a confusing ``None * Tensor`` error at call time.
        """
        raise NotImplementedError

    def forward(self, x: Tensor) -> Tensor:
        # First half is activated, second half gates it elementwise.
        f, g = x.chunk(2, dim=self.dim)
        return self._activation(f) * g
class SwiGLU(GatedUnit):
    """Gated linear unit with a SiLU (swish) gate: ``silu(f) * g``.

    The constructor is inherited unchanged from :class:`GatedUnit`
    (``dim: int = -1``), so no override is needed here.
    """

    def _activation(self, x: Tensor) -> Tensor:
        # SiLU(x) = x * sigmoid(x), applied to the non-gate half.
        return silu(x)
class GeGLU(GatedUnit):
    """Gated linear unit with a GELU gate: ``gelu(f) * g``."""

    def __init__(
        self,
        dim: int = -1,
        approximate: Literal["tanh", "none"] = "tanh"
    ) -> None:
        """
        Args:
            dim: Dimension along which the input is split in half.
            approximate: GELU mode forwarded to
                ``torch.nn.functional.gelu``; ``"tanh"`` uses the fast
                tanh approximation, ``"none"`` the exact erf form.
        """
        super().__init__(dim)
        self.approximate = approximate

    def _activation(self, x: Tensor) -> Tensor:
        # `approximate` is keyword-only in F.gelu; the previous
        # positional call `gelu(x, self.approximate)` raised
        # "TypeError: gelu() takes 1 positional argument" at runtime.
        return gelu(x, approximate=self.approximate)