import re from collections import defaultdict from math import sqrt from typing import Any, Iterable, Self, cast import torch from torch import Tensor from torch.nn import ( Module, ModuleList, Parameter, Buffer, Linear, LayerNorm, RMSNorm, Dropout, Flatten, init ) from torch.nn.functional import pad, scaled_dot_product_attention from einops import rearrange from glu import SwiGLU class IndexedAdd(Module): def __init__( self, n_indices: int, dim: int, weight_shape: tuple[int, ...] | None = None, *, inplace: bool = False, device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> None: super().__init__() self.dim = dim self.inplace = inplace self.index = Buffer(torch.empty( 2, n_indices, device=device, dtype=torch.int32 )) self.weight = Parameter(torch.ones( *(sz if sz != -1 else n_indices for sz in weight_shape), device=device, dtype=dtype )) if weight_shape is not None else None def _save_to_state_dict( self, destination: dict[str, Any], prefix: str, keep_vars: bool ) -> None: super()._save_to_state_dict(destination, prefix, keep_vars) if keep_vars: return with torch.no_grad(): index_key = f"{prefix}index" index = destination[index_key] min_index = index.amin(None).item() if min_index >= 0: max_index = index.amax(None).item() if max_index < (1 << 8): destination[index_key] = index.to(dtype=torch.uint8) elif max_index < (1 << 16): destination[index_key] = index.to(dtype=torch.uint16) @torch.no_grad() def load_indices(self, indices: Iterable[tuple[int, int]], *, mean: bool = False) -> None: if mean: if self.weight is None: raise ValueError("No weights to initialize with means.") groups: dict[int, list[int]] = defaultdict(list) idx = -1 for idx, (src, dst) in enumerate(indices): self.index[0, idx] = src self.index[1, idx] = dst if mean: groups[dst].append(idx) if (idx + 1) != self.index.size(1): raise IndexError(f"Expected {self.index.size(1)} indices, but got {idx + 1}.") if not mean: return assert self.weight is not None for idxs in groups.values(): if len(idxs) < 2: continue self.weight.index_fill_( self.dim, torch.tensor(idxs, device=self.weight.device, dtype=torch.int64), 1.0 / len(idxs) ) def forward(self, dst: Tensor, src: Tensor) -> Tensor: src = src.index_select(self.dim, self.index[0]) if self.weight is not None: src.mul_(self.weight) return ( dst.index_add_(self.dim, self.index[1], src) if self.inplace else dst.index_add(self.dim, self.index[1], src) ) class BatchLinear(Module): def __init__( self, batch_shape: tuple[int, ...] | int, in_features: int, out_features: int, *, bias: bool = False, flatten: bool = False, bias_inplace: bool = True, device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> None: super().__init__() if isinstance(batch_shape, int): batch_shape = (batch_shape,) elif not batch_shape: raise ValueError("At least one batch dimension is required.") self.flatten = -(len(batch_shape) + 1) if flatten else 0 self.weight = Parameter(torch.empty( *batch_shape, in_features, out_features, device=device, dtype=dtype )) bt = self.weight.flatten(end_dim=-3).mT for idx in range(bt.size(0)): init.kaiming_uniform_(bt[idx], a=sqrt(5)) self.bias = Parameter(torch.zeros( *batch_shape, out_features, device=device, dtype=dtype )) if bias else None self.bias_inplace = bias_inplace def forward(self, x: Tensor) -> Tensor: # ... B... 1 I @ B... I O -> ... B... O x = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2) if self.bias is not None: if self.bias_inplace: x.add_(self.bias) else: x = x + self.bias if self.flatten: x = x.flatten(self.flatten) return x class Mean(Module): def __init__(self, dim: tuple[int, ...] | int = -1, *, keepdim: bool = False) -> None: super().__init__() self.dim = dim self.keepdim = keepdim def forward(self, x: Tensor) -> Tensor: return x.mean(self.dim, self.keepdim) class _MidBlock(Module): def __init__( self, attn_dim: int, head_dim: int, n_classes: int, *, ff_ratio: float, ff_dropout: float, q_cls_inplace: bool = True, device: torch.device | str | None, dtype: torch.dtype | None, ) -> None: super().__init__() self.head_dim = head_dim self.q_cls_inplace = q_cls_inplace hidden_dim = int(attn_dim * ff_ratio) self.q_proj = Linear( attn_dim, attn_dim, bias=False, device=device, dtype=dtype ) self.q_cls = Parameter(torch.zeros( n_classes, attn_dim, device=device, dtype=dtype )) self.q_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False) self.attn_out = Linear( attn_dim, attn_dim, bias=False, device=device, dtype=dtype ) self.ff_norm = LayerNorm( attn_dim, device=device, dtype=dtype ) self.ff_in = Linear( attn_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype ) self.ff_act = SwiGLU() self.ff_drop = Dropout(ff_dropout) self.ff_out = Linear( hidden_dim, attn_dim, bias=False, device=device, dtype=dtype ) def _forward_q(self, x: Tensor) -> Tensor: x = self.q_proj(x) if self.q_cls_inplace: x.add_(self.q_cls) else: x = x + self.q_cls x = self.q_norm(x) x = rearrange(x, "... s (h e) -> ... h s e", e=self.head_dim) return x def _forward_attn(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None) -> Tensor: a = scaled_dot_product_attention( self._forward_q(x), k, v, attn_mask=attn_mask ) a = rearrange(a, "... h s e -> ... s (h e)") a = self.attn_out(a) return x + a def _forward_ff(self, x: Tensor) -> Tensor: f = self.ff_norm(x) f = self.ff_in(f) f = self.ff_act(f) f = self.ff_drop(f) f = self.ff_out(f) return x + f def forward(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None = None) -> Tensor: x = self._forward_attn(x, k, v, attn_mask) x = self._forward_ff(x) return x class HydraPool(Module): def __init__( self, attn_dim: int, head_dim: int, n_classes: int, *, mid_blocks: int = 0, roots: tuple[int, int, int] = (0, 0, 0), ff_ratio: float = 3.0, ff_dropout: float = 0.0, input_dim: int = -1, output_dim: int = 1, device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> None: super().__init__() if input_dim < 0: input_dim = attn_dim assert attn_dim % head_dim == 0 n_heads = attn_dim // head_dim self.n_classes = n_classes self.head_dim = head_dim self.output_dim = output_dim self._has_roots = False self._has_ff = False self.q: Parameter | Buffer self._q_normed: bool | None if roots != (0, 0, 0): self._has_roots = True n_roots, n_classroots, n_subclasses = roots if n_classroots < n_roots: raise ValueError("Number of classroots cannot be less than the number of roots.") self.cls = Parameter(torch.randn( n_heads, n_classes, head_dim, device=device, dtype=dtype )) self.roots = Parameter(torch.randn( n_heads, n_roots, head_dim, device=device, dtype=dtype )) if n_roots > 0 else None self.clsroots = IndexedAdd( n_classroots, dim=-2, weight_shape=(n_heads, -1, 1), device=device, dtype=dtype ) if n_classroots > 0 else None self.clscls = IndexedAdd( n_subclasses, dim=-2, weight_shape=(n_heads, -1, 1), inplace=True, device=device, dtype=dtype ) if n_subclasses > 0 else None self.q = Buffer(torch.empty( n_heads, n_classes, head_dim, device=device, dtype=dtype )) self._q_normed = None else: self.q = Parameter(torch.randn( n_heads, n_classes, head_dim, device=device, dtype=dtype )) self._q_normed = False self.kv = Linear( input_dim, attn_dim * 2, bias=False, device=device, dtype=dtype ) self.qk_norm = RMSNorm( head_dim, eps=1e-5, elementwise_affine=False ) if ff_ratio > 0.0: self._has_ff = True hidden_dim = int(attn_dim * ff_ratio) self.ff_norm = LayerNorm( attn_dim, device=device, dtype=dtype ) self.ff_in = Linear( attn_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype ) self.ff_act = SwiGLU() self.ff_drop = Dropout(ff_dropout) self.ff_out = Linear( hidden_dim, attn_dim, bias=False, device=device, dtype=dtype ) elif mid_blocks > 0: raise ValueError("Feedforward required with mid blocks.") self.mid_blocks = ModuleList( _MidBlock( attn_dim, head_dim, n_classes, ff_ratio=ff_ratio, ff_dropout=ff_dropout, device=device, dtype=dtype ) for _ in range(mid_blocks) ) self.out_proj = BatchLinear( n_classes, attn_dim, output_dim * 2, device=device, dtype=dtype ) self.out_act = SwiGLU() @property def has_roots(self) -> bool: return self._has_roots def get_extra_state(self) -> dict[str, Any]: return { "q_normed": self._q_normed } def set_extra_state(self, state: dict[str, Any]) -> None: self._q_normed = state["q_normed"] def create_head(self) -> Module: if self.output_dim == 1: return Flatten(-2) return Mean(-1) def train(self, mode: bool = True) -> Self: super().train(mode) if mode: if self._has_roots: self._q_normed = None else: self._q_normed = False else: if self._has_roots: self._cache_query() return self def inference(self) -> Self: super().train(False) self._cache_query() if self._has_roots: self._has_roots = False self.q = Parameter(self.q) del self.cls, self.roots, self.clsroots, self.clscls return self def _cache_query(self) -> None: assert not self.training if self._q_normed: return with torch.no_grad(): self.q.to(device=self.kv.weight.device) self.q.copy_(self._forward_q()) self._q_normed = True def _forward_q(self) -> Tensor: match self._q_normed: case None: assert self._has_roots if self.roots is not None: q = self.qk_norm(self.roots) q = self.clsroots(self.cls, q) else: q = self.cls if self.clscls is not None: q = self.clscls(q, q.detach()) q = self.qk_norm(q) return q case False: assert not self._has_roots return self.qk_norm(self.q) case True: return self.q def _forward_attn(self, x: Tensor, attn_mask: Tensor | None) -> tuple[Tensor, Tensor, Tensor]: q = self._forward_q().expand(*x.shape[:-2], -1, -1, -1) x = self.kv(x) k, v = rearrange(x, "... s (n h e) -> n ... h s e", n=2, e=self.head_dim).unbind(0) k = self.qk_norm(k) x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) return rearrange(x, "... h s e -> ... s (h e)"), k, v def _forward_ff(self, x: Tensor) -> Tensor: if not self._has_ff: return x f = self.ff_norm(x) f = self.ff_in(f) f = self.ff_act(f) f = self.ff_drop(f) f = self.ff_out(f) return x + f def _forward_out(self, x: Tensor) -> Tensor: x = self.out_proj(x) x = self.out_act(x) return x def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor: x, k, v = self._forward_attn(x, attn_mask) x = self._forward_ff(x) for block in self.mid_blocks: x = block(x, k, v, attn_mask) x = self._forward_out(x) return x def prune_roots(self, retain_classes: set[int]) -> tuple[list[int], list[int]]: if not self._has_roots or self.roots is None: raise TypeError("No roots to prune.") if self.clscls is not None: raise TypeError("Subclass roots cannot be pruned.") used_roots: set[int] = set() used_clsroots: list[int] = [] assert self.clsroots is not None clsroots = [ cast(list[int], clsroot.tolist()) for clsroot in self.clsroots.index.cpu().unbind(1) ] for idx, (src, dest) in enumerate(clsroots): if dest in retain_classes: used_roots.add(src) used_clsroots.append(idx) sorted_roots = sorted(used_roots) del used_roots rootmap = { root: idx for idx, root in enumerate(sorted_roots) } clsmap = { cls: idx for idx, cls in enumerate(sorted(retain_classes)) } for idx in used_clsroots: src, dest = clsroots[idx] self.clsroots.index[0, idx] = rootmap[src] self.clsroots.index[1, idx] = clsmap[dest] return sorted_roots, used_clsroots @staticmethod def for_state( state_dict: dict[str, Any], prefix: str = "", *, ff_dropout: float = 0.0, device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> "HydraPool": n_heads, n_classes, head_dim = state_dict[f"{prefix}q"].shape attn_dim = n_heads * head_dim roots_t = state_dict.get(f"{prefix}roots") clsroots_t = state_dict.get(f"{prefix}clsroots.index") clscls_t = state_dict.get(f"{prefix}clscls.index") roots = ( roots_t.size(1) if roots_t is not None else 0, clsroots_t.size(1) if clsroots_t is not None else 0, clscls_t.size(1) if clscls_t is not None else 0 ) input_dim = state_dict[f"{prefix}kv.weight"].size(1) output_dim = state_dict[f"{prefix}out_proj.weight"].size(2) // 2 # avoid off-by-one issue due to truncation ffout_t = state_dict.get(f"{prefix}ff_out.weight") hidden_dim = ffout_t.size(1) + 0.5 if ffout_t is not None else 0 ff_ratio = hidden_dim / attn_dim pattern = re.compile(rf"^{re.escape(prefix)}mid_blocks\.([0-9]+)\.") mid_blocks = max([-1, *( int(match[1]) for key in state_dict if (match := pattern.match(key)) is not None )]) + 1 return HydraPool( attn_dim, head_dim, n_classes, mid_blocks=mid_blocks, roots=roots, ff_ratio=ff_ratio, ff_dropout=ff_dropout, input_dim=input_dim, output_dim=output_dim, device=device, dtype=dtype )