import math
import random
from typing import Union

import torch
from PIL import Image
from torchvision.transforms import functional as TVF
from torchvision.transforms.functional import InterpolationMode

class AreaResize:
    """Resize an image so that its pixel area is approximately ``max_area``, keeping the aspect ratio."""

    def __init__(
        self,
        max_area: float,
        downsample_only: bool = False,
        interpolation: InterpolationMode = InterpolationMode.BICUBIC,
    ):
        self.max_area = max_area
        self.downsample_only = downsample_only
        self.interpolation = interpolation

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
        elif isinstance(image, Image.Image):
            width, height = image.size
        else:
            raise NotImplementedError

        # Uniform scale factor that maps the current area onto max_area.
        scale = math.sqrt(self.max_area / (height * width))

        # In downsample-only mode, never enlarge the image.
        scale = 1 if scale >= 1 and self.downsample_only else scale

        resized_height, resized_width = round(height * scale), round(width * scale)

        return TVF.resize(
            image,
            size=(resized_height, resized_width),
            interpolation=self.interpolation,
        )
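

# Illustrative sizing for AreaResize (assumed numbers, not project settings):
# with max_area = 512 * 512 = 262144, a 1920x1080 input gets
# scale = sqrt(262144 / 2073600) ≈ 0.356 and an output of roughly 683x384
# (width x height), i.e. an area within rounding of the budget.

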
class AreaRandomCrop:
    """Randomly crop a window whose area is approximately ``max_area``, keeping the input's aspect ratio."""

    def __init__(
        self,
        max_area: float,
    ):
        self.max_area = max_area

    def get_params(self, input_size, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            input_size (tuple): (height, width) of the image to be cropped.
            output_size (tuple): Expected (height, width) of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random crop.
        """
        h, w = input_size
        th, tw = output_size
        if w <= tw and h <= th:
            # The requested crop covers the whole image; return it unchanged.
            return 0, 0, h, w

        # Rounding can push a single side past the image size; clamp so the
        # randint ranges below stay valid.
        th, tw = min(th, h), min(tw, w)
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
        elif isinstance(image, Image.Image):
            width, height = image.size
        else:
            raise NotImplementedError

        # Crop size with the same aspect ratio as the input and an area of max_area.
        resized_height = math.sqrt(self.max_area / (width / height))
        resized_width = (width / height) * resized_height

        resized_height, resized_width = round(resized_height), round(resized_width)
        i, j, h, w = self.get_params((height, width), (resized_height, resized_width))
        image = TVF.crop(image, i, j, h, w)
        return image
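

# Illustrative sizing for AreaRandomCrop (assumed numbers): with
# max_area = 512 * 512 on a 1920x1080 input, the crop window is
# sqrt(262144 / (1920 / 1080)) = 384 pixels tall and about 683 wide,
# taken at a uniformly random location inside the frame.

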
class ScaleResize:
    """Resize an image by a fixed scale factor applied to both sides."""

    def __init__(
        self,
        scale: float,
    ):
        self.scale = scale

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
            interpolation_mode = InterpolationMode.BILINEAR
            # "warn" is torchvision's transitional antialias sentinel (pre-0.17);
            # newer releases take a plain bool here.
            antialias = True if image.ndim == 4 else "warn"
        elif isinstance(image, Image.Image):
            width, height = image.size
            interpolation_mode = InterpolationMode.LANCZOS
            antialias = "warn"
        else:
            raise NotImplementedError

        scale = self.scale

        resized_height, resized_width = round(height * scale), round(width * scale)
        image = TVF.resize(
            image,
            size=(resized_height, resized_width),
            interpolation=interpolation_mode,
            antialias=antialias,
        )
        return image
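

if __name__ == "__main__":
    # Minimal usage sketch / smoke test. The 512 * 512 pixel budget, the 0.5
    # scale, and the dummy 1080p tensor are illustrative assumptions, not
    # settings taken from the surrounding project; ScaleResize additionally
    # assumes a torchvision release compatible with the antialias handling above.
    dummy = torch.rand(3, 1080, 1920)

    resized = AreaResize(max_area=512 * 512)(dummy)
    print("AreaResize:", tuple(resized.shape))       # H * W is ~512 * 512

    cropped = AreaRandomCrop(max_area=512 * 512)(dummy)
    print("AreaRandomCrop:", tuple(cropped.shape))   # random 384x683 window

    scaled = ScaleResize(scale=0.5)(dummy)
    print("ScaleResize:", tuple(scaled.shape))       # (3, 540, 960)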