# NEO1_0-2B-SFT / modeling_neo_vit.py
# Author: Paranioar
# Commit message: "Upload folder using huggingface_hub"
# Commit: d585119 (verified)
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from .configuration_neo_vit import NEOVisionConfig
def precompute_rope_freqs_sincos(
    dim: int, max_position: int, base: float = 10000.0, device=None
):
    """
    Precompute the cosine and sine tables for 1D rotary position embeddings.

    Args:
        dim: rotary dimension. One frequency is produced for every even index in
            ``range(0, dim, 2)``, i.e. ``ceil(dim / 2)`` frequencies.
        max_position: number of positions (table rows) to precompute.
        base: RoPE frequency base (theta).
        device: device on which the tables are built.

    Returns:
        ``(cos, sin)``, each of shape ``(max_position, ceil(dim / 2))``, where
        entry ``[p, i]`` is ``cos / sin`` of ``p / base ** (2 * i / dim)``.
    """
    exponents = torch.arange(0, dim, 2, device=device).float() / dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(max_position, device=device).type_as(inv_freq)
    # Outer product written as a broadcasted multiply: angles[p, i] = p * inv_freq[i].
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()
def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
    """
    Compute the (x, y) grid coordinates of every patch in a packed batch of images.

    The patches of all images are assumed to be concatenated image after image,
    each image laid out in row-major order (x varies fastest).

    Args:
        grid_hw: (B, 2) integer tensor holding (H, W), in patches, per image.
        device: optional device for the returned tensors. The computation runs on
            ``grid_hw.device``; when ``device`` is given the results are moved there,
            otherwise they stay on ``grid_hw.device``.

    Returns:
        ``(abs_x, abs_y)``: 1-D int64 tensors of length ``sum(H * W)``. Entry ``i``
        is the column / row of the i-th packed patch inside its own image.
    """
    work_device = grid_hw.device
    heights = grid_hw[:, 0]
    widths = grid_hw[:, 1]
    counts = heights * widths  # number of patches per image

    # Image index of each packed patch, e.g. counts=[6, 2] -> [0,0,0,0,0,0,1,1].
    image_ids = torch.repeat_interleave(
        torch.arange(grid_hw.shape[0], device=work_device), counts
    )
    # Exclusive prefix sum: packed index of each image's first patch. Computed from
    # `counts` alone, so no extra tensor of a possibly different dtype is mixed in.
    offsets = torch.cumsum(counts, dim=0) - counts
    local_ids = torch.arange(image_ids.numel(), device=work_device) - offsets[image_ids]

    widths_per_patch = widths[image_ids]
    abs_x = local_ids % widths_per_patch
    abs_y = local_ids // widths_per_patch

    # Honour the requested output device (previously this argument was ignored).
    if device is not None:
        abs_x = abs_x.to(device)
        abs_y = abs_y.to(device)
    return abs_x, abs_y
