import os
import json
import inspect
from loguru import logger
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from safetensors.torch import safe_open

from transformers.modeling_utils import PreTrainedModel
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from models.config import NextStepConfig
from models.llama_model import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
from models.heads import FlowMatchingHead
from utils.misc import LargeInt
from utils.compile_utils import smart_compile


@dataclass
class NextStepOutputWithPast(CausalLMOutputWithPast):
    lm_loss: torch.FloatTensor | None = None
    im_loss: torch.FloatTensor | None = None


class NextStepPreTrainedModel(PreTrainedModel):
    config_class = NextStepConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def trainable_params(self) -> LargeInt:
        n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return LargeInt(n_params)


class NextStep(NextStepPreTrainedModel):
    def __init__(self, config: NextStepConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.gradient_checkpointing = False

        token_dim = self.config.latent_channels * self.config.latent_patch_size**2
        self.image_in_projector = nn.Linear(token_dim, config.hidden_size)
        self.image_in_projector.weight.data.normal_(mean=0.0, std=config.initializer_range)
        self.image_in_projector.bias.data.zero_()

        self.image_out_projector = nn.Linear(config.hidden_size, config.hidden_size)
        self.image_out_projector.weight.data.normal_(mean=0.0, std=config.initializer_range)
        self.image_out_projector.bias.data.zero_()

        self.image_head = FlowMatchingHead(
            input_dim=token_dim,
            cond_dim=config.hidden_size,
            dim=config.fm_head_dim,
            layers=config.fm_head_layers,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def load_lm_head(self, lm_head_dir: str | None = None):
        index_json_file = os.path.join(lm_head_dir, "model.safetensors.index.json")
        head_weight_name = "lm_head.weight" if not self.config.tie_word_embeddings else "model.embed_tokens.weight"

        if os.path.exists(index_json_file):
            with open(index_json_file, "r") as f:
                index = json.load(f)
            model_name = index["weight_map"][head_weight_name]
        else:
            model_name = "model.safetensors"

        with safe_open(os.path.join(lm_head_dir, model_name), framework="pt") as f:
            loaded_weight = f.get_tensor(head_weight_name)
            loaded_weight = loaded_weight.to(dtype=self.lm_head.weight.dtype, device=self.lm_head.weight.device)
            self.lm_head.weight.data.copy_(loaded_weight)
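
    # Checkpoint-layout sketch (assumption: a standard sharded safetensors export; file names are illustrative):
    #   lm_head_dir/
    #     model.safetensors.index.json        # "weight_map" points "lm_head.weight" at its shard
    #     model-00001-of-0000N.safetensors    # shard(s) holding the actual tensors
    # When `tie_word_embeddings` is set, the head is read from "model.embed_tokens.weight" instead.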

    def patchify(self, img: torch.Tensor):
        """
        img: (bsz, C, H, W)
        x: (bsz, H * W / patch_size**2, patch_size**2 * C)
        """
        bsz, c, h, w = img.shape
        p = self.config.latent_patch_size
        h_, w_ = h // p, w // p

        img = img.reshape(bsz, c, h_, p, w_, p)
        img = torch.einsum("nchpwq->nhwcpq", img)
        x = img.reshape(bsz, h_ * w_, c * p**2)
        return x

    def unpatchify(self, x: torch.Tensor, h: int = None, w: int = None):
        """
        x: (bsz, H * W / patch_size**2, patch_size**2 * C)
        img: (bsz, C, H, W)
        """
        bsz = x.shape[0]
        p = self.config.latent_patch_size
        c = self.config.latent_channels
        if h is None and w is None:
            h_ = w_ = int(x.shape[1] ** 0.5)
        else:
            h_, w_ = h, w
        assert h_ * w_ == x.shape[1], f"Invalid sequence length {x.shape[1]}."

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum("nhwcpq->nchpwq", x)
        img = x.reshape(bsz, c, h_ * p, w_ * p)
        return img
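
    # Shape round-trip sketch (hypothetical sizes, assuming latent_patch_size=2 and latent_channels=16):
    # a (1, 16, 32, 32) latent patchifies to (1, 16 * 16, 16 * 2 * 2) = (1, 256, 64) tokens, and
    # unpatchify inverts the transform exactly:
    #   tokens = self.patchify(latents)               # (1, 256, 64)
    #   assert torch.equal(self.unpatchify(tokens), latents)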

    def prepare_inputs_embeds(self, input_ids: torch.LongTensor | None = None, latents: torch.FloatTensor | None = None):
        if latents is None:
            if not self.training:
                return self.embed_tokens(input_ids)
            else:  # dummy forward for the image pass, to keep gradient shapes consistent
                raise NotImplementedError("Dummy forward for image pass is not implemented.")
        else:
            bs, seq_length = input_ids.shape
            inputs_embeds = torch.zeros(
                (bs, seq_length, self.config.hidden_size),
                device=self.embed_tokens.weight.device,
                dtype=self.embed_tokens.weight.dtype,
            )
            im_indices = input_ids == self.config.image_placeholder_id
            lm_indices = ~im_indices

            if isinstance(latents, list):
                tokens = torch.cat([self.patchify(latent) for latent in latents], dim=1)
            else:
                tokens = self.patchify(latents)
            # tokens = tokens.reshape(1, -1, tokens.shape[-1])
            image_embeds = self.image_in_projector(tokens)
            image_embeds = image_embeds.view(-1, self.config.hidden_size)
            token_embeds = self.embed_tokens(input_ids[lm_indices])

            inputs_embeds[im_indices] = image_embeds.to(inputs_embeds.dtype)
            inputs_embeds[lm_indices] = token_embeds
            return inputs_embeds
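
    # Merge sketch (hypothetical numbers): with a 256-token patchified latent, `input_ids` must contain exactly
    # 256 positions equal to `config.image_placeholder_id`; those positions receive the projected image
    # embeddings, and every other position receives its ordinary token embedding.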

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        # For SDPA, when possible, we rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch to Flash Attention 2. This feature is not compatible with static cache, as SDPA will
        # fail to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output_attentions is True, the SDPA implementation falls back to the eager implementation's forward.
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention_mask` is 2D, we generate a 4D causal mask here.
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows of the causal_mask, for example the relevant first rows
            # when using left padding. This is required by the memory-efficient attention path of
            # F.scaled_dot_product_attention. Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
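
    # Dispatch summary (restating the branches above): flash_attention_2 receives the raw 2D mask (or None),
    # SDPA may drop the mask entirely so its fused causal path can be used, and every other implementation
    # gets the full 4D additive mask built by `_prepare_4d_causal_attention_mask_with_cache_position`.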

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of
        shape `(batch_size, key_value_length)`, or does nothing if the input `attention_mask` is already 4D.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static
                cache, to account for the 0 padding, i.e. the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)

        return causal_mask
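
    # Illustrative output (assuming no padding and sequence_length == target_length == 3): the returned mask has
    # shape (batch_size, 1, 3, 3) with 0.0 on and below the diagonal and torch.finfo(dtype).min above it,
    # i.e. each query position attends only to itself and to earlier positions.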

    def forward_model(
        self,
        inputs_embeds: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | list[torch.FloatTensor] | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        cache_position: torch.LongTensor | None = None,
    ) -> tuple | BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
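
    # Call-flow sketch (a minimal sketch of how this class's own methods fit together; shapes are illustrative):
    #   embeds = self.prepare_inputs_embeds(input_ids, latents)          # (bsz, seq, hidden)
    #   out = self.forward_model(inputs_embeds=embeds, use_cache=False)  # BaseModelOutputWithPast
    #   logits = self.lm_head(out.last_hidden_state)                     # (bsz, seq, vocab)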

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Cache | None = None,
        attention_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ):
        """
        Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
        slicing inputs given the existing cache.

        See the forward pass in the model documentation for expected arguments (different models might have
        different requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
        """
        # 1. Handle BC:
        model_inputs = {}
        # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
        if self._supports_cache_class:
            model_inputs["cache_position"] = cache_position
        # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
        #   function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
        #   (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
        elif cache_position is None:
            past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
            cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

        # 2. Generic cache-dependent input preparation
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
        if past_key_values is not None:
            model_inputs["past_key_values"] = past_key_values
            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # 3. Prepare base model inputs
        input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if not self.config.is_encoder_decoder:
            if inputs_embeds is not None and cache_position[0] == 0:
                model_inputs[input_ids_key] = None
                model_inputs["inputs_embeds"] = inputs_embeds
            else:
                # `clone` calls in this function ensure a consistent stride. See #32227
                model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
                model_inputs["inputs_embeds"] = None
        else:
            model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)

        # 4. Create missing `position_ids` on the fly
        if (
            attention_mask is not None
            and kwargs.get("position_ids") is None
            and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
        ):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)

        # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
        for model_input_name in ["position_ids", "token_type_ids"]:
            model_input = kwargs.get(model_input_name)
            if model_input is not None:
                if past_key_values:
                    model_input = model_input[:, -input_ids.shape[1] :]
                    model_input = model_input.clone(memory_format=torch.contiguous_format)
                model_inputs[model_input_name] = model_input
        # 6. Create 4D attention mask if we are using a `StaticCache` (important for performant compiled forward pass)
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs[input_ids_key].shape
                device = model_inputs[input_ids_key].device

            # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
            # the 4D causal mask exists, it should be present in the base model (XXXModel class).
            base_model = getattr(self, self.base_model_prefix, None)
            if base_model is None:
                causal_mask_creation_function = getattr(self, "_prepare_4d_causal_attention_mask_with_cache_position", None)
            else:
                causal_mask_creation_function = getattr(
                    base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
                )
            if causal_mask_creation_function is None:
                logger.warning(
                    f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
                    "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
                    "writing code, see Llama for an example implementation. If you're a user, please report this "
                    "issue on GitHub."
                )
            else:
                attention_mask = causal_mask_creation_function(
                    attention_mask,
                    sequence_length=sequence_length,
                    target_length=past_key_values.get_max_cache_shape(),
                    dtype=self.dtype,
                    device=device,
                    cache_position=cache_position,
                    batch_size=batch_size,
                    config=self.config,
                    past_key_values=past_key_values,
                )
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask

        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value

        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
        model_inputs.pop("labels", None)
        return model_inputs

    def generate(self, inputs: torch.LongTensor = None, **kwargs):
        input_ids = kwargs.pop("input_ids")
        latents = kwargs.pop("latents", None)
        inputs_embeds = self.prepare_inputs_embeds(input_ids, latents)
        return super().generate(inputs=inputs, input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
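
    # Usage sketch (hedged; the tokenizer and VAE producing `input_ids`/`latents` are external and not part of
    # this module):
    #   output_ids = model.generate(input_ids=input_ids, latents=latents, max_new_tokens=64)
    # This follows the standard `GenerationMixin` path, with `inputs_embeds` injected for the first step only.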

    def gradient_checkpointing_enable(self, **kwargs):
        super().gradient_checkpointing_enable(**kwargs)
        self.image_head.net.grad_checkpointing = True