Spaces:
Running
Running
Commit
·
d62ba4b
1
Parent(s):
6884ab9
JTP-3 Hydra Release
Browse files- README.md +5 -5
- app.py +434 -0
- glu.py +40 -0
- hydra_pool.py +581 -0
- image.py +271 -0
- model.py +192 -0
- requirements.txt +8 -0
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
title: JTP 3 Demo
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
|
@@ -11,4 +11,4 @@ license: apache-2.0
|
|
| 11 |
short_description: JTP-3 Hydra Demo
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
| 1 |
---
|
| 2 |
+
title: JTP 3 Hydra Demo
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
|
|
|
| 11 |
short_description: JTP-3 Hydra Demo
|
| 12 |
---
|
| 13 |
|
| 14 |
+
<a href="https://huggingface.co/RedRocket/JTP-3">JTP-3 Hydra Main Repository</a>
|
app.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BytesIO
|
| 2 |
+
from threading import Lock
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.nn import Parameter
|
| 9 |
+
|
| 10 |
+
import spaces
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
import gradio as gr
|
| 13 |
+
|
| 14 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 15 |
+
|
| 16 |
+
import requests
|
| 17 |
+
|
| 18 |
+
from model import load_model, process_image, patchify_image
|
| 19 |
+
from image import unpatchify
|
| 20 |
+
|
| 21 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
+
|
| 23 |
+
PATCH_SIZE = 16
|
| 24 |
+
MAX_SEQ_LEN = 1024
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
model_lock = Lock()
|
| 28 |
+
model, tag_list = load_model(
|
| 29 |
+
hf_hub_download(repo_id="RedRocket/JTP-3", filename="models/jtp-3-hydra.safetensors"),
|
| 30 |
+
device=device
|
| 31 |
+
)
|
| 32 |
+
model.requires_grad_(False)
|
| 33 |
+
|
| 34 |
+
tags = {
|
| 35 |
+
tag.replace("_", " ").replace("vulva", "pussy"): idx
|
| 36 |
+
for idx, tag in enumerate(tag_list)
|
| 37 |
+
}
|
| 38 |
+
tag_list = list(tags.keys())
|
| 39 |
+
|
| 40 |
+
FONT = ImageFont.load_default(24)
|
| 41 |
+
|
| 42 |
+
@spaces.GPU(duration=5)
@torch.no_grad()
def run_classifier(image: Image.Image, cam_depth: int):
    """
    Run the tagger on one image and collect its highest-scoring tags.

    Args:
        image: PIL image handed to ``patchify_image``.
        cam_depth: Forwarded as ``indices`` to ``model.forward_intermediates``.

    Returns:
        A ``(features, predictions)`` tuple. ``features`` is the dict returned
        by ``forward_intermediates`` with ``"image_features"`` removed and
        ``"patch_coords"`` / ``"patch_valid"`` added. ``predictions`` maps tag
        name to score, highest score first, holding at most 250 entries.
    """
    patches, patch_coords, patch_valid = patchify_image(image, PATCH_SIZE, MAX_SEQ_LEN)

    # Add a leading batch dimension of 1 and move everything to the model device.
    patches = patches.unsqueeze(0).to(device=device, non_blocking=True)
    patch_coords = patch_coords.unsqueeze(0).to(device=device, non_blocking=True)
    patch_valid = patch_valid.unsqueeze(0).to(device=device, non_blocking=True)

    # x / 127.5 - 1; presumably maps 8-bit pixel values onto [-1, 1] - TODO confirm
    # against patchify_image's output range.
    patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
    patch_coords = patch_coords.to(dtype=torch.int32)

    # run_cam temporarily swaps attn_pool parameters while holding model_lock,
    # so both the backbone pass and the head pass stay under the same lock.
    with model_lock:
        features = model.forward_intermediates(
            patches,
            patch_coord=patch_coords,
            patch_valid=patch_valid,
            indices=cam_depth,
            output_dict=True,
            output_fmt='NLC'
        )

        logits = model.forward_head(features["image_features"], patch_valid=patch_valid)

    # The pooled features are only needed for the head pass above.
    del features["image_features"]

    features["patch_coords"] = patch_coords
    features["patch_valid"] = patch_valid
    del patches, patch_coords, patch_valid

    probits = logits[0].float().sigmoid_().mul_(2.0).sub_(1.0)  # scale to -1 to 1

    scores = probits.cpu()
    # Clamp k so a model with fewer than 250 tags does not make topk raise.
    values, indices = scores.topk(min(250, scores.numel()))

    # topk already returns its results in descending order, so no re-sort is
    # needed; tolist() converts to plain Python numbers in a single call.
    predictions = {
        tag_list[idx]: val
        for idx, val in zip(indices.tolist(), values.tolist())
    }

    return features, predictions
|
| 83 |
+
|
| 84 |
+
@spaces.GPU(duration=5)
@torch.no_grad()
def run_cam(
    display_image: Image.Image,
    image: Image.Image, features: dict[str, Tensor],
    tag_idx: int, cam_depth: int
):
    """
    Build a gradient-based activation overlay for a single tag.

    Args:
        display_image: Image the overlay is composited onto.
        image: Image passed to ``run_classifier`` when more intermediates
            are needed than ``features`` currently holds.
        features: Feature dict as produced by ``run_classifier``.
        tag_idx: Index of the tag whose head output is back-propagated.
        cam_depth: Number of trailing intermediates to accumulate over.

    Returns:
        ``(overlay_image, features)``; ``features`` is the refreshed dict
        when the classifier had to be re-run, otherwise the one passed in.
    """
    layers = features["image_intermediates"]
    available = len(layers)

    if available < cam_depth:
        # Not enough cached intermediates: recompute at the requested depth.
        features, _ = run_classifier(image, cam_depth)
        layers = features["image_intermediates"]
    elif available > cam_depth:
        # More than needed: keep only the last cam_depth entries.
        layers = layers[-cam_depth:]

    coords = features["patch_coords"]
    valid = features["patch_valid"]

    with model_lock:
        pool = model.attn_pool
        full_q = pool.q
        full_proj = pool.out_proj.weight

        try:
            # Temporarily shrink the pooling head to the selected tag only,
            # so forward_head(...)[0, 0] is that tag's output.
            pool.q = Parameter(full_q[:, [tag_idx], :], requires_grad=False)
            pool.out_proj.weight = Parameter(full_proj[[tag_idx], :, :], requires_grad=False)

            # The function runs under no_grad; gradients are re-enabled just
            # for the backward passes.
            with torch.enable_grad():
                for layer in layers:
                    layer.requires_grad_(True).retain_grad()
                    model.forward_head(layer, patch_valid=valid)[0, 0].backward()
        finally:
            # Always restore the full parameters, even if a pass failed.
            pool.q = full_q
            pool.out_proj.weight = full_proj

    cam_1d: Tensor | None = None
    for layer in layers:
        # Gradient times the sign of the activation, summed over dims 0 and 2.
        contribution = (layer.grad.float() * layer.sign()).sum(dim=(0, 2))
        layer.grad = None

        cam_1d = contribution if cam_1d is None else cam_1d.add_(contribution)

    assert cam_1d is not None

    cam_2d = unpatchify(cam_1d, coords, valid).cpu().numpy()
    return cam_composite(display_image, cam_2d), features
|
| 131 |
+
|
| 132 |
+
def cam_composite(image: Image.Image, cam: np.ndarray):
    """
    Overlays CAM on image and returns a PIL image.

    Negative CAM values are drawn in red and positive values in green, with
    opacity proportional to the absolute value (the largest magnitude maps
    to 50% opacity). The largest magnitude is also printed in the
    bottom-right corner.

    Args:
        image: PIL Image the overlay is drawn on
        cam: 2D numpy array (activation map)

    Returns:
        PIL.Image.Image with overlay
    """

    cam_abs = np.abs(cam)
    cam_scale = cam_abs.max()

    # An all-zero map would divide by zero below and push NaNs into the uint8
    # cast, so fall back to a fully transparent overlay in that case.
    if cam_scale > 0:
        alpha = cam_abs * (0.5 / cam_scale)
    else:
        alpha = np.zeros_like(cam, dtype=np.float32)

    cam_rgba = np.dstack((
        (cam < 0).astype(np.float32),           # red where the map is negative
        (cam > 0).astype(np.float32),           # green where the map is positive
        np.zeros_like(cam, dtype=np.float32),   # blue unused
        alpha,
    ))  # Shape: (H, W, 4)

    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8))
    # Nearest-neighbour keeps the per-cell blocks crisp when scaling up.
    cam_pil = cam_pil.resize(image.size, resample=Image.Resampling.NEAREST)

    # Mix the image 33% of the way towards its grayscale version so the
    # coloured overlay stands out.
    image = Image.blend(
        image.convert('RGBA'),
        image.convert('L').convert('RGBA'),
        0.33
    )

    image = Image.alpha_composite(image, cam_pil)

    draw = ImageDraw.Draw(image)
    draw.text(
        (image.width - 7, image.height - 7),
        f"{cam_scale.item():.4g}",
        anchor="rd", font=FONT, fill=(32, 32, 255, 255)
    )

    return image
