JTP-3-Demo / image.py
RedHotTensors's picture
JTP-3 Hydra Release
d62ba4b
raw
history blame
8.28 kB
from io import BytesIO
from typing import Any, Callable, cast
from warnings import warn, catch_warnings, filterwarnings
import numpy as np
from torch import Tensor
from einops import rearrange
import PIL.Image as image
import PIL.ImageCms as image_cms
from PIL.Image import Image, Resampling
from PIL.ImageCms import (
Direction, Intent, ImageCmsProfile, PyCMSError,
createProfile, getDefaultIntent, isIntentSupported, profileToProfile
)
from PIL.ImageOps import exif_transpose
try:
import pillow_jxl
except ImportError:
pass
# NOTE(review): setting this to None disables Pillow's decompression-bomb
# guard for the whole process, so arbitrarily large images will be decoded.
# Confirm this is acceptable if images can come from untrusted sources.
image.MAX_IMAGE_PIXELS = None

# Destination profile for every color-managed conversion in this module.
_SRGB = createProfile(colorSpace='sRGB')

# ImageCms transform flags, keyed by rendering intent. Every listed intent uses
# high-resolution precalculation; relative colorimetric additionally enables
# black-point compensation. Intent.SATURATION has no entry, so a lookup for it
# yields None (treated as "unsupported" by the sRGB conversion code).
_INTENT_FLAGS = dict.fromkeys(
    (Intent.PERCEPTUAL, Intent.RELATIVE_COLORIMETRIC, Intent.ABSOLUTE_COLORIMETRIC),
    image_cms.FLAGS["HIGHRESPRECALC"],
)
_INTENT_FLAGS[Intent.RELATIVE_COLORIMETRIC] |= image_cms.FLAGS["BLACKPOINTCOMPENSATION"]
class CMSWarning(UserWarning):
def __init__(
self,
message: str,
*,
path: str | None = None,
cms_info: dict[str, Any] | None = None,
cause: Exception | None = None,
):
super().__init__(message)
self.__cause__ = cause
self.path = path
self.cms_info = cms_info
self.add_note(f"path: {path}")
self.add_note(f"info: {cms_info}")
def _coalesce_intent(intent: Intent | int) -> Intent:
    """Normalize a rendering intent to an ``Intent`` member.

    ``Intent`` members pass through unchanged. Raw integer codes 0-3 are mapped
    to PERCEPTUAL, RELATIVE_COLORIMETRIC, SATURATION and ABSOLUTE_COLORIMETRIC
    respectively (compared with ``==``).

    Raises:
        ValueError: if *intent* matches none of the known codes.
    """
    if isinstance(intent, Intent):
        return intent
    known = (
        (0, Intent.PERCEPTUAL),
        (1, Intent.RELATIVE_COLORIMETRIC),
        (2, Intent.SATURATION),
        (3, Intent.ABSOLUTE_COLORIMETRIC),
    )
    for code, member in known:
        if intent == code:
            return member
    raise ValueError("invalid intent")
def _add_info(info: dict[str, Any], source: object, key: str) -> None:
    """Best-effort copy of attribute *key* from *source* into ``info[key]``.

    Nothing is stored when the attribute is missing or None. Any ``Exception``
    raised while reading it (or storing it) is swallowed, because this only
    collects diagnostics.
    """
    try:
        value = getattr(source, key, None)
        if value is None:
            return
        info[key] = value
    except Exception:
        return
def open_srgb(
    path: str,
    *,
    resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
    crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
    expect: tuple[int, int] | None = None,
) -> Image:
    """Open the image file at *path* and return it processed by ``process_srgb``.

    The file is read through a 1 MiB buffer, and *resize*, *crop* and *expect*
    are forwarded unchanged. The decoded source image is closed whenever it is
    not itself the object being returned, including when processing raises.
    """
    with open(path, "rb", buffering=(1024 * 1024)) as file:
        source: Image = image.open(file)
        result = None
        try:
            result = process_srgb(source, resize=resize, crop=crop, expect=expect)
            return result
        finally:
            # On failure result is still None, so the source is always closed.
            if result is not source:
                source.close()
