Spaces:
Build error
Build error
| from typing import List, Tuple | |
| import torch | |
| import torch.nn as nn | |
| from einops import repeat | |
| from diffusers.models.embeddings import get_1d_rotary_pos_embed | |
class OmniGen2RotaryPosEmbed(nn.Module):
    """Three-axis rotary position embeddings for packed OmniGen2 sequences.

    A packed sample is laid out as
    ``[caption tokens | reference-image tokens ... | target-image tokens]``
    and every token receives a 3-component position id:

    * caption token ``j`` gets ``(j, j, j)``;
    * an image token gets ``(shift, row, col)``, where ``shift`` is constant
      for the whole image and ``row`` / ``col`` index its patch grid.

    One frequency table per axis is looked up with these ids and the three
    results are concatenated along the last dimension.
    """

    def __init__(self, theta: int,
                 axes_dim: Tuple[int, int, int],
                 axes_lens: Tuple[int, int, int] = (300, 512, 512),
                 patch_size: int = 2):
        """
        Args:
            theta: rotary base; stored as ``self.theta`` (the methods of this
                class do not read it -- ``get_freqs_cis`` takes its own ``theta``).
            axes_dim: per-axis embedding size; its length is the number of
                position axes looked up in ``_get_freqs_cis``.
            axes_lens: per-axis table length; stored as ``self.axes_lens``.
            patch_size: divisor converting image ``(H, W)`` sizes into
                token-grid sizes in ``forward``.
        """
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(axes_dim: Tuple[int, int, int],
                      axes_lens: Tuple[int, int, int],
                      theta: int) -> List[torch.Tensor]:
        """Precompute one 1-D rotary frequency table per position axis.

        Declared ``@staticmethod`` so it can be called both on the class and on
        an instance (without the decorator, an instance call would bind the
        module itself as ``axes_dim``).

        Args:
            axes_dim: embedding size of each axis.
            axes_lens: number of positions (table rows) of each axis.
            theta: rotary base forwarded to ``get_1d_rotary_pos_embed``.

        Returns:
            A list with one tensor per axis, as returned by
            ``get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=...)``.
            Rows are indexed by position id.
        """
        # float32 whenever an MPS backend is present -- presumably because MPS
        # has no float64 support; float64 otherwise.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        return [
            get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
            for d, e in zip(axes_dim, axes_lens)
        ]

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        """Look up and concatenate per-axis frequencies for integer position ids.

        Args:
            freqs_cis: per-axis tables (one per entry of ``self.axes_dim``);
                row ``j`` of table ``a`` holds the frequencies for position ``j``.
            ids: ``(batch, seq_len, n_axes)`` integer position ids.

        Returns:
            ``(batch, seq_len, sum of table widths)`` tensor on ``ids``' device,
            where ``out[b, l]`` is the concatenation of
            ``freqs_cis[a][ids[b, l, a]]`` over all axes ``a``.
        """
        out_device = ids.device
        # The lookup runs on CPU when the ids live on MPS.
        if ids.device.type == "mps":
            ids = ids.to("cpu")
        ids = ids.to(torch.int64)

        per_axis = []
        for axis in range(len(self.axes_dim)):
            table = freqs_cis[axis].to(ids.device)
            # Advanced indexing: (seq_len_table, d) indexed by (B, L) -> (B, L, d).
            # Equivalent to the gather formulation without copying the table
            # once per batch element.
            per_axis.append(table[ids[:, :, axis]])
        return torch.cat(per_axis, dim=-1).to(out_device)

    @staticmethod
    def _grid_ids(h_tokens: int, w_tokens: int, device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return flattened row-major ``(row_ids, col_ids)`` for an ``h_tokens x w_tokens`` grid."""
        row_ids = torch.arange(h_tokens, dtype=torch.int32, device=device).repeat_interleave(w_tokens)
        col_ids = torch.arange(w_tokens, dtype=torch.int32, device=device).repeat(h_tokens)
        return row_ids, col_ids

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device,
    ):
        """Build rotary embeddings for captions, reference images and target images.

        Args:
            freqs_cis: per-axis frequency tables (see ``get_freqs_cis``).
            attention_mask: ``(batch, encoder_seq_len)`` caption mask; each
                row's sum is used as that sample's caption length.
            l_effective_ref_img_len: per sample, a list of token counts, one per
                reference image (may be empty).
            l_effective_img_len: per sample, the target image's token count.
            ref_img_sizes: per sample, a list of ``(H, W)`` per reference image,
                or ``None`` when the sample has no reference images.
            img_sizes: per sample, the target image's ``(H, W)``.
            device: device for the position ids and output buffers.

        Returns:
            Tuple ``(cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis,
            l_effective_cap_len, seq_lengths)``. The first three are zero-padded
            to ``(batch, encoder_seq_len | max_ref_img_len | max_img_len, D)``;
            ``freqs_cis`` is the combined ``(batch, max_seq_len, D)`` tensor;
            the last two are per-sample Python lists.

        Raises:
            AssertionError: if an image's token grid does not match its declared
                token count, or the packed lengths do not add up.
        """
        batch_size = len(attention_mask)
        p = self.patch_size
        encoder_seq_len = attention_mask.shape[1]
        # NOTE(review): the mask sum is used as a prefix length below, which
        # assumes valid caption tokens come first in each mask row -- confirm.
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()
        # Total reference-image tokens per sample (0 for an empty list).
        l_ref_total_len = [sum(ref_lens) for ref_lens in l_effective_ref_img_len]

        seq_lengths = [
            cap_len + ref_len + img_len
            for cap_len, ref_len, img_len in zip(l_effective_cap_len, l_ref_total_len, l_effective_img_len)
        ]
        max_seq_len = max(seq_lengths)
        max_ref_img_len = max(l_ref_total_len)
        max_img_len = max(l_effective_img_len)

        # Position ids; positions past a sample's seq_len stay 0 (padding).
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # Caption tokens: index j on all three axes (broadcast over last dim).
            position_ids[i, :cap_seq_len] = torch.arange(
                cap_seq_len, dtype=torch.int32, device=device
            ).unsqueeze(-1)
            pe_shift = cap_seq_len      # axis-0 id given to the next image
            pe_shift_len = cap_seq_len  # sequence offset where the next image starts

            if ref_img_sizes[i] is not None:
                for (ref_H, ref_W), ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    ref_H_tokens, ref_W_tokens = ref_H // p, ref_W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len, (
                        f"reference image grid {ref_H_tokens}x{ref_W_tokens} "
                        f"does not match token count {ref_img_len}"
                    )
                    row_ids, col_ids = self._grid_ids(ref_H_tokens, ref_W_tokens, device)
                    end = pe_shift_len + ref_img_len
                    position_ids[i, pe_shift_len:end, 0] = pe_shift
                    position_ids[i, pe_shift_len:end, 1] = row_ids
                    position_ids[i, pe_shift_len:end, 2] = col_ids
                    # The next image's axis-0 id advances by this image's larger grid side.
                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len = end

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i], (
                f"image grid {H_tokens}x{W_tokens} does not match token count {l_effective_img_len[i]}"
            )
            assert pe_shift_len + l_effective_img_len[i] == seq_len, (
                f"packed length {pe_shift_len} + {l_effective_img_len[i]} != sequence length {seq_len}"
            )
            row_ids, col_ids = self._grid_ids(H_tokens, W_tokens, device)
            position_ids[i, pe_shift_len:seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len:seq_len, 1] = row_ids
            position_ids[i, pe_shift_len:seq_len, 2] = col_ids

        # Combined rotary embeddings for the whole packed sequence.
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
        embed_dim = freqs_cis.shape[-1]

        # Split the packed embeddings into zero-padded per-segment buffers.
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, embed_dim, device=device, dtype=freqs_cis.dtype
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size, max_ref_img_len, embed_dim, device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, max_img_len, embed_dim, device=device, dtype=freqs_cis.dtype
        )
        for i, (cap_seq_len, ref_len, img_len) in enumerate(
            zip(l_effective_cap_len, l_ref_total_len, l_effective_img_len)
        ):
            img_start = cap_seq_len + ref_len
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :ref_len] = freqs_cis[i, cap_seq_len:img_start]
            img_freqs_cis[i, :img_len] = freqs_cis[i, img_start:img_start + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )