# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import copy
from typing import List, Tuple, Optional

import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel

from data.data_utils import (
    create_sparse_mask,
    get_flattened_position_ids_extrapolate,
    get_flattened_position_ids_interpolate,
    patchify,
)

from .qwen2_navit import NaiveCache
from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding


class BagelConfig(PretrainedConfig):
    def __init__(
        self,
        visual_gen=True,
        visual_und=True,
        llm_config=None,
        vit_config=None,
        vae_config=None,
        latent_patch_size=2,
        max_latent_size=32,
        vit_max_num_patch_per_side=70,
        connector_act="gelu_pytorch_tanh",
        interpolate_pos=False,
        timestep_shift=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.visual_gen = visual_gen
        self.visual_und = visual_und
        self.llm_config = llm_config
        self.vit_config = vit_config
        self.vae_config = vae_config
        self.latent_patch_size = latent_patch_size
        self.max_latent_size = max_latent_size
        self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
        self.connector_act = connector_act
        self.interpolate_pos = interpolate_pos
        self.timestep_shift = timestep_shift


class Bagel(PreTrainedModel):
    config_class = BagelConfig
    base_model_prefix = 'bagel'

    def __init__(self, language_model, vit_model, config: BagelConfig):
        super().__init__(config)
        self.language_model = language_model
        self.hidden_size = config.llm_config.hidden_size
        self.use_moe = "Mo" in config.llm_config.layer_module
        self.num_heads = config.llm_config.num_attention_heads

        if config.visual_gen:
            self.latent_patch_size = config.latent_patch_size
            self.timestep_shift = config.timestep_shift
            self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
            self.max_latent_size = config.max_latent_size
            self.latent_channel = config.vae_config.z_channels
            self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
            self.time_embedder = TimestepEmbedder(self.hidden_size)
            self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
            self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
            self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)

        if config.visual_und:
            self.vit_model = vit_model
            self.vit_patch_size = config.vit_config.patch_size
            self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
            self.vit_hidden_size = config.vit_config.hidden_size
            self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
            self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)

        if config.interpolate_pos:
            self.get_flattened_position_ids = get_flattened_position_ids_interpolate
        else:
            self.get_flattened_position_ids = get_flattened_position_ids_extrapolate

        self.config = config
        self._init_weights()

    def _init_weights(self):
        if self.config.visual_gen:
            nn.init.constant_(self.llm2vae.weight, 0)
            nn.init.constant_(self.llm2vae.bias, 0)
    def forward(
        self,
        sequence_length: int,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        sample_lens: List[int],
        packed_position_ids: torch.LongTensor,
        nested_attention_masks: List[torch.Tensor] = None,
        split_lens: List[int] = None,
        attn_modes: List[str] = None,
        # for visual understanding
        ce_loss_indexes: Optional[torch.BoolTensor] = None,
        packed_label_ids: Optional[torch.LongTensor] = None,
        packed_vit_tokens: Optional[torch.Tensor] = None,
        packed_vit_token_indexes: Optional[torch.LongTensor] = None,
        packed_vit_position_ids: Optional[torch.LongTensor] = None,
        vit_token_seqlens: Optional[torch.IntTensor] = None,
        # for visual generation
        padded_latent: Optional[torch.Tensor] = None,
        patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
        packed_latent_position_ids: Optional[torch.LongTensor] = None,
        packed_vae_token_indexes: Optional[torch.LongTensor] = None,
        packed_timesteps: Optional[torch.LongTensor] = None,
        mse_loss_indexes: Optional[torch.BoolTensor] = None,
    ) -> dict:
| """ | |
| Args: | |
| sequence_length: length of sequence. | |
| packed_text_ids: 1-D int tensor, packed text token ids. | |
| packed_text_indexes: 1-D int tensor, packed text token indexes in sequence. | |
| sample_lens: A list of N ints, length of each sample in packed_sequence. | |
| nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and | |
| -inf means ignore. | |
| packed_position_ids: packed 1-D positions, an image has only one global position shared | |
| by all latent tokens. | |
| packed_vit_tokens: packed patchified image tokens for vit model. | |
| packed_vit_position_ids: 1-D int tensor, the position of each token for vit model. | |
| packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence. | |
| vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model. | |
| packed_label_ids: 1-D int tensor, packed label token ids. | |
| ce_loss_indexes: 1-D bool tensor, where to compute ce loss. | |
| padded_latent: padded latent from VAE encoder. | |
| patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image. | |
| packed_latent_position_ids: 1-D int tensor, the position of each token for latent. | |
| packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence. | |
| packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image. | |
| mse_loss_indexes: 1-D bool tensor, where to compute mse loss. | |
| """ | |
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
        packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
        packed_sequence[packed_text_indexes] = packed_text_embedding

        if nested_attention_masks is None:
            sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
            seqlen = sum(sample_lens)
            block_mask = create_block_mask(
                sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
                device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
            )
            attention_mask = block_mask
        else:
            attention_mask = nested_attention_masks
        if self.config.visual_und:
            cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
            cu_seqlens = cu_seqlens.to(torch.int32)
            max_seqlen = torch.max(vit_token_seqlens).item()
            packed_vit_token_embed = self.vit_model(
                packed_pixel_values=packed_vit_tokens,
                packed_flattened_position_ids=packed_vit_position_ids,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )
            packed_vit_token_embed = self.connector(packed_vit_token_embed)
            vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
            packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
            packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
        if self.config.visual_gen:
            p = self.latent_patch_size
            packed_latent = []
            for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
                latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
                latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
                packed_latent.append(latent)
            packed_latent_clean = torch.cat(packed_latent, dim=0)
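            # Rectified-flow noising: squash the raw timesteps with a sigmoid, shift them with
            # t' = s * t / (1 + (s - 1) * t) where s = timestep_shift, then interpolate
            # x_t = (1 - t') * x_clean + t' * noise before projecting into the LLM hidden space.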
            noise = torch.randn_like(packed_latent_clean)
            packed_timesteps = torch.sigmoid(packed_timesteps)
            packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
            packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
            packed_timestep_embeds = self.time_embedder(packed_timesteps)
            latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
            packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
            packed_sequence[packed_vae_token_indexes] = packed_latent
        extra_inputs = {}
        if self.use_moe:
            packed_und_token_indexes = packed_text_indexes
            if packed_vit_token_indexes is not None:
                packed_und_token_indexes = torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
            extra_inputs.update(
                packed_und_token_indexes=packed_und_token_indexes,
                packed_gen_token_indexes=packed_vae_token_indexes,
            )

        last_hidden_state = self.language_model(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_ids=packed_position_ids,
            **extra_inputs,
        )

        mse = None
        if self.config.visual_gen:
            packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
            target = noise - packed_latent_clean  # NOTE: v_t = dx_t/dt = x_1 - x_0, pointing from data to noise
            has_mse = packed_timesteps > 0
            mse = (packed_mse_preds - target[has_mse]) ** 2

        ce = None
        if ce_loss_indexes is not None:
            packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
            ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
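        # Both losses are returned unreduced (per element). How they are weighted and reduced
        # to a scalar is left to the training loop; a sketch (not defined in this file) would be
        # something like `loss = out["mse"].mean() + out["ce"].mean()`.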
        return dict(mse=mse, ce=ce)
    def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
        packed_text_ids = list()
        packed_text_position_ids = list()
        text_token_lens = list()
        packed_text_indexes = list()
        packed_key_value_indexes = list()

        curr = 0
        newlens, new_rope = list(), list()
        for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            text_ids = tokenizer.encode(prompt)
            text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']]
            text_token_lens.append(len(text_ids))
            packed_text_ids.extend(text_ids)
            packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
            packed_text_indexes.extend(range(curr, curr + len(text_ids)))
            newlens.append(curr_kvlen + len(text_ids))
            new_rope.append(curr_position_id + len(text_ids))
            curr += len(text_ids)

        generation_input = {
            "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
            "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        }

        return generation_input, newlens, new_rope
    def forward_cache_update_text(
        self,
        past_key_values: NaiveCache,
        packed_text_ids: torch.IntTensor,
        packed_text_position_ids: torch.LongTensor,
        text_token_lens: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_key_value_indexes: torch.LongTensor,
        key_values_lens: torch.IntTensor,
    ):
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {"mode": "und"}

        output = self.language_model.forward_inference(
            packed_query_sequence=packed_text_embedding,
            query_lens=text_token_lens,
            packed_query_position_ids=packed_text_position_ids,
            packed_query_indexes=packed_text_indexes,
            past_key_values=past_key_values,
            packed_key_value_indexes=packed_key_value_indexes,
            key_values_lens=key_values_lens,
            update_past_key_values=True,
            is_causal=True,
            **extra_inputs,
        )
        past_key_values = output.past_key_values

        return past_key_values
    def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
        packed_vit_token_indexes = list()
        vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
        packed_text_ids, packed_text_indexes = list(), list()
        packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
        packed_key_value_indexes = list()

        _curr = curr = 0
        newlens, new_rope = list(), list()
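        # Two running offsets are tracked: `curr` indexes the full layout that includes each
        # sample's existing KV cache, while `_curr` indexes only the newly packed query tokens
        # (<start_of_image>, the image patches, <end_of_image>).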
        for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            packed_text_ids.append(new_token_ids['start_of_image'])
            packed_text_indexes.append(_curr)
            packed_indexes.append(curr)
            curr += 1
            _curr += 1

            image_tensor = transforms(image)
            vit_position_ids = self.get_flattened_position_ids(
                image_tensor.size(1), image_tensor.size(2),
                self.vit_patch_size,
                max_num_patches_per_side=self.vit_max_num_patch_per_side
            )
            vit_tokens = patchify(image_tensor, self.vit_patch_size)
            packed_vit_tokens.append(vit_tokens)
            num_img_tokens = vit_tokens.shape[0]
            packed_vit_position_ids.append(vit_position_ids)
            vit_token_seqlens.append(num_img_tokens)
            packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
            packed_indexes.extend(range(curr, curr + num_img_tokens))
            curr += num_img_tokens
            _curr += num_img_tokens

            packed_text_ids.append(new_token_ids['end_of_image'])
            packed_text_indexes.append(_curr)
            packed_indexes.append(curr)
            curr += 1
            _curr += 1

            packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
            packed_seqlens.append(num_img_tokens + 2)
            newlens.append(curr_kvlen + num_img_tokens + 2)
            new_rope.append(curr_position_id + 1)

        generation_input = {
            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
            "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
            "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
            "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
            "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
            "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
            "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
            "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        }

        return generation_input, newlens, new_rope
    def forward_cache_update_vit(
        self,
        past_key_values: NaiveCache,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_vit_tokens: torch.Tensor,
        packed_vit_token_indexes: torch.LongTensor,
        packed_vit_position_ids: torch.LongTensor,
        vit_token_seqlens: torch.IntTensor,
        packed_position_ids: torch.LongTensor,
        packed_seqlens: torch.IntTensor,
        packed_indexes: torch.LongTensor,
        packed_key_value_indexes: torch.LongTensor,
        key_values_lens: torch.IntTensor,
    ):
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
        packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
        packed_sequence[packed_text_indexes] = packed_text_embedding

        cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
        cu_seqlens = cu_seqlens.to(torch.int32)
        max_seqlen = torch.max(vit_token_seqlens).item()
        packed_vit_token_embed = self.vit_model(
            packed_pixel_values=packed_vit_tokens,
            packed_flattened_position_ids=packed_vit_position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        packed_vit_token_embed = self.connector(packed_vit_token_embed)
        pos_emb = self.vit_pos_embed(packed_vit_position_ids)
        packed_vit_token_embed = packed_vit_token_embed + pos_emb
        packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {"mode": "und"}

        output = self.language_model.forward_inference(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=packed_position_ids,
            packed_query_indexes=packed_indexes,
            past_key_values=past_key_values,
            packed_key_value_indexes=packed_key_value_indexes,
            key_values_lens=key_values_lens,
            update_past_key_values=True,
            is_causal=False,
            **extra_inputs,
        )
        past_key_values = output.past_key_values

        return past_key_values
    def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
        patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
        packed_vae_token_indexes = list()
        packed_text_ids, packed_text_indexes = list(), list()
        packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
        packed_key_value_indexes = list()

        _curr = curr = 0
        vae_image_tensors = list()
        newlens, new_rope = list(), list()
        for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            packed_text_ids.append(new_token_ids['start_of_image'])
            packed_text_indexes.append(_curr)
            packed_indexes.append(curr)
            curr += 1
            _curr += 1

            image_tensor = transforms(image)
            vae_image_tensors.append(image_tensor)
            vae_position_ids = self.get_flattened_position_ids(
                image_tensor.size(1), image_tensor.size(2),
                self.latent_downsample,
                max_num_patches_per_side=self.max_latent_size
            )
            packed_vae_position_ids.append(vae_position_ids)
            H, W = image_tensor.shape[1:]
            h = H // self.latent_downsample
            w = W // self.latent_downsample
            patchified_vae_latent_shapes.append((h, w))

            num_img_tokens = w * h
            packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
            packed_indexes.extend(range(curr, curr + num_img_tokens))
            curr += num_img_tokens
            _curr += num_img_tokens

            packed_text_ids.append(new_token_ids['end_of_image'])
            packed_text_indexes.append(_curr)
            packed_indexes.append(curr)
            curr += 1
            _curr += 1

            packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
            packed_seqlens.append(num_img_tokens + 2)
            newlens.append(curr_kvlen + num_img_tokens + 2)
            new_rope.append(curr_position_id + 1)

        image_sizes = [item.shape for item in vae_image_tensors]
        max_image_size = [max(item) for item in list(zip(*image_sizes))]
        padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
        for i, image_tensor in enumerate(vae_image_tensors):
            padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor

        generation_input = {
            "padded_images": padded_images,
            "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
            "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
            "packed_timesteps": torch.tensor([timestep]),
            "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
            "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
            "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
            "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        }

        return generation_input, newlens, new_rope
    def forward_cache_update_vae(
        self,
        vae_model,
        past_key_values: NaiveCache,
        padded_images: torch.Tensor,
        patchified_vae_latent_shapes: List,
        packed_vae_position_ids: torch.LongTensor,
        packed_timesteps: torch.Tensor,
        packed_vae_token_indexes: torch.LongTensor,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_position_ids: torch.LongTensor,
        packed_seqlens: torch.IntTensor,
        packed_indexes: torch.LongTensor,
        key_values_lens: torch.IntTensor,
        packed_key_value_indexes: torch.Tensor,
    ):
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
        packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
        packed_sequence[packed_text_indexes] = packed_text_embedding
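        # Encode the (zero-padded) images with the external VAE, patchify the latents into
        # p x p x C tokens, and embed them with the timestep embedding (timestep 0 marks a
        # clean conditioning latent) plus the 2-D latent position embedding.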
        padded_latent = vae_model.encode(padded_images)

        p = self.latent_patch_size
        packed_latent = list()
        for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
            latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
            latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
            packed_latent.append(latent)
        packed_latent = torch.cat(packed_latent, dim=0)
        packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
        packed_timestep_embeds = self.time_embedder(packed_timesteps)
        packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
        packed_sequence[packed_vae_token_indexes] = packed_latent

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {
                "mode": "gen",
                "packed_vae_token_indexes": packed_vae_token_indexes,
                "packed_text_indexes": packed_text_indexes,
            }

        output = self.language_model.forward_inference(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=packed_position_ids,
            packed_query_indexes=packed_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=True,
            is_causal=False,
            **extra_inputs,
        )
        past_key_values = output.past_key_values

        return past_key_values
    def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
        packed_text_ids, packed_text_indexes = list(), list()
        packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
        packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
        packed_key_value_indexes = list()

        query_curr = curr = 0
        for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            packed_text_ids.append(new_token_ids['start_of_image'])
            packed_text_indexes.append(query_curr)
            packed_indexes.append(curr)
            curr += 1
            query_curr += 1

            vae_position_ids = self.get_flattened_position_ids(
                H, W,
                self.latent_downsample,
                max_num_patches_per_side=self.max_latent_size
            )
            packed_vae_position_ids.append(vae_position_ids)

            h, w = H // self.latent_downsample, W // self.latent_downsample
            num_image_tokens = h * w
            packed_init_noises.append(
                torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2)
            )
            packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
            packed_indexes.extend(range(curr, curr + num_image_tokens))
            curr += num_image_tokens
            query_curr += num_image_tokens

            packed_text_ids.append(new_token_ids['end_of_image'])
            packed_text_indexes.append(query_curr)
            packed_indexes.append(curr)
            curr += 1
            query_curr += 1

            packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
            packed_seqlens.append(num_image_tokens + 2)

        generation_input = {
            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
            "packed_init_noises": torch.cat(packed_init_noises, dim=0),
            "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
            "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
            "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
            "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
            "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
        }

        return generation_input
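    # Builds the query/key-value index layout for a CFG branch: the same latent query slots as
    # prepare_vae_latent, but laid out against a different KV cache (e.g. one with the text or
    # image condition dropped) so that _forward_flow can run the guidance passes with matching shapes.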
    def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
        packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()

        query_curr = curr = 0
        for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            curr += curr_kvlen

            packed_indexes.append(curr)
            curr += 1
            query_curr += 1

            h, w = H // self.latent_downsample, W // self.latent_downsample
            num_image_tokens = h * w
            packed_indexes.extend(range(curr, curr + num_image_tokens))
            curr += num_image_tokens
            query_curr += num_image_tokens

            packed_indexes.append(curr)
            curr += 1
            query_curr += 1

            packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))

        generation_input = {
            "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
            "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
            "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
            "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
        }

        return generation_input
    def generate_image(
        self,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_init_noises: torch.Tensor,
        packed_vae_position_ids: torch.LongTensor,
        packed_vae_token_indexes: torch.LongTensor,
        packed_seqlens: torch.IntTensor,
        packed_position_ids: torch.LongTensor,
        packed_indexes: torch.LongTensor,
        past_key_values: NaiveCache,
        key_values_lens: torch.IntTensor,
        packed_key_value_indexes: torch.LongTensor,
        num_timesteps: int = 24,
        timestep_shift: float = 1.0,
        cfg_renorm_min: float = 0.0,
        cfg_renorm_type: str = "global",
        cfg_interval: Tuple[float, float] = (0.0, 1.0),
        # cfg_text
        cfg_text_scale: float = 1.0,
        cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
        cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
        cfg_text_past_key_values: Optional[NaiveCache] = None,
        cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
        cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
        # cfg_img
        cfg_img_scale: float = 1.0,
        cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
        cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
        cfg_img_past_key_values: Optional[NaiveCache] = None,
        cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
        cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
        cfg_type: str = "parallel",
    ):
        x_t = packed_init_noises

        timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
        timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
        dts = timesteps[:-1] - timesteps[1:]
        timesteps = timesteps[:-1]
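        # Euler integration of the rectified-flow ODE: starting from pure noise at t = 1,
        # each step predicts the velocity v_t (pointing from data to noise) and takes
        # x_{t - dt} = x_t - v_t * dt, so that x_t reaches the data at t = 0. CFG is only
        # applied while t lies inside `cfg_interval`.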
        for i, t in enumerate(timesteps):
            timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
            if t > cfg_interval[0] and t <= cfg_interval[1]:
                cfg_text_scale_ = cfg_text_scale
                cfg_img_scale_ = cfg_img_scale
            else:
                cfg_text_scale_ = 1.0
                cfg_img_scale_ = 1.0
            v_t = self._forward_flow(
                x_t=x_t,
                timestep=timestep,
                packed_vae_token_indexes=packed_vae_token_indexes,
                packed_vae_position_ids=packed_vae_position_ids,
                packed_text_ids=packed_text_ids,
                packed_text_indexes=packed_text_indexes,
                packed_position_ids=packed_position_ids,
                packed_indexes=packed_indexes,
                packed_seqlens=packed_seqlens,
                key_values_lens=key_values_lens,
                past_key_values=past_key_values,
                packed_key_value_indexes=packed_key_value_indexes,
                cfg_renorm_min=cfg_renorm_min,
                cfg_renorm_type=cfg_renorm_type,
                # cfg_text
                cfg_text_scale=cfg_text_scale_,
                cfg_text_packed_position_ids=cfg_text_packed_position_ids,
                cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
                cfg_text_key_values_lens=cfg_text_key_values_lens,
                cfg_text_past_key_values=cfg_text_past_key_values,
                cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
                # cfg_img
                cfg_img_scale=cfg_img_scale_,
                cfg_img_packed_position_ids=cfg_img_packed_position_ids,
                cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
                cfg_img_key_values_lens=cfg_img_key_values_lens,
                cfg_img_past_key_values=cfg_img_past_key_values,
                cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
                cfg_type=cfg_type,
            )
            x_t = x_t - v_t.to(x_t.device) * dts[i]  # velocity pointing from data to noise

        unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
        return unpacked_latent
    def _forward_flow(
        self,
        x_t: torch.Tensor,
        timestep: torch.LongTensor,
        packed_vae_token_indexes: torch.LongTensor,
        packed_vae_position_ids: torch.LongTensor,
        packed_text_ids: torch.LongTensor,
        packed_text_indexes: torch.LongTensor,
        packed_indexes: torch.LongTensor,
        packed_position_ids: torch.LongTensor,
        packed_seqlens: torch.IntTensor,
        key_values_lens: torch.IntTensor,
        past_key_values: NaiveCache,
        packed_key_value_indexes: torch.LongTensor,
        cfg_renorm_min: float = 0.0,
        cfg_renorm_type: str = "global",
        # cfg_text
        cfg_text_scale: float = 1.0,
        cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
        cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
        cfg_text_key_values_lens: Optional[torch.Tensor] = None,
        cfg_text_past_key_values: Optional[NaiveCache] = None,
        cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
        # cfg_img
        cfg_img_scale: float = 1.0,
        cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
        cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
        cfg_img_key_values_lens: Optional[torch.Tensor] = None,
        cfg_img_past_key_values: Optional[NaiveCache] = None,
        cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
        cfg_type: str = "parallel",
    ):
        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
        packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
        packed_sequence[packed_text_indexes] = packed_text_embedding

        assert timestep.unique().shape[0] == 1
        packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
        packed_timestep_embeds = self.time_embedder(timestep)
        x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
        packed_sequence[packed_vae_token_indexes] = x_t

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {
                "mode": "gen",
                "packed_vae_token_indexes": packed_vae_token_indexes,
                "packed_text_indexes": packed_text_indexes,
            }

        output = self.language_model.forward_inference(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=packed_position_ids,
            packed_query_indexes=packed_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=False,
            is_causal=False,
            **extra_inputs,
        )
        v_t = self.llm2vae(output.packed_query_sequence)
        v_t = v_t[packed_vae_token_indexes]

        if cfg_text_scale > 1.0:
            cfg_text_output = self.language_model.forward_inference(
                packed_query_sequence=packed_sequence,
                query_lens=packed_seqlens,
                packed_query_position_ids=cfg_text_packed_position_ids,
                packed_query_indexes=cfg_text_packed_query_indexes,
                past_key_values=cfg_text_past_key_values,
                key_values_lens=cfg_text_key_values_lens,
                packed_key_value_indexes=cfg_text_packed_key_value_indexes,
                update_past_key_values=False,
                is_causal=False,
                **extra_inputs,
            )
            cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
            cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]

        if cfg_img_scale > 1.0:
            cfg_img_output = self.language_model.forward_inference(
                packed_query_sequence=packed_sequence,
                query_lens=packed_seqlens,
                packed_query_position_ids=cfg_img_packed_position_ids,
                packed_query_indexes=cfg_img_packed_query_indexes,
                past_key_values=cfg_img_past_key_values,
                key_values_lens=cfg_img_key_values_lens,
                packed_key_value_indexes=cfg_img_packed_key_value_indexes,
                update_past_key_values=False,
                is_causal=False,
                **extra_inputs,
            )
            cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
            cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
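        # Classifier-free guidance: extrapolate from the unconditional velocity toward the
        # conditional one, v = v_uncond + scale * (v_cond - v_uncond), first for the text
        # condition and then for the image condition, then renormalize so the guided
        # velocity's norm does not exceed that of the fully conditional prediction.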
| if cfg_text_scale > 1.0: | |
| if cfg_renorm_type == "text_channel": | |
| v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) | |
| norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) | |
| norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True) | |
| scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) | |
| v_t_text = v_t_text_ * scale | |
| if cfg_img_scale > 1.0: | |
| v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t) | |
| else: | |
| v_t = v_t_text | |
| else: | |
| v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) | |
| if cfg_img_scale > 1.0: | |
| v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t) | |
| else: | |
| v_t_ = v_t_text_ | |
| # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit | |
| if cfg_renorm_type == "global": | |
| norm_v_t = torch.norm(v_t) | |
| norm_v_t_ = torch.norm(v_t_) | |
| elif cfg_renorm_type == "channel": | |
| norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) | |
| norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True) | |
| else: | |
| raise NotImplementedError(f"{cfg_renorm_type} is not suppoprted") | |
| scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) | |
| v_t = v_t_ * scale | |
| else: | |
| # No CFG | |
| pass | |
| return v_t | |
    def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
        packed_start_tokens, packed_key_value_indexes = list(), list()
        packed_query_position_ids = list()

        curr = 0
        for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
            packed_start_tokens.append(new_token_ids['bos_token_id'])
            packed_query_position_ids.append(curr_position_id)
            curr += curr_kvlen

        generation_input = {
            "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
            "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
        }

        return generation_input
    def generate_text(
        self,
        past_key_values: NaiveCache,
        packed_key_value_indexes: torch.LongTensor,
        key_values_lens: torch.IntTensor,
        packed_start_tokens: torch.LongTensor,
        packed_query_position_ids: torch.LongTensor,
        max_length: int,
        do_sample: bool = False,
        temperature: float = 1.0,
        end_token_id: int = None,
    ):
        """
        Generates text token by token in a streaming fashion.

        This function is a generator that yields one token at a time. It replicates
        the behavior of the original batch generation function, including the handling
        of start tokens and the end-of-sequence token.
        """
        step = 0
        curr_tokens = packed_start_tokens
        while step < max_length:
            packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
            query_lens = torch.ones_like(curr_tokens)
            packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
                0, len(key_values_lens),
                device=key_values_lens.device,
                dtype=key_values_lens.dtype
            )
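            # Each sample's cached KV indexes shift by its sample index, because one new query
            # slot is appended after every preceding sample in the packed layout.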
            unpacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
            for i in range(len(unpacked)):
                unpacked[i] += i
            packed_key_value_indexes = torch.cat(unpacked, dim=0)

            extra_inputs = {}
            if self.use_moe:
                extra_inputs = {"mode": "und"}

            output = self.language_model.forward_inference(
                packed_query_sequence=packed_text_embedding,
                query_lens=query_lens,
                packed_query_position_ids=packed_query_position_ids,
                packed_query_indexes=packed_query_indexes,
                past_key_values=past_key_values,
                key_values_lens=key_values_lens,
                packed_key_value_indexes=packed_key_value_indexes,
                update_past_key_values=True,
                is_causal=True,
                **extra_inputs,
            )
            past_key_values = output.past_key_values
            packed_query_sequence = output.packed_query_sequence
            pred_logits = self.language_model.lm_head(packed_query_sequence)

            if do_sample:
                probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
                curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                curr_tokens = torch.argmax(pred_logits, dim=-1)

            # Append the newly cached token to each sample's KV index list.
            unpacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
            for i in range(len(unpacked)):
                unpacked[i] = torch.cat(
                    [unpacked[i], torch.tensor([unpacked[i][-1] + 1], device=unpacked[i].device)], dim=0
                )
            packed_key_value_indexes = torch.cat(unpacked, dim=0)
            key_values_lens = key_values_lens + 1
            packed_query_position_ids = packed_query_position_ids + 1
            step += 1

            yield curr_tokens  # yield each token as it is generated

            if end_token_id is not None and curr_tokens[0] == end_token_id:  # only supports batch = 1
                break
    # for evaluation
    def chat(
        self,
        tokenizer,
        new_token_ids,
        image_transform,
        images,
        prompt,
        max_length: int,
        do_sample: bool = False,
        temperature: float = 1.0,
    ):
        device = next(self.parameters()).device

        if isinstance(new_token_ids, dict):
            for k, v in new_token_ids.items():
                if torch.is_tensor(v):
                    new_token_ids[k] = v.to(device)
        elif torch.is_tensor(new_token_ids):
            new_token_ids = new_token_ids.to(device)

        # prefill
        past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
        newlens = [0]
        new_rope = [0]

        # add images
        for image in images:
            generation_input, newlens, new_rope = self.prepare_vit_images(
                curr_kvlens=newlens,
                curr_rope=new_rope,
                images=[image],
                transforms=image_transform,
                new_token_ids=new_token_ids,
            )
            for k, v in generation_input.items():
                if torch.is_tensor(v):
                    generation_input[k] = v.to(device)
            with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
                past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)

        # add text
        generation_input, newlens, new_rope = self.prepare_prompts(
            curr_kvlens=newlens,
            curr_rope=new_rope,
            prompts=[prompt],
            tokenizer=tokenizer,
            new_token_ids=new_token_ids,
        )
        for k, v in generation_input.items():
            if torch.is_tensor(v):
                generation_input[k] = v.to(device)
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)

        # decode
        generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
        for k, v in generation_input.items():
            if torch.is_tensor(v):
                generation_input[k] = v.to(device)
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            for curr_tokens in self.generate_text(
                past_key_values=past_key_values,
                max_length=max_length,
                do_sample=do_sample,
                temperature=temperature,
                end_token_id=new_token_ids['eos_token_id'],
                **generation_input,
            ):
                # generate_text yields a 1-D tensor of token ids (one per sample; batch size 1 here)
                output = tokenizer.decode(curr_tokens)
                yield output
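

# Usage sketch (illustrative only; `model`, `tokenizer`, `new_token_ids`, `vit_transform`,
# and `img` are hypothetical names assumed to be constructed by the surrounding inference
# code, which is not part of this file):
#
#     for piece in model.chat(
#         tokenizer=tokenizer,
#         new_token_ids=new_token_ids,
#         image_transform=vit_transform,
#         images=[img],
#         prompt="Describe the image.",
#         max_length=128,
#     ):
#         print(piece, end="", flush=True)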