|
| 172 |
+
|
| 173 |
+
def filter_tags(predictions: dict[str, float], threshold: float):
    """
    Keep only the predictions scoring at or above ``threshold``.

    Returns:
        ``(tag_str, kept)`` where ``kept`` preserves the input order and
        ``tag_str`` is the kept tag names joined with ``", "``.
    """
    kept: dict[str, float] = {}
    for tag, score in predictions.items():
        if score >= threshold:
            kept[tag] = score

    return ", ".join(kept), kept
|
| 182 |
+
|
| 183 |
+
def resize_image(image: Image.Image) -> Image.Image:
    """
    Scale an image so its longest side is 1080 pixels.

    Images whose longest side is already below 1080 are returned as-is
    (the same object, not a copy).
    """
    longest_side = max(image.height, image.width)
    if longest_side < 1080:
        return image

    scale = 1080 / longest_side
    target_size = (
        int(round(image.width * scale)),
        int(round(image.height * scale)),
    )

    return image.resize(
        target_size,
        resample=Image.Resampling.LANCZOS,
        reducing_gap=3.0
    )
|
| 197 |
+
|
| 198 |
+
def image_upload(image: Image.Image):
    """
    Prepare a freshly uploaded image for display and for the model.

    Returns the reset values for the tag string, tag box, CAM tag and URL
    box, followed by the image-component update, the display image and the
    processed image.
    """
    preview = resize_image(image)
    model_input = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)

    # The upload is only closed when neither derived image is the same object.
    upload_is_unused = preview is not image and model_input is not image
    if upload_is_unused:
        image.close()

    # Skip updating the image component when no resize took place.
    image_update = gr.skip() if preview is image else preview

    return (
        "", {}, "None", "",
        image_update, preview,
        model_input,
    )
|
| 210 |
+
|
| 211 |
+
def url_submit(url: str):
    """
    Download an image from ``url`` and prepare it for display and the model.

    Raises whatever ``requests`` raises for connection problems, timeouts
    (10 seconds) or non-success HTTP status codes.

    Returns the reset values for the tag string, tag box and CAM tag,
    followed by the display image (twice) and the processed image.
    """
    # NOTE(review): the URL comes straight from the user and is fetched
    # server-side without any host or size restriction - confirm that this is
    # acceptable for the deployment.
    response = requests.get(url, timeout=10)
    response.raise_for_status()

    image = Image.open(BytesIO(response.content))
    preview = resize_image(image)
    model_input = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)

    # The download is only closed when neither derived image is the same object.
    download_is_unused = preview is not image and model_input is not image
    if download_is_unused:
        image.close()

    return (
        "", {}, "None",
        preview, preview,
        model_input,
    )
|
| 227 |
+
|
| 228 |
+
def image_changed(image: Image.Image, threshold: float, cam_depth: int):
    """
    Classify ``image`` and filter the result by ``threshold``.

    Returns ``(tag_string, filtered_predictions, features, predictions)``.
    """
    features, predictions = run_classifier(image, cam_depth)
    tag_str, visible = filter_tags(predictions, threshold)
    return tag_str, visible, features, predictions
|
| 231 |
+
|
| 232 |
+
def image_clear():
    """
    Return the nine reset values used when the image is cleared.

    The order matches the outputs wired to ``image.clear``: tag string,
    tag box, CAM tag, URL, image, display image state, image state,
    features state and predictions state.
    """
    widget_resets = ("", {}, "None", "")
    image_resets = (None, None)
    state_resets = (None, None, {})
    return widget_resets + image_resets + state_resets
|
| 238 |
+
|
| 239 |
+
def cam_changed(
    display_image: Image.Image,
    image: Image.Image, features: dict[str, Tensor],
    tag: str, cam_depth: int
):
    """
    Recompute the CAM overlay after the selected tag or depth changes.

    Args:
        display_image: Image shown in the UI; returned unchanged when no
            overlay is drawn.
        image: Processed image forwarded to ``run_cam``.
        features: Cached classifier features, or ``None`` when nothing has
            been classified yet.
        tag: Selected tag name, or ``"None"`` to show the plain image.
        cam_depth: Depth forwarded to ``run_cam``.

    Returns:
        ``(image_to_show, features)``.
    """
    # The state values are None until an image has been uploaded and
    # classified; without this guard, choosing a tag first would index into
    # None inside run_cam.
    if tag == "None" or features is None or image is None:
        return display_image, features

    return run_cam(display_image, image, features, tags[tag], cam_depth)
|
| 248 |
+
|
| 249 |
+
def tag_box_select(evt: gr.SelectData):
    """Return the value carried by the selection event."""
    selected = evt.value
    return selected
|
| 251 |
+
|
| 252 |
+
custom_css = """
|
| 253 |
+
.output-class { display: none; }
|
| 254 |
+
.inferno-slider input[type=range] {
|
| 255 |
+
background: linear-gradient(to right,
|
| 256 |
+
#000004, #1b0c41, #4a0c6b, #781c6d,
|
| 257 |
+
#a52c60, #cf4446, #ed6925, #fb9b06,
|
| 258 |
+
#f7d13d, #fcffa4
|
| 259 |
+
) !important;
|
| 260 |
+
background-size: 100% 100% !important;
|
| 261 |
+
}
|
| 262 |
+
#image_container-image {
|
| 263 |
+
width: 100%;
|
| 264 |
+
aspect-ratio: 1 / 1;
|
| 265 |
+
max-height: 100%;
|
| 266 |
+
}
|
| 267 |
+
#image_container img {
|
| 268 |
+
object-fit: contain !important;
|
| 269 |
+
}
|
| 270 |
+
.show-api, .show-api-divider {
|
| 271 |
+
display: none !important;
|
| 272 |
+
}
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
with gr.Blocks(
|
| 276 |
+
title="RedRocket JTP-3 Hydra Demo",
|
| 277 |
+
css=custom_css,
|
| 278 |
+
analytics_enabled=False,
|
| 279 |
+
) as demo:
|
| 280 |
+
display_image_state = gr.State()
|
| 281 |
+
image_state = gr.State()
|
| 282 |
+
features_state = gr.State()
|
| 283 |
+
predictions_state = gr.State(value={})
|
| 284 |
+
|
| 285 |
+
gr.HTML(
|
| 286 |
+
"<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
|
| 287 |
+
"<a href='https://huggingface.co/RedRocket' target='_blank'>"
|
| 288 |
+
"<img src='https://huggingface.co/spaces/RedRocket/README/resolve/main/RedRocket.png' style='width: 2em; margin-right: 0.5em;'>"
|
| 289 |
+
"</a>"
|
| 290 |
+
"<span><a href='https://huggingface.co/RedRocket' target='_blank'>RedRocket</a> – JTP-3 Hydra Demo</span>"
|
| 291 |
+
"<span style='font-weight: normal;'> • <a href='https://huggingface.co/RedRocket/JTP-3' target='_blank'>Download</a></span>"
|
| 292 |
+
"</h1>"
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
with gr.Row():
|
| 296 |
+
with gr.Column():
|
| 297 |
+
with gr.Column():
|
| 298 |
+
image = gr.Image(
|
| 299 |
+
sources=['upload', 'clipboard'], type='pil',
|
| 300 |
+
show_label=False,
|
| 301 |
+
show_download_button=False,
|
| 302 |
+
show_share_button=False,
|
| 303 |
+
elem_id="image_container"
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
url = gr.Textbox(
|
| 307 |
+
label="Upload Image via Url:",
|
| 308 |
+
placeholder="https://example.com/image.jpg",
|
| 309 |
+
max_lines=1,
|
| 310 |
+
submit_btn="⮝",
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
with gr.Column():
|
| 314 |
+
cam_tag = gr.Dropdown(
|
| 315 |
+
value="None", choices=["None"] + tag_list,
|
| 316 |
+
label="CAM Attention Overlay (You can also click a tag on the right.)", show_label=True
|
| 317 |
+
)
|
| 318 |
+
cam_depth = gr.Slider(
|
| 319 |
+
minimum=1, maximum=27, step=1, value=1,
|
| 320 |
+
label="CAM Depth (1=fastest, more precise; 27=slowest, more general)"
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
with gr.Column():
|
| 324 |
+
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Tag Threshold")
|
| 325 |
+
tag_string = gr.Textbox(lines=3, label="Tags", show_label=True, show_copy_button=True)
|
| 326 |
+
tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
|
| 327 |
+
|
| 328 |
+
image.upload(
|
| 329 |
+
fn=image_upload,
|
| 330 |
+
inputs=[image],
|
| 331 |
+
outputs=[
|
| 332 |
+
tag_string, tag_box, cam_tag, url,
|
| 333 |
+
image, display_image_state,
|
| 334 |
+
image_state,
|
| 335 |
+
],
|
| 336 |
+
show_progress='minimal',
|
| 337 |
+
show_progress_on=[image]
|
| 338 |
+
).then(
|
| 339 |
+
fn=image_changed,
|
| 340 |
+
inputs=[image_state, threshold_slider, cam_depth],
|
| 341 |
+
outputs=[
|
| 342 |
+
tag_string, tag_box,
|
| 343 |
+
features_state, predictions_state,
|
| 344 |
+
],
|
| 345 |
+
show_progress='minimal',
|
| 346 |
+
show_progress_on=[tag_box]
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
url.submit(
|
| 350 |
+
fn=url_submit,
|
| 351 |
+
inputs=[url],
|
| 352 |
+
outputs=[
|
| 353 |
+
tag_string, tag_box, cam_tag,
|
| 354 |
+
image, display_image_state,
|
| 355 |
+
image_state,
|
| 356 |
+
],
|
| 357 |
+
show_progress='minimal',
|
| 358 |
+
show_progress_on=[url]
|
| 359 |
+
).then(
|
| 360 |
+
fn=image_changed,
|
| 361 |
+
inputs=[image_state, threshold_slider, cam_depth],
|
| 362 |
+
outputs=[
|
| 363 |
+
tag_string, tag_box,
|
| 364 |
+
features_state, predictions_state,
|
| 365 |
+
],
|
| 366 |
+
show_progress='minimal',
|
| 367 |
+
show_progress_on=[tag_box]
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
image.clear(
|
| 371 |
+
fn=image_clear,
|
| 372 |
+
inputs=[],
|
| 373 |
+
outputs=[
|
| 374 |
+
tag_string, tag_box, cam_tag, url,
|
| 375 |
+
image, display_image_state,
|
| 376 |
+
image_state, features_state, predictions_state,
|
| 377 |
+
],
|
| 378 |
+
show_progress='hidden'
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
threshold_slider.input(
|
| 382 |
+
fn=filter_tags,
|
| 383 |
+
inputs=[predictions_state, threshold_slider],
|
| 384 |
+
outputs=[tag_string, tag_box],
|
| 385 |
+
trigger_mode='always_last',
|
| 386 |
+
show_progress='hidden'
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
cam_tag.input(
|
| 390 |
+
fn=cam_changed,
|
| 391 |
+
inputs=[
|
| 392 |
+
display_image_state,
|
| 393 |
+
image_state, features_state,
|
| 394 |
+
cam_tag, cam_depth,
|
| 395 |
+
],
|
| 396 |
+
outputs=[image, features_state],
|
| 397 |
+
trigger_mode='always_last',
|
| 398 |
+
show_progress='minimal',
|
| 399 |
+
show_progress_on=[cam_tag]
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
cam_depth.input(
|
| 403 |
+
fn=cam_changed,
|
| 404 |
+
inputs=[
|
| 405 |
+
display_image_state,
|
| 406 |
+
image_state, features_state,
|
| 407 |
+
cam_tag, cam_depth,
|
| 408 |
+
],
|
| 409 |
+
outputs=[image, features_state],
|
| 410 |
+
trigger_mode='always_last',
|
| 411 |
+
show_progress='minimal',
|
| 412 |
+
show_progress_on=[cam_depth]
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
tag_box.select(
|
| 416 |
+
fn=tag_box_select,
|
| 417 |
+
inputs=[],
|
| 418 |
+
outputs=[cam_tag],
|
| 419 |
+
trigger_mode='always_last',
|
| 420 |
+
show_progress='hidden',
|
| 421 |
+
).then(
|
| 422 |
+
fn=cam_changed,
|
| 423 |
+
inputs=[
|
| 424 |
+
display_image_state,
|
| 425 |
+
image_state, features_state,
|
| 426 |
+
cam_tag, cam_depth,
|
| 427 |
+
],
|
| 428 |
+
outputs=[image, features_state],
|
| 429 |
+
show_progress='minimal',
|
| 430 |
+
show_progress_on=[cam_tag]
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
if __name__ == "__main__":
|
| 434 |
+
demo.launch()
|
glu.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from torch.nn import Module
|
| 6 |
+
from torch.nn.functional import silu, gelu
|
| 7 |
+
|
| 8 |
+
class GatedUnit(Module):
    """
    Base class for gated linear units.

    The input is split into two equal chunks along ``dim``; the activation
    is applied to the first chunk and the result is multiplied elementwise
    by the second. Subclasses supply the activation via ``_activation``.
    """

    def __init__(self, dim: int = -1) -> None:
        super().__init__()

        # Dimension along which forward() splits its input in two.
        self.dim = dim

    @abstractmethod
    def _activation(self, x: Tensor) -> Tensor:
        """Apply the gate activation; must be overridden by subclasses."""
        # Module does not use an ABC metaclass, so @abstractmethod alone does
        # not block instantiation; fail loudly instead of returning None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _activation()."
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``activation(first_half) * second_half``."""
        f, g = x.chunk(2, dim=self.dim)
        return self._activation(f) * g
