File size: 8,279 Bytes
d62ba4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
from io import BytesIO
from typing import Any, Callable, cast
from warnings import warn, catch_warnings, filterwarnings

import numpy as np
from torch import Tensor

from einops import rearrange

import PIL.Image as image
import PIL.ImageCms as image_cms

from PIL.Image import Image, Resampling
from PIL.ImageCms import (
    Direction, Intent, ImageCmsProfile, PyCMSError,
    createProfile, getDefaultIntent, isIntentSupported, profileToProfile
)
from PIL.ImageOps import exif_transpose

try:
    import pillow_jxl
except ImportError:
    pass

# Remove Pillow's pixel-count limit (its decompression-bomb check) so that
# arbitrarily large images can be opened.
image.MAX_IMAGE_PIXELS = None

# Built-in sRGB profile; the target of every embedded-ICC conversion.
_SRGB = createProfile(colorSpace='sRGB')

# littleCMS transform flags per rendering intent. Every supported intent uses
# HIGHRESPRECALC; relative colorimetric additionally enables black point
# compensation. Intents missing from this table are treated as unsupported
# by the conversion code.
_INTENT_FLAGS = {
    rendering_intent: image_cms.FLAGS["HIGHRESPRECALC"] | extra_flags
    for rendering_intent, extra_flags in (
        (Intent.PERCEPTUAL, 0),
        (Intent.RELATIVE_COLORIMETRIC, image_cms.FLAGS["BLACKPOINTCOMPENSATION"]),
        (Intent.ABSOLUTE_COLORIMETRIC, 0),
    )
}

class CMSWarning(UserWarning):
    def __init__(
        self,
        message: str,
        *,
        path: str | None = None,
        cms_info: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ):
        super().__init__(message)
        self.__cause__ = cause

        self.path = path
        self.cms_info = cms_info

        self.add_note(f"path: {path}")
        self.add_note(f"info: {cms_info}")

def _coalesce_intent(intent: Intent | int) -> Intent:
    """
    Normalize a rendering intent given as an ``Intent`` member or a raw code.

    Raw codes map as: 0 -> PERCEPTUAL, 1 -> RELATIVE_COLORIMETRIC,
    2 -> SATURATION, 3 -> ABSOLUTE_COLORIMETRIC.

    Raises:
        ValueError: If ``intent`` is neither an ``Intent`` nor one of 0-3.
    """
    if isinstance(intent, Intent):
        return intent

    for code, member in (
        (0, Intent.PERCEPTUAL),
        (1, Intent.RELATIVE_COLORIMETRIC),
        (2, Intent.SATURATION),
        (3, Intent.ABSOLUTE_COLORIMETRIC),
    ):
        if intent == code:
            return member

    raise ValueError("invalid intent")

def _add_info(info: dict[str, Any], source: object, key: str) -> None:
    try:
        if (value := getattr(source, key, None)) is not None:
            info[key] = value
    except Exception:
        pass

def open_srgb(
    path: str,
    *,
    resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
    crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
    expect: tuple[int, int] | None = None,
) -> Image:
    """
    Open the image file at ``path`` and run it through ``process_srgb``.

    The file is read through a 1 MiB buffer and closed before returning;
    ``process_srgb`` calls ``img.load()`` while the file is still open.
    The opened source image is closed on error, and also on success whenever
    processing produced a different image object.

    ``resize``, ``crop`` and ``expect`` are forwarded to ``process_srgb``.
    """
    with open(path, "rb", buffering=(1024 * 1024)) as stream:
        source: Image = image.open(stream)

        try:
            result = process_srgb(source, resize=resize, crop=crop, expect=expect)
        except BaseException:
            # Release the decoder on any failure (including interrupts), then propagate.
            source.close()
            raise

        # Only close the source if it is not what we hand back to the caller.
        if result is not source:
            source.close()

        return result