def _icc_to_srgb(img: Image, icc_raw: bytes) -> Image:
    """Convert *img* from its embedded ICC profile (*icc_raw*) into sRGB.

    RGB/BGR/palette modes are first normalized to RGB or RGBA, and L/I/F/1
    modes to L or LA, keeping alpha when the image has transparency data.
    Relative colorimetric intent is used when the profile supports it as an
    input intent; otherwise the profile's default intent is used.

    This is best-effort. On any failure a ``CMSWarning`` carrying the
    collected diagnostics is emitted, and the image is returned without color
    conversion. It may already have been converted to its working mode.
    """
    cms_info: dict[str, Any] = {
        "native_mode": img.mode,
        "transparency": img.has_transparency_data,
    }
    try:
        profile = ImageCmsProfile(BytesIO(icc_raw))
        for key in (
            "profile_description",
            "target",
            "xcolor_space",
            "connection_space",
            "colorimetric_intent",
            "rendering_intent",
        ):
            _add_info(cms_info, profile.profile, key)

        # Bring the image into a mode the transform can handle, preserving alpha.
        working_mode = img.mode
        if img.mode.startswith(("RGB", "BGR", "P")):
            working_mode = "RGBA" if img.has_transparency_data else "RGB"
        elif img.mode.startswith(("L", "I", "F")) or img.mode == "1":
            working_mode = "LA" if img.has_transparency_data else "L"
        if img.mode != working_mode:
            cms_info["working_mode"] = working_mode
            img = img.convert(working_mode)

        mode = "RGBA" if img.has_transparency_data else "RGB"

        # Prefer relative colorimetric; fall back to the profile's default intent.
        intent = Intent.RELATIVE_COLORIMETRIC
        if isIntentSupported(profile, intent, Direction.INPUT) != 1:
            intent = _coalesce_intent(getDefaultIntent(profile))
        cms_info["conversion_intent"] = intent
        if (flags := _INTENT_FLAGS.get(intent)) is None:
            raise RuntimeError("Unsupported intent")

        if img.mode == mode:
            profileToProfile(
                img,
                profile,
                _SRGB,
                renderingIntent=intent,
                inPlace=True,
                flags=flags,
            )
        else:
            img = cast(Image, profileToProfile(
                img,
                profile,
                _SRGB,
                renderingIntent=intent,
                outputMode=mode,
                flags=flags,
            ))
    except Exception as ex:
        # Previously swallowed silently; surface it, but keep the fallback path.
        warn(
            CMSWarning(
                f"ICC color conversion to sRGB failed: {ex}",
                cms_info=cms_info,
                cause=ex,
            ),
            stacklevel=3,
        )
    return img


def process_srgb(
    img: Image,
    *,
    resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
    crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
    expect: tuple[int, int] | None = None,
) -> Image:
    """Normalize a decoded image to sRGB, then optionally crop and/or resize it.

    Steps:
      1. Load the pixel data and apply the EXIF orientation in place. Corrupt
         EXIF data is ignored.
      2. Optionally check the oriented size against *expect*.
      3. If an ICC profile is embedded, color-convert into sRGB (best-effort;
         failures emit a ``CMSWarning``).
      4. Convert to "RGBa" (Pillow's premultiplied-alpha mode) when the image
         has transparency data, otherwise to "RGB".
      5. Crop and/or resize. When both apply, they are fused into one LANCZOS
         resample via ``Image.resize(box=...)``.

    Args:
        img: Decoded image. It is loaded and EXIF-transposed in place.
        resize: Target ``(width, height)``, or a callable that receives the
            (post-crop) ``(width, height)`` and returns a target size or None.
            No resample happens when the target equals the current size.
        crop: ``(left, top, right, bottom)`` box, or a callable that receives
            the oriented ``(width, height)`` and returns such a box or None.
        expect: If given, the required oriented ``(width, height)``.

    Returns:
        An image in mode "RGB" or "RGBa". This may be *img* itself when no
        conversion was needed.

    Raises:
        RuntimeError: if *expect* is given and the oriented size differs.
    """
    img.load()
    try:
        exif_transpose(img, in_place=True)
    except Exception:
        pass  # corrupt EXIF metadata is fine

    size = (img.width, img.height)
    if expect is not None and size != expect:
        raise RuntimeError(
            f"Image is {size[0]}x{size[1]}, "
            f"but expected {expect[0]}x{expect[1]}."
        )

    # An empty payload can never parse as a profile; skip it instead of warning.
    if icc_raw := img.info.get("icc_profile"):
        img = _icc_to_srgb(img, icc_raw)

    if img.has_transparency_data:
        if img.mode != "RGBa":
            try:
                img = img.convert("RGBa")
            except ValueError:
                # Some modes cannot go straight to RGBa; hop through RGBA.
                img = img.convert("RGBA").convert("RGBa")
    elif img.mode != "RGB":
        img = img.convert("RGB")

    if crop is not None and not isinstance(crop, tuple):
        crop = crop(size)
    if crop is not None:
        left, top, right, bottom = crop
        # Fixed: height is bottom - top (was top - bottom, i.e. negative).
        size = (right - left, bottom - top)

    if resize is not None and not isinstance(resize, tuple):
        resize = resize(size)
    if resize is not None and size != resize:
        img = img.resize(
            resize,
            Resampling.LANCZOS,
            box=crop,
            reducing_gap=3.0,
        )
        crop = None  # box= already applied the crop
    if crop is not None:
        img = img.crop(crop)
    return img
