Spaces:
Runtime error
Runtime error
import torch
import torch.nn as nn

from .JPEG_utils import diff_round, quality_to_factor, Quantization
from .compression import compress_jpeg
from .decompression import decompress_jpeg
class DiffJPEG(nn.Module):
    """JPEG compress-then-decompress layer.

    Passes its input through a ``compress_jpeg`` module and then a
    ``decompress_jpeg`` module. Both are built with the same rounding
    function and quality factor.
    """

    def __init__(self, differentiable=True, quality=75):
        """Initialize the DiffJPEG layer.

        Args:
            differentiable (bool): If True, use the custom differentiable
                rounding function ``diff_round``. If False, use the standard
                ``torch.round``.
            quality (float): Quality factor for the JPEG compression scheme.
                It is converted to a scaling factor by ``quality_to_factor``
                and passed to both the compressor and the decompressor.
        """
        super().__init__()
        rounding = diff_round if differentiable else torch.round
        factor = quality_to_factor(quality)
        self.compress = compress_jpeg(rounding=rounding, factor=factor)
        # The decompressor is not bound to a fixed image height/width here.
        # forward() passes the original height/width of each input instead.
        self.decompress = decompress_jpeg(rounding=rounding, factor=factor)

    def forward(self, x):
        """Run ``x`` through JPEG compression and decompression.

        Args:
            x (torch.Tensor): Input images. Dimensions 2 and 3 are read as
                the original height and width (presumably an (N, C, H, W)
                batch; confirm against ``compress_jpeg``).

        Returns:
            The output of ``decompress_jpeg``. It is called with the three
            components from ``compress_jpeg`` (named y, cb, cr here) and the
            original height and width.
        """
        org_height, org_width = x.shape[2], x.shape[3]
        y, cb, cr = self.compress(x)
        return self.decompress(y, cb, cr, org_height, org_width)