import torch
import torch.nn.functional as F
import torchvision
def remap_image_torch(image):
    """Remap an image tensor from [-1, 1] floats to uint8 values in [0, 255]."""
    image_torch = ((image + 1) / 2.0) * 255.0
    image_torch = torch.clip(image_torch, 0, 255).to(torch.uint8)
    return image_torch
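# A minimal usage sketch (assumes an image tensor already scaled to [-1, 1],
# e.g. a decoded generator sample); the shape below is a hypothetical choice:
#
#     img = torch.rand(3, 64, 64) * 2 - 1    # float tensor in [-1, 1]
#     img_uint8 = remap_image_torch(img)     # torch.uint8 tensor in [0, 255]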
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center. When ``size`` is omitted, crops to the
    maximum possible size that fits the requested aspect ratio.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop (size, size) is
            made.
        ratio (str): Desired aspect ratio of the crop as a "w:h" string. Used only
            when ``size`` is None; the largest possible crop with this ratio is taken.
    """
    def __init__(self, size=None, ratio="1:1"):
        super().__init__()
        self.size = size
        self.ratio = ratio
    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.size is None:
            if isinstance(img, torch.Tensor):
                h, w = img.shape[-2:]
            else:
                w, h = img.size
            # Parse the "w:h" string into a single width/height ratio.
            ratio = self.ratio.split(":")
            ratio = float(ratio[0]) / float(ratio[1])
            # Candidate dimensions: width of a full-height crop and height of a
            # full-width crop at the target ratio.
            ratioed_w = int(h * ratio)
            ratioed_h = int(w / ratio)
            if w >= h:
                if ratioed_h <= h:
                    # Full width fits: crop the height down to match the ratio.
                    size = (ratioed_h, w)
                else:
                    # Otherwise keep the full height and crop the width.
                    size = (h, ratioed_w)
            else:
                if ratioed_w <= w:
                    # Full height fits: crop the width down to match the ratio.
                    size = (h, ratioed_w)
                else:
                    # Otherwise keep the full width and crop the height.
                    size = (ratioed_h, w)
        else:
            size = self.size
        return torchvision.transforms.functional.center_crop(img, size)
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, ratio={self.ratio})"