Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copyright (c) Facebook, Inc. and its affiliates. | |
| This source code is licensed under the MIT license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class SSIMLoss(nn.Module): | |
| """ | |
| SSIM loss module. | |
| """ | |
| def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03): | |
| """ | |
| Initialize the Losses class. | |
| Parameters | |
| ---------- | |
| win_size : int, optional | |
| Window size for SSIM calculation. | |
| k1 : float, optional | |
| k1 parameter for SSIM calculation. | |
| k2 : float, optional | |
| k2 parameter for SSIM calculation. | |
| """ | |
| super().__init__() | |
| self.win_size = win_size | |
| self.k1, self.k2 = k1, k2 | |
| self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2) | |
| NP = win_size**2 | |
| self.cov_norm = NP / (NP - 1) | |
| def forward( | |
| self, | |
| X: torch.Tensor, | |
| Y: torch.Tensor, | |
| data_range: torch.Tensor, | |
| reduced: bool = True, | |
| ): | |
| assert isinstance(self.w, torch.Tensor) | |
| data_range = data_range[:, None, None, None].to(X.device) | |
| C1 = (self.k1 * data_range) ** 2 | |
| C2 = (self.k2 * data_range) ** 2 | |
| # Compute means | |
| ux = F.conv2d(X, self.w) | |
| uy = F.conv2d(Y, self.w) | |
| # Compute variances | |
| uxx = F.conv2d(X * X, self.w) | |
| uyy = F.conv2d(Y * Y, self.w) | |
| uxy = F.conv2d(X * Y, self.w) | |
| # Compute covariances | |
| vx = self.cov_norm * (uxx - ux * ux) | |
| vy = self.cov_norm * (uyy - uy * uy) | |
| vxy = self.cov_norm * (uxy - ux * uy) | |
| # Compute SSIM components | |
| A1, A2 = 2 * ux * uy + C1, 2 * vxy + C2 | |
| B1, B2 = ux**2 + uy**2 + C1, vx + vy + C2 | |
| D = B1 * B2 | |
| S = (A1 * A2) / D | |
| if reduced: | |
| return 1 - S.mean() | |
| else: | |
| return 1 - S | |
| if __name__ == "__main__": | |
| # Example usage | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Create the SSIMLoss module and move it to the GPU | |
| ssim_loss = SSIMLoss().to(device) | |
| # Create example tensors and move them to the GPU | |
| X = torch.randn(4, 1, 256, 256).to(device) | |
| Y = torch.randn(4, 1, 256, 256).to(device) | |
| data_range = torch.rand(4).to(device) | |
| # Compute the loss | |
| loss = ssim_loss(X, Y, data_range) | |
| print(loss) | |