# Standard libraries
import itertools

import numpy as np

# PyTorch
import torch
import torch.nn as nn

# Local
from . import JPEG_utils

class rgb_to_ycbcr_jpeg(nn.Module):
    """Converts an RGB image to YCbCr.

    Input:
        image(tensor): batch x 3 x height x width
    Output:
        result(tensor): batch x height x width x 3
    """

    def __init__(self):
        super(rgb_to_ycbcr_jpeg, self).__init__()
        matrix = np.array(
            [
                [0.299, 0.587, 0.114],
                [-0.168736, -0.331264, 0.5],
                [0.5, -0.418688, -0.081312],
            ],
            dtype=np.float32,
        ).T
        self.shift = nn.Parameter(torch.tensor([0.0, 128.0, 128.0]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        # Channels-last layout so the 3x3 color matrix applies per pixel.
        image = image.permute(0, 2, 3, 1)
        result = torch.tensordot(image, self.matrix, dims=1) + self.shift
        return result
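
# Usage sketch (illustrative only; assumes RGB values already scaled to [0, 255]):
#
#   rgb = torch.rand(2, 3, 16, 16) * 255
#   ycbcr = rgb_to_ycbcr_jpeg()(rgb)      # -> 2 x 16 x 16 x 3, channels last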

class chroma_subsampling(nn.Module):
    """Chroma subsampling (4:2:0) on the Cb and Cr channels.

    Input:
        image(tensor): batch x height x width x 3
    Output:
        y(tensor): batch x height x width
        cb(tensor): batch x height/2 x width/2
        cr(tensor): batch x height/2 x width/2
    """

    def __init__(self):
        super(chroma_subsampling, self).__init__()

    def forward(self, image):
        image_2 = image.permute(0, 3, 1, 2).clone()
        # Average every 2x2 block of the chroma channels.
        avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2), count_include_pad=False)
        cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1))
        cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1))
        cb = cb.permute(0, 2, 3, 1)
        cr = cr.permute(0, 2, 3, 1)
        return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
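
# Usage sketch (illustrative only; the input is the channels-last YCbCr tensor
# from the sketch above):
#
#   y, cb, cr = chroma_subsampling()(ycbcr)
#   # y: 2 x 16 x 16, cb and cr: 2 x 8 x 8 (each 2x2 chroma block averaged)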

class block_splitting(nn.Module):
    """Splits an image plane into 8x8 patches.

    Input:
        image(tensor): batch x height x width
    Output:
        patch(tensor): batch x h*w/64 x 8 x 8
    """

    def __init__(self):
        super(block_splitting, self).__init__()
        self.k = 8

    def forward(self, image):
        height, width = image.shape[1:3]
        batch_size = image.shape[0]
        # Reshape into a grid of k x k blocks, then flatten the grid dimensions.
        image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
        return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
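
# Usage sketch (illustrative only): splitting the 16 x 16 luma plane from the
# sketch above yields (16 * 16) / 64 = 4 blocks of 8 x 8 each.
#
#   patches = block_splitting()(y)        # -> 2 x 4 x 8 x 8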

class dct_8x8(nn.Module):
    """2-D Discrete Cosine Transform on 8x8 blocks.

    Input:
        image(tensor): batch x height x width
    Output:
        dct(tensor): batch x height x width
    """

    def __init__(self):
        super(dct_8x8, self).__init__()
        # Precompute the DCT-II basis and normalization for 8x8 blocks.
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos(
                (2 * y + 1) * v * np.pi / 16
            )
        alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
        self.scale = nn.Parameter(
            torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()
        )

    def forward(self, image):
        # Center pixel values around zero before the transform.
        image = image - 128
        result = self.scale * torch.tensordot(image, self.tensor, dims=2)
        return result
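
# The precomputed basis implements the standard 2-D DCT-II used by JPEG:
#
#   D(u, v) = 1/4 * a(u) * a(v) * sum_{x,y} (I(x, y) - 128)
#                 * cos((2x + 1) * u * pi / 16) * cos((2y + 1) * v * pi / 16)
#
# with a(0) = 1/sqrt(2) and a(u) = 1 otherwise; `scale` holds the
# 1/4 * a(u) * a(v) term and `tensor` holds the cosine products.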

class y_quantize(nn.Module):
    """JPEG quantization for the Y channel.

    Input:
        image(tensor): batch x height x width
        rounding(function): rounding function to use
        factor(float): Degree of compression
    Output:
        image(tensor): batch x height x width
    """

    def __init__(self, rounding, factor=1):
        super(y_quantize, self).__init__()
        self.rounding = rounding
        self.factor = factor
        self.y_table = JPEG_utils.y_table

    def forward(self, image):
        image = image.float() / (self.y_table * self.factor)
        image = self.rounding(image)
        return image

class c_quantize(nn.Module):
    """JPEG quantization for the Cb and Cr channels.

    Input:
        image(tensor): batch x height x width
        rounding(function): rounding function to use
        factor(float): Degree of compression
    Output:
        image(tensor): batch x height x width
    """

    def __init__(self, rounding, factor=1):
        super(c_quantize, self).__init__()
        self.rounding = rounding
        self.factor = factor
        self.c_table = JPEG_utils.c_table

    def forward(self, image):
        image = image.float() / (self.c_table * self.factor)
        image = self.rounding(image)
        return image
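
# Usage sketch (illustrative only): both quantizers divide each 8 x 8 DCT block
# elementwise by the corresponding JPEG table scaled by `factor` and then round,
# so a larger `factor` gives coarser coefficients (stronger compression). Assumes
# JPEG_utils.y_table is an 8 x 8 tensor, as in the original DiffJPEG-style utils.
#
#   quant = y_quantize(rounding=torch.round, factor=1)
#   coeffs = quant(torch.full((2, 4, 8, 8), 100.0))   # 100 / y_table, rounded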

class compress_jpeg(nn.Module):
    """Full JPEG compression algorithm.

    Input:
        imgs(tensor): batch x 3 x height x width
        rounding(function): rounding function to use
        factor(float): Compression factor
    Output:
        y, cb, cr(tensor): quantized 8x8 DCT blocks, batch x h*w/64 x 8 x 8
    """

    def __init__(self, rounding=torch.round, factor=1):
        super(compress_jpeg, self).__init__()
        self.l1 = nn.Sequential(
            rgb_to_ycbcr_jpeg(),
            # comment this line if no subsampling
            chroma_subsampling(),
        )
        self.l2 = nn.Sequential(block_splitting(), dct_8x8())
        self.c_quantize = c_quantize(rounding=rounding, factor=factor)
        self.y_quantize = y_quantize(rounding=rounding, factor=factor)

    def forward(self, image):
        # Inputs in [0, 1] are rescaled to [0, 255] before color conversion.
        y, cb, cr = self.l1(image * 255)
        components = {"y": y, "cb": cb, "cr": cr}
        for k in components.keys():
            comp = self.l2(components[k])
            if k in ("cb", "cr"):
                comp = self.c_quantize(comp)
            else:
                comp = self.y_quantize(comp)
            components[k] = comp
        return components["y"], components["cb"], components["cr"]
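
# Minimal end-to-end sketch (illustrative; assumes the package containing
# JPEG_utils is importable, e.g. run with `python -m <package>.<this_module>`).
if __name__ == "__main__":
    imgs = torch.rand(2, 3, 64, 64)                       # dummy RGB batch in [0, 1]
    jpeg = compress_jpeg(rounding=torch.round, factor=1)
    y, cb, cr = jpeg(imgs)
    # One 8x8 block per patch: y from the full-resolution plane, cb/cr from the
    # 2x subsampled planes.
    print(y.shape, cb.shape, cr.shape)                    # (2, 64, 8, 8) (2, 16, 8, 8) (2, 16, 8, 8)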