import torch.nn as nn

from fam.llm.layers.attn import SelfAttention
from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm


class Block(nn.Module):
    """
    A single transformer block: pre-norm self-attention followed by a
    pre-norm feed-forward network, each wrapped in a residual connection.

    Args:
        config (object): Configuration object containing parameters for the block.

    Attributes:
        ln_1 (object): Normalization applied before the attention layer.
        ln_2 (object): Normalization applied before the feed-forward layer.
        attn (object): Self-attention layer.
        mlp (object): Multi-layer perceptron (feed-forward) layer.

    Methods:
        forward(x): Performs a forward pass through the block.
    """

    def __init__(self, config):
        super().__init__()
        # Select the normalization flavour from the config: RMSNorm needs an
        # explicit epsilon, while LayerNorm takes the configurable bias flag.
        if config.norm_type == "rmsnorm":
            if config.rmsnorm_eps is None:
                raise ValueError("RMSNorm requires rmsnorm_eps to be set")
            self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)  # attn norm
            self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)  # ffn norm
        elif config.norm_type == "layernorm":
            self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)  # attn norm
            self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)  # ffn norm
        else:
            raise ValueError(f"Unknown norm type: {config.norm_type}")
        self.attn = SelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        """
        Performs a forward pass through the block.

        Args:
            x (tensor): Input tensor.

        Returns:
            tensor: Output tensor of the same shape, after attention and MLP.
        """
        # Pre-norm residual connections: normalize, apply the sublayer, then
        # add the result back onto the running stream.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
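

# A minimal, self-contained sketch of the same pre-norm residual pattern that
# Block.forward implements, written with stock torch modules (nn.LayerNorm,
# nn.MultiheadAttention, nn.Linear) instead of fam's SelfAttention/MLP, whose
# required config fields are defined in their own modules. Illustration of the
# technique only, not the library's implementation.
if __name__ == "__main__":
    import torch

    class _PreNormBlockSketch(nn.Module):
        def __init__(self, dim: int, n_head: int = 4):
            super().__init__()
            self.ln_1 = nn.LayerNorm(dim)
            self.ln_2 = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, n_head, batch_first=True)
            self.mlp = nn.Sequential(
                nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
            )

        def forward(self, x):
            # Same shape-preserving residual structure as Block.forward above.
            h = self.ln_1(x)
            attn_out, _ = self.attn(h, h, h, need_weights=False)
            x = x + attn_out
            x = x + self.mlp(self.ln_2(x))
            return x

    x = torch.randn(2, 16, 64)  # (batch, sequence, embedding dim)
    y = _PreNormBlockSketch(64)(x)
    assert y.shape == x.shape  # residual adds preserve the input shape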