|  | """ | 
					
						
						|  | Orginally Taken verbatim from xformers library | 
					
						
						|  | https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py | 
					
						
						|  |  | 
					
						
						|  | The difference is that xformers seems to assume the inputs to be | 
					
						
						|  | (bs, head, seq_len, dim) while we assume (bs, seq_len, head, dim) | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import math | 
					
						
						|  | from typing import List, Optional, Tuple, Dict, Union | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import dataclasses | 
					
						
						|  | from transformers.utils import logging | 
					
						
						|  |  | 
					
						
						|  | from transformers import PretrainedConfig | 
					
						
						|  |  | 
					
						
						|  | is_dacite_available = False | 
					
						
						|  | try: | 
					
						
						|  | import dacite | 
					
						
						|  | is_dacite_available = True | 
					
						
						|  | except ImportError: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | logger = logging.get_logger(__name__) | 
					
						
						|  |  | 
					
						


@dataclasses.dataclass
class LongRopeConfig(object):
    short_factor: List[float]
    long_factor: List[float]
    original_max_position_embeddings: int
    type: str = "longrope"
    short_mscale: float = -1
    long_mscale: float = -1

    def __post_init__(self):
        assert self.type in ("longrope", "su"), (
            f"Invalid type {self.type} for LongRopeConfig. Expected longrope / su"
        )

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Union[float, List[float], int]]) -> "LongRopeConfig":
        if is_dacite_available:
            return dacite.from_dict(data_class=cls, data=config_dict)
        kwargs = {}
        for field in dataclasses.fields(cls):
            if field.name in config_dict:
                if field.init:
                    kwargs[field.name] = config_dict[field.name]
                else:
                    raise ValueError(f"Field {field.name} cannot be set in the constructor")
            else:
                if field.default is dataclasses.MISSING:
                    raise ValueError(f"Field {field.name} is required")
        extra_keys = set(config_dict.keys()) - set(kwargs.keys())
        if len(extra_keys) > 0:
            for key in extra_keys:
                logger.error(f"Unrecognized key {key} in config_dict")
            raise ValueError("Unrecognized keys in config_dict")
        return cls(**kwargs)
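
# Example (illustrative only; the values below are made up, not from any real
# checkpoint): building a LongRopeConfig from a Hugging Face style
# ``rope_scaling`` dict. Each factor list is expected to have head_dim // 2
# entries (this is checked later in RotaryEmbedding.forward).
#
#   rope_scaling = {
#       "type": "longrope",
#       "short_factor": [1.0] * 32,
#       "long_factor": [4.0] * 32,
#       "original_max_position_embeddings": 4096,
#   }
#   long_rope_config = LongRopeConfig.from_dict(rope_scaling)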
					
						
def rotate_half(x):
    """Split the last dimension in half and map (x1, x2) -> (-x2, x1)."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)
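
# Worked example (illustrative only): for x = torch.tensor([0., 1., 2., 3.]),
# the two halves are x1 = [0., 1.] and x2 = [2., 3.], so
# rotate_half(x) == tensor([-2., -3., 0., 1.]).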
					
						
@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int):
    # Broadcast the cos/sin tables against x depending on which dimension
    # of x holds the sequence length.
    if seq_dimension == 0:
        # x is (seq_len, bs, num_heads, head_dim)
        cos = cos[: x.shape[0], None, None, :]
        sin = sin[: x.shape[0], None, None, :]
    elif seq_dimension == 1:
        # x is (bs, seq_len, num_heads, head_dim)
        cos = cos[None, : x.shape[1], None, :]
        sin = sin[None, : x.shape[1], None, :]
    elif seq_dimension == 2:
        # x is (bs, num_heads, seq_len, head_dim)
        cos = cos[None, None, : x.shape[2], :]
        sin = sin[None, None, : x.shape[2], :]
    return (x * cos) + (rotate_half(x) * sin)
					
						
class RotaryEmbedding(torch.nn.Module):
    """
    Adapted from the xformers library.

    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the queries and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis.
					
						
    # Arguments
    :param dim_model: head dimension
    :param max_seq_len: length of the precomputed cos/sin cache (or of the position range for LongRoPE)
    :param dtype: cos/sin dtype
    :param base: base used for the inverse-frequency computation
    :param position_scale: scale applied to the position indices
    :param device: device on which the buffers are created
    :param longrope_config: optional LongRoPE / SU scaling configuration
    """
					
						
    def __init__(
        self,
        dim_model: int,
        *,
        max_seq_len: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        base=10000,
        position_scale=1,
        device: Optional[torch.device] = None,
        longrope_config: Optional[LongRopeConfig] = None,
    ):
        super().__init__()
        self.base = base
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.longrope_config = longrope_config

        if self.is_longrope:
            # LongRoPE recomputes cos/sin on the fly in forward, so only the
            # position range and the per-frequency rescale factors are cached.
            self.register_buffer(
                "range_vector",
                torch.arange(max_seq_len, device=device, dtype=torch.float32),
                persistent=False,
            )
            self.register_buffer(
                "short_factors",
                torch.tensor(self.longrope_config.short_factor, dtype=torch.float32),
                persistent=False,
            )
            self.register_buffer(
                "long_factors",
                torch.tensor(self.longrope_config.long_factor, dtype=torch.float32),
                persistent=False,
            )
        else:
            inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float().to(device) / self.dim_model))
            self.register_buffer("inv_freq", inv_freq)

        self.position_scale = position_scale

        if not self.is_longrope:
            dtype = dtype or torch.get_default_dtype()
            self._set_cos_sin_cache(
                seq_len=max_seq_len,
                device=self.inv_freq.device,
                dtype=dtype,
            )
					
						
    @property
    def is_longrope(self):
        return self.longrope_config is not None

    @property
    def original_max_seq_len(self):
        if self.longrope_config is not None:
            return self.longrope_config.original_max_position_embeddings
        logger.warning_once(
            "`original_max_seq_len` is being accessed, but longrope_config has not been set. "
            "Please only do this if you are sure about the context."
        )
        return self.max_seq_len
					
						
    def get_range_vector(self, seq_len: int, device: torch.device):
        if self.is_longrope:
            assert seq_len <= self.range_vector.shape[0], (
                f"Found seq_len {seq_len} greater than max_seq_len {self.range_vector.shape[0]}"
            )
            if self.range_vector.device != device:
                self.range_vector = self.range_vector.to(device)
            return self.range_vector[:seq_len]
        return torch.arange(seq_len, device=device, dtype=torch.float32)
					
						
    def _calc_mscale(self, scale: torch.Tensor) -> torch.Tensor:
        if scale <= 1.0:
            return 1.0
        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_seq_len))
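
    # Worked example (illustrative numbers only): with original_max_seq_len = 4096
    # and scale = 32, this returns sqrt(1 + ln(32) / ln(4096)) ≈ 1.19, so the
    # cos/sin tables are scaled up slightly when extrapolating past the
    # original context length.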
					
						
    def _set_cos_sin_cache(
        self,
        seq_len: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        dtype = dtype or torch.get_default_dtype()
        self.max_seq_len_cached = seq_len
        t = (
            torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) * self.position_scale
        ).type_as(self.inv_freq)
        device_type = device.type if device is not None else "cpu"
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            # Compute the frequencies in full precision, then cast the cached
            # cos/sin tables to the requested dtype.
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
        self.register_buffer("sin_cached", sin.to(dtype), persistent=False)
					
						
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dimension: int = 1,
        seqlen_offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """q and k do not include `seqlen_offset`.

        q: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
        k: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
        """
        if seq_dimension < 0:
            seq_dimension = k.ndim + seq_dimension
        assert seq_dimension in (0, 1, 2)
        seq_len = k.shape[seq_dimension] + seqlen_offset

        if self.is_longrope:
            # Pick the long or short rescale factors depending on whether the
            # sequence extends past the original training context.
            if seq_len > self.original_max_seq_len:
                t = self.get_range_vector(seq_len, device=q.device)
                rescale_factors = self.long_factors.to(q.device)
                long_mscale = self.longrope_config.long_mscale
                mscale = (
                    long_mscale
                    if long_mscale > 0
                    else self._calc_mscale(self.max_seq_len / self.original_max_seq_len)
                )
            else:
                t = self.get_range_vector(self.original_max_seq_len, device=q.device)
                rescale_factors = self.short_factors.to(q.device)
                short_mscale = self.longrope_config.short_mscale
                mscale = short_mscale if short_mscale > 0 else 1.0
            assert rescale_factors.shape == (self.dim_model // 2,), (
                f"misaligned shape for LongRoPE rescale factors:\n"
                f"\tExpected {(self.dim_model // 2,)}, got {rescale_factors.shape}."
            )
            inv_freq = 1.0 / (
                rescale_factors
                * (self.base ** (torch.arange(0, self.dim_model, 2).float().to(q.device) / self.dim_model))
            )
            device_type = q.device.type if q.device is not None else "cpu"
            device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = torch.outer(t, inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos() * mscale
                sin = emb.sin() * mscale
            cos_cached = cos.to(q.dtype)
            sin_cached = sin.to(q.dtype)
        else:
            if seq_len > self.max_seq_len_cached:
                self._set_cos_sin_cache(
                    seq_len=seq_len,
                    device=k.device,
                    dtype=k.dtype,
                )
            cos_cached = self.cos_cached
            sin_cached = self.sin_cached
        return (
            apply_rotary_pos_emb(
                q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
            ).to(q.dtype),
            apply_rotary_pos_emb(
                k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
            ).to(k.dtype),
        )
					
						
    @classmethod
    def from_config(cls, config: PretrainedConfig) -> "RotaryEmbedding":
        kwargs = dict(
            dim_model=config.hidden_size // config.num_attention_heads,
            max_seq_len=config.max_position_embeddings,
            base=config.rope_embedding_base,
            position_scale=config.rope_position_scale,
        )
        if config.rope_scaling is not None:
            kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
        return cls(**kwargs)
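

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module):
    # apply rotary embeddings to random q/k tensors laid out as
    # (bs, seq_len, num_heads, head_dim), the layout this module assumes.
    rope = RotaryEmbedding(dim_model=64, max_seq_len=128)
    q = torch.randn(2, 16, 4, 64)
    k = torch.randn(2, 16, 4, 64)
    q_rot, k_rot = rope(q, k, seq_dimension=1)
    print(q_rot.shape, k_rot.shape)  # both stay torch.Size([2, 16, 4, 64])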