|
| 21 |
+
|
| 22 |
+
class SwiGLU(GatedUnit):
    """Gated unit that uses SiLU as the gate activation."""

    # The constructor is inherited unchanged: SwiGLU(dim: int = -1).

    def _activation(self, x: Tensor) -> Tensor:
        activated = silu(x)
        return activated
|
| 28 |
+
|
| 29 |
+
class GeGLU(GatedUnit):
    """Gated unit that uses GELU as the gate activation."""

    def __init__(
        self,
        dim: int = -1,
        approximate: Literal["tanh", "none"] = "tanh"
    ) -> None:
        super().__init__(dim)

        # Passed straight through to torch's gelu on every call.
        self.approximate = approximate

    def _activation(self, x: Tensor) -> Tensor:
        mode = self.approximate
        return gelu(x, mode)
|
hydra_pool.py
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from math import sqrt
|
| 4 |
+
from typing import Any, Iterable, Self, cast
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.nn import (
|
| 9 |
+
Module, ModuleList, Parameter, Buffer,
|
| 10 |
+
Linear, LayerNorm, RMSNorm, Dropout, Flatten,
|
| 11 |
+
init
|
| 12 |
+
)
|
| 13 |
+
from torch.nn.functional import pad, scaled_dot_product_attention
|
| 14 |
+
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
|
| 17 |
+
from glu import SwiGLU
|
| 18 |
+
|
| 19 |
+
class IndexedAdd(Module):
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
n_indices: int,
|
| 23 |
+
dim: int,
|
| 24 |
+
weight_shape: tuple[int, ...] | None = None,
|
| 25 |
+
*,
|
| 26 |
+
inplace: bool = False,
|
| 27 |
+
device: torch.device | str | None = None,
|
| 28 |
+
dtype: torch.dtype | None = None,
|
| 29 |
+
) -> None:
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
self.dim = dim
|
| 33 |
+
self.inplace = inplace
|
| 34 |
+
|
| 35 |
+
self.index = Buffer(torch.empty(
|
| 36 |
+
2, n_indices,
|
| 37 |
+
device=device, dtype=torch.int32
|
| 38 |
+
))
|
| 39 |
+
|
| 40 |
+
self.weight = Parameter(torch.ones(
|
| 41 |
+
*(sz if sz != -1 else n_indices for sz in weight_shape),
|
| 42 |
+
device=device, dtype=dtype
|
| 43 |
+
)) if weight_shape is not None else None
|
| 44 |
+
|
| 45 |
+
def _save_to_state_dict(
|
| 46 |
+
self,
|
| 47 |
+
destination: dict[str, Any],
|
| 48 |
+
prefix: str,
|
| 49 |
+
keep_vars: bool
|
| 50 |
+
) -> None:
|
| 51 |
+
super()._save_to_state_dict(destination, prefix, keep_vars)
|
| 52 |
+
|
| 53 |
+
if keep_vars:
|
| 54 |
+
return
|
| 55 |
+
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
index_key = f"{prefix}index"
|
| 58 |
+
index = destination[index_key]
|
| 59 |
+
|
| 60 |
+
min_index = index.amin(None).item()
|
| 61 |
+
if min_index >= 0:
|
| 62 |
+
max_index = index.amax(None).item()
|
| 63 |
+
if max_index < (1 << 8):
|
| 64 |
+
destination[index_key] = index.to(dtype=torch.uint8)
|
| 65 |
+
elif max_index < (1 << 16):
|
| 66 |
+
destination[index_key] = index.to(dtype=torch.uint16)
|
| 67 |
+
|
| 68 |
+
@torch.no_grad()
|
| 69 |
+
def load_indices(self, indices: Iterable[tuple[int, int]], *, mean: bool = False) -> None:
|
| 70 |
+
if mean:
|
| 71 |
+
if self.weight is None:
|
| 72 |
+
raise ValueError("No weights to initialize with means.")
|
| 73 |
+
|
| 74 |
+
groups: dict[int, list[int]] = defaultdict(list)
|
| 75 |
+
|
| 76 |
+
idx = -1
|
| 77 |
+
for idx, (src, dst) in enumerate(indices):
|
| 78 |
+
self.index[0, idx] = src
|
| 79 |
+
self.index[1, idx] = dst
|
| 80 |
+
|
| 81 |
+
if mean:
|
| 82 |
+
groups[dst].append(idx)
|
| 83 |
+
|
| 84 |
+
if (idx + 1) != self.index.size(1):
|
| 85 |
+
raise IndexError(f"Expected {self.index.size(1)} indices, but got {idx + 1}.")
|
| 86 |
+
|
| 87 |
+
if not mean:
|
| 88 |
+
return
|
| 89 |
+
|
| 90 |
+
assert self.weight is not None
|
| 91 |
+
|
| 92 |
+
for idxs in groups.values():
|
| 93 |
+
if len(idxs) < 2:
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
self.weight.index_fill_(
|
| 97 |
+
self.dim,
|
| 98 |
+
torch.tensor(idxs, device=self.weight.device, dtype=torch.int64),
|
| 99 |
+
1.0 / len(idxs)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def forward(self, dst: Tensor, src: Tensor) -> Tensor:
|
| 103 |
+
src = src.index_select(self.dim, self.index[0])
|
| 104 |
+
|
| 105 |
+
if self.weight is not None:
|
| 106 |
+
src.mul_(self.weight)
|
| 107 |
+
|
| 108 |
+
return (
|
| 109 |
+
dst.index_add_(self.dim, self.index[1], src)
|
| 110 |
+
if self.inplace else
|
| 111 |
+
dst.index_add(self.dim, self.index[1], src)
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
class BatchLinear(Module):
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
batch_shape: tuple[int, ...] | int,
|
| 118 |
+
in_features: int,
|
| 119 |
+
out_features: int,
|
| 120 |
+
*,
|
| 121 |
+
bias: bool = False,
|
| 122 |
+
flatten: bool = False,
|
| 123 |
+
bias_inplace: bool = True,
|
| 124 |
+
device: torch.device | str | None = None,
|
| 125 |
+
dtype: torch.dtype | None = None,
|
| 126 |
+
) -> None:
|
| 127 |
+
super().__init__()
|
| 128 |
+
|
| 129 |
+
if isinstance(batch_shape, int):
|
| 130 |
+
batch_shape = (batch_shape,)
|
| 131 |
+
elif not batch_shape:
|
| 132 |
+
raise ValueError("At least one batch dimension is required.")
|
| 133 |
+
|
| 134 |
+
self.flatten = -(len(batch_shape) + 1) if flatten else 0
|
| 135 |
+
|
| 136 |
+
self.weight = Parameter(torch.empty(
|
| 137 |
+
*batch_shape, in_features, out_features,
|
| 138 |
+
device=device, dtype=dtype
|
| 139 |
+
))
|
| 140 |
+
|
| 141 |
+
bt = self.weight.flatten(end_dim=-3).mT
|
| 142 |
+
for idx in range(bt.size(0)):
|
| 143 |
+
init.kaiming_uniform_(bt[idx], a=sqrt(5))
|
| 144 |
+
|
| 145 |
+
self.bias = Parameter(torch.zeros(
|
| 146 |
+
*batch_shape, out_features,
|
| 147 |
+
device=device, dtype=dtype
|
| 148 |
+
)) if bias else None
|
| 149 |
+
|
| 150 |
+
self.bias_inplace = bias_inplace
|
| 151 |
+
|
| 152 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 153 |
+
# ... B... 1 I @ B... I O -> ... B... O
|
| 154 |
+
x = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2)
|
| 155 |
+
|
| 156 |
+
if self.bias is not None:
|
| 157 |
+
if self.bias_inplace:
|
| 158 |
+
x.add_(self.bias)
|
| 159 |
+
else:
|
| 160 |
+
x = x + self.bias
|
| 161 |
+
|
| 162 |
+
if self.flatten:
|
| 163 |
+
x = x.flatten(self.flatten)
|
| 164 |
+
|
| 165 |
+
return x
|
| 166 |
+
|
| 167 |
+
class Mean(Module):
|
| 168 |
+
def __init__(self, dim: tuple[int, ...] | int = -1, *, keepdim: bool = False) -> None:
|
| 169 |
+
super().__init__()
|
| 170 |
+
|
| 171 |
+
self.dim = dim
|
| 172 |
+
self.keepdim = keepdim
|
| 173 |
+
|
| 174 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 175 |
+
return x.mean(self.dim, self.keepdim)
|
| 176 |
+
|
| 177 |
+
class _MidBlock(Module):
    """
    Residual block: cross-attention with per-class query offsets, followed
    by a SwiGLU feed-forward layer.

    Keys and values are supplied by the caller; the block only projects the
    queries.
    """

    def __init__(
        self,
        attn_dim: int,
        head_dim: int,
        n_classes: int,
        *,
        ff_ratio: float,
        ff_dropout: float,
        q_cls_inplace: bool = True,
        device: torch.device | str | None,
        dtype: torch.dtype | None,
    ) -> None:
        """
        Args:
            attn_dim: Width of the residual stream and of the attention.
            head_dim: Width of a single attention head.
            n_classes: Number of rows in the per-class query offset.
            ff_ratio: Feed-forward hidden width as a multiple of ``attn_dim``.
            ff_dropout: Dropout probability inside the feed-forward layer.
            q_cls_inplace: Whether the class offset is added in place.
            device: Device for the parameters.
            dtype: Dtype for the parameters.
        """
        super().__init__()

        self.head_dim = head_dim
        self.q_cls_inplace = q_cls_inplace

        hidden_dim = int(attn_dim * ff_ratio)
        factory = {"device": device, "dtype": dtype}

        self.q_proj = Linear(attn_dim, attn_dim, bias=False, **factory)

        # Zero-initialised offset added to the projected queries.
        self.q_cls = Parameter(torch.zeros(n_classes, attn_dim, **factory))

        self.q_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)

        self.attn_out = Linear(attn_dim, attn_dim, bias=False, **factory)

        self.ff_norm = LayerNorm(attn_dim, **factory)
        # Twice the hidden width: SwiGLU halves it again when gating.
        self.ff_in = Linear(attn_dim, hidden_dim * 2, bias=False, **factory)
        self.ff_act = SwiGLU()
        self.ff_drop = Dropout(ff_dropout)
        self.ff_out = Linear(hidden_dim, attn_dim, bias=False, **factory)

    def _forward_q(self, x: Tensor) -> Tensor:
        """Project ``x`` to normalised per-head queries."""
        q = self.q_proj(x)

        # q is a fresh tensor, so the in-place add never touches the input x.
        q = q.add_(self.q_cls) if self.q_cls_inplace else q + self.q_cls

        # NOTE(review): q_norm is built with normalized_shape=head_dim but is
        # applied here before the heads are split out - confirm whether that
        # is intended or whether the norm should follow the rearrange.
        q = self.q_norm(q)
        return rearrange(q, "... s (h e) -> ... h s e", e=self.head_dim)

    def _forward_attn(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None) -> Tensor:
        """Attention sub-layer with a residual connection."""
        queries = self._forward_q(x)
        attended = scaled_dot_product_attention(
            queries, k, v,
            attn_mask=attn_mask
        )
        # Merge the heads back into a single feature dimension.
        merged = rearrange(attended, "... h s e -> ... s (h e)")
        return x + self.attn_out(merged)

    def _forward_ff(self, x: Tensor) -> Tensor:
        """Feed-forward sub-layer with a residual connection."""
        hidden = self.ff_in(self.ff_norm(x))
        hidden = self.ff_drop(self.ff_act(hidden))
        return x + self.ff_out(hidden)

    def forward(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None = None) -> Tensor:
        """Run the attention sub-layer, then the feed-forward sub-layer."""
        after_attn = self._forward_attn(x, k, v, attn_mask)
        return self._forward_ff(after_attn)
