RedHotTensors committed on
Commit d62ba4b · 1 Parent(s): 6884ab9

JTP-3 Hydra Release

Files changed (7)
  1. README.md +5 -5
  2. app.py +434 -0
  3. glu.py +40 -0
  4. hydra_pool.py +581 -0
  5. image.py +271 -0
  6. model.py +192 -0
  7. requirements.txt +8 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: JTP 3 Demo
- emoji: 🌖
- colorFrom: blue
- colorTo: purple
+ title: JTP 3 Hydra Demo
+ emoji: 🚀
+ colorFrom: red
+ colorTo: blue
  sdk: gradio
  sdk_version: 5.49.1
  app_file: app.py
@@ -11,4 +11,4 @@ license: apache-2.0
  short_description: JTP-3 Hydra Demo
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <a href="https://huggingface.co/RedRocket/JTP-3">JTP-3 Hydra Main Repository</a>
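For anyone following the repository link, a minimal sketch of fetching the checkpoint this Space loads; the repo_id and filename are taken from the hf_hub_download call in app.py below.

```python
# Minimal sketch: download the JTP-3 Hydra checkpoint used by app.py.
from huggingface_hub import hf_hub_download

checkpoint = hf_hub_download(
    repo_id="RedRocket/JTP-3",
    filename="models/jtp-3-hydra.safetensors",
)
print(checkpoint)  # local cache path of the safetensors file
```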
app.py ADDED
@@ -0,0 +1,434 @@
1
+ from io import BytesIO
2
+ from threading import Lock
3
+
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.nn import Parameter
9
+
10
+ import spaces
11
+ from huggingface_hub import hf_hub_download
12
+ import gradio as gr
13
+
14
+ from PIL import Image, ImageDraw, ImageFont
15
+
16
+ import requests
17
+
18
+ from model import load_model, process_image, patchify_image
19
+ from image import unpatchify
20
+
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+
23
+ PATCH_SIZE = 16
24
+ MAX_SEQ_LEN = 1024
25
+
26
+
27
+ model_lock = Lock()
28
+ model, tag_list = load_model(
29
+ hf_hub_download(repo_id="RedRocket/JTP-3", filename="models/jtp-3-hydra.safetensors"),
30
+ device=device
31
+ )
32
+ model.requires_grad_(False)
33
+
34
+ tags = {
35
+ tag.replace("_", " ").replace("vulva", "pussy"): idx
36
+ for idx, tag in enumerate(tag_list)
37
+ }
38
+ tag_list = list(tags.keys())
39
+
40
+ FONT = ImageFont.load_default(24)
41
+
42
+ @spaces.GPU(duration=5)
43
+ @torch.no_grad()
44
+ def run_classifier(image: Image.Image, cam_depth: int):
45
+ patches, patch_coords, patch_valid = patchify_image(image, PATCH_SIZE, MAX_SEQ_LEN)
46
+ patches = patches.unsqueeze(0).to(device=device, non_blocking=True)
47
+ patch_coords = patch_coords.unsqueeze(0).to(device=device, non_blocking=True)
48
+ patch_valid = patch_valid.unsqueeze(0).to(device=device, non_blocking=True)
49
+
50
+ patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
51
+ patch_coords = patch_coords.to(dtype=torch.int32)
52
+
53
+ with model_lock:
54
+ features = model.forward_intermediates(
55
+ patches,
56
+ patch_coord=patch_coords,
57
+ patch_valid=patch_valid,
58
+ indices=cam_depth,
59
+ output_dict=True,
60
+ output_fmt='NLC'
61
+ )
62
+
63
+ logits = model.forward_head(features["image_features"], patch_valid=patch_valid)
64
+ del features["image_features"]
65
+
66
+ features["patch_coords"] = patch_coords
67
+ features["patch_valid"] = patch_valid
68
+ del patches, patch_coords, patch_valid
69
+
70
+ probits = logits[0].float().sigmoid_().mul_(2.0).sub_(1.0) # scale to -1 to 1
71
+
72
+ values, indices = probits.cpu().topk(250)
73
+ predictions = {
74
+ tag_list[idx.item()]: val.item()
75
+ for idx, val in sorted(
76
+ zip(indices, values),
77
+ key=lambda item: item[1].item(),
78
+ reverse=True
79
+ )
80
+ }
81
+
82
+ return features, predictions
83
+
84
+ @spaces.GPU(duration=5)
85
+ @torch.no_grad()
86
+ def run_cam(
87
+ display_image: Image.Image,
88
+ image: Image.Image, features: dict[str, Tensor],
89
+ tag_idx: int, cam_depth: int
90
+ ):
91
+ intermediates = features["image_intermediates"]
92
+ if len(intermediates) < cam_depth:
93
+ features, _ = run_classifier(image, cam_depth)
94
+ intermediates = features["image_intermediates"]
95
+ elif len(intermediates) > cam_depth:
96
+ intermediates = intermediates[-cam_depth:]
97
+
98
+ patch_coords = features["patch_coords"]
99
+ patch_valid = features["patch_valid"]
100
+
101
+ with model_lock:
102
+ saved_q = model.attn_pool.q
103
+ saved_p = model.attn_pool.out_proj.weight
104
+
105
+ try:
106
+ model.attn_pool.q = Parameter(saved_q[:, [tag_idx], :], requires_grad=False)
107
+ model.attn_pool.out_proj.weight = Parameter(saved_p[[tag_idx], :, :], requires_grad=False)
108
+
109
+ with torch.enable_grad():
110
+ for intermediate in intermediates:
111
+ intermediate.requires_grad_(True).retain_grad()
112
+ model.forward_head(intermediate, patch_valid=patch_valid)[0, 0].backward()
113
+ finally:
114
+ model.attn_pool.q = saved_q
115
+ model.attn_pool.out_proj.weight = saved_p
116
+
117
+ cam_1d: Tensor | None = None
118
+ for intermediate in intermediates:
119
+ patch_grad = (intermediate.grad.float() * intermediate.sign()).sum(dim=(0, 2))
120
+ intermediate.grad = None
121
+
122
+ if cam_1d is None:
123
+ cam_1d = patch_grad
124
+ else:
125
+ cam_1d.add_(patch_grad)
126
+
127
+ assert cam_1d is not None
128
+
129
+ cam_2d = unpatchify(cam_1d, patch_coords, patch_valid).cpu().numpy()
130
+ return cam_composite(display_image, cam_2d), features
131
+
132
+ def cam_composite(image: Image.Image, cam: np.ndarray):
133
+ """
134
+ Overlays CAM on image and returns a PIL image.
135
+ Args:
136
+ image: PIL Image (RGB)
137
+ cam: 2D numpy array (activation map)
138
+
139
+ Returns:
140
+ PIL.Image.Image with overlay
141
+ """
142
+
143
+ cam_abs = np.abs(cam)
144
+ cam_scale = cam_abs.max()
145
+
146
+ cam_rgba = np.dstack((
147
+ (cam < 0).astype(np.float32),
148
+ (cam > 0).astype(np.float32),
149
+ np.zeros_like(cam, dtype=np.float32),
150
+ cam_abs * (0.5 / cam_scale),
151
+ )) # Shape: (H, W, 4)
152
+
153
+ cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8))
154
+ cam_pil = cam_pil.resize(image.size, resample=Image.Resampling.NEAREST)
155
+
156
+ image = Image.blend(
157
+ image.convert('RGBA'),
158
+ image.convert('L').convert('RGBA'),
159
+ 0.33
160
+ )
161
+
162
+ image = Image.alpha_composite(image, cam_pil)
163
+
164
+ draw = ImageDraw.Draw(image)
165
+ draw.text(
166
+ (image.width - 7, image.height - 7),
167
+ f"{cam_scale.item():.4g}",
168
+ anchor="rd", font=FONT, fill=(32, 32, 255, 255)
169
+ )
170
+
171
+ return image
172
+
173
+ def filter_tags(predictions: dict[str, float], threshold: float):
174
+ predictions = {
175
+ key: value
176
+ for key, value in predictions.items()
177
+ if value >= threshold
178
+ }
179
+
180
+ tag_str = ", ".join(predictions.keys())
181
+ return tag_str, predictions
182
+
183
+ def resize_image(image: Image.Image) -> Image.Image:
184
+ longest_side = max(image.height, image.width)
185
+ if longest_side < 1080:
186
+ return image
187
+
188
+ scale = 1080 / longest_side
189
+ return image.resize(
190
+ (
191
+ int(round(image.width * scale)),
192
+ int(round(image.height * scale)),
193
+ ),
194
+ resample=Image.Resampling.LANCZOS,
195
+ reducing_gap=3.0
196
+ )
197
+
198
+ def image_upload(image: Image.Image):
199
+ display_image = resize_image(image)
200
+ processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
201
+
202
+ if display_image is not image and processed_image is not image:
203
+ image.close()
204
+
205
+ return (
206
+ "", {}, "None", "",
207
+ gr.skip() if display_image is image else display_image, display_image,
208
+ processed_image,
209
+ )
210
+
211
+ def url_submit(url: str):
212
+ resp = requests.get(url, timeout=10)
213
+ resp.raise_for_status()
214
+
215
+ image = Image.open(BytesIO(resp.content))
216
+ display_image = resize_image(image)
217
+ processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
218
+
219
+ if display_image is not image and processed_image is not image:
220
+ image.close()
221
+
222
+ return (
223
+ "", {}, "None",
224
+ display_image, display_image,
225
+ processed_image,
226
+ )
227
+
228
+ def image_changed(image: Image.Image, threshold: float, cam_depth: int):
229
+ features, predictions = run_classifier(image, cam_depth)
230
+ return *filter_tags(predictions, threshold), features, predictions
231
+
232
+ def image_clear():
233
+ return (
234
+ "", {}, "None", "",
235
+ None, None,
236
+ None, None, {},
237
+ )
238
+
239
+ def cam_changed(
240
+ display_image: Image.Image,
241
+ image: Image.Image, features: dict[str, Tensor],
242
+ tag: str, cam_depth: int
243
+ ):
244
+ if tag == "None":
245
+ return display_image, features
246
+
247
+ return run_cam(display_image, image, features, tags[tag], cam_depth)
248
+
249
+ def tag_box_select(evt: gr.SelectData):
250
+ return evt.value
251
+
252
+ custom_css = """
253
+ .output-class { display: none; }
254
+ .inferno-slider input[type=range] {
255
+ background: linear-gradient(to right,
256
+ #000004, #1b0c41, #4a0c6b, #781c6d,
257
+ #a52c60, #cf4446, #ed6925, #fb9b06,
258
+ #f7d13d, #fcffa4
259
+ ) !important;
260
+ background-size: 100% 100% !important;
261
+ }
262
+ #image_container-image {
263
+ width: 100%;
264
+ aspect-ratio: 1 / 1;
265
+ max-height: 100%;
266
+ }
267
+ #image_container img {
268
+ object-fit: contain !important;
269
+ }
270
+ .show-api, .show-api-divider {
271
+ display: none !important;
272
+ }
273
+ """
274
+
275
+ with gr.Blocks(
276
+ title="RedRocket JTP-3 Hydra Demo",
277
+ css=custom_css,
278
+ analytics_enabled=False,
279
+ ) as demo:
280
+ display_image_state = gr.State()
281
+ image_state = gr.State()
282
+ features_state = gr.State()
283
+ predictions_state = gr.State(value={})
284
+
285
+ gr.HTML(
286
+ "<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
287
+ "<a href='https://huggingface.co/RedRocket' target='_blank'>"
288
+ "<img src='https://huggingface.co/spaces/RedRocket/README/resolve/main/RedRocket.png' style='width: 2em; margin-right: 0.5em;'>"
289
+ "</a>"
290
+ "<span><a href='https://huggingface.co/RedRocket' target='_blank'>RedRocket</a> &ndash; JTP-3 Hydra Demo</span>"
291
+ "<span style='font-weight: normal;'>&nbsp;&bull;&nbsp;<a href='https://huggingface.co/RedRocket/JTP-3' target='_blank'>Download</a></span>"
292
+ "</h1>"
293
+ )
294
+
295
+ with gr.Row():
296
+ with gr.Column():
297
+ with gr.Column():
298
+ image = gr.Image(
299
+ sources=['upload', 'clipboard'], type='pil',
300
+ show_label=False,
301
+ show_download_button=False,
302
+ show_share_button=False,
303
+ elem_id="image_container"
304
+ )
305
+
306
+ url = gr.Textbox(
307
+ label="Upload Image via Url:",
308
+ placeholder="https://example.com/image.jpg",
309
+ max_lines=1,
310
+ submit_btn="⮝",
311
+ )
312
+
313
+ with gr.Column():
314
+ cam_tag = gr.Dropdown(
315
+ value="None", choices=["None"] + tag_list,
316
+ label="CAM Attention Overlay (You can also click a tag on the right.)", show_label=True
317
+ )
318
+ cam_depth = gr.Slider(
319
+ minimum=1, maximum=27, step=1, value=1,
320
+ label="CAM Depth (1=fastest, more precise; 27=slowest, more general)"
321
+ )
322
+
323
+ with gr.Column():
324
+ threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Tag Threshold")
325
+ tag_string = gr.Textbox(lines=3, label="Tags", show_label=True, show_copy_button=True)
326
+ tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
327
+
328
+ image.upload(
329
+ fn=image_upload,
330
+ inputs=[image],
331
+ outputs=[
332
+ tag_string, tag_box, cam_tag, url,
333
+ image, display_image_state,
334
+ image_state,
335
+ ],
336
+ show_progress='minimal',
337
+ show_progress_on=[image]
338
+ ).then(
339
+ fn=image_changed,
340
+ inputs=[image_state, threshold_slider, cam_depth],
341
+ outputs=[
342
+ tag_string, tag_box,
343
+ features_state, predictions_state,
344
+ ],
345
+ show_progress='minimal',
346
+ show_progress_on=[tag_box]
347
+ )
348
+
349
+ url.submit(
350
+ fn=url_submit,
351
+ inputs=[url],
352
+ outputs=[
353
+ tag_string, tag_box, cam_tag,
354
+ image, display_image_state,
355
+ image_state,
356
+ ],
357
+ show_progress='minimal',
358
+ show_progress_on=[url]
359
+ ).then(
360
+ fn=image_changed,
361
+ inputs=[image_state, threshold_slider, cam_depth],
362
+ outputs=[
363
+ tag_string, tag_box,
364
+ features_state, predictions_state,
365
+ ],
366
+ show_progress='minimal',
367
+ show_progress_on=[tag_box]
368
+ )
369
+
370
+ image.clear(
371
+ fn=image_clear,
372
+ inputs=[],
373
+ outputs=[
374
+ tag_string, tag_box, cam_tag, url,
375
+ image, display_image_state,
376
+ image_state, features_state, predictions_state,
377
+ ],
378
+ show_progress='hidden'
379
+ )
380
+
381
+ threshold_slider.input(
382
+ fn=filter_tags,
383
+ inputs=[predictions_state, threshold_slider],
384
+ outputs=[tag_string, tag_box],
385
+ trigger_mode='always_last',
386
+ show_progress='hidden'
387
+ )
388
+
389
+ cam_tag.input(
390
+ fn=cam_changed,
391
+ inputs=[
392
+ display_image_state,
393
+ image_state, features_state,
394
+ cam_tag, cam_depth,
395
+ ],
396
+ outputs=[image, features_state],
397
+ trigger_mode='always_last',
398
+ show_progress='minimal',
399
+ show_progress_on=[cam_tag]
400
+ )
401
+
402
+ cam_depth.input(
403
+ fn=cam_changed,
404
+ inputs=[
405
+ display_image_state,
406
+ image_state, features_state,
407
+ cam_tag, cam_depth,
408
+ ],
409
+ outputs=[image, features_state],
410
+ trigger_mode='always_last',
411
+ show_progress='minimal',
412
+ show_progress_on=[cam_depth]
413
+ )
414
+
415
+ tag_box.select(
416
+ fn=tag_box_select,
417
+ inputs=[],
418
+ outputs=[cam_tag],
419
+ trigger_mode='always_last',
420
+ show_progress='hidden',
421
+ ).then(
422
+ fn=cam_changed,
423
+ inputs=[
424
+ display_image_state,
425
+ image_state, features_state,
426
+ cam_tag, cam_depth,
427
+ ],
428
+ outputs=[image, features_state],
429
+ show_progress='minimal',
430
+ show_progress_on=[cam_tag]
431
+ )
432
+
433
+ if __name__ == "__main__":
434
+ demo.launch()
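For reference, a headless sketch of the same classification path app.py wires into Gradio. The checkpoint and image paths are placeholders, and the forward_intermediates/forward_head calls mirror run_classifier above; note the demo additionally rescales sigmoids to [-1, 1] and remaps a few tag names.

```python
# Headless sketch of the classifier path above (no Gradio, no CAM).
# "jtp-3-hydra.safetensors" and "example.jpg" are placeholder paths.
import torch
from PIL import Image
from model import load_model, process_image, patchify_image

PATCH_SIZE, MAX_SEQ_LEN = 16, 1024
device = "cuda" if torch.cuda.is_available() else "cpu"

model, tag_list = load_model("jtp-3-hydra.safetensors", device=device)
model.requires_grad_(False)

img = process_image(Image.open("example.jpg"), PATCH_SIZE, MAX_SEQ_LEN)
patches, coords, valid = patchify_image(img, PATCH_SIZE, MAX_SEQ_LEN)

with torch.no_grad():
    patches = patches.unsqueeze(0).to(device, torch.bfloat16).div_(127.5).sub_(1.0)
    coords = coords.unsqueeze(0).to(device, torch.int32)
    valid = valid.unsqueeze(0).to(device)

    feats = model.forward_intermediates(
        patches, patch_coord=coords, patch_valid=valid,
        indices=1, output_dict=True, output_fmt="NLC",
    )
    probits = model.forward_head(feats["image_features"], patch_valid=valid)[0].float().sigmoid()

# 0.65 on the raw sigmoid corresponds to the demo's default 0.30 after its 2x-1 rescale
for idx in (probits >= 0.65).nonzero(as_tuple=True)[0].tolist():
    print(tag_list[idx], f"{probits[idx].item():.3f}")
```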
glu.py ADDED
@@ -0,0 +1,40 @@
+ from abc import abstractmethod
+ from typing import Literal
+
+ from torch import Tensor
+ from torch.nn import Module
+ from torch.nn.functional import silu, gelu
+
+ class GatedUnit(Module):
+     def __init__(self, dim: int = -1) -> None:
+         super().__init__()
+
+         self.dim = dim
+
+     @abstractmethod
+     def _activation(self, x: Tensor) -> Tensor:
+         ...
+
+     def forward(self, x: Tensor) -> Tensor:
+         f, g = x.chunk(2, dim=self.dim)
+         return self._activation(f) * g
+
+ class SwiGLU(GatedUnit):
+     def __init__(self, dim: int = -1) -> None:
+         super().__init__(dim)
+
+     def _activation(self, x: Tensor) -> Tensor:
+         return silu(x)
+
+ class GeGLU(GatedUnit):
+     def __init__(
+         self,
+         dim: int = -1,
+         approximate: Literal["tanh", "none"] = "tanh"
+     ) -> None:
+         super().__init__(dim)
+
+         self.approximate = approximate
+
+     def _activation(self, x: Tensor) -> Tensor:
+         return gelu(x, approximate=self.approximate)
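A quick usage sketch for the gated units: the gate halves the dimension it chunks over, so the projection in front of SwiGLU must emit twice the desired hidden width (as ff_in does in hydra_pool.py). Sizes here mirror the 1152-wide, 3x-ratio feedforward used elsewhere in this release, but any width works.

```python
# SwiGLU feedforward sketch: Linear -> 2x hidden, gate -> hidden, Linear -> back.
import torch
from torch.nn import Linear, Sequential
from glu import SwiGLU

ff = Sequential(
    Linear(1152, 3456 * 2, bias=False),  # project to 2 * hidden_dim
    SwiGLU(),                            # silu(f) * g -> hidden_dim features
    Linear(3456, 1152, bias=False),
)
x = torch.randn(4, 256, 1152)
print(ff(x).shape)  # torch.Size([4, 256, 1152])
```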
hydra_pool.py ADDED
@@ -0,0 +1,581 @@
1
+ import re
2
+ from collections import defaultdict
3
+ from math import sqrt
4
+ from typing import Any, Iterable, Self, cast
5
+
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.nn import (
9
+ Module, ModuleList, Parameter, Buffer,
10
+ Linear, LayerNorm, RMSNorm, Dropout, Flatten,
11
+ init
12
+ )
13
+ from torch.nn.functional import pad, scaled_dot_product_attention
14
+
15
+ from einops import rearrange
16
+
17
+ from glu import SwiGLU
18
+
19
+ class IndexedAdd(Module):
20
+ def __init__(
21
+ self,
22
+ n_indices: int,
23
+ dim: int,
24
+ weight_shape: tuple[int, ...] | None = None,
25
+ *,
26
+ inplace: bool = False,
27
+ device: torch.device | str | None = None,
28
+ dtype: torch.dtype | None = None,
29
+ ) -> None:
30
+ super().__init__()
31
+
32
+ self.dim = dim
33
+ self.inplace = inplace
34
+
35
+ self.index = Buffer(torch.empty(
36
+ 2, n_indices,
37
+ device=device, dtype=torch.int32
38
+ ))
39
+
40
+ self.weight = Parameter(torch.ones(
41
+ *(sz if sz != -1 else n_indices for sz in weight_shape),
42
+ device=device, dtype=dtype
43
+ )) if weight_shape is not None else None
44
+
45
+ def _save_to_state_dict(
46
+ self,
47
+ destination: dict[str, Any],
48
+ prefix: str,
49
+ keep_vars: bool
50
+ ) -> None:
51
+ super()._save_to_state_dict(destination, prefix, keep_vars)
52
+
53
+ if keep_vars:
54
+ return
55
+
56
+ with torch.no_grad():
57
+ index_key = f"{prefix}index"
58
+ index = destination[index_key]
59
+
60
+ min_index = index.amin(None).item()
61
+ if min_index >= 0:
62
+ max_index = index.amax(None).item()
63
+ if max_index < (1 << 8):
64
+ destination[index_key] = index.to(dtype=torch.uint8)
65
+ elif max_index < (1 << 16):
66
+ destination[index_key] = index.to(dtype=torch.uint16)
67
+
68
+ @torch.no_grad()
69
+ def load_indices(self, indices: Iterable[tuple[int, int]], *, mean: bool = False) -> None:
70
+ if mean:
71
+ if self.weight is None:
72
+ raise ValueError("No weights to initialize with means.")
73
+
74
+ groups: dict[int, list[int]] = defaultdict(list)
75
+
76
+ idx = -1
77
+ for idx, (src, dst) in enumerate(indices):
78
+ self.index[0, idx] = src
79
+ self.index[1, idx] = dst
80
+
81
+ if mean:
82
+ groups[dst].append(idx)
83
+
84
+ if (idx + 1) != self.index.size(1):
85
+ raise IndexError(f"Expected {self.index.size(1)} indices, but got {idx + 1}.")
86
+
87
+ if not mean:
88
+ return
89
+
90
+ assert self.weight is not None
91
+
92
+ for idxs in groups.values():
93
+ if len(idxs) < 2:
94
+ continue
95
+
96
+ self.weight.index_fill_(
97
+ self.dim,
98
+ torch.tensor(idxs, device=self.weight.device, dtype=torch.int64),
99
+ 1.0 / len(idxs)
100
+ )
101
+
102
+ def forward(self, dst: Tensor, src: Tensor) -> Tensor:
103
+ src = src.index_select(self.dim, self.index[0])
104
+
105
+ if self.weight is not None:
106
+ src.mul_(self.weight)
107
+
108
+ return (
109
+ dst.index_add_(self.dim, self.index[1], src)
110
+ if self.inplace else
111
+ dst.index_add(self.dim, self.index[1], src)
112
+ )
113
+
114
+ class BatchLinear(Module):
115
+ def __init__(
116
+ self,
117
+ batch_shape: tuple[int, ...] | int,
118
+ in_features: int,
119
+ out_features: int,
120
+ *,
121
+ bias: bool = False,
122
+ flatten: bool = False,
123
+ bias_inplace: bool = True,
124
+ device: torch.device | str | None = None,
125
+ dtype: torch.dtype | None = None,
126
+ ) -> None:
127
+ super().__init__()
128
+
129
+ if isinstance(batch_shape, int):
130
+ batch_shape = (batch_shape,)
131
+ elif not batch_shape:
132
+ raise ValueError("At least one batch dimension is required.")
133
+
134
+ self.flatten = -(len(batch_shape) + 1) if flatten else 0
135
+
136
+ self.weight = Parameter(torch.empty(
137
+ *batch_shape, in_features, out_features,
138
+ device=device, dtype=dtype
139
+ ))
140
+
141
+ bt = self.weight.flatten(end_dim=-3).mT
142
+ for idx in range(bt.size(0)):
143
+ init.kaiming_uniform_(bt[idx], a=sqrt(5))
144
+
145
+ self.bias = Parameter(torch.zeros(
146
+ *batch_shape, out_features,
147
+ device=device, dtype=dtype
148
+ )) if bias else None
149
+
150
+ self.bias_inplace = bias_inplace
151
+
152
+ def forward(self, x: Tensor) -> Tensor:
153
+ # ... B... 1 I @ B... I O -> ... B... O
154
+ x = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2)
155
+
156
+ if self.bias is not None:
157
+ if self.bias_inplace:
158
+ x.add_(self.bias)
159
+ else:
160
+ x = x + self.bias
161
+
162
+ if self.flatten:
163
+ x = x.flatten(self.flatten)
164
+
165
+ return x
166
+
167
+ class Mean(Module):
168
+ def __init__(self, dim: tuple[int, ...] | int = -1, *, keepdim: bool = False) -> None:
169
+ super().__init__()
170
+
171
+ self.dim = dim
172
+ self.keepdim = keepdim
173
+
174
+ def forward(self, x: Tensor) -> Tensor:
175
+ return x.mean(self.dim, self.keepdim)
176
+
177
+ class _MidBlock(Module):
178
+ def __init__(
179
+ self,
180
+ attn_dim: int,
181
+ head_dim: int,
182
+ n_classes: int,
183
+ *,
184
+ ff_ratio: float,
185
+ ff_dropout: float,
186
+ q_cls_inplace: bool = True,
187
+ device: torch.device | str | None,
188
+ dtype: torch.dtype | None,
189
+ ) -> None:
190
+ super().__init__()
191
+
192
+ self.head_dim = head_dim
193
+ self.q_cls_inplace = q_cls_inplace
194
+
195
+ hidden_dim = int(attn_dim * ff_ratio)
196
+
197
+ self.q_proj = Linear(
198
+ attn_dim, attn_dim, bias=False,
199
+ device=device, dtype=dtype
200
+ )
201
+
202
+ self.q_cls = Parameter(torch.zeros(
203
+ n_classes, attn_dim,
204
+ device=device, dtype=dtype
205
+ ))
206
+
207
+ self.q_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)
208
+
209
+ self.attn_out = Linear(
210
+ attn_dim, attn_dim, bias=False,
211
+ device=device, dtype=dtype
212
+ )
213
+
214
+ self.ff_norm = LayerNorm(
215
+ attn_dim,
216
+ device=device, dtype=dtype
217
+ )
218
+ self.ff_in = Linear(
219
+ attn_dim, hidden_dim * 2, bias=False,
220
+ device=device, dtype=dtype
221
+ )
222
+ self.ff_act = SwiGLU()
223
+ self.ff_drop = Dropout(ff_dropout)
224
+ self.ff_out = Linear(
225
+ hidden_dim, attn_dim, bias=False,
226
+ device=device, dtype=dtype
227
+ )
228
+
229
+ def _forward_q(self, x: Tensor) -> Tensor:
230
+ x = self.q_proj(x)
231
+
232
+ if self.q_cls_inplace:
233
+ x.add_(self.q_cls)
234
+ else:
235
+ x = x + self.q_cls
236
+
237
+ x = self.q_norm(x)
238
+ x = rearrange(x, "... s (h e) -> ... h s e", e=self.head_dim)
239
+ return x
240
+
241
+ def _forward_attn(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None) -> Tensor:
242
+ a = scaled_dot_product_attention(
243
+ self._forward_q(x), k, v,
244
+ attn_mask=attn_mask
245
+ )
246
+ a = rearrange(a, "... h s e -> ... s (h e)")
247
+ a = self.attn_out(a)
248
+ return x + a
249
+
250
+ def _forward_ff(self, x: Tensor) -> Tensor:
251
+ f = self.ff_norm(x)
252
+ f = self.ff_in(f)
253
+ f = self.ff_act(f)
254
+ f = self.ff_drop(f)
255
+ f = self.ff_out(f)
256
+ return x + f
257
+
258
+ def forward(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None = None) -> Tensor:
259
+ x = self._forward_attn(x, k, v, attn_mask)
260
+ x = self._forward_ff(x)
261
+ return x
262
+
263
+ class HydraPool(Module):
264
+ def __init__(
265
+ self,
266
+ attn_dim: int,
267
+ head_dim: int,
268
+ n_classes: int,
269
+ *,
270
+ mid_blocks: int = 0,
271
+ roots: tuple[int, int, int] = (0, 0, 0),
272
+ ff_ratio: float = 3.0,
273
+ ff_dropout: float = 0.0,
274
+ input_dim: int = -1,
275
+ output_dim: int = 1,
276
+ device: torch.device | str | None = None,
277
+ dtype: torch.dtype | None = None,
278
+ ) -> None:
279
+ super().__init__()
280
+
281
+ if input_dim < 0:
282
+ input_dim = attn_dim
283
+
284
+ assert attn_dim % head_dim == 0
285
+ n_heads = attn_dim // head_dim
286
+
287
+ self.n_classes = n_classes
288
+ self.head_dim = head_dim
289
+ self.output_dim = output_dim
290
+
291
+ self._has_roots = False
292
+ self._has_ff = False
293
+
294
+ self.q: Parameter | Buffer
295
+ self._q_normed: bool | None
296
+
297
+ if roots != (0, 0, 0):
298
+ self._has_roots = True
299
+ n_roots, n_classroots, n_subclasses = roots
300
+
301
+ if n_classroots < n_roots:
302
+ raise ValueError("Number of classroots cannot be less than the number of roots.")
303
+
304
+ self.cls = Parameter(torch.randn(
305
+ n_heads, n_classes, head_dim,
306
+ device=device, dtype=dtype
307
+ ))
308
+
309
+ self.roots = Parameter(torch.randn(
310
+ n_heads, n_roots, head_dim,
311
+ device=device, dtype=dtype
312
+ )) if n_roots > 0 else None
313
+
314
+ self.clsroots = IndexedAdd(
315
+ n_classroots, dim=-2, weight_shape=(n_heads, -1, 1),
316
+ device=device, dtype=dtype
317
+ ) if n_classroots > 0 else None
318
+
319
+ self.clscls = IndexedAdd(
320
+ n_subclasses, dim=-2, weight_shape=(n_heads, -1, 1),
321
+ inplace=True, device=device, dtype=dtype
322
+ ) if n_subclasses > 0 else None
323
+
324
+ self.q = Buffer(torch.empty(
325
+ n_heads, n_classes, head_dim,
326
+ device=device, dtype=dtype
327
+ ))
328
+ self._q_normed = None
329
+ else:
330
+ self.q = Parameter(torch.randn(
331
+ n_heads, n_classes, head_dim,
332
+ device=device, dtype=dtype
333
+ ))
334
+ self._q_normed = False
335
+
336
+ self.kv = Linear(
337
+ input_dim, attn_dim * 2, bias=False,
338
+ device=device, dtype=dtype
339
+ )
340
+ self.qk_norm = RMSNorm(
341
+ head_dim, eps=1e-5, elementwise_affine=False
342
+ )
343
+
344
+ if ff_ratio > 0.0:
345
+ self._has_ff = True
346
+ hidden_dim = int(attn_dim * ff_ratio)
347
+
348
+ self.ff_norm = LayerNorm(
349
+ attn_dim,
350
+ device=device, dtype=dtype
351
+ )
352
+ self.ff_in = Linear(
353
+ attn_dim, hidden_dim * 2, bias=False,
354
+ device=device, dtype=dtype
355
+ )
356
+ self.ff_act = SwiGLU()
357
+ self.ff_drop = Dropout(ff_dropout)
358
+ self.ff_out = Linear(
359
+ hidden_dim, attn_dim, bias=False,
360
+ device=device, dtype=dtype
361
+ )
362
+ elif mid_blocks > 0:
363
+ raise ValueError("Feedforward required with mid blocks.")
364
+
365
+ self.mid_blocks = ModuleList(
366
+ _MidBlock(
367
+ attn_dim, head_dim, n_classes,
368
+ ff_ratio=ff_ratio, ff_dropout=ff_dropout,
369
+ device=device, dtype=dtype
370
+ ) for _ in range(mid_blocks)
371
+ )
372
+
373
+ self.out_proj = BatchLinear(
374
+ n_classes, attn_dim, output_dim * 2,
375
+ device=device, dtype=dtype
376
+ )
377
+ self.out_act = SwiGLU()
378
+
379
+ @property
380
+ def has_roots(self) -> bool:
381
+ return self._has_roots
382
+
383
+ def get_extra_state(self) -> dict[str, Any]:
384
+ return { "q_normed": self._q_normed }
385
+
386
+ def set_extra_state(self, state: dict[str, Any]) -> None:
387
+ self._q_normed = state["q_normed"]
388
+
389
+ def create_head(self) -> Module:
390
+ if self.output_dim == 1:
391
+ return Flatten(-2)
392
+
393
+ return Mean(-1)
394
+
395
+ def train(self, mode: bool = True) -> Self:
396
+ super().train(mode)
397
+
398
+ if mode:
399
+ if self._has_roots:
400
+ self._q_normed = None
401
+ else:
402
+ self._q_normed = False
403
+ else:
404
+ if self._has_roots:
405
+ self._cache_query()
406
+
407
+ return self
408
+
409
+ def inference(self) -> Self:
410
+ super().train(False)
411
+ self._cache_query()
412
+
413
+ if self._has_roots:
414
+ self._has_roots = False
415
+ self.q = Parameter(self.q)
416
+
417
+ del self.cls, self.roots, self.clsroots, self.clscls
418
+
419
+ return self
420
+
421
+ def _cache_query(self) -> None:
422
+ assert not self.training
423
+
424
+ if self._q_normed:
425
+ return
426
+
427
+ with torch.no_grad():
428
+ self.q.to(device=self.kv.weight.device)
429
+ self.q.copy_(self._forward_q())
430
+ self._q_normed = True
431
+
432
+ def _forward_q(self) -> Tensor:
433
+ match self._q_normed:
434
+ case None:
435
+ assert self._has_roots
436
+
437
+ if self.roots is not None:
438
+ q = self.qk_norm(self.roots)
439
+ q = self.clsroots(self.cls, q)
440
+ else:
441
+ q = self.cls
442
+
443
+ if self.clscls is not None:
444
+ q = self.clscls(q, q.detach())
445
+
446
+ q = self.qk_norm(q)
447
+ return q
448
+
449
+ case False:
450
+ assert not self._has_roots
451
+ return self.qk_norm(self.q)
452
+
453
+ case True:
454
+ return self.q
455
+
456
+ def _forward_attn(self, x: Tensor, attn_mask: Tensor | None) -> tuple[Tensor, Tensor, Tensor]:
457
+ q = self._forward_q().expand(*x.shape[:-2], -1, -1, -1)
458
+
459
+ x = self.kv(x)
460
+ k, v = rearrange(x, "... s (n h e) -> n ... h s e", n=2, e=self.head_dim).unbind(0)
461
+ k = self.qk_norm(k)
462
+
463
+ x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
464
+ return rearrange(x, "... h s e -> ... s (h e)"), k, v
465
+
466
+ def _forward_ff(self, x: Tensor) -> Tensor:
467
+ if not self._has_ff:
468
+ return x
469
+
470
+ f = self.ff_norm(x)
471
+ f = self.ff_in(f)
472
+ f = self.ff_act(f)
473
+ f = self.ff_drop(f)
474
+ f = self.ff_out(f)
475
+ return x + f
476
+
477
+ def _forward_out(self, x: Tensor) -> Tensor:
478
+ x = self.out_proj(x)
479
+ x = self.out_act(x)
480
+ return x
481
+
482
+ def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
483
+ x, k, v = self._forward_attn(x, attn_mask)
484
+ x = self._forward_ff(x)
485
+
486
+ for block in self.mid_blocks:
487
+ x = block(x, k, v, attn_mask)
488
+
489
+ x = self._forward_out(x)
490
+ return x
491
+
492
+ def prune_roots(self, retain_classes: set[int]) -> tuple[list[int], list[int]]:
493
+ if not self._has_roots or self.roots is None:
494
+ raise TypeError("No roots to prune.")
495
+
496
+ if self.clscls is not None:
497
+ raise TypeError("Subclass roots cannot be pruned.")
498
+
499
+ used_roots: set[int] = set()
500
+ used_clsroots: list[int] = []
501
+
502
+ assert self.clsroots is not None
503
+ clsroots = [
504
+ cast(list[int], clsroot.tolist())
505
+ for clsroot in self.clsroots.index.cpu().unbind(1)
506
+ ]
507
+
508
+ for idx, (src, dest) in enumerate(clsroots):
509
+ if dest in retain_classes:
510
+ used_roots.add(src)
511
+ used_clsroots.append(idx)
512
+
513
+ sorted_roots = sorted(used_roots)
514
+ del used_roots
515
+
516
+ rootmap = {
517
+ root: idx
518
+ for idx, root in enumerate(sorted_roots)
519
+ }
520
+
521
+ clsmap = {
522
+ cls: idx
523
+ for idx, cls in enumerate(sorted(retain_classes))
524
+ }
525
+
526
+ for idx in used_clsroots:
527
+ src, dest = clsroots[idx]
528
+ self.clsroots.index[0, idx] = rootmap[src]
529
+ self.clsroots.index[1, idx] = clsmap[dest]
530
+
531
+ return sorted_roots, used_clsroots
532
+
533
+ @staticmethod
534
+ def for_state(
535
+ state_dict: dict[str, Any],
536
+ prefix: str = "",
537
+ *,
538
+ ff_dropout: float = 0.0,
539
+ device: torch.device | str | None = None,
540
+ dtype: torch.dtype | None = None,
541
+ ) -> "HydraPool":
542
+ n_heads, n_classes, head_dim = state_dict[f"{prefix}q"].shape
543
+ attn_dim = n_heads * head_dim
544
+
545
+ roots_t = state_dict.get(f"{prefix}roots")
546
+ clsroots_t = state_dict.get(f"{prefix}clsroots.index")
547
+ clscls_t = state_dict.get(f"{prefix}clscls.index")
548
+ roots = (
549
+ roots_t.size(1) if roots_t is not None else 0,
550
+ clsroots_t.size(1) if clsroots_t is not None else 0,
551
+ clscls_t.size(1) if clscls_t is not None else 0
552
+ )
553
+
554
+ input_dim = state_dict[f"{prefix}kv.weight"].size(1)
555
+ output_dim = state_dict[f"{prefix}out_proj.weight"].size(2) // 2
556
+
557
+ # avoid off-by-one issue due to truncation
558
+ ffout_t = state_dict.get(f"{prefix}ff_out.weight")
559
+ hidden_dim = ffout_t.size(1) + 0.5 if ffout_t is not None else 0
560
+ ff_ratio = hidden_dim / attn_dim
561
+
562
+ pattern = re.compile(rf"^{re.escape(prefix)}mid_blocks\.([0-9]+)\.")
563
+ mid_blocks = max([-1, *(
564
+ int(match[1])
565
+ for key in state_dict
566
+ if (match := pattern.match(key)) is not None
567
+ )]) + 1
568
+
569
+ return HydraPool(
570
+ attn_dim,
571
+ head_dim,
572
+ n_classes,
573
+ mid_blocks=mid_blocks,
574
+ roots=roots,
575
+ ff_ratio=ff_ratio,
576
+ ff_dropout=ff_dropout,
577
+ input_dim=input_dim,
578
+ output_dim=output_dim,
579
+ device=device,
580
+ dtype=dtype
581
+ )
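A shape sketch for HydraPool with toy sizes (assumed for illustration, not the shipped JTP-3 configuration): each class owns a learned query, attention pools the patch sequence once per class, and create_head() flattens the per-class SwiGLU output into a single logit per class when output_dim == 1.

```python
# HydraPool shape sketch with toy dimensions.
import torch
from hydra_pool import HydraPool

pool = HydraPool(attn_dim=256, head_dim=64, n_classes=10, output_dim=1)
head = pool.create_head()        # Flatten(-2) when output_dim == 1

x = torch.randn(2, 196, 256)     # (batch, seq_len, attn_dim) patch features
logits = head(pool(x))           # (2, 10): one logit per class
print(logits.shape)
```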
image.py ADDED
@@ -0,0 +1,271 @@
1
+ from io import BytesIO
2
+ from typing import Any, Callable, cast
3
+ from warnings import warn, catch_warnings, filterwarnings
4
+
5
+ import numpy as np
6
+ from torch import Tensor
7
+
8
+ from einops import rearrange
9
+
10
+ import PIL.Image as image
11
+ import PIL.ImageCms as image_cms
12
+
13
+ from PIL.Image import Image, Resampling
14
+ from PIL.ImageCms import (
15
+ Direction, Intent, ImageCmsProfile, PyCMSError,
16
+ createProfile, getDefaultIntent, isIntentSupported, profileToProfile
17
+ )
18
+ from PIL.ImageOps import exif_transpose
19
+
20
+ try:
21
+ import pillow_jxl
22
+ except ImportError:
23
+ pass
24
+
25
+ image.MAX_IMAGE_PIXELS = None
26
+
27
+ _SRGB = createProfile(colorSpace='sRGB')
28
+
29
+ _INTENT_FLAGS = {
30
+ Intent.PERCEPTUAL: image_cms.FLAGS["HIGHRESPRECALC"],
31
+ Intent.RELATIVE_COLORIMETRIC: (
32
+ image_cms.FLAGS["HIGHRESPRECALC"] |
33
+ image_cms.FLAGS["BLACKPOINTCOMPENSATION"]
34
+ ),
35
+ Intent.ABSOLUTE_COLORIMETRIC: image_cms.FLAGS["HIGHRESPRECALC"]
36
+ }
37
+
38
+ class CMSWarning(UserWarning):
39
+ def __init__(
40
+ self,
41
+ message: str,
42
+ *,
43
+ path: str | None = None,
44
+ cms_info: dict[str, Any] | None = None,
45
+ cause: Exception | None = None,
46
+ ):
47
+ super().__init__(message)
48
+ self.__cause__ = cause
49
+
50
+ self.path = path
51
+ self.cms_info = cms_info
52
+
53
+ self.add_note(f"path: {path}")
54
+ self.add_note(f"info: {cms_info}")
55
+
56
+ def _coalesce_intent(intent: Intent | int) -> Intent:
57
+ if isinstance(intent, Intent):
58
+ return intent
59
+
60
+ match intent:
61
+ case 0:
62
+ return Intent.PERCEPTUAL
63
+ case 1:
64
+ return Intent.RELATIVE_COLORIMETRIC
65
+ case 2:
66
+ return Intent.SATURATION
67
+ case 3:
68
+ return Intent.ABSOLUTE_COLORIMETRIC
69
+ case _:
70
+ raise ValueError("invalid intent")
71
+
72
+ def _add_info(info: dict[str, Any], source: object, key: str) -> None:
73
+ try:
74
+ if (value := getattr(source, key, None)) is not None:
75
+ info[key] = value
76
+ except Exception:
77
+ pass
78
+
79
+ def open_srgb(
80
+ path: str,
81
+ *,
82
+ resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
83
+ crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
84
+ expect: tuple[int, int] | None = None,
85
+ ) -> Image:
86
+ with open(path, "rb", buffering=(1024 * 1024)) as file:
87
+ img: Image = image.open(file)
88
+
89
+ try:
90
+ out = process_srgb(img, resize=resize, crop=crop, expect=expect)
91
+ except:
92
+ img.close()
93
+ raise
94
+
95
+ if img is not out:
96
+ img.close()
97
+
98
+ return out
99
+
100
+ def process_srgb(
101
+ img: Image,
102
+ *,
103
+ resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
104
+ crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
105
+ expect: tuple[int, int] | None = None,
106
+ ) -> Image:
107
+ img.load()
108
+
109
+ try:
110
+ exif_transpose(img, in_place=True)
111
+ except Exception:
112
+ pass # corrupt EXIF metadata is fine
113
+
114
+ size = (img.width, img.height)
115
+
116
+ if expect is not None and size != expect:
117
+ raise RuntimeError(
118
+ f"Image is {size[0]}x{size[1]}, "
119
+ f"but expected {expect[0]}x{expect[1]}."
120
+ )
121
+
122
+ if (icc_raw := img.info.get("icc_profile")) is not None:
123
+ cms_info: dict[str, Any] = {
124
+ "native_mode": img.mode,
125
+ "transparency": img.has_transparency_data,
126
+ }
127
+
128
+ try:
129
+ profile = ImageCmsProfile(BytesIO(icc_raw))
130
+ _add_info(cms_info, profile.profile, "profile_description")
131
+ _add_info(cms_info, profile.profile, "target")
132
+ _add_info(cms_info, profile.profile, "xcolor_space")
133
+ _add_info(cms_info, profile.profile, "connection_space")
134
+ _add_info(cms_info, profile.profile, "colorimetric_intent")
135
+ _add_info(cms_info, profile.profile, "rendering_intent")
136
+
137
+ working_mode = img.mode
138
+ if img.mode.startswith(("RGB", "BGR", "P")):
139
+ working_mode = "RGBA" if img.has_transparency_data else "RGB"
140
+ elif img.mode.startswith(("L", "I", "F")) or img.mode == "1":
141
+ working_mode = "LA" if img.has_transparency_data else "L"
142
+
143
+ if img.mode != working_mode:
144
+ cms_info["working_mode"] = working_mode
145
+ img = img.convert(working_mode)
146
+
147
+ mode = "RGBA" if img.has_transparency_data else "RGB"
148
+
149
+ intent = Intent.RELATIVE_COLORIMETRIC
150
+ if isIntentSupported(profile, intent, Direction.INPUT) != 1:
151
+ intent = _coalesce_intent(getDefaultIntent(profile))
152
+
153
+ cms_info["conversion_intent"] = intent
154
+
155
+ if (flags := _INTENT_FLAGS.get(intent)) is None:
156
+ raise RuntimeError("Unsupported intent")
157
+
158
+ if img.mode == mode:
159
+ profileToProfile(
160
+ img,
161
+ profile,
162
+ _SRGB,
163
+ renderingIntent=intent,
164
+ inPlace=True,
165
+ flags=flags
166
+ )
167
+ else:
168
+ img = cast(Image, profileToProfile(
169
+ img,
170
+ profile,
171
+ _SRGB,
172
+ renderingIntent=intent,
173
+ outputMode=mode,
174
+ flags=flags
175
+ ))
176
+ except Exception as ex:
177
+ pass
178
+
179
+ if img.has_transparency_data:
180
+ if img.mode != "RGBa":
181
+ try:
182
+ img = img.convert("RGBa")
183
+ except ValueError:
184
+ img = img.convert("RGBA").convert("RGBa")
185
+ elif img.mode != "RGB":
186
+ img = img.convert("RGB")
187
+
188
+ if crop is not None and not isinstance(crop, tuple):
189
+ crop = crop(size)
190
+
191
+ if crop is not None:
192
+ left, top, right, bottom = crop
193
+ size = (right - left, bottom - top)
194
+
195
+ if resize is not None and not isinstance(resize, tuple):
196
+ resize = resize(size)
197
+
198
+ if resize is not None and size != resize:
199
+ img = img.resize(
200
+ resize,
201
+ Resampling.LANCZOS,
202
+ box=crop,
203
+ reducing_gap=3.0
204
+ )
205
+ crop = None
206
+
207
+ if crop is not None:
208
+ img = img.crop(crop)
209
+
210
+ return img
211
+
212
+ def put_srgb(img: Image, tensor: Tensor) -> None:
213
+ if img.mode not in ("RGB", "RGBA", "RGBa"):
214
+ raise ValueError(f"Image has non-RGB mode {img.mode}.")
215
+
216
+ np.copyto(tensor.numpy(), np.asarray(img)[:, :, :3], casting="no")
217
+
218
+ def put_srgb_patch(
219
+ img: Image,
220
+ patch_data: Tensor,
221
+ patch_coord: Tensor,
222
+ patch_valid: Tensor,
223
+ patch_size: int
224
+ ) -> None:
225
+ if img.mode not in ("RGB", "RGBA", "RGBa"):
226
+ raise ValueError(f"Image has non-RGB mode {img.mode}.")
227
+
228
+ patches = rearrange(
229
+ np.asarray(img)[:, :, :3],
230
+ "(h p1) (w p2) c -> h w (p1 p2 c)",
231
+ p1=patch_size, p2=patch_size
232
+ )
233
+
234
+ coords = np.stack(np.meshgrid(
235
+ np.arange(patches.shape[0], dtype=np.int16),
236
+ np.arange(patches.shape[1], dtype=np.int16),
237
+ indexing="ij"
238
+ ), axis=-1)
239
+
240
+ coords = rearrange(coords, "h w c -> (h w) c")
241
+ patches = rearrange(patches, "h w p -> (h w) p")
242
+ n = patches.shape[0]
243
+
244
+ np.copyto(patch_data[:n].numpy(), patches, casting="no")
245
+ np.copyto(patch_coord[:n].numpy(), coords, casting="no")
246
+ patch_valid[:n] = True
247
+
248
+ def unpatchify(input: Tensor, coords: Tensor, valid: Tensor) -> Tensor:
249
+ """
250
+ Scatter valid patches from (seqlen, ...) to (H, W, ...), using coords and valid mask.
251
+
252
+ Args:
253
+ input: Tensor of shape (seqlen, ...), patch data.
254
+ coords: Tensor of shape (seqlen, 2), spatial coordinates [y, x] for each patch.
255
+ valid: Tensor of shape (seqlen,), boolean mask for valid patches.
256
+
257
+ Returns:
258
+ Tensor of shape (H, W, ...), with valid patches scattered to their spatial locations.
259
+ """
260
+
261
+ valid_coords = coords[0, valid[0]] # (n_valid, 2)
262
+ valid_patches = input[valid[0]] # (n_valid, ...)
263
+
264
+ h = int(valid_coords[:, 0].max().item()) + 1
265
+ w = int(valid_coords[:, 1].max().item()) + 1
266
+
267
+ output_shape = (h, w) + input.shape[1:]
268
+ output = input.new_zeros(output_shape)
269
+
270
+ output[valid_coords[:, 0], valid_coords[:, 1]] = valid_patches
271
+ return output
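A small round-trip sketch for unpatchify, using the patchify_image helper from model.py below: coords and valid are passed with a leading batch dimension (and coords cast to int32, as app.py does) while the per-patch data is not.

```python
# Patchify a 256x160 image (16x10 patches), then scatter one value per patch
# back onto the (H/16, W/16) grid with unpatchify.
import torch
from PIL import Image
from model import patchify_image
from image import unpatchify

img = Image.new("RGB", (256, 160))
patches, coords, valid = patchify_image(img, 16, 1024)   # (1024, 768), (1024, 2), (1024,)

per_patch = patches.float().mean(dim=-1)                 # one scalar per patch slot
coords = coords.unsqueeze(0).to(torch.int32)             # (1, 1024, 2), int32 for indexing
valid = valid.unsqueeze(0)                               # (1, 1024)

grid = unpatchify(per_patch, coords, valid)
print(grid.shape)                                        # torch.Size([10, 16])
```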
model.py ADDED
@@ -0,0 +1,192 @@
1
+ from math import ceil
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.nn import Identity
6
+
7
+ import timm
8
+ from timm.models import NaFlexVit
9
+
10
+ from PIL import Image
11
+
12
+ from safetensors import safe_open
13
+
14
+ from image import process_srgb, put_srgb_patch
15
+
16
+ def sdpa_attn_mask(
17
+ patch_valid: Tensor,
18
+ num_prefix_tokens: int = 0,
19
+ symmetric: bool = True,
20
+ q_len: int | None = None,
21
+ dtype: torch.dtype | None = None,
22
+ ) -> Tensor:
23
+ mask = patch_valid.unflatten(-1, (1, 1, -1))
24
+
25
+ if num_prefix_tokens:
26
+ mask = torch.cat((
27
+ torch.ones(
28
+ *mask.shape[:-1], num_prefix_tokens,
29
+ device=patch_valid.device, dtype=torch.bool
30
+ ), mask
31
+ ), dim=-1)
32
+
33
+ return mask
34
+
35
+ timm.models.naflexvit.create_attention_mask = sdpa_attn_mask
36
+
37
+ def get_image_size_for_seq(
38
+ image_hw: tuple[int, int],
39
+ patch_size: int = 16,
40
+ max_seq_len: int = 1024,
41
+ max_ratio: float = 1.0,
42
+ eps: float = 1e-5,
43
+ ) -> tuple[int, int]:
44
+ """Determine image size for sequence length constraint."""
45
+
46
+ assert max_ratio >= 1.0
47
+ assert eps * 2 < max_ratio
48
+
49
+ h, w = image_hw
50
+ max_py = int(max((h * max_ratio) // patch_size, 1))
51
+ max_px = int(max((w * max_ratio) // patch_size, 1))
52
+
53
+ if (max_py * max_px) <= max_seq_len:
54
+ return max_py * patch_size, max_px * patch_size
55
+
56
+ def patchify(ratio: float) -> tuple[int, int]:
57
+ return (
58
+ min(int(ceil((h * ratio) / patch_size)), max_py),
59
+ min(int(ceil((w * ratio) / patch_size)), max_px)
60
+ )
61
+
62
+ py, px = patchify(eps)
63
+ if (py * px) > max_seq_len:
64
+ raise ValueError(f"Image of size {w}x{h} is too large.")
65
+
66
+ ratio = eps
67
+ while (max_ratio - ratio) >= eps:
68
+ mid = (ratio + max_ratio) / 2.0
69
+
70
+ mpy, mpx = patchify(mid)
71
+ seq_len = mpy * mpx
72
+
73
+ if seq_len > max_seq_len:
74
+ max_ratio = mid
75
+ continue
76
+
77
+ ratio = mid
78
+ py = mpy
79
+ px = mpx
80
+
81
+ if seq_len == max_seq_len:
82
+ break
83
+
84
+ assert py >= 1 and px >= 1
85
+ return py * patch_size, px * patch_size
86
+
87
+ def process_image(img: Image.Image, patch_size: int, max_seq_len: int) -> Image.Image:
88
+ def compute_resize(wh: tuple[int, int]) -> tuple[int, int]:
89
+ h, w = get_image_size_for_seq((wh[1], wh[0]), patch_size, max_seq_len)
90
+ return w, h
91
+
92
+ return process_srgb(img, resize=compute_resize)
93
+
94
+ def patchify_image(img: Image.Image, patch_size: int, max_seq_len: int, share_memory: bool = False) -> tuple[Tensor, Tensor, Tensor]:
95
+ patches = torch.zeros(max_seq_len, patch_size * patch_size * 3, device="cpu", dtype=torch.uint8)
96
+ patch_coords = torch.zeros(max_seq_len, 2, device="cpu", dtype=torch.int16)
97
+ patch_valid = torch.zeros(max_seq_len, device="cpu", dtype=torch.bool)
98
+
99
+ if share_memory:
100
+ patches.share_memory_()
101
+ patch_coords.share_memory_()
102
+ patch_valid.share_memory_()
103
+
104
+ put_srgb_patch(img, patches, patch_coords, patch_valid, patch_size)
105
+ return patches, patch_coords, patch_valid
106
+
107
+ def load_image(
108
+ path: str,
109
+ patch_size: int = 16,
110
+ max_seq_len: int = 1024,
111
+ share_memory: bool = False
112
+ ) -> tuple[Tensor, Tensor, Tensor]:
113
+ with open(path, "rb", buffering=(1024 * 1024)) as file:
114
+ img: Image.Image = Image.open(file)
115
+
116
+ try:
117
+ processed = process_image(img, patch_size, max_seq_len)
118
+ except:
119
+ img.close()
120
+ raise
121
+
122
+ if img is not processed:
123
+ img.close()
124
+
125
+ return patchify_image(processed, patch_size, max_seq_len, share_memory)
126
+
127
+ def load_model(path: str, device: torch.device | str | None = None) -> tuple[NaFlexVit, list[str]]:
128
+ with safe_open(path, framework="pt", device="cpu") as file:
129
+ metadata = file.metadata()
130
+
131
+ state_dict = {
132
+ key: file.get_tensor(key)
133
+ for key in file.keys()
134
+ }
135
+
136
+ arch = metadata["modelspec.architecture"]
137
+ if not arch.startswith("naflexvit_so400m_patch16_siglip"):
138
+ raise ValueError(f"Unrecognized model architecture: {arch}")
139
+
140
+ tags = metadata["classifier.labels"].split("\n")
141
+
142
+ model = timm.create_model(
143
+ 'naflexvit_so400m_patch16_siglip',
144
+ pretrained=False, num_classes=0,
145
+ pos_embed_interp_mode="bilinear",
146
+ weight_init="skip", fix_init=False,
147
+ device="cpu", dtype=torch.bfloat16,
148
+ )
149
+
150
+ match arch[31:]:
151
+ case "": # vanilla
152
+ model.reset_classifier(len(tags))
153
+
154
+ case "+rr_slim":
155
+ model.reset_classifier(len(tags))
156
+
157
+ if "attn_pool.q.weight" not in state_dict:
158
+ model.attn_pool.q = Identity()
159
+
160
+ if "head.bias" not in state_dict:
161
+ model.head.bias = None
162
+
163
+ case "+rr_chonker":
164
+ from chonker_pool import ChonkerPool
165
+
166
+ model.attn_pool = ChonkerPool(
167
+ 2, 1152, 72,
168
+ device=device, dtype=torch.bfloat16
169
+ )
170
+ model.head = model.attn_pool.create_head(len(tags))
171
+ model.num_classes = len(tags)
172
+
173
+ case "+rr_hydra":
174
+ from hydra_pool import HydraPool
175
+
176
+ model.attn_pool = HydraPool.for_state(
177
+ state_dict, "attn_pool.",
178
+ device=device, dtype=torch.bfloat16
179
+ )
180
+ model.head = model.attn_pool.create_head()
181
+ model.num_classes = len(tags)
182
+
183
+ state_dict["attn_pool._extra_state"] = { "q_normed": True }
184
+
185
+ case _:
186
+ raise ValueError(f"Unrecognized model architecture: {arch}")
187
+
188
+ model.eval().to(dtype=torch.bfloat16)
189
+ model.load_state_dict(state_dict, strict=True)
190
+ model.to(device=device)
191
+
192
+ return model, tags
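A short sketch of the sizing logic in get_image_size_for_seq: it searches for the largest scale whose patch grid still fits within max_seq_len and returns a height/width rounded to whole patches. The example sizes are arbitrary.

```python
# Sequence-length budgeting sketch for get_image_size_for_seq.
from model import get_image_size_for_seq

for hw in [(512, 512), (1080, 1920), (300, 2000)]:
    h, w = get_image_size_for_seq(hw, patch_size=16, max_seq_len=1024)
    print(hw, "->", (h, w), "patches:", (h // 16) * (w // 16))
```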
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ timm
+ numpy
+ pillow
+ einops
+ safetensors
+ gradio
+ requests