import torch.nn as nn
# Parent classes from transformers; import paths assume use outside the repo's own source tree.
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM
from transformers.models.phi3.modeling_phi3 import Phi3MLP

class GlmMLP(Phi3MLP):
    # GLM's MLP is the Phi3 MLP, reused unchanged.
    pass

class GlmAttention(LlamaAttention):
    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx)
        # GLM's output projection carries no bias, overriding the parent's config-driven setting.
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
        )

class GlmForCausalLM(LlamaForCausalLM):
    # The causal-LM wrapper is identical to Llama's; only the submodules above differ.
    pass