import torch
import torch.nn as nn

from transformers import Qwen2PreTrainedModel, Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2DecoderLayer


class Qwen2Connector(Qwen2PreTrainedModel):
    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        # The connector is used as a bidirectional encoder: drop the causal
        # constraint so every position can attend to every other position.
        for layer in self.layers:
            layer.self_attn.is_causal = False

        # Only the FlashAttention-2 path is supported here, since its kernel
        # reads the per-layer `is_causal` flag set above.
        self._attn_implementation = config._attn_implementation
        assert self._attn_implementation == 'flash_attention_2'

        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(self, inputs_embeds):
        # Sequential position ids 0..seq_len-1, broadcast over the batch; the
        # layers use them to compute rotary position embeddings.
        position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = position_ids.expand(inputs_embeds.shape[0], -1)

        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Positional args follow the decoder layer's signature:
                # (hidden_states, attention_mask, position_ids).
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    None,  # attention_mask: no padding or causal mask is applied
                    position_ids,
                    use_reentrant=False,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                )

            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        return hidden_states
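

# --- Usage sketch (illustrative only, not part of the connector itself) ---
# A minimal example of running the connector on random embeddings. The tiny
# Qwen2Config below is a made-up assumption for demonstration; real use would
# take the config of the Qwen2 checkpoint being connected. FlashAttention-2
# additionally requires the flash-attn package, a CUDA device, and fp16/bf16
# tensors, so this only runs in such an environment.
if __name__ == "__main__":
    config = Qwen2Config(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        attn_implementation="flash_attention_2",
    )
    connector = Qwen2Connector(config).to(device="cuda", dtype=torch.bfloat16)

    # Batch of 2 sequences, 16 "tokens" each, already embedded to hidden_size.
    dummy_embeds = torch.randn(2, 16, config.hidden_size, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        out = connector(dummy_embeds)
    print(out.shape)  # torch.Size([2, 16, 256])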