def _apply_icc_profile(img: Image, icc_raw: bytes) -> Image:
    """
    Best-effort transform of ``img`` from its embedded ICC profile to sRGB.

    The image is first brought into a working mode for the transform
    (RGB/RGBA for RGB-, BGR- and palette-family modes; L/LA for bilevel,
    grayscale, integer and float modes; other modes are left as they are),
    then transformed to "RGB", or to "RGBA" when it has transparency data.
    Relative colorimetric intent is preferred. If the profile does not
    support it as an input intent, the profile's default intent is used.

    Any failure is reported as a ``CMSWarning`` instead of being raised. The
    warning carries the collected ``cms_info`` and has the original exception
    as its cause. The image is then returned as it stood at the point of
    failure, which may already be mode-converted but not colour-converted.

    Args:
        img: Loaded image the ``icc_raw`` profile was taken from.
        icc_raw: Raw ICC profile bytes.

    Returns:
        The resulting image, possibly ``img`` itself modified in place.
    """
    cms_info: dict[str, Any] = {
        "native_mode": img.mode,
        "transparency": img.has_transparency_data,
    }

    try:
        profile = ImageCmsProfile(BytesIO(icc_raw))
        for key in (
            "profile_description",
            "target",
            "xcolor_space",
            "connection_space",
            "colorimetric_intent",
            "rendering_intent",
        ):
            _add_info(cms_info, profile.profile, key)

        working_mode = img.mode
        if img.mode.startswith(("RGB", "BGR", "P")):
            working_mode = "RGBA" if img.has_transparency_data else "RGB"
        elif img.mode in ("1", "L", "LA", "La") or img.mode.startswith(("I", "F")):
            # Explicit list: a bare startswith("L") would also match "LAB",
            # which is a colour mode and must not be collapsed to grayscale.
            working_mode = "LA" if img.has_transparency_data else "L"

        if img.mode != working_mode:
            cms_info["working_mode"] = working_mode
            img = img.convert(working_mode)

        mode = "RGBA" if img.has_transparency_data else "RGB"

        intent = Intent.RELATIVE_COLORIMETRIC
        if isIntentSupported(profile, intent, Direction.INPUT) != 1:
            intent = _coalesce_intent(getDefaultIntent(profile))

        cms_info["conversion_intent"] = intent

        if (flags := _INTENT_FLAGS.get(intent)) is None:
            raise RuntimeError("Unsupported intent")

        if img.mode == mode:
            profileToProfile(
                img,
                profile,
                _SRGB,
                renderingIntent=intent,
                inPlace=True,
                flags=flags
            )
        else:
            img = cast(Image, profileToProfile(
                img,
                profile,
                _SRGB,
                renderingIntent=intent,
                outputMode=mode,
                flags=flags
            ))
    except Exception as ex:
        # Deliberately best-effort: keep the unconverted pixels, but report it.
        warn(
            CMSWarning(
                f"Failed to convert embedded ICC profile to sRGB: {ex}",
                cms_info=cms_info,
                cause=ex,
            ),
            stacklevel=3,
        )

    return img

def process_srgb(
    img: Image,
    *,
    resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
    crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
    expect: tuple[int, int] | None = None,
) -> Image:
    """
    Load ``img``, convert it to sRGB, and optionally crop and/or resize it.

    Steps, in order:
      1. Load pixel data and apply the EXIF orientation in place (corrupt
         EXIF data is ignored).
      2. If ``expect`` is given, require the oriented size to equal it.
      3. If the image carries an embedded ICC profile, colour-convert it to
         sRGB. Failures emit a ``CMSWarning`` and the pixels are kept
         unconverted.
      4. Normalize the mode: "RGBa" (premultiplied alpha) when the image has
         transparency data, otherwise "RGB".
      5. Crop and/or resize (LANCZOS). When both apply, they are done in a
         single resample via ``resize(box=...)``.

    Args:
        img: Source image. It may be modified in place; the returned image
            may or may not be the same object.
        resize: Target ``(width, height)``, or a callable that receives the
            (post-crop) ``(width, height)`` and returns a target or None.
        crop: ``(left, top, right, bottom)`` box, or a callable that receives
            the oriented ``(width, height)`` and returns a box or None.
        expect: Required ``(width, height)`` after EXIF orientation.

    Returns:
        An image in "RGB" or "RGBa" mode.

    Raises:
        RuntimeError: If ``expect`` is given and does not match the size.
    """
    img.load()

    try:
        exif_transpose(img, in_place=True)
    except Exception:
        pass  # corrupt EXIF metadata is fine

    size = (img.width, img.height)

    if expect is not None and size != expect:
        raise RuntimeError(
            f"Image is {size[0]}x{size[1]}, "
            f"but expected {expect[0]}x{expect[1]}."
        )

    if (icc_raw := img.info.get("icc_profile")) is not None:
        img = _apply_icc_profile(img, icc_raw)

    if img.has_transparency_data:
        if img.mode != "RGBa":
            try:
                img = img.convert("RGBa")
            except ValueError:
                # Not every mode converts directly to premultiplied alpha.
                img = img.convert("RGBA").convert("RGBa")
    elif img.mode != "RGB":
        img = img.convert("RGB")

    if crop is not None and not isinstance(crop, tuple):
        crop = crop(size)

    if crop is not None:
        left, top, right, bottom = crop
        # Height is bottom - top; the reversed order produced a negative height.
        size = (right - left, bottom - top)

    if resize is not None and not isinstance(resize, tuple):
        resize = resize(size)

    if resize is not None and size != resize:
        # Crop and resample in one pass; the separate crop below is skipped.
        img = img.resize(
            resize,
            Resampling.LANCZOS,
            box=crop,
            reducing_gap=3.0
        )
        crop = None

    if crop is not None:
        img = img.crop(crop)

    return img