def apply_rotary_emb_1d(
    x: torch.Tensor,
    cos_cached: torch.Tensor,
    sin_cached: torch.Tensor,
    positions: torch.Tensor,
):
    """
    Apply 1D RoPE to ``x``, rotating interleaved channel pairs.

    Channels ``(0, 1), (2, 3), ...`` form pairs ``(a, b)`` that are mapped to
    ``(a * cos - b * sin, a * sin + b * cos)``, with ``cos`` / ``sin`` gathered from
    the cached tables at ``positions``.

    Args:
        x: tensor of shape ``(..., seq_len, dim_part)``.
        cos_cached, sin_cached: tables of shape ``(max_pos, dim_part / 2)``.
        positions: integer indices into the tables, shape ``(..., seq_len)``.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    cos = cos_cached[positions]
    sin = sin_cached[positions]
    even, odd = x[..., 0::2], x[..., 1::2]
    # Pre-allocate like `x` so the output keeps x's dtype/device and interleaving.
    out = torch.empty_like(x)
    out[..., 0::2] = even * cos - odd * sin
    out[..., 1::2] = even * sin + odd * cos
    return out
def apply_2d_rotary_pos_emb(
    x: torch.Tensor,
    cos_cached_x: torch.Tensor,
    sin_cached_x: torch.Tensor,
    cos_cached_y: torch.Tensor,
    sin_cached_y: torch.Tensor,
    abs_positions_x: torch.Tensor,
    abs_positions_y: torch.Tensor
):
    """
    Apply 2D RoPE to ``x`` by rotating its two channel halves independently.

    The first ``dim // 2`` channels are rotated with the x tables at
    ``abs_positions_x``; the remaining channels are rotated with the y tables at
    ``abs_positions_y``. The two halves are concatenated back in the same order.
    """
    dim = x.shape[-1]
    first_len = dim // 2
    # Same views as slicing [:first_len] and [first_len:].
    x_half, y_half = torch.split(x, [first_len, dim - first_len], dim=-1)
    rotated = (
        apply_rotary_emb_1d(x_half, cos_cached_x, sin_cached_x, abs_positions_x),
        apply_rotary_emb_1d(y_half, cos_cached_y, sin_cached_y, abs_positions_y),
    )
    return torch.cat(rotated, dim=-1)
class NEOVisionEmbeddings(nn.Module):
    """
    Embedding Module for Vision.

    Pipeline (see ``forward``):
      1. Flattened pixel patches are projected to ``hidden_size`` by a Conv2d whose
         kernel and stride equal ``patch_size``, followed by GELU.
      2. 2D RoPE is applied: the first half of the channels is rotated by each
         patch's column (x) index, the second half by its row (y) index.
      3. Each image's patch grid is downsampled by a Conv2d with kernel/stride
         ``downsample_factor`` that projects to the LLM hidden size.
    """

    def __init__(self, config: NEOVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        # `llm_hidden_size` and `downsample_ratio` are list-valued in the config;
        # only their first entries are used here.
        self.llm_embed_dim = config.llm_hidden_size[0]
        self.downsample_factor = int(1 / config.downsample_ratio[0])
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        # Merges every (downsample_factor x downsample_factor) block of patches into one token.
        self.dense_embedding = nn.Conv2d(
            in_channels=self.embed_dim, out_channels=self.llm_embed_dim, kernel_size=self.downsample_factor, stride=self.downsample_factor
        )
        self.gelu = nn.GELU()

        # Each spatial axis receives half of the embedding channels.
        self.rope_dim_part = self.embed_dim // 2
        cos_x, sin_x = precompute_rope_freqs_sincos(
            self.rope_dim_part, config.max_position_embeddings_vision, base=config.rope_theta_vision, device=None
        )
        cos_y, sin_y = precompute_rope_freqs_sincos(
            self.rope_dim_part, config.max_position_embeddings_vision, base=config.rope_theta_vision, device=None
        )
        # Non-persistent: derived from the config, so excluded from the state_dict.
        self.register_buffer("cos_cached_x", cos_x, persistent=False)
        self.register_buffer("sin_cached_x", sin_x, persistent=False)
        self.register_buffer("cos_cached_y", cos_y, persistent=False)
        self.register_buffer("sin_cached_y", sin_y, persistent=False)

    def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw):
        """
        Apply 2D RoPE to ``patch_embeds`` (one row per patch, packed image by image).

        The rotation is computed in float32 and the result is cast back to the
        dtype of ``patch_embedding.weight``.
        """
        abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device)
        embeddings = apply_2d_rotary_pos_emb(
            patch_embeds.to(torch.float32),  # RoPE calculations are often more stable in float32
            self.cos_cached_x, self.sin_cached_x,
            self.cos_cached_y, self.sin_cached_y,
            abs_pos_x,
            abs_pos_y
        ).to(self.patch_embedding.weight.dtype)
        return embeddings

    def forward(self, pixel_values: torch.FloatTensor, grid_hw=None) -> torch.Tensor:
        """
        Embed packed patches of one or more images.

        Args:
            pixel_values: flattened patches, reshaped to
                ``(-1, num_channels, patch_size, patch_size)``; the patches of all
                images are concatenated image after image, row-major within each image.
            grid_hw: (B, 2) integer tensor with the (H, W) patch-grid size of each image.

        Returns:
            Tensor of shape ``(num_patches // downsample_factor ** 2, llm_hidden_size)``
            holding the downsampled tokens of all images, image after image.

        Raises:
            ValueError: if ``grid_hw`` is not given.
            AssertionError: if ``grid_hw`` does not account for every patch, or if an
                image's H/W is not divisible by ``downsample_factor``.
        """
        if grid_hw is None:
            raise ValueError("grid_hw is required to locate each image's patches")

        # Use the configured channel count (the conv was built with it) instead of a hard-coded 3.
        pixel_values = pixel_values.view(
            -1,
            self.config.num_channels,
            self.patch_size,
            self.patch_size,
        )  # (num_patches, C * P * P) -> (num_patches, C, P, P)
        patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim)

        # Read the grid sizes once on the host: avoids one device sync per image below.
        grid_sizes = grid_hw.tolist()
        num_patches = sum(h * w for h, w in grid_sizes)
        # Validate before RoPE so a mismatch fails here with a clear message rather
        # than as a broadcasting error inside the rotary computation.
        assert num_patches == patch_embeds.shape[0], (
            f"grid_hw describes {num_patches} patches but pixel_values holds {patch_embeds.shape[0]}"
        )

        # Keep the RoPE tables on the same device as the activations.
        self.cos_cached_x = self.cos_cached_x.to(patch_embeds.device)
        self.sin_cached_x = self.sin_cached_x.to(patch_embeds.device)
        self.cos_cached_y = self.cos_cached_y.to(patch_embeds.device)
        self.sin_cached_y = self.sin_cached_y.to(patch_embeds.device)
        patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw)  # (num_patches, embed_dim)

        tokens = []
        start = 0
        for h, w in grid_sizes:
            count = h * w
            # (h*w, C) -> (1, C, h, w) so the conv sees this image's patch grid.
            grid = patch_embeds[start:start + count].view(h, w, -1).unsqueeze(0).permute(0, 3, 1, 2)
            merged = self.dense_embedding(grid).permute(0, 2, 3, 1)  # (1, h', w', llm_dim)
            tokens.append(merged.reshape(-1, merged.shape[-1]))
            start += count

        if not tokens:
            # No images: return an empty token matrix instead of failing in torch.cat.
            return patch_embeds.new_zeros((0, self.llm_embed_dim))

        embeddings = torch.cat(tokens, dim=0)  # (num_patches // downsample_factor**2, llm_dim)
        assert start == patch_embeds.shape[0]
        # Fails when an image's H or W is not divisible by downsample_factor
        # (the strided conv drops the remainder).
        assert embeddings.shape[0] == patch_embeds.shape[0] // self.downsample_factor ** 2, (
            "each image's H and W must be divisible by downsample_factor"
        )
        return embeddings
class NEOVisionModel(PreTrainedModel):
    """
    NEO vision tower: a thin ``PreTrainedModel`` wrapper around ``NEOVisionEmbeddings``.
    """

    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    config_class = NEOVisionConfig
    # Defined for compatibility with transformers >= 4.51.
    _tp_plan = ''

    def __init__(self, config: NEOVisionConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = NEOVisionEmbeddings(config)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
        grid_hw: Optional[torch.Tensor] = None
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        Produce vision embeddings.

        Args:
            pixel_values: 2-D packed patches, passed to ``self.embeddings`` together
                with ``grid_hw``. Ignored when ``pixel_embeds`` is given.
            output_hidden_states, return_dict: resolved from the config when ``None``;
                neither changes the output (a ``BaseModelOutputWithPooling`` is
                always returned).
            pixel_embeds: precomputed embeddings, returned as-is.
            grid_hw: per-image (H, W) patch-grid sizes for ``self.embeddings``.

        Returns:
            ``BaseModelOutputWithPooling`` with only ``last_hidden_state`` set.

        Raises:
            ValueError: if neither ``pixel_values`` nor ``pixel_embeds`` is given.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_embeds is not None:
            last_hidden = pixel_embeds
        elif pixel_values is not None:
            assert pixel_values.dim() == 2, f"pixel_values must be 2D for native resolution, got: {pixel_values.dim()}"
            last_hidden = self.embeddings(pixel_values, grid_hw=grid_hw)
        else:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden,
            pooler_output=None,
            hidden_states=None,
            attentions=None,
        )