# Puffin/src/models/connector/modeling_qwen2.py
import torch
import torch.nn as nn
from transformers import Qwen2PreTrainedModel, Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2DecoderLayer


class Qwen2Connector(Qwen2PreTrainedModel):
    """A stack of non-causal Qwen2 decoder layers applied directly to input embeddings."""

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # The connector attends bidirectionally, so disable the causal mask in every layer.
        for layer in self.layers:
            layer.self_attn.is_causal = False
        self._attn_implementation = config._attn_implementation
        assert self._attn_implementation == 'flash_attention_2'
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, inputs_embeds):
        # Sequential position ids, shared across the batch: shape (batch_size, seq_len).
        position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = position_ids.expand(inputs_embeds.shape[0], -1)

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    None,  # attention_mask
                    position_ids,
                    use_reentrant=False,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        return hidden_states
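

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The config values
# below are illustrative placeholders, not Puffin's actual settings, and the
# snippet assumes a CUDA GPU with flash-attn installed, since the connector
# asserts the flash_attention_2 implementation.
if __name__ == "__main__":
    config = Qwen2Config(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    # Select FlashAttention-2, which Qwen2Connector.__init__ requires.
    config._attn_implementation = "flash_attention_2"

    connector = Qwen2Connector(config).to(device="cuda", dtype=torch.bfloat16)
    dummy_embeds = torch.randn(2, 16, config.hidden_size, device="cuda", dtype=torch.bfloat16)
    out = connector(dummy_embeds)
    print(out.shape)  # torch.Size([2, 16, 64])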