# The code is revised from DiT
import os
import math

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from safetensors.torch import load_file
from typing import List, Optional, Tuple, Union

from huggingface_hub import snapshot_download
from transformers import Phi3Config, Phi3Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging

logger = logging.get_logger(__name__)

class Phi3Transformer(Phi3Model):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers, each a [`Phi3DecoderLayer`].
    Only the attention-mask handling is modified relative to the upstream Phi3Model.

    Args:
        config: Phi3Config
    """

    def prefetch_layer(self, layer_idx: int, device: torch.device):
        "Starts prefetching the given layer's parameters onto the GPU."
        with torch.cuda.stream(self.prefetch_stream):
            # Prefetch next layer tensors to GPU
            for name, param in self.layers[layer_idx].named_parameters():
                param.data = param.data.to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer's parameters back to the CPU."
        prev_layer_idx = layer_idx - 1
        for name, param in self.layers[prev_layer_idx].named_parameters():
            param.data = param.data.to("cpu", non_blocking=True)

    def get_offload_layer(self, layer_idx: int, device: torch.device):
        # init stream
        if not hasattr(self, "prefetch_stream"):
            self.prefetch_stream = torch.cuda.Stream()

        # delete previous layer
        torch.cuda.current_stream().synchronize()
        self.evict_previous_layer(layer_idx)

        # make sure the current layer is ready
        torch.cuda.synchronize(self.prefetch_stream)

        # load next layer
        self.prefetch_layer((layer_idx + 1) % len(self.layers), device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        offload_model: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        # if inputs_embeds is None:
        #     inputs_embeds = self.embed_tokens(input_ids)

        # if cache_position is None:
        #     past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        #     cache_position = torch.arange(
        #         past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        #     )
        # if position_ids is None:
        #     position_ids = cache_position.unsqueeze(0)

        if attention_mask is not None and attention_mask.dim() == 3:
            dtype = inputs_embeds.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = (1 - attention_mask) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
        else:
            raise Exception("attention_mask parameter was unavailable or invalid")
        # causal_mask = self._update_causal_mask(
        #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        # )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        layer_idx = -1
        for decoder_layer in self.layers:
            layer_idx += 1

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                if offload_model and not self.training:
                    self.get_offload_layer(layer_idx, device=inputs_embeds.device)
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
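

# Illustrative sketch (not part of the original model): how the 3-D attention mask expected by
# Phi3Transformer.forward is turned into the additive 4-D mask the decoder layers consume.
# Entries of 1 mean "attend"; entries of 0 become a large negative bias. Shapes are assumptions.
def _demo_additive_attention_mask():
    mask = torch.ones(1, 4, 4)          # (batch, query_len, key_len), all positions visible
    mask[:, :, -1] = 0                  # block attention to the last key position
    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min
    additive = (1 - mask) * min_dtype   # 1 -> 0 bias, 0 -> large negative bias
    additive = additive.unsqueeze(1)    # (batch, 1, query_len, key_len), broadcast over heads
    return additive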


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
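

# Illustrative sketch (not part of the original model): modulate() applies per-sample shift and
# scale from an adaLN branch to a token sequence; the shapes below are assumptions.
def _demo_modulate():
    x = torch.randn(2, 16, 8)           # (batch, tokens, hidden)
    shift = torch.zeros(2, 8)           # (batch, hidden)
    scale = 0.5 * torch.ones(2, 8)      # (batch, hidden)
    out = modulate(x, shift, scale)     # shift/scale broadcast over the token dimension
    return out.shape                    # torch.Size([2, 16, 8])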


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype=torch.float32):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
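

# Illustrative sketch (not part of the original model): TimestepEmbedder maps a batch of scalar
# timesteps to hidden-size vectors via sinusoidal features and a two-layer MLP. Sizes are assumptions.
def _demo_timestep_embedder():
    embedder = TimestepEmbedder(hidden_size=64, frequency_embedding_size=32)
    t = torch.tensor([0.0, 0.25, 1.0])  # three (possibly fractional) timesteps
    emb = embedder(t)                   # shape (3, 64)
    return emb.shape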


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
    """
    grid_size: int of the grid height and width
    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
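

# Illustrative sketch (not part of the original model): the 2-D sin/cos positional embedding is
# half height-encoding, half width-encoding, with one row per grid cell. Sizes are assumptions.
def _demo_pos_embed():
    pos_embed = get_2d_sincos_pos_embed(embed_dim=16, grid_size=4)
    return pos_embed.shape              # (16, 16): 4*4 positions, 16 dims each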


class PatchEmbedMR(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(
        self,
        patch_size: int = 2,
        in_chans: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
    ):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        return x
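

# Illustrative sketch (not part of the original model): PatchEmbedMR turns an NCHW latent into a
# token sequence; a 64x64 latent with patch_size=2 yields 32*32 = 1024 tokens. Sizes are assumptions.
def _demo_patch_embed():
    embedder = PatchEmbedMR(patch_size=2, in_chans=4, embed_dim=768)
    latent = torch.randn(1, 4, 64, 64)
    tokens = embedder(latent)           # shape (1, 1024, 768)
    return tokens.shape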


class OmniGenOriginalModel(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        transformer_config: Phi3Config,
        patch_size=2,
        in_channels=4,
        pe_interpolation: float = 1.0,
        pos_embed_max_size: int = 192,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.pos_embed_max_size = pos_embed_max_size

        hidden_size = transformer_config.hidden_size

        self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
        self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)

        self.time_token = TimestepEmbedder(hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)

        self.pe_interpolation = pe_interpolation
        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

        self.llm = Phi3Transformer(config=transformer_config)
        self.llm.config.use_cache = False

    @classmethod
    def from_pretrained(cls, model_name):
        if not os.path.exists(model_name):
            cache_folder = os.getenv('HF_HUB_CACHE')
            model_name = snapshot_download(repo_id=model_name,
                                           cache_dir=cache_folder,
                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
        config = Phi3Config.from_pretrained(model_name)
        model = cls(config)
        if os.path.exists(os.path.join(model_name, 'model.safetensors')):
            print("Loading safetensors")
            ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
        else:
            ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
        model.load_state_dict(ckpt)
        return model

    def initialize_weights(self):
        assert not hasattr(self, "llama")

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        w = self.input_x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.input_x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels

        x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h, w))
        return imgs
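
    # Shape sketch (illustrative): with patch_size=2 and out_channels=4, unpatchify folds a token
    # tensor of shape (N, (h/2)*(w/2), 2*2*4) back into an (N, 4, h, w) latent, i.e. the inverse of
    # the patchify step performed by PatchEmbedMR plus the final linear projection.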

    def cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        # print(top, top + height, left, left + width, spatial_pos_embed.size())
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed
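
    # Cropping sketch (illustrative): for a 64x64 latent with patch_size=2, 32x32 patch positions
    # are needed, so a centered 32x32 window is sliced out of the stored
    # pos_embed_max_size x pos_embed_max_size grid (192x192 by default) and flattened to
    # (1, 32*32, hidden_size) before being added to the patch tokens.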

    def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images: bool = False):
        if isinstance(latents, list):
            return_list = False
            if padding_latent is None:
                padding_latent = [None] * len(latents)
                return_list = True
            patched_latents, num_tokens, shapes = [], [], []
            for latent, padding in zip(latents, padding_latent):
                height, width = latent.shape[-2:]
                if is_input_images:
                    latent = self.input_x_embedder(latent)
                else:
                    latent = self.x_embedder(latent)
                pos_embed = self.cropped_pos_embed(height, width)
                latent = latent + pos_embed
                if padding is not None:
                    latent = torch.cat([latent, padding], dim=-2)
                patched_latents.append(latent)

                num_tokens.append(pos_embed.size(1))
                shapes.append([height, width])
            if not return_list:
                latents = torch.cat(patched_latents, dim=0)
            else:
                latents = patched_latents
        else:
            height, width = latents.shape[-2:]
            if is_input_images:
                latents = self.input_x_embedder(latents)
            else:
                latents = self.x_embedder(latents)
            pos_embed = self.cropped_pos_embed(height, width)
            latents = latents + pos_embed
            num_tokens = latents.size(1)
            shapes = [height, width]
        return latents, num_tokens, shapes

    def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model: bool = False):
        """
        Patchify the noisy latents, splice the input-image latents into the text condition
        embeddings, prepend the time token, run the Phi-3 backbone, and decode the image
        tokens back into latents.
        """
        input_is_list = isinstance(x, list)
        x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
        time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)

        if input_img_latents is not None:
            input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
        if input_ids is not None:
            condition_embeds = self.llm.embed_tokens(input_ids).clone()
            input_img_inx = 0
            for b_inx in input_image_sizes.keys():
                for start_inx, end_inx in input_image_sizes[b_inx]:
                    condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
                    input_img_inx += 1
            if input_img_latents is not None:
                assert input_img_inx == len(input_latents)

            input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
        else:
            input_emb = torch.cat([time_token, x], dim=1)
        output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
        output, past_key_values = output.last_hidden_state, output.past_key_values
        if input_is_list:
            image_embedding = output[:, -max(num_tokens):]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = []
            for i in range(x.size(0)):
                latent = x[i:i+1, :num_tokens[i]]
                latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
                latents.append(latent)
        else:
            image_embedding = output[:, -num_tokens:]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = self.unpatchify(x, shapes[0], shapes[1])

        if return_past_key_values:
            return latents, past_key_values
        return latents

    def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
        self.llm.config.use_cache = use_kv_cache
        model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
        if use_img_cfg:
            cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        else:
            cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]

        return torch.cat(model_out, dim=0), past_key_values
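
    # Guidance sketch (illustrative): with image CFG the batch holds [cond, uncond, img_cond]
    # chunks and the combined prediction is
    #   uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond);
    # without it, the standard uncond + cfg_scale * (cond - uncond) is used. The result is
    # replicated so the returned batch keeps the same size as the input.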

    def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
        self.llm.config.use_cache = use_kv_cache
        if past_key_values is None:
            past_key_values = [None] * len(attention_mask)

        x = torch.split(x, len(x) // len(attention_mask), dim=0)
        timestep = timestep.to(x[0].dtype)
        timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)

        model_out, past_key_values_out = [], []
        for i in range(len(input_ids)):
            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
            model_out.append(temp_out)
            past_key_values_out.append(temp_past_key_values)

        if len(model_out) == 3:
            cond, uncond, img_cond = model_out
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        elif len(model_out) == 2:
            cond, uncond = model_out
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]
        else:
            return model_out[0]

        return torch.cat(model_out, dim=0), past_key_values_out


class OmniGenTransformer(OmniGenOriginalModel):
    def __init__(self):
        config = {
            "_name_or_path": "Phi-3-vision-128k-instruct",
            "architectures": [
                "Phi3ForCausalLM"
            ],
            "attention_dropout": 0.0,
            "bos_token_id": 1,
            "eos_token_id": 2,
            "hidden_act": "silu",
            "hidden_size": 3072,
            "initializer_range": 0.02,
            "intermediate_size": 8192,
            "max_position_embeddings": 131072,
            "model_type": "phi3",
            "num_attention_heads": 32,
            "num_hidden_layers": 32,
            "num_key_value_heads": 32,
            "original_max_position_embeddings": 4096,
            "rms_norm_eps": 1e-05,
            "rope_scaling": {
                "long_factor": [
                    1.0299999713897705,
                    1.0499999523162842,
                    1.0499999523162842,
                    1.0799999237060547,
                    1.2299998998641968,
                    1.2299998998641968,
                    1.2999999523162842,
                    1.4499999284744263,
                    1.5999999046325684,
                    1.6499998569488525,
                    1.8999998569488525,
                    2.859999895095825,
                    3.68999981880188,
                    5.419999599456787,
                    5.489999771118164,
                    5.489999771118164,
                    9.09000015258789,
                    11.579999923706055,
                    15.65999984741211,
                    15.769999504089355,
                    15.789999961853027,
                    18.360000610351562,
                    21.989999771118164,
                    23.079999923706055,
                    30.009998321533203,
                    32.35000228881836,
                    32.590003967285156,
                    35.56000518798828,
                    39.95000457763672,
                    53.840003967285156,
                    56.20000457763672,
                    57.95000457763672,
                    59.29000473022461,
                    59.77000427246094,
                    59.920005798339844,
                    61.190006256103516,
                    61.96000671386719,
                    62.50000762939453,
                    63.3700065612793,
                    63.48000717163086,
                    63.48000717163086,
                    63.66000747680664,
                    63.850006103515625,
                    64.08000946044922,
                    64.760009765625,
                    64.80001068115234,
                    64.81001281738281,
                    64.81001281738281
                ],
                "short_factor": [
                    1.05,
                    1.05,
                    1.05,
                    1.1,
                    1.1,
                    1.1,
                    1.2500000000000002,
                    1.2500000000000002,
                    1.4000000000000004,
                    1.4500000000000004,
                    1.5500000000000005,
                    1.8500000000000008,
                    1.9000000000000008,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.000000000000001,
                    2.1000000000000005,
                    2.1000000000000005,
                    2.2,
                    2.3499999999999996,
                    2.3499999999999996,
                    2.3499999999999996,
                    2.3499999999999996,
                    2.3999999999999995,
                    2.3999999999999995,
                    2.6499999999999986,
                    2.6999999999999984,
                    2.8999999999999977,
                    2.9499999999999975,
                    3.049999999999997,
                    3.049999999999997,
                    3.049999999999997
                ],
                "type": "su"
            },
            "rope_theta": 10000.0,
            "sliding_window": 131072,
            "tie_word_embeddings": False,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.38.1",
            "use_cache": True,
            "vocab_size": 32064,
            "_attn_implementation": "sdpa"
        }
        config = Phi3Config(**config)
        super().__init__(config)

    def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model: bool = False):
        input_is_list = isinstance(x, list)
        x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
        time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)

        if input_img_latents is not None:
            input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
        if input_ids is not None:
            condition_embeds = self.llm.embed_tokens(input_ids).clone()
            input_img_inx = 0
            for b_inx in input_image_sizes.keys():
                for start_inx, end_inx in input_image_sizes[b_inx]:
                    condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
                    input_img_inx += 1
            if input_img_latents is not None:
                assert input_img_inx == len(input_latents)

            input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
        else:
            input_emb = torch.cat([time_token, x], dim=1)
        output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
        output, past_key_values = output.last_hidden_state, output.past_key_values
        if input_is_list:
            image_embedding = output[:, -max(num_tokens):]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = []
            for i in range(x.size(0)):
                latent = x[i:i+1, :num_tokens[i]]
                latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
                latents.append(latent)
        else:
            image_embedding = output[:, -num_tokens:]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = self.unpatchify(x, shapes[0], shapes[1])

        if return_past_key_values:
            return latents, past_key_values
        return latents

    def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
        self.llm.config.use_cache = use_kv_cache
        if past_key_values is None:
            past_key_values = [None] * len(attention_mask)

        x = torch.split(x, len(x) // len(attention_mask), dim=0)
        timestep = timestep.to(x[0].dtype)
        timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)

        model_out, past_key_values_out = [], []
        for i in range(len(input_ids)):
            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
            model_out.append(temp_out)
            past_key_values_out.append(temp_past_key_values)

        if len(model_out) == 3:
            cond, uncond, img_cond = model_out
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        elif len(model_out) == 2:
            cond, uncond = model_out
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]
        else:
            return model_out[0]

        return torch.cat(model_out, dim=0), past_key_values_out

    @staticmethod
    def state_dict_converter():
        return OmniGenTransformerStateDictConverter()


class OmniGenTransformerStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict
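

# Usage sketch (illustrative, not part of the original file): constructing the transformer and
# what a single denoising call needs. The shapes below are assumptions; input_ids, attention_mask,
# position_ids and input_image_sizes come from the surrounding pipeline.
def _demo_build_transformer():
    model = OmniGenTransformer()                 # builds the Phi-3 backbone from the config above
    model.eval()
    latent = torch.randn(1, 4, 64, 64)           # noisy latent to denoise, (N, C, H, W)
    timestep = torch.tensor([0.5])               # one scalar timestep per sample
    # A call such as
    #   latents, past_kv = model(latent, timestep, input_ids, None, {}, attention_mask, position_ids)
    # returns denoised latents of the same shape plus the backbone's KV cache.
    return model, latent, timestep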