# Spaces: Running on L40S
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Criterion to train CroCo
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
| import torch | |
class MaskedMSE(torch.nn.Module):
    """Mean-squared-error criterion, optionally restricted to masked patches.

    The squared error between ``pred`` and ``target`` is first averaged over
    the last dimension (one value per patch), then reduced to a scalar either
    over the patches selected by ``mask`` or over all patches.
    """

    def __init__(self, norm_pix_loss=False, masked=True):
        """
        norm_pix_loss: normalize each target patch by its own mean and
            variance (taken over the last dimension) before comparing
        masked: compute the loss over the masked patches only; when False,
            ``mask`` is ignored and the loss is averaged over every patch
        """
        super().__init__()
        self.norm_pix_loss = norm_pix_loss
        self.masked = masked

    def forward(self, pred, mask, target):
        """Return the scalar MSE between ``pred`` and ``target``.

        pred:   predicted patches; must broadcast against ``target``
        mask:   per-patch weights multiplied into the per-patch loss
                (presumably 1/True for masked patches, 0/False otherwise --
                confirm against the caller); only read when ``self.masked``
        target: ground-truth patches, reduced over their last dimension

        When ``self.masked`` is set and ``mask`` sums to zero, the result is
        0 rather than the NaN a 0/0 division would produce.
        """
        if self.norm_pix_loss:
            # Standardize every target patch independently; the small
            # epsilon keeps the division finite for constant patches.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        # [N, L], mean loss per patch
        per_patch = ((pred - target) ** 2).mean(dim=-1)

        if not self.masked:
            return per_patch.mean()  # mean loss over all patches

        mask_total = mask.sum()
        # Guard the empty-mask case: replace a zero denominator by one so the
        # (zero) numerator yields 0 instead of NaN. Any non-zero mask sum is
        # left untouched, and torch.where avoids a data-dependent Python branch.
        denom = torch.where(mask_total == 0, torch.ones_like(mask_total), mask_total)
        return (per_patch * mask).sum() / denom  # mean loss on masked patches