| """ | |
| Activation functions for the LLM model. | |
| """ | |
| import jax | |
| import jax.numpy as jnp | |
| from typing import Callable | |
def gelu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Gaussian Error Linear Unit (GELU) activation function (tanh approximation).

    Args:
        x: Input tensor

    Returns:
        GELU activation applied to input
    """
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))


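# Illustrative check (assumes only the imports above): the tanh-based formula
# should closely match JAX's built-in tanh approximation,
# jax.nn.gelu(x, approximate=True). For example:
#
#     x = jnp.linspace(-3.0, 3.0, 7)
#     assert jnp.allclose(gelu(x), jax.nn.gelu(x, approximate=True), atol=1e-6)

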
def swiglu(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """
    SwiGLU activation function (Swish-Gated Linear Unit).

    Used in modern LLMs such as PaLM and LLaMA. The gate input ``y`` is passed
    through Swish/SiLU and multiplied elementwise with ``x``.

    Args:
        x: First input tensor (value stream)
        y: Second input tensor (gate stream)

    Returns:
        SwiGLU activation applied to inputs
    """
    # Swish/SiLU gating: silu(y) = y * sigmoid(y). Plain sigmoid gating would
    # be GLU, not SwiGLU.
    return x * jax.nn.silu(y)


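# A minimal sketch of how swiglu is typically wired into a transformer
# feed-forward block: two parallel projections produce the value and gate
# streams, and a third projects back to the model dimension. The parameter
# names (w_in, w_gate, w_out) are illustrative assumptions, not part of any
# particular model's implementation.
def swiglu_ffn(
    x: jnp.ndarray,
    w_in: jnp.ndarray,
    w_gate: jnp.ndarray,
    w_out: jnp.ndarray,
) -> jnp.ndarray:
    """Example SwiGLU feed-forward block (illustrative only)."""
    value = x @ w_in    # [..., d_model] -> [..., d_ff]
    gate = x @ w_gate   # [..., d_model] -> [..., d_ff]
    return swiglu(value, gate) @ w_out  # back to [..., d_model]

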
def relu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Rectified Linear Unit (ReLU) activation function.

    Args:
        x: Input tensor

    Returns:
        ReLU activation applied to input
    """
    return jnp.maximum(0, x)


class GELU:
    """GELU activation function class."""

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return gelu(x)


class SwiGLU:
    """SwiGLU activation function class."""

    def __call__(self, x: jnp.ndarray, gate: jnp.ndarray) -> jnp.ndarray:
        return swiglu(x, gate)


class ReLU:
    """ReLU activation function class."""

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return relu(x)


def get_activation_fn(name: str) -> Callable:
    """
    Get an activation function by name.

    Note that 'swiglu' returns a two-argument function (value and gate),
    while 'gelu' and 'relu' take a single input.

    Args:
        name: Name of the activation function ('gelu', 'swiglu', or 'relu')

    Returns:
        Activation function

    Raises:
        ValueError: If the activation function is not supported
    """
    activations = {'gelu': gelu, 'swiglu': swiglu, 'relu': relu}
    key = name.lower()
    if key not in activations:
        raise ValueError(
            f"Activation function {name!r} not supported; "
            f"expected one of {sorted(activations)}"
        )
    return activations[key]
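

# Minimal usage sketch (illustrative, not required by the module): resolve an
# activation by name and apply it to a small example input.
if __name__ == "__main__":
    act = get_activation_fn("gelu")
    x = jnp.linspace(-2.0, 2.0, 5)
    print("gelu:", act(x))
    print("swiglu:", swiglu(x, x))
    print("relu:", relu(x))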