|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import lru_cache |
|
|
from typing import Tuple |
|
|
import torch |
|
|
from einops import rearrange |
|
|
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb |
|
|
from torch import nn |
|
|
|
|
|
from common.cache import Cache |
|
|
|
|
|
|
|
|
class RotaryEmbeddingBase(nn.Module):
    """Shared base for axial rotary position embeddings.

    Wraps a ``RotaryEmbedding`` built with ``freqs_for="pixel"`` and
    ``max_freq=256`` whose width is ``dim // rope_dim``, and memoizes the
    axial frequency tables it produces, keyed by the requested grid sizes.

    Args:
        dim: Total width; it is integer-divided by ``rope_dim`` to size the
            wrapped embedding.
        rope_dim: Divisor applied to ``dim`` (the 3D subclass passes 3).
    """

    # Maximum number of distinct ``dims`` tuples memoized per instance.
    _AXIAL_FREQS_CACHE_SIZE = 128

    def __init__(self, dim: int, rope_dim: int):
        super().__init__()
        self.rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="pixel",
            max_freq=256,
        )
        # Replace the library-created ``freqs`` attribute with a buffer that
        # holds the same data.
        # NOTE(review): presumably the library registers ``freqs`` as an
        # nn.Parameter and this keeps it out of ``parameters()`` so that
        # ``requires_grad_`` / sharding wrappers do not touch it -- confirm.
        freqs = self.rope.freqs
        del self.rope.freqs
        self.rope.register_buffer("freqs", freqs.data)

    def get_axial_freqs(self, *dims):
        """Return the axial frequency table for a grid of sizes ``dims``.

        Results are memoized per instance in a least-recently-used cache of
        at most ``_AXIAL_FREQS_CACHE_SIZE`` entries. The cache lives on the
        instance rather than in a class-level ``functools.lru_cache``, which
        would key on ``self`` and keep every instance (and its cached
        tensors) alive for the lifetime of the process.

        Args:
            *dims: Grid size along each axis; must be hashable.

        Returns:
            Whatever ``self.rope.get_axial_freqs(*dims)`` returns. Repeated
            calls with the same ``dims`` return the same cached object.
        """
        # Created lazily so instances built without ``__init__`` (or restored
        # from an older pickle) still work. A plain dict is used, so the cache
        # is neither a parameter nor a buffer of the module.
        # NOTE(review): cached tables are not refreshed if the module is later
        # moved to another device; the previous lru_cache behaved the same.
        cache = self.__dict__.get("_axial_freqs_cache")
        if cache is None:
            cache = {}
            object.__setattr__(self, "_axial_freqs_cache", cache)

        try:
            # Pop on hit; re-inserting below marks the entry most recent.
            freqs = cache.pop(dims)
        except KeyError:
            freqs = self.rope.get_axial_freqs(*dims)
            if len(cache) >= self._AXIAL_FREQS_CACHE_SIZE:
                # Dicts preserve insertion order: the first key is the
                # least recently used one.
                del cache[next(iter(cache))]
        cache[dims] = freqs
        return freqs
|
|
|
|
|
|
|
|
class RotaryEmbedding3d(RotaryEmbeddingBase):
    """Axial rotary embedding over a (T, H, W) grid for batched q/k tensors."""

    def __init__(self, dim: int):
        # Three axes, so the base class sizes its embedding from dim // 3.
        super().__init__(dim, rope_dim=3)

    def forward(
        self,
        q: torch.FloatTensor,
        k: torch.FloatTensor,
        size: Tuple[int, int, int],
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Rotate queries and keys with the axial table for ``size``.

        Args:
            q: Queries laid out as ``b h (T H W) d``.
            k: Keys laid out as ``b h (T H W) d``.
            size: The ``(T, H, W)`` grid the flattened sequence axis unfolds to.

        Returns:
            ``(q, k)`` after rotation, back in the ``b h (T H W) d`` layout.
        """
        T, H, W = size
        table = self.get_axial_freqs(T, H, W)

        unfold = "b h (T H W) d -> b h T H W d"
        fold = "b h T H W d -> b h (T H W) d"

        # Each stage handles q before k, and a stage finishes for both
        # tensors before the next one starts.
        q_grid, k_grid = [rearrange(t, unfold, T=T, H=H, W=W) for t in (q, k)]
        q_rot, k_rot = [apply_rotary_emb(table, t) for t in (q_grid, k_grid)]
        return rearrange(q_rot, fold), rearrange(k_rot, fold)
|
|
|
|
|
|
|
|
class NaRotaryEmbedding3d(RotaryEmbedding3d):
    """3D rotary embedding for q/k laid out as ``L h d`` with per-sample grids."""

    def forward(
        self,
        q: torch.FloatTensor,
        k: torch.FloatTensor,
        shape: torch.LongTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Rotate queries and keys using the concatenated per-sample table.

        Args:
            q: Queries laid out as ``L h d``.
            k: Keys laid out as ``L h d``.
            shape: Rows of ``(f, h, w)`` grid sizes, one row per sample.
                Presumably ``L`` equals the sum of ``f * h * w`` over the
                rows -- confirm against the caller.
            cache: Called with the key ``"rope_freqs_3d"`` and a factory that
                builds the table. Presumably it returns a stored value when
                one exists -- confirm against ``Cache``.

        Returns:
            ``(q, k)`` after rotation, in the ``L h d`` layout and with the
            dtype each tensor came in with.
        """
        table = cache("rope_freqs_3d", lambda: self.get_freqs(shape))

        heads_first = "L h d -> h L d"
        tokens_first = "h L d -> L h d"

        q_heads, k_heads = [rearrange(t, heads_first) for t in (q, k)]
        # The rotation runs in float32; each result is cast back to the dtype
        # of its input.
        q_rot, k_rot = [
            apply_rotary_emb(table, t.float()).to(t.dtype)
            for t in (q_heads, k_heads)
        ]
        return rearrange(q_rot, tokens_first), rearrange(k_rot, tokens_first)

    def get_freqs(
        self,
        shape: torch.LongTensor,
    ) -> torch.Tensor:
        """Build one 2-D frequency table covering every row of ``shape``.

        For each ``(f, h, w)`` row the axial table is fetched, collapsed to
        two dimensions (all leading axes merged, last axis kept), and the
        pieces are concatenated along the first axis in row order.
        """
        # Lazy generator: each table is fetched right before it is reshaped.
        tables = (self.get_axial_freqs(f, h, w) for f, h, w in shape.tolist())
        return torch.cat([t.view(-1, t.shape[-1]) for t in tables], dim=0)
|
|
|