from math import ceil

import torch
from torch import Tensor
from torch.nn import Identity
import timm
from timm.models import NaFlexVit
from PIL import Image
from safetensors import safe_open

from image import process_srgb, put_srgb_patch


def sdpa_attn_mask(
    patch_valid: Tensor,
    num_prefix_tokens: int = 0,
    symmetric: bool = True,
    q_len: int | None = None,
    dtype: torch.dtype | None = None,
) -> Tensor:
    """Build a boolean SDPA attention mask from per-patch validity.

    `symmetric`, `q_len`, and `dtype` are accepted only to match the
    signature of `timm.models.naflexvit.create_attention_mask`, which
    this function replaces below; they are intentionally unused.
    """
    # (..., N) -> (..., 1, 1, N) so the mask broadcasts over heads and queries.
    mask = patch_valid.unflatten(-1, (1, 1, -1))
    if num_prefix_tokens:
        # Prefix tokens (class/register tokens) are always attendable.
        mask = torch.cat((
            torch.ones(
                *mask.shape[:-1], num_prefix_tokens,
                device=patch_valid.device, dtype=torch.bool
            ),
            mask
        ), dim=-1)
    return mask


# Monkey-patch timm so NaFlexVit uses the boolean SDPA mask above.
timm.models.naflexvit.create_attention_mask = sdpa_attn_mask


def get_image_size_for_seq(
    image_hw: tuple[int, int],
    patch_size: int = 16,
    max_seq_len: int = 1024,
    max_ratio: float = 1.0,
    eps: float = 1e-5,
) -> tuple[int, int]:
    """Determine the largest image size whose patch grid fits `max_seq_len`.

    Binary-searches a scale factor in `[eps, max_ratio]` so that the scaled
    image, rounded up to whole patches, yields at most `max_seq_len` patches.
    Returns the target (height, width) in pixels.
    """
    assert max_ratio >= 1.0
    assert eps * 2 < max_ratio

    h, w = image_hw
    # Patch-grid bounds at the maximum allowed upscale.
    max_py = int(max((h * max_ratio) // patch_size, 1))
    max_px = int(max((w * max_ratio) // patch_size, 1))

    # Fast path: even the maximally scaled grid fits the budget.
    if (max_py * max_px) <= max_seq_len:
        return max_py * patch_size, max_px * patch_size

    def patchify(ratio: float) -> tuple[int, int]:
        # Patch-grid size at `ratio`, clamped to the upscale bounds.
        return (
            min(int(ceil((h * ratio) / patch_size)), max_py),
            min(int(ceil((w * ratio) / patch_size)), max_px)
        )

    py, px = patchify(eps)
    if (py * px) > max_seq_len:
        raise ValueError(f"Image of size {w}x{h} is too large.")

    # Binary search for the largest ratio whose grid still fits.
    ratio = eps
    while (max_ratio - ratio) >= eps:
        mid = (ratio + max_ratio) / 2.0
        mpy, mpx = patchify(mid)
        seq_len = mpy * mpx
        if seq_len > max_seq_len:
            max_ratio = mid
            continue
        ratio = mid
        py = mpy
        px = mpx
        if seq_len == max_seq_len:
            break

    assert py >= 1 and px >= 1
    return py * patch_size, px * patch_size


def process_image(img: Image.Image, patch_size: int, max_seq_len: int) -> Image.Image:
    """Convert `img` to sRGB and resize it to fit the sequence-length budget."""
    def compute_resize(wh: tuple[int, int]) -> tuple[int, int]:
        # PIL reports (width, height); the sizing helper works in (height, width).
        h, w = get_image_size_for_seq((wh[1], wh[0]), patch_size, max_seq_len)
        return w, h
    return process_srgb(img, resize=compute_resize)


def patchify_image(
    img: Image.Image,
    patch_size: int,
    max_seq_len: int,
    share_memory: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Split `img` into flat RGB patches padded out to `max_seq_len` slots.

    Returns `(patches, patch_coords, patch_valid)`, where `patch_valid`
    marks which of the `max_seq_len` slots hold real patches.
    """
    patches = torch.zeros(max_seq_len, patch_size * patch_size * 3, device="cpu", dtype=torch.uint8)
    patch_coords = torch.zeros(max_seq_len, 2, device="cpu", dtype=torch.int16)
    patch_valid = torch.zeros(max_seq_len, device="cpu", dtype=torch.bool)
    if share_memory:
        # Back the buffers with shared memory so worker processes can hand
        # them to the main process without a copy.
        patches.share_memory_()
        patch_coords.share_memory_()
        patch_valid.share_memory_()
    # Fill the buffers in place from the image.
    put_srgb_patch(img, patches, patch_coords, patch_valid, patch_size)
    return patches, patch_coords, patch_valid


def load_image(
    path: str,
    patch_size: int = 16,
    max_seq_len: int = 1024,
    share_memory: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Load, resize, and patchify the image at `path`."""
    with open(path, "rb", buffering=(1024 * 1024)) as file:
        img: Image.Image = Image.open(file)
        try:
            processed = process_image(img, patch_size, max_seq_len)
        except BaseException:
            img.close()
            raise
        # `process_srgb` may return the input unchanged; only close the
        # original when a new image was produced.
        if img is not processed:
            img.close()
    return patchify_image(processed, patch_size, max_seq_len, share_memory)


def load_model(path: str, device: torch.device | str | None = None) -> tuple[NaFlexVit, list[str]]:
    """Load a NaFlexVit tagger checkpoint from a safetensors file.

    Returns the model and its list of tag labels. The head variant is
    selected by the suffix of `metadata["modelspec.architecture"]`.
    """
    with safe_open(path, framework="pt", device="cpu") as file:
        metadata = file.metadata()
        state_dict = {
            key: file.get_tensor(key)
            for key in file.keys()
        }

    arch_prefix = "naflexvit_so400m_patch16_siglip"
    arch = metadata["modelspec.architecture"]
    if not arch.startswith(arch_prefix):
        raise ValueError(f"Unrecognized model architecture: {arch}")
    tags = metadata["classifier.labels"].split("\n")

    model = timm.create_model(
        arch_prefix,
        pretrained=False,
        num_classes=0,
        pos_embed_interp_mode="bilinear",
        weight_init="skip",
        fix_init=False,
        device="cpu",
        dtype=torch.bfloat16,
    )

    # Dispatch on the architecture suffix to rebuild the pooling head.
    match arch[len(arch_prefix):]:
        case "":  # vanilla attention pooling + linear head
            model.reset_classifier(len(tags))
        case "+rr_slim":
            model.reset_classifier(len(tags))
            # Slim checkpoints may omit the pool's query projection and/or
            # the head bias; mirror that in the module structure so the
            # strict state-dict load succeeds.
            if "attn_pool.q.weight" not in state_dict:
                model.attn_pool.q = Identity()
            if "head.bias" not in state_dict:
                model.head.bias = None
        case "+rr_chonker":
            from chonker_pool import ChonkerPool
            model.attn_pool = ChonkerPool(
                2, 1152, 72,
                device=device, dtype=torch.bfloat16
            )
            model.head = model.attn_pool.create_head(len(tags))
            model.num_classes = len(tags)
        case "+rr_hydra":
            from hydra_pool import HydraPool
            model.attn_pool = HydraPool.for_state(
                state_dict, "attn_pool.",
                device=device, dtype=torch.bfloat16
            )
            model.head = model.attn_pool.create_head()
            model.num_classes = len(tags)
            # `_extra_state` entries are routed to the module's
            # `set_extra_state` hook during `load_state_dict`.
            state_dict["attn_pool._extra_state"] = {
                "q_normed": True
            }
        case _:
            raise ValueError(f"Unrecognized model architecture: {arch}")

    # Load weights on CPU first, then move the assembled model to the target device.
    model.eval().to(dtype=torch.bfloat16)
    model.load_state_dict(state_dict, strict=True)
    model.to(device=device)
    return model, tags
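

# A minimal usage sketch, kept out of the library surface above. It assumes:
#   * `NaFlexVit.forward` accepts `patch_coord` / `patch_valid` keyword
#     arguments alongside the patch tensor, as in recent timm releases;
#   * uint8 sRGB patches should be mapped to SigLIP's [-1, 1] input range
#     (mean 0.5, std 0.5), i.e. x / 127.5 - 1;
#   * checkpoint and image paths are supplied on the command line.
# Treat it as illustrative, not as the canonical inference path.
if __name__ == "__main__":
    import sys

    model, tags = load_model(sys.argv[1], device="cpu")
    patches, patch_coords, patch_valid = load_image(sys.argv[2])

    # uint8 [0, 255] -> bfloat16 [-1, 1], with a leading batch dimension.
    x = patches.unsqueeze(0).to(torch.bfloat16) / 127.5 - 1.0

    with torch.inference_mode():
        logits = model(
            x,
            patch_coord=patch_coords.unsqueeze(0).long(),
            patch_valid=patch_valid.unsqueeze(0),
        )
        scores = logits.float().sigmoid().squeeze(0)

    # Print the ten highest-scoring tags.
    for score, tag in sorted(zip(scores.tolist(), tags), reverse=True)[:10]:
        print(f"{score:.3f}\t{tag}")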