Spaces:
Runtime error
Runtime error
| # Standard libraries | |
| import numpy as np | |
| # PyTorch | |
| import torch | |
| import torch.nn as nn | |
| import math | |
# Base 8x8 JPEG quantization tables, stored as float32 torch Parameters.
#
# Luminance: the values below match the ITU-T T.81 Annex K luminance table
# (row-major as written); the array is transposed (.T) before wrapping.
# NOTE(review): the reason for the transpose (e.g. a column-major indexing
# convention in the consumers) is not visible here -- confirm against callers.
#
# NOTE(review): nn.Parameter defaults to requires_grad=True, so any module that
# assigns these globals as attributes will register them as trainable
# parameters shared by every such module -- confirm that this is intended.
y_table = nn.Parameter(
    torch.from_numpy(
        np.array(
            [
                [16, 11, 10, 16, 24, 40, 51, 61],
                [12, 12, 14, 19, 26, 58, 60, 55],
                [14, 13, 16, 24, 40, 57, 69, 56],
                [14, 17, 22, 29, 51, 87, 80, 62],
                [18, 22, 37, 56, 68, 109, 103, 77],
                [24, 35, 55, 64, 81, 104, 113, 92],
                [49, 64, 78, 87, 103, 121, 120, 101],
                [72, 92, 95, 98, 112, 100, 103, 99],
            ],
            dtype=np.float32,
        ).T
    )
)

# Chrominance: every entry is 99 except the top-left 4x4 corner, which takes
# the values below (transposed, like the luminance table).
c_table = np.full((8, 8), 99, dtype=np.float32)
c_table[:4, :4] = np.array(
    [[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]
).T
c_table = nn.Parameter(torch.from_numpy(c_table))
def diff_round_back(x):
    """Differentiable stand-in for rounding (cubic-residual variant).

    Computes ``round(x) + (x - round(x)) ** 3``. ``torch.round`` itself has a
    zero gradient, so the gradient with respect to ``x`` comes entirely from
    the cubic residual term.

    Input:
        x(tensor): values to round
    Output:
        x(tensor): elementwise result, same shape as ``x``
    """
    nearest = torch.round(x)
    residual = x - nearest
    return nearest + residual ** 3
def diff_round(input_tensor, num_terms=9):
    """Differentiable approximation of rounding via a truncated Fourier series.

    Uses the Fourier series of the period-1 sawtooth

        x - round(x) = (1/pi) * sum_{n>=1} (-1)^(n+1) / n * sin(2*pi*n*x)

    so that

        round(x) ~= x - (1/pi) * sum_{n=1}^{num_terms} (-1)^(n+1) / n * sin(2*pi*n*x).

    Unlike ``torch.round``, the result has a usable (non-zero) gradient.

    Input:
        input_tensor(tensor): values to round
        num_terms(int): number of series terms to keep. The default of 9
            matches the original hard-coded loop ``range(1, 10)``; 0 returns
            the input unchanged.
    Output:
        tensor: elementwise approximation of ``round(input_tensor)``
    Raises:
        ValueError: if ``num_terms`` is negative
    """
    if num_terms < 0:
        raise ValueError(f"num_terms must be non-negative, got {num_terms!r}")
    series = 0
    for n in range(1, num_terms + 1):
        # (-1)^(n+1) / n alternates the sign of successive harmonics.
        series += math.pow(-1, n + 1) / n * torch.sin(2 * math.pi * n * input_tensor)
    return input_tensor - 1 / math.pi * series
class Quant(torch.autograd.Function):
    """Quantize values to 256 evenly spaced levels in [0, 1].

    Forward: clamp the input to [0, 1], then snap each value to the nearest
    multiple of 1/255.
    Backward: return the incoming gradient unchanged (a straight-through
    estimator). This also applies to inputs that the clamp saturated.

    Use via ``Quant.apply(x)``.
    """

    @staticmethod
    def forward(ctx, input):
        input = torch.clamp(input, 0, 1)
        output = (input * 255.0).round() / 255.0
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat quantization as the identity for gradients.
        return grad_output
class Quantization(nn.Module):
    """``nn.Module`` wrapper that applies the ``Quant`` autograd function.

    It holds no parameters or state; ``forward`` simply delegates to
    ``Quant.apply``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        quantized = Quant.apply(input)
        return quantized
def quality_to_factor(quality):
    """Calculate the quantization-table scale factor for a JPEG quality.

    This uses the same piecewise curve as libjpeg's quality scaling:
    ``5000 / quality`` below 50, ``200 - 2 * quality`` otherwise, divided by
    100. Quality 50 maps to 1.0 (the base tables) and quality 100 maps to 0.0.

    Input:
        quality(float): JPEG quality, must satisfy ``0 < quality <= 100``
    Output:
        factor(float): multiplier applied to the base quantization tables
    Raises:
        ValueError: if ``quality`` is outside ``(0, 100]`` (or is NaN).
            Without this check, 0 would divide by zero, and negative values or
            values above 100 would silently give a negative factor.
    """
    if not 0 < quality <= 100:
        raise ValueError(f"quality must be in (0, 100], got {quality!r}")
    if quality < 50:
        scale = 5000.0 / quality
    else:
        scale = 200.0 - quality * 2
    return scale / 100.0