def put_srgb(img: Image, tensor: Tensor) -> None:
    """Copy the RGB channels of *img* into *tensor* in place.

    *img* must be in mode RGB, RGBA or RGBa; any fourth (alpha) channel is
    dropped. Because ``np.copyto(..., casting="no")`` is used, the tensor's
    NumPy view must have exactly the image array's dtype (uint8 for Pillow's
    RGB modes) and a broadcast-compatible shape. Presumably that shape is
    (H, W, 3); TODO confirm with callers.

    Raises:
        ValueError: if *img* is not in an RGB-family mode.
    """
    if img.mode in ("RGB", "RGBA", "RGBa"):
        target = tensor.numpy()
        np.copyto(target, np.asarray(img)[..., :3], casting="no")
        return
    raise ValueError(f"Image has non-RGB mode {img.mode}.")
def put_srgb_patch(
    img: Image,
    patch_data: Tensor,
    patch_coord: Tensor,
    patch_valid: Tensor,
    patch_size: int
) -> None:
    """Write *img* into preallocated buffers as non-overlapping square patches.

    The RGB channels of *img* (mode RGB, RGBA or RGBa; alpha is dropped) are
    cut into ``patch_size`` x ``patch_size`` tiles, taken in row-major grid
    order. The image height and width must be multiples of ``patch_size``,
    otherwise ``einops.rearrange`` raises. For the ``n`` tiles produced:

    * ``patch_data[:n]`` receives each tile flattened in (row, column, channel)
      order, i.e. ``patch_size * patch_size * 3`` values per tile;
    * ``patch_coord[:n]`` receives each tile's ``[row, column]`` grid index;
    * ``patch_valid[:n]`` is set to True.

    Entries from ``n`` onward are left untouched. Because ``casting="no"`` is
    used, ``patch_data`` must match the image array dtype (uint8 for Pillow's
    RGB modes), and ``patch_coord`` must be int16.

    Raises:
        ValueError: if *img* is not in an RGB-family mode.
    """
    if img.mode not in ("RGB", "RGBA", "RGBa"):
        raise ValueError(f"Image has non-RGB mode {img.mode}.")
    rgb = np.asarray(img)[..., :3]
    # Straight to (num_tiles, patch_size * patch_size * 3).
    tiles = rearrange(
        rgb,
        "(h p1) (w p2) c -> (h w) (p1 p2 c)",
        p1=patch_size,
        p2=patch_size,
    )
    grid_h = rgb.shape[0] // patch_size
    grid_w = rgb.shape[1] // patch_size
    # (2, grid_h, grid_w) index planes -> (grid_h * grid_w, 2) rows of
    # [row, col], in the same row-major order as the tiles.
    grid = np.indices((grid_h, grid_w), dtype=np.int16).reshape(2, -1).T
    count = tiles.shape[0]
    np.copyto(patch_data[:count].numpy(), tiles, casting="no")
    np.copyto(patch_coord[:count].numpy(), grid, casting="no")
    patch_valid[:count] = True
def unpatchify(input: Tensor, coords: Tensor, valid: Tensor) -> Tensor:
    """
    Scatter valid patches from (seqlen, ...) to (H, W, ...), using coords and valid mask.

    Args:
        input: Tensor of shape (seqlen, ...), patch data.
        coords: Tensor of shape (1, seqlen, 2) or (seqlen, 2), spatial
            coordinates [y, x] for each patch. When a leading batch dimension
            is present, only its first entry is used. Any integer dtype is
            accepted (e.g. int16); it is cast to int64 for indexing.
        valid: Tensor of shape (1, seqlen) or (seqlen,), boolean mask for
            valid patches. When a leading batch dimension is present, only its
            first entry is used.

    Returns:
        Tensor of shape (H, W, ...), with valid patches scattered to their
        spatial locations. H and W are one past the largest valid y and x.
        Positions without a valid patch are zero. If no patch is valid, the
        result has shape (0, 0, ...).
    """
    # The original code always indexed [0] (batched inputs). Unbatched shapes,
    # as described above, are now accepted as well.
    if coords.ndim == 3:
        coords = coords[0]
    if valid.ndim == 2:
        valid = valid[0]

    # torch advanced indexing rejects int16 indices; put_srgb_patch emits int16.
    valid_coords = coords[valid].long()  # (n_valid, 2)
    valid_patches = input[valid]  # (n_valid, ...)

    if valid_coords.shape[0] == 0:
        # max() of an empty tensor would raise; there is nothing to scatter.
        return input.new_zeros((0, 0) + input.shape[1:])

    h = int(valid_coords[:, 0].max().item()) + 1
    w = int(valid_coords[:, 1].max().item()) + 1
    output = input.new_zeros((h, w) + input.shape[1:])
    output[valid_coords[:, 0], valid_coords[:, 1]] = valid_patches
    return output