def put_srgb(img: Image, tensor: Tensor) -> None:
    """
    Copy the first three channels of an RGB-family image into ``tensor``.

    Args:
        img: Image in "RGB", "RGBA" or "RGBa" mode; a fourth channel, if
            present, is dropped.
        tensor: Destination whose ``.numpy()`` view is written in place.
            The copy uses ``casting="no"``, so its dtype must match the
            image array's dtype exactly.

    Raises:
        ValueError: If ``img`` is not in one of the accepted modes.
    """
    if img.mode in ("RGB", "RGBA", "RGBa"):
        destination = tensor.numpy()
        np.copyto(destination, np.asarray(img)[:, :, :3], casting="no")
        return

    raise ValueError(f"Image has non-RGB mode {img.mode}.")

def put_srgb_patch(
    img: Image,
    patch_data: Tensor,
    patch_coord: Tensor,
    patch_valid: Tensor,
    patch_size: int
) -> None:
    """
    Split an RGB-family image into square patches and write them in place.

    The image's first three channels are cut into ``patch_size`` x
    ``patch_size`` tiles. Each tile is flattened in (row, column, channel)
    order and the tiles are enumerated row-major. For the first ``n`` slots
    (``n`` = number of tiles):
      - ``patch_data[i]`` receives tile ``i``'s flattened pixels,
      - ``patch_coord[i]`` receives its int16 ``(row, col)`` tile index,
      - ``patch_valid[i]`` is set to True.
    Slots at index ``n`` and beyond are not touched.

    Both numpy copies use ``casting="no"``, so ``patch_data`` must share the
    pixel dtype and ``patch_coord`` must be int16.

    Raises:
        ValueError: If ``img`` is not in "RGB", "RGBA" or "RGBa" mode.
    """
    if img.mode not in ("RGB", "RGBA", "RGBa"):
        raise ValueError(f"Image has non-RGB mode {img.mode}.")

    rgb = np.asarray(img)[:, :, :3]
    tile_grid = rearrange(
        rgb,
        "(h p1) (w p2) c -> h w (p1 p2 c)",
        p1=patch_size, p2=patch_size
    )
    rows, cols = tile_grid.shape[0], tile_grid.shape[1]

    flat_tiles = rearrange(tile_grid, "h w p -> (h w) p")
    # Row-major (row, col) index per tile, in the same order as flat_tiles.
    tile_yx = np.indices((rows, cols), dtype=np.int16).reshape(2, -1).T
    count = rows * cols

    np.copyto(patch_data[:count].numpy(), flat_tiles, casting="no")
    np.copyto(patch_coord[:count].numpy(), tile_yx, casting="no")
    patch_valid[:count] = True

def unpatchify(input: Tensor, coords: Tensor, valid: Tensor) -> Tensor:
    """
    Scatter valid patches from (seqlen, ...) to (H, W, ...), using coords and valid mask.

    ``coords`` and ``valid`` may be given either unbatched, as ``(seqlen, 2)``
    and ``(seqlen,)``, or with a leading batch dimension, as
    ``(B, seqlen, 2)`` and ``(B, seqlen)``. In the batched case only batch
    element 0 is used. ``input`` itself is never batch-indexed.

    Args:
        input: Tensor of shape (seqlen, ...), patch data.
        coords: Integer tensor of shape (seqlen, 2) or (B, seqlen, 2), giving
            spatial coordinates [y, x] for each patch. Any integer dtype is
            accepted; the values are cast to int64 for indexing.
        valid: Boolean tensor of shape (seqlen,) or (B, seqlen), marking
            valid patches.

    Returns:
        Tensor of shape (H, W, ...), where H and W are the maximum valid
        y and x plus one. Valid patches are placed at their coordinates and
        every other cell is zero. If no patch is valid, the result has shape
        (0, 0, ...).
    """
    if valid.dim() == 2:
        valid = valid[0]
    if coords.dim() == 3:
        coords = coords[0]

    # torch only accepts int64/int32/uint8/bool index tensors; coords may be int16.
    valid_coords = coords[valid].long()  # (n_valid, 2)
    valid_patches = input[valid]  # (n_valid, ...)

    if valid_coords.shape[0] == 0:
        # .max() on an empty tensor would raise; return an empty grid instead.
        return input.new_zeros((0, 0) + tuple(input.shape[1:]))

    h = int(valid_coords[:, 0].max().item()) + 1
    w = int(valid_coords[:, 1].max().item()) + 1

    output = input.new_zeros((h, w) + tuple(input.shape[1:]))
    output[valid_coords[:, 0], valid_coords[:, 1]] = valid_patches
    return output