|
| 262 |
+
|
| 263 |
+
class HydraPool(Module):
|
| 264 |
+
def __init__(
|
| 265 |
+
self,
|
| 266 |
+
attn_dim: int,
|
| 267 |
+
head_dim: int,
|
| 268 |
+
n_classes: int,
|
| 269 |
+
*,
|
| 270 |
+
mid_blocks: int = 0,
|
| 271 |
+
roots: tuple[int, int, int] = (0, 0, 0),
|
| 272 |
+
ff_ratio: float = 3.0,
|
| 273 |
+
ff_dropout: float = 0.0,
|
| 274 |
+
input_dim: int = -1,
|
| 275 |
+
output_dim: int = 1,
|
| 276 |
+
device: torch.device | str | None = None,
|
| 277 |
+
dtype: torch.dtype | None = None,
|
| 278 |
+
) -> None:
|
| 279 |
+
super().__init__()
|
| 280 |
+
|
| 281 |
+
if input_dim < 0:
|
| 282 |
+
input_dim = attn_dim
|
| 283 |
+
|
| 284 |
+
assert attn_dim % head_dim == 0
|
| 285 |
+
n_heads = attn_dim // head_dim
|
| 286 |
+
|
| 287 |
+
self.n_classes = n_classes
|
| 288 |
+
self.head_dim = head_dim
|
| 289 |
+
self.output_dim = output_dim
|
| 290 |
+
|
| 291 |
+
self._has_roots = False
|
| 292 |
+
self._has_ff = False
|
| 293 |
+
|
| 294 |
+
self.q: Parameter | Buffer
|
| 295 |
+
self._q_normed: bool | None
|
| 296 |
+
|
| 297 |
+
if roots != (0, 0, 0):
|
| 298 |
+
self._has_roots = True
|
| 299 |
+
n_roots, n_classroots, n_subclasses = roots
|
| 300 |
+
|
| 301 |
+
if n_classroots < n_roots:
|
| 302 |
+
raise ValueError("Number of classroots cannot be less than the number of roots.")
|
| 303 |
+
|
| 304 |
+
self.cls = Parameter(torch.randn(
|
| 305 |
+
n_heads, n_classes, head_dim,
|
| 306 |
+
device=device, dtype=dtype
|
| 307 |
+
))
|
| 308 |
+
|
| 309 |
+
self.roots = Parameter(torch.randn(
|
| 310 |
+
n_heads, n_roots, head_dim,
|
| 311 |
+
device=device, dtype=dtype
|
| 312 |
+
)) if n_roots > 0 else None
|
| 313 |
+
|
| 314 |
+
self.clsroots = IndexedAdd(
|
| 315 |
+
n_classroots, dim=-2, weight_shape=(n_heads, -1, 1),
|
| 316 |
+
device=device, dtype=dtype
|
| 317 |
+
) if n_classroots > 0 else None
|
| 318 |
+
|
| 319 |
+
self.clscls = IndexedAdd(
|
| 320 |
+
n_subclasses, dim=-2, weight_shape=(n_heads, -1, 1),
|
| 321 |
+
inplace=True, device=device, dtype=dtype
|
| 322 |
+
) if n_subclasses > 0 else None
|
| 323 |
+
|
| 324 |
+
self.q = Buffer(torch.empty(
|
| 325 |
+
n_heads, n_classes, head_dim,
|
| 326 |
+
device=device, dtype=dtype
|
| 327 |
+
))
|
| 328 |
+
self._q_normed = None
|
| 329 |
+
else:
|
| 330 |
+
self.q = Parameter(torch.randn(
|
| 331 |
+
n_heads, n_classes, head_dim,
|
| 332 |
+
device=device, dtype=dtype
|
| 333 |
+
))
|
| 334 |
+
self._q_normed = False
|
| 335 |
+
|
| 336 |
+
self.kv = Linear(
|
| 337 |
+
input_dim, attn_dim * 2, bias=False,
|
| 338 |
+
device=device, dtype=dtype
|
| 339 |
+
)
|
| 340 |
+
self.qk_norm = RMSNorm(
|
| 341 |
+
head_dim, eps=1e-5, elementwise_affine=False
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
if ff_ratio > 0.0:
|
| 345 |
+
self._has_ff = True
|
| 346 |
+
hidden_dim = int(attn_dim * ff_ratio)
|
| 347 |
+
|
| 348 |
+
self.ff_norm = LayerNorm(
|
| 349 |
+
attn_dim,
|
| 350 |
+
device=device, dtype=dtype
|
| 351 |
+
)
|
| 352 |
+
self.ff_in = Linear(
|
| 353 |
+
attn_dim, hidden_dim * 2, bias=False,
|
| 354 |
+
device=device, dtype=dtype
|
| 355 |
+
)
|
| 356 |
+
self.ff_act = SwiGLU()
|
| 357 |
+
self.ff_drop = Dropout(ff_dropout)
|
| 358 |
+
self.ff_out = Linear(
|
| 359 |
+
hidden_dim, attn_dim, bias=False,
|
| 360 |
+
device=device, dtype=dtype
|
| 361 |
+
)
|
| 362 |
+
elif mid_blocks > 0:
|
| 363 |
+
raise ValueError("Feedforward required with mid blocks.")
|
| 364 |
+
|
| 365 |
+
self.mid_blocks = ModuleList(
|
| 366 |
+
_MidBlock(
|
| 367 |
+
attn_dim, head_dim, n_classes,
|
| 368 |
+
ff_ratio=ff_ratio, ff_dropout=ff_dropout,
|
| 369 |
+
device=device, dtype=dtype
|
| 370 |
+
) for _ in range(mid_blocks)
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
self.out_proj = BatchLinear(
|
| 374 |
+
n_classes, attn_dim, output_dim * 2,
|
| 375 |
+
device=device, dtype=dtype
|
| 376 |
+
)
|
| 377 |
+
self.out_act = SwiGLU()
|
| 378 |
+
|
| 379 |
+
@property
|
| 380 |
+
def has_roots(self) -> bool:
|
| 381 |
+
return self._has_roots
|
| 382 |
+
|
| 383 |
+
def get_extra_state(self) -> dict[str, Any]:
|
| 384 |
+
return { "q_normed": self._q_normed }
|
| 385 |
+
|
| 386 |
+
def set_extra_state(self, state: dict[str, Any]) -> None:
|
| 387 |
+
self._q_normed = state["q_normed"]
|
| 388 |
+
|
| 389 |
+
def create_head(self) -> Module:
|
| 390 |
+
if self.output_dim == 1:
|
| 391 |
+
return Flatten(-2)
|
| 392 |
+
|
| 393 |
+
return Mean(-1)
|
| 394 |
+
|
| 395 |
+
def train(self, mode: bool = True) -> Self:
|
| 396 |
+
super().train(mode)
|
| 397 |
+
|
| 398 |
+
if mode:
|
| 399 |
+
if self._has_roots:
|
| 400 |
+
self._q_normed = None
|
| 401 |
+
else:
|
| 402 |
+
self._q_normed = False
|
| 403 |
+
else:
|
| 404 |
+
if self._has_roots:
|
| 405 |
+
self._cache_query()
|
| 406 |
+
|
| 407 |
+
return self
|
| 408 |
+
|
| 409 |
+
def inference(self) -> Self:
|
| 410 |
+
super().train(False)
|
| 411 |
+
self._cache_query()
|
| 412 |
+
|
| 413 |
+
if self._has_roots:
|
| 414 |
+
self._has_roots = False
|
| 415 |
+
self.q = Parameter(self.q)
|
| 416 |
+
|
| 417 |
+
del self.cls, self.roots, self.clsroots, self.clscls
|
| 418 |
+
|
| 419 |
+
return self
|
| 420 |
+
|
| 421 |
+
def _cache_query(self) -> None:
|
| 422 |
+
assert not self.training
|
| 423 |
+
|
| 424 |
+
if self._q_normed:
|
| 425 |
+
return
|
| 426 |
+
|
| 427 |
+
with torch.no_grad():
|
| 428 |
+
self.q.to(device=self.kv.weight.device)
|
| 429 |
+
self.q.copy_(self._forward_q())
|
| 430 |
+
self._q_normed = True
|
| 431 |
+
|
| 432 |
+
def _forward_q(self) -> Tensor:
|
| 433 |
+
match self._q_normed:
|
| 434 |
+
case None:
|
| 435 |
+
assert self._has_roots
|
| 436 |
+
|
| 437 |
+
if self.roots is not None:
|
| 438 |
+
q = self.qk_norm(self.roots)
|
| 439 |
+
q = self.clsroots(self.cls, q)
|
| 440 |
+
else:
|
| 441 |
+
q = self.cls
|
| 442 |
+
|
| 443 |
+
if self.clscls is not None:
|
| 444 |
+
q = self.clscls(q, q.detach())
|
| 445 |
+
|
| 446 |
+
q = self.qk_norm(q)
|
| 447 |
+
return q
|
| 448 |
+
|
| 449 |
+
case False:
|
| 450 |
+
assert not self._has_roots
|
| 451 |
+
return self.qk_norm(self.q)
|
| 452 |
+
|
| 453 |
+
case True:
|
| 454 |
+
return self.q
|
| 455 |
+
|
| 456 |
+
def _forward_attn(self, x: Tensor, attn_mask: Tensor | None) -> tuple[Tensor, Tensor, Tensor]:
|
| 457 |
+
q = self._forward_q().expand(*x.shape[:-2], -1, -1, -1)
|
| 458 |
+
|
| 459 |
+
x = self.kv(x)
|
| 460 |
+
k, v = rearrange(x, "... s (n h e) -> n ... h s e", n=2, e=self.head_dim).unbind(0)
|
| 461 |
+
k = self.qk_norm(k)
|
| 462 |
+
|
| 463 |
+
x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
| 464 |
+
return rearrange(x, "... h s e -> ... s (h e)"), k, v
|
| 465 |
+
|
| 466 |
+
def _forward_ff(self, x: Tensor) -> Tensor:
|
| 467 |
+
if not self._has_ff:
|
| 468 |
+
return x
|
| 469 |
+
|
| 470 |
+
f = self.ff_norm(x)
|
| 471 |
+
f = self.ff_in(f)
|
| 472 |
+
f = self.ff_act(f)
|
| 473 |
+
f = self.ff_drop(f)
|
| 474 |
+
f = self.ff_out(f)
|
| 475 |
+
return x + f
|
| 476 |
+
|
| 477 |
+
def _forward_out(self, x: Tensor) -> Tensor:
|
| 478 |
+
x = self.out_proj(x)
|
| 479 |
+
x = self.out_act(x)
|
| 480 |
+
return x
|
| 481 |
+
|
| 482 |
+
def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
|
| 483 |
+
x, k, v = self._forward_attn(x, attn_mask)
|
| 484 |
+
x = self._forward_ff(x)
|
| 485 |
+
|
| 486 |
+
for block in self.mid_blocks:
|
| 487 |
+
x = block(x, k, v, attn_mask)
|
| 488 |
+
|
| 489 |
+
x = self._forward_out(x)
|
| 490 |
+
return x
|
| 491 |
+
|
| 492 |
+
def prune_roots(self, retain_classes: set[int]) -> tuple[list[int], list[int]]:
|
| 493 |
+
if not self._has_roots or self.roots is None:
|
| 494 |
+
raise TypeError("No roots to prune.")
|
| 495 |
+
|
| 496 |
+
if self.clscls is not None:
|
| 497 |
+
raise TypeError("Subclass roots cannot be pruned.")
|
| 498 |
+
|
| 499 |
+
used_roots: set[int] = set()
|
| 500 |
+
used_clsroots: list[int] = []
|
| 501 |
+
|
| 502 |
+
assert self.clsroots is not None
|
| 503 |
+
clsroots = [
|
| 504 |
+
cast(list[int], clsroot.tolist())
|
| 505 |
+
for clsroot in self.clsroots.index.cpu().unbind(1)
|
| 506 |
+
]
|
| 507 |
+
|
| 508 |
+
for idx, (src, dest) in enumerate(clsroots):
|
| 509 |
+
if dest in retain_classes:
|
| 510 |
+
used_roots.add(src)
|
| 511 |
+
used_clsroots.append(idx)
|
| 512 |
+
|
| 513 |
+
sorted_roots = sorted(used_roots)
|
| 514 |
+
del used_roots
|
| 515 |
+
|
| 516 |
+
rootmap = {
|
| 517 |
+
root: idx
|
| 518 |
+
for idx, root in enumerate(sorted_roots)
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
clsmap = {
|
| 522 |
+
cls: idx
|
| 523 |
+
for idx, cls in enumerate(sorted(retain_classes))
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
for idx in used_clsroots:
|
| 527 |
+
src, dest = clsroots[idx]
|
| 528 |
+
self.clsroots.index[0, idx] = rootmap[src]
|
| 529 |
+
self.clsroots.index[1, idx] = clsmap[dest]
|
| 530 |
+
|
| 531 |
+
return sorted_roots, used_clsroots
|
| 532 |
+
|
| 533 |
+
@staticmethod
|
| 534 |
+
def for_state(
|
| 535 |
+
state_dict: dict[str, Any],
|
| 536 |
+
prefix: str = "",
|
| 537 |
+
*,
|
| 538 |
+
ff_dropout: float = 0.0,
|
| 539 |
+
device: torch.device | str | None = None,
|
| 540 |
+
dtype: torch.dtype | None = None,
|
| 541 |
+
) -> "HydraPool":
|
| 542 |
+
n_heads, n_classes, head_dim = state_dict[f"{prefix}q"].shape
|
| 543 |
+
attn_dim = n_heads * head_dim
|
| 544 |
+
|
| 545 |
+
roots_t = state_dict.get(f"{prefix}roots")
|
| 546 |
+
clsroots_t = state_dict.get(f"{prefix}clsroots.index")
|
| 547 |
+
clscls_t = state_dict.get(f"{prefix}clscls.index")
|
| 548 |
+
roots = (
|
| 549 |
+
roots_t.size(1) if roots_t is not None else 0,
|
| 550 |
+
clsroots_t.size(1) if clsroots_t is not None else 0,
|
| 551 |
+
clscls_t.size(1) if clscls_t is not None else 0
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
input_dim = state_dict[f"{prefix}kv.weight"].size(1)
|
| 555 |
+
output_dim = state_dict[f"{prefix}out_proj.weight"].size(2) // 2
|
| 556 |
+
|
| 557 |
+
# avoid off-by-one issue due to truncation
|
| 558 |
+
ffout_t = state_dict.get(f"{prefix}ff_out.weight")
|
| 559 |
+
hidden_dim = ffout_t.size(1) + 0.5 if ffout_t is not None else 0
|
| 560 |
+
ff_ratio = hidden_dim / attn_dim
|
| 561 |
+
|
| 562 |
+
pattern = re.compile(rf"^{re.escape(prefix)}mid_blocks\.([0-9]+)\.")
|
| 563 |
+
mid_blocks = max([-1, *(
|
| 564 |
+
int(match[1])
|
| 565 |
+
for key in state_dict
|
| 566 |
+
if (match := pattern.match(key)) is not None
|
| 567 |
+
)]) + 1
|
| 568 |
+
|
| 569 |
+
return HydraPool(
|
| 570 |
+
attn_dim,
|
| 571 |
+
head_dim,
|
| 572 |
+
n_classes,
|
| 573 |
+
mid_blocks=mid_blocks,
|
| 574 |
+
roots=roots,
|
| 575 |
+
ff_ratio=ff_ratio,
|
| 576 |
+
ff_dropout=ff_dropout,
|
| 577 |
+
input_dim=input_dim,
|
| 578 |
+
output_dim=output_dim,
|
| 579 |
+
device=device,
|
| 580 |
+
dtype=dtype
|
| 581 |
+
)
|
image.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BytesIO
|
| 2 |
+
from typing import Any, Callable, cast
|
| 3 |
+
from warnings import warn, catch_warnings, filterwarnings
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
import PIL.Image as image
|
| 11 |
+
import PIL.ImageCms as image_cms
|
| 12 |
+
|
| 13 |
+
from PIL.Image import Image, Resampling
|
| 14 |
+
from PIL.ImageCms import (
|
| 15 |
+
Direction, Intent, ImageCmsProfile, PyCMSError,
|
| 16 |
+
createProfile, getDefaultIntent, isIntentSupported, profileToProfile
|
| 17 |
+
)
|
| 18 |
+
from PIL.ImageOps import exif_transpose
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import pillow_jxl
|
| 22 |
+
except ImportError:
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
# Lift Pillow's pixel-count limit so very large images are not rejected on open.
image.MAX_IMAGE_PIXELS = None

# Destination profile for every colour-managed conversion in this module.
_SRGB = createProfile(colorSpace='sRGB')

# Transform flags per supported rendering intent. An intent that is missing
# here (e.g. SATURATION) is rejected by process_srgb.
_HIGHRES_PRECALC = image_cms.FLAGS["HIGHRESPRECALC"]
_BLACKPOINT_COMPENSATION = image_cms.FLAGS["BLACKPOINTCOMPENSATION"]

_INTENT_FLAGS = {
    Intent.PERCEPTUAL: _HIGHRES_PRECALC,
    Intent.RELATIVE_COLORIMETRIC: _HIGHRES_PRECALC | _BLACKPOINT_COMPENSATION,
    Intent.ABSOLUTE_COLORIMETRIC: _HIGHRES_PRECALC,
}
|
| 37 |
+
|
| 38 |
+
class CMSWarning(UserWarning):
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
message: str,
|
| 42 |
+
*,
|
| 43 |
+
path: str | None = None,
|
| 44 |
+
cms_info: dict[str, Any] | None = None,
|
| 45 |
+
cause: Exception | None = None,
|
| 46 |
+
):
|
| 47 |
+
super().__init__(message)
|
| 48 |
+
self.__cause__ = cause
|
| 49 |
+
|
| 50 |
+
self.path = path
|
| 51 |
+
self.cms_info = cms_info
|
| 52 |
+
|
| 53 |
+
self.add_note(f"path: {path}")
|
| 54 |
+
self.add_note(f"info: {cms_info}")
|
| 55 |
+
|
| 56 |
+
def _coalesce_intent(intent: Intent | int) -> Intent:
    """
    Return ``intent`` as an ``Intent`` member.

    ``Intent`` values pass through untouched; the raw codes 0-3 are mapped to
    their enum members.

    Raises:
        ValueError: If ``intent`` is neither an ``Intent`` nor one of 0-3.
    """
    if isinstance(intent, Intent):
        return intent

    if intent == 0:
        return Intent.PERCEPTUAL
    if intent == 1:
        return Intent.RELATIVE_COLORIMETRIC
    if intent == 2:
        return Intent.SATURATION
    if intent == 3:
        return Intent.ABSOLUTE_COLORIMETRIC

    raise ValueError("invalid intent")
|
| 71 |
+
|
| 72 |
+
def _add_info(info: dict[str, Any], source: object, key: str) -> None:
|
| 73 |
+
try:
|
| 74 |
+
if (value := getattr(source, key, None)) is not None:
|
| 75 |
+
info[key] = value
|
| 76 |
+
except Exception:
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
def open_srgb(
    path: str,
    *,
    resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
    crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
    expect: tuple[int, int] | None = None,
) -> Image:
    """
    Open the image at ``path`` and return it processed by ``process_srgb``.

    ``resize``, ``crop`` and ``expect`` are forwarded unchanged. The opened
    source image is closed unless it is itself the returned object.
    """
    with open(path, "rb", buffering=(1024 * 1024)) as file:
        source: Image = image.open(file)

        try:
            result = process_srgb(source, resize=resize, crop=crop, expect=expect)
        except BaseException:
            # Release the decoder on any failure, then let the error propagate.
            source.close()
            raise

        if result is not source:
            source.close()

        return result
|
| 99 |
+
|
| 100 |
+
def process_srgb(
    img: Image,
    *,
    resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
    crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
    expect: tuple[int, int] | None = None,
) -> Image:
    """
    Normalize an image to sRGB ``RGB`` (or premultiplied ``RGBa``), then crop and resize.

    Steps, in order: load the pixel data, apply the EXIF orientation, check
    the size against ``expect``, convert an embedded ICC profile to sRGB,
    convert to the final mode, then crop and/or resize.

    Args:
        img: Source image. It may be modified in place (EXIF transpose,
            in-place profile conversion) and may or may not be the returned object.
        resize: Target ``(width, height)``, or a callable mapping the
            (post-crop) size to a target size or None.
        crop: ``(left, top, right, bottom)`` box, or a callable mapping the
            image size to such a box or None.
        expect: If given, the exact ``(width, height)`` the image must have
            after EXIF orientation.

    Returns:
        The processed image, in mode ``RGB`` or ``RGBa``.

    Raises:
        RuntimeError: If the image size does not match ``expect``.

    Warns:
        CMSWarning: If the embedded ICC profile could not be applied; the
            image is then processed without colour management.
    """
    img.load()

    # Remember where the image came from before conversions replace the object.
    source_name = getattr(img, "filename", None)

    try:
        exif_transpose(img, in_place=True)
    except Exception:
        pass # corrupt EXIF metadata is fine

    size = (img.width, img.height)

    if expect is not None and size != expect:
        raise RuntimeError(
            f"Image is {size[0]}x{size[1]}, "
            f"but expected {expect[0]}x{expect[1]}."
        )

    if (icc_raw := img.info.get("icc_profile")) is not None:
        # Diagnostics attached to the warning if the conversion fails.
        cms_info: dict[str, Any] = {
            "native_mode": img.mode,
            "transparency": img.has_transparency_data,
        }

        try:
            profile = ImageCmsProfile(BytesIO(icc_raw))
            _add_info(cms_info, profile.profile, "profile_description")
            _add_info(cms_info, profile.profile, "target")
            _add_info(cms_info, profile.profile, "xcolor_space")
            _add_info(cms_info, profile.profile, "connection_space")
            _add_info(cms_info, profile.profile, "colorimetric_intent")
            _add_info(cms_info, profile.profile, "rendering_intent")

            # Reduce exotic modes to an 8-bit colour or grey mode before the transform.
            working_mode = img.mode
            if img.mode.startswith(("RGB", "BGR", "P")):
                working_mode = "RGBA" if img.has_transparency_data else "RGB"
            elif img.mode.startswith(("L", "I", "F")) or img.mode == "1":
                working_mode = "LA" if img.has_transparency_data else "L"

            if img.mode != working_mode:
                cms_info["working_mode"] = working_mode
                img = img.convert(working_mode)

            mode = "RGBA" if img.has_transparency_data else "RGB"

            # Prefer relative colorimetric; otherwise use the profile's own default.
            intent = Intent.RELATIVE_COLORIMETRIC
            if isIntentSupported(profile, intent, Direction.INPUT) != 1:
                intent = _coalesce_intent(getDefaultIntent(profile))

            cms_info["conversion_intent"] = intent

            if (flags := _INTENT_FLAGS.get(intent)) is None:
                raise RuntimeError("Unsupported intent")

            if img.mode == mode:
                profileToProfile(
                    img,
                    profile,
                    _SRGB,
                    renderingIntent=intent,
                    inPlace=True,
                    flags=flags
                )
            else:
                img = cast(Image, profileToProfile(
                    img,
                    profile,
                    _SRGB,
                    renderingIntent=intent,
                    outputMode=mode,
                    flags=flags
                ))
        except Exception as ex:
            # Still best-effort, but no longer silent: the colours may be off,
            # so report what was learned about the profile.
            warn(
                CMSWarning(
                    "Could not convert the embedded ICC profile to sRGB; "
                    "continuing without colour management.",
                    path=source_name if isinstance(source_name, str) and source_name else None,
                    cms_info=cms_info,
                    cause=ex,
                ),
                stacklevel=2,
            )

    if img.has_transparency_data:
        if img.mode != "RGBa":
            try:
                img = img.convert("RGBa")
            except ValueError:
                # No direct conversion from this mode; go through straight alpha.
                img = img.convert("RGBA").convert("RGBa")
    elif img.mode != "RGB":
        img = img.convert("RGB")

    if crop is not None and not isinstance(crop, tuple):
        crop = crop(size)

    if crop is not None:
        left, top, right, bottom = crop
        # Image coordinates grow downwards, so the height is bottom - top.
        size = (right - left, bottom - top)

    if resize is not None and not isinstance(resize, tuple):
        resize = resize(size)

    if resize is not None and size != resize:
        # `box` makes resize() sample from the crop region directly.
        img = img.resize(
            resize,
            Resampling.LANCZOS,
            box=crop,
            reducing_gap=3.0
        )
        crop = None

    if crop is not None:
        img = img.crop(crop)

    return img
|
| 211 |
+
|
| 212 |
+
def put_srgb(img: Image, tensor: Tensor) -> None:
    """
    Copy the first three channels of ``img`` into ``tensor`` in place.

    The copy goes through the tensor's NumPy view with ``casting="no"``, so
    the tensor's dtype has to match the image data exactly.

    Raises:
        ValueError: If the image mode is not RGB, RGBA or RGBa.
    """
    if img.mode not in ("RGB", "RGBA", "RGBa"):
        raise ValueError(f"Image has non-RGB mode {img.mode}.")

    pixels = np.asarray(img)
    target = tensor.numpy()
    np.copyto(target, pixels[:, :, :3], casting="no")
|
| 217 |
+
|
| 218 |
+
def put_srgb_patch(
    img: Image,
    patch_data: Tensor,
    patch_coord: Tensor,
    patch_valid: Tensor,
    patch_size: int
) -> None:
    """
    Cut ``img`` into square patches and write them into preallocated tensors.

    The first ``n`` rows of each output tensor are filled, where ``n`` is the
    number of patches: ``patch_data`` gets the flattened RGB patch pixels,
    ``patch_coord`` the (row, column) grid position of each patch, and
    ``patch_valid`` is set to True. Rows beyond ``n`` are left untouched.

    Raises:
        ValueError: If the image mode is not RGB, RGBA or RGBa.
    """
    if img.mode not in ("RGB", "RGBA", "RGBa"):
        raise ValueError(f"Image has non-RGB mode {img.mode}.")

    rgb = np.asarray(img)[:, :, :3]
    grid = rearrange(
        rgb,
        "(h p1) (w p2) c -> h w (p1 p2 c)",
        p1=patch_size, p2=patch_size
    )

    # "ij" indexing keeps the first coordinate as the row (y), the second as the column (x).
    rows = np.arange(grid.shape[0], dtype=np.int16)
    cols = np.arange(grid.shape[1], dtype=np.int16)
    coord_grid = np.stack(np.meshgrid(rows, cols, indexing="ij"), axis=-1)

    flat_coords = rearrange(coord_grid, "h w c -> (h w) c")
    flat_patches = rearrange(grid, "h w p -> (h w) p")
    n = flat_patches.shape[0]

    np.copyto(patch_data[:n].numpy(), flat_patches, casting="no")
    np.copyto(patch_coord[:n].numpy(), flat_coords, casting="no")
    patch_valid[:n] = True
|
| 247 |
+
|
| 248 |
+
def unpatchify(input: Tensor, coords: Tensor, valid: Tensor) -> Tensor:
    """
    Scatter the valid patches of one sequence onto a dense (H, W, ...) grid.

    Args:
        input: Tensor of shape (seqlen, ...), patch data (no batch axis).
        coords: Integer tensor of shape (batch, seqlen, 2) with the spatial
            coordinates [y, x] of each patch; only batch entry 0 is read.
        valid: Boolean tensor of shape (batch, seqlen) marking valid patches;
            only batch entry 0 is read.

    Returns:
        Tensor of shape (H, W, ...), where H and W are one more than the
        largest valid y and x coordinate. Cells without a valid patch are zero.
    """

    mask = valid[0]

    # Index tensors must be int64/int32 (or bool/uint8); the coordinate buffers
    # built by patchify_image are int16, so widen them before indexing.
    valid_coords = coords[0, mask].long()  # (n_valid, 2)
    valid_patches = input[mask]            # (n_valid, ...)

    ys = valid_coords[:, 0]
    xs = valid_coords[:, 1]

    h = int(ys.max().item()) + 1
    w = int(xs.max().item()) + 1

    output_shape = (h, w) + input.shape[1:]
    output = input.new_zeros(output_shape)

    output[ys, xs] = valid_patches
    return output
|
model.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from math import ceil
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from torch.nn import Identity
|
| 6 |
+
|
| 7 |
+
import timm
|
| 8 |
+
from timm.models import NaFlexVit
|
| 9 |
+
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from safetensors import safe_open
|
| 13 |
+
|
| 14 |
+
from image import process_srgb, put_srgb_patch
|
| 15 |
+
|
| 16 |
+
def sdpa_attn_mask(
|
| 17 |
+
patch_valid: Tensor,
|
| 18 |
+
num_prefix_tokens: int = 0,
|
| 19 |
+
symmetric: bool = True,
|
| 20 |
+
q_len: int | None = None,
|
| 21 |
+
dtype: torch.dtype | None = None,
|
| 22 |
+
) -> Tensor:
|
| 23 |
+
mask = patch_valid.unflatten(-1, (1, 1, -1))
|
| 24 |
+
|
| 25 |
+
if num_prefix_tokens:
|
| 26 |
+
mask = torch.cat((
|
| 27 |
+
torch.ones(
|
| 28 |
+
*mask.shape[:-1], num_prefix_tokens,
|
| 29 |
+
device=patch_valid.device, dtype=torch.bool
|
| 30 |
+
), mask
|
| 31 |
+
), dim=-1)
|
| 32 |
+
|
| 33 |
+
return mask
|
| 34 |
+
|
| 35 |
+
# Monkeypatch: make timm's NaFlexViT module use sdpa_attn_mask as its attention-mask builder.
timm.models.naflexvit.create_attention_mask = sdpa_attn_mask
|
| 36 |
+
|
| 37 |
+
def get_image_size_for_seq(
    image_hw: tuple[int, int],
    patch_size: int = 16,
    max_seq_len: int = 1024,
    max_ratio: float = 1.0,
    eps: float = 1e-5,
) -> tuple[int, int]:
    """
    Pick a patch-aligned (height, width) whose patch count fits ``max_seq_len``.

    If the image scaled by ``max_ratio`` already fits, its size rounded down
    to whole patches (at least one per axis) is returned. Otherwise a binary
    search over the scale factor finds the largest patch grid that still fits.

    Args:
        image_hw: Source image size as (height, width).
        patch_size: Side length of a square patch in pixels.
        max_seq_len: Maximum number of patches allowed.
        max_ratio: Largest scale factor to consider; must be >= 1.
        eps: Smallest scale factor and search resolution.

    Returns:
        The target (height, width) in pixels, both multiples of ``patch_size``.

    Raises:
        ValueError: If even the smallest scale does not fit ``max_seq_len``.
    """

    assert max_ratio >= 1.0
    assert eps * 2 < max_ratio

    h, w = image_hw
    cap_y = int(max((h * max_ratio) // patch_size, 1))
    cap_x = int(max((w * max_ratio) // patch_size, 1))

    if (cap_y * cap_x) <= max_seq_len:
        return cap_y * patch_size, cap_x * patch_size

    def grid_at(scale: float) -> tuple[int, int]:
        # Patch grid at this scale, rounded up and clamped to the upper bound.
        rows = int(ceil((h * scale) / patch_size))
        cols = int(ceil((w * scale) / patch_size))
        return min(rows, cap_y), min(cols, cap_x)

    best_y, best_x = grid_at(eps)
    if (best_y * best_x) > max_seq_len:
        raise ValueError(f"Image of size {w}x{h} is too large.")

    # Invariant: `lo` always fits, `hi` is the smallest scale known to be too large.
    lo, hi = eps, max_ratio
    while (hi - lo) >= eps:
        mid = (lo + hi) / 2.0
        rows, cols = grid_at(mid)
        count = rows * cols

        if count > max_seq_len:
            hi = mid
        else:
            lo = mid
            best_y, best_x = rows, cols

            if count == max_seq_len:
                # The budget is used exactly; nothing larger can fit.
                break

    assert best_y >= 1 and best_x >= 1
    return best_y * patch_size, best_x * patch_size
|
| 86 |
+
|
| 87 |
+
def process_image(img: Image.Image, patch_size: int, max_seq_len: int) -> Image.Image:
    """
    Convert ``img`` via ``process_srgb`` and resize it to fit the patch budget.

    The target size comes from ``get_image_size_for_seq`` with the given
    ``patch_size`` and ``max_seq_len``.
    """
    def fit_to_sequence(wh: tuple[int, int]) -> tuple[int, int]:
        # process_srgb passes (width, height); the size helper works in (height, width).
        new_h, new_w = get_image_size_for_seq((wh[1], wh[0]), patch_size, max_seq_len)
        return new_w, new_h

    return process_srgb(img, resize=fit_to_sequence)
|
| 93 |
+
|
| 94 |
+
def patchify_image(img: Image.Image, patch_size: int, max_seq_len: int, share_memory: bool = False) -> tuple[Tensor, Tensor, Tensor]:
    """
    Split ``img`` into patches stored in fixed-size CPU tensors.

    Returns:
        ``(patches, patch_coords, patch_valid)``, each with ``max_seq_len``
        rows: uint8 patch pixels, int16 patch grid coordinates, and a boolean
        flag marking the rows that hold a real patch.
    """
    patches = torch.zeros(max_seq_len, patch_size * patch_size * 3, device="cpu", dtype=torch.uint8)
    patch_coords = torch.zeros(max_seq_len, 2, device="cpu", dtype=torch.int16)
    patch_valid = torch.zeros(max_seq_len, device="cpu", dtype=torch.bool)
    buffers = (patches, patch_coords, patch_valid)

    if share_memory:
        # Move the storage to shared memory before it is filled.
        for buffer in buffers:
            buffer.share_memory_()

    put_srgb_patch(img, patches, patch_coords, patch_valid, patch_size)
    return buffers
|
| 106 |
+
|
| 107 |
+
def load_image(
    path: str,
    patch_size: int = 16,
    max_seq_len: int = 1024,
    share_memory: bool = False
) -> tuple[Tensor, Tensor, Tensor]:
    """
    Load the image at ``path`` and return it as patch tensors.

    The image is run through ``process_image`` and then ``patchify_image``.
    The opened source image is closed unless it is itself the processed image.

    Returns:
        The ``(patches, patch_coords, patch_valid)`` tuple from ``patchify_image``.
    """
    with open(path, "rb", buffering=(1024 * 1024)) as file:
        source: Image.Image = Image.open(file)

        try:
            processed = process_image(source, patch_size, max_seq_len)
        except BaseException:
            # Release the decoder on any failure, then let the error propagate.
            source.close()
            raise

        if processed is not source:
            source.close()

        return patchify_image(processed, patch_size, max_seq_len, share_memory)
|
| 126 |
+
|
| 127 |
+
def load_model(path: str, device: torch.device | str | None = None) -> tuple[NaFlexVit, list[str]]:
|
| 128 |
+
with safe_open(path, framework="pt", device="cpu") as file:
|
| 129 |
+
metadata = file.metadata()
|
| 130 |
+
|
| 131 |
+
state_dict = {
|
| 132 |
+
key: file.get_tensor(key)
|
| 133 |
+
for key in file.keys()
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
arch = metadata["modelspec.architecture"]
|
| 137 |
+
if not arch.startswith("naflexvit_so400m_patch16_siglip"):
|
| 138 |
+
raise ValueError(f"Unrecognized model architecture: {arch}")
|
| 139 |
+
|
| 140 |
+
tags = metadata["classifier.labels"].split("\n")
|
| 141 |
+
|
| 142 |
+
model = timm.create_model(
|
| 143 |
+
'naflexvit_so400m_patch16_siglip',
|
| 144 |
+
pretrained=False, num_classes=0,
|
| 145 |
+
pos_embed_interp_mode="bilinear",
|
| 146 |
+
weight_init="skip", fix_init=False,
|
| 147 |
+
device="cpu", dtype=torch.bfloat16,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
match arch[31:]:
|
| 151 |
+
case "": # vanilla
|
| 152 |
+
model.reset_classifier(len(tags))
|
| 153 |
+
|
| 154 |
+
case "+rr_slim":
|
| 155 |
+
model.reset_classifier(len(tags))
|
| 156 |
+
|
| 157 |
+
if "attn_pool.q.weight" not in state_dict:
|
| 158 |
+
model.attn_pool.q = Identity()
|
| 159 |
+
|
| 160 |
+
if "head.bias" not in state_dict:
|
| 161 |
+
model.head.bias = None
|
| 162 |
+
|
| 163 |
+
case "+rr_chonker":
|
| 164 |
+
from chonker_pool import ChonkerPool
|
| 165 |
+
|
| 166 |
+
model.attn_pool = ChonkerPool(
|
| 167 |
+
2, 1152, 72,
|
| 168 |
+
device=device, dtype=torch.bfloat16
|
| 169 |
+
)
|
| 170 |
+
model.head = model.attn_pool.create_head(len(tags))
|
| 171 |
+
model.num_classes = len(tags)
|
| 172 |
+
|
| 173 |
+
case "+rr_hydra":
|
| 174 |
+
from hydra_pool import HydraPool
|
| 175 |
+
|
| 176 |
+
model.attn_pool = HydraPool.for_state(
|
| 177 |
+
state_dict, "attn_pool.",
|
| 178 |
+
device=device, dtype=torch.bfloat16
|
| 179 |
+
)
|
| 180 |
+
model.head = model.attn_pool.create_head()
|
| 181 |
+
model.num_classes = len(tags)
|
| 182 |
+
|
| 183 |
+
state_dict["attn_pool._extra_state"] = { "q_normed": True }
|
| 184 |
+
|
| 185 |
+
case _:
|
| 186 |
+
raise ValueError(f"Unrecognized model architecture: {arch}")
|
| 187 |
+
|
| 188 |
+
model.eval().to(dtype=torch.bfloat16)
|
| 189 |
+
model.load_state_dict(state_dict, strict=True)
|
| 190 |
+
model.to(device=device)
|
| 191 |
+
|
| 192 |
+
return model, tags
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
timm
|
| 3 |
+
numpy
|
| 4 |
+
pillow
|
| 5 |
+
einops
|
| 6 |
+
safetensors
|
| 7 |
+
gradio
|
| 8 |
+
requests
|