# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/facebookresearch/moco-v3
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/BUPT-PRIV/MAE-priv
# https://github.com/facebookresearch/mae
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class MaskedCrossEntropyLoss(nn.Module):
    """Cross-entropy loss with masking

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param label_smoothing: Amount of smoothing in the loss (default is 0.0)
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, label_smoothing: float = 0.0):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.scale_factor = patch_size // stride
        self.label_smoothing = label_smoothing

    def forward(self, input, target, mask=None):
        loss = F.cross_entropy(input, target, reduction='none', label_smoothing=self.label_smoothing)

        if mask is not None:
            if mask.sum() == 0:
                return torch.tensor(0.0, device=loss.device)

            H, W = input.shape[-2:]
            nh, nw = H // self.scale_factor, W // self.scale_factor
            # Resize mask and upsample
            mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw)
            mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)
            loss = loss * mask
            # Compute mean per sample
            loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
            loss = loss.nanmean()  # Account for zero masks
        else:
            loss = loss.mean()  # If this is ever nan, we want it to stop training

        return loss
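
# The following is an illustrative usage sketch, not part of the original module.
# Shapes and hyperparameters (batch size, number of classes, 64x64 resolution,
# a random 0/1 patch mask) are assumptions chosen for the default
# patch_size=16, stride=1 configuration.
def _example_masked_cross_entropy():
    B, num_classes, H, W = 2, 10, 64, 64
    criterion = MaskedCrossEntropyLoss(patch_size=16, stride=1)
    logits = torch.randn(B, num_classes, H, W)          # dense per-pixel logits
    labels = torch.randint(0, num_classes, (B, H, W))   # per-pixel class indices
    # One mask entry per 16x16 patch; 1 = patch contributes to the loss.
    mask = torch.randint(0, 2, (B, (H // 16) * (W // 16)))
    return criterion(logits, labels, mask)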

class MaskedMSELoss(nn.Module):
    """MSE (L2) loss with masking

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param norm_pix: Normalized pixel loss
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, norm_pix=False):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.scale_factor = patch_size // stride
        self.norm_pix = norm_pix

    def patchify(self, imgs, nh, nw):
        p = self.scale_factor
        x = rearrange(imgs, "b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)", nh=nh, nw=nw, p1=p, p2=p)
        return x

    def unpatchify(self, x, nh, nw):
        p = self.scale_factor
        imgs = rearrange(x, "b (nh nw) (p1 p2 c) -> b c (nh p1) (nw p2)", nh=nh, nw=nw, p1=p, p2=p)
        return imgs

    def forward(self, input, target, mask=None):
        H, W = input.shape[-2:]
        nh, nw = H // self.scale_factor, W // self.scale_factor

        if self.norm_pix:
            target = self.patchify(target, nh, nw)
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            eps = 1e-6
            target = (target - mean) / torch.sqrt(var + eps)
            target = self.unpatchify(target, nh, nw)

        loss = F.mse_loss(input, target, reduction='none')

        if mask is not None:
            if mask.sum() == 0:
                return torch.tensor(0.0, device=loss.device)

            # Resize mask and upsample
            mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw)
            mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)
            loss = loss.mean(dim=1)  # B, C, H, W -> B, H, W
            loss = loss * mask
            # Compute mean per sample
            loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
            loss = loss.nanmean()  # Account for zero masks
        else:
            loss = loss.mean()  # If this is ever nan, we want it to stop training

        return loss
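
# Illustrative usage sketch, not part of the original module. The dense RGB
# prediction/target shapes and the random patch mask are assumptions for the
# default patch_size=16, stride=1 configuration; norm_pix=True normalizes each
# 16x16 target patch to zero mean and unit variance before computing the loss.
def _example_masked_mse():
    B, C, H, W = 2, 3, 64, 64
    criterion = MaskedMSELoss(patch_size=16, stride=1, norm_pix=True)
    pred = torch.randn(B, C, H, W)
    target = torch.randn(B, C, H, W)
    mask = torch.randint(0, 2, (B, (H // 16) * (W // 16)))  # 1 = patch is in the loss
    return criterion(pred, target, mask)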

class MaskedL1Loss(nn.Module):
    """L1 loss with masking

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param norm_pix: Normalized pixel loss
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, norm_pix=False):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.scale_factor = patch_size // stride
        self.norm_pix = norm_pix

    def patchify(self, imgs, nh, nw):
        p = self.scale_factor
        x = rearrange(imgs, "b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)", nh=nh, nw=nw, p1=p, p2=p)
        return x

    def unpatchify(self, x, nh, nw):
        p = self.scale_factor
        imgs = rearrange(x, "b (nh nw) (p1 p2 c) -> b c (nh p1) (nw p2)", nh=nh, nw=nw, p1=p, p2=p)
        return imgs

    def forward(self, input, target, mask=None):
        H, W = input.shape[-2:]
        nh, nw = H // self.scale_factor, W // self.scale_factor

        if self.norm_pix:
            target = self.patchify(target, nh, nw)
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            eps = 1e-6
            target = (target - mean) / torch.sqrt(var + eps)
            target = self.unpatchify(target, nh, nw)

        loss = F.l1_loss(input, target, reduction='none')

        if mask is not None:
            if mask.sum() == 0:
                return torch.tensor(0.0, device=loss.device)

            # Resize mask and upsample
            mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw)
            mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)
            loss = loss.mean(dim=1)  # B, C, H, W -> B, H, W
            loss = loss * mask
            # Compute mean per sample
            loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
            loss = loss.nanmean()  # Account for zero masks
        else:
            loss = loss.mean()  # If this is ever nan, we want it to stop training

        return loss
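
# Illustrative usage sketch, not part of the original module. Here the loss is
# applied to a single-channel regression target (e.g. a depth-like map); the
# shapes and the random patch mask are assumptions for the default
# patch_size=16, stride=1 setup.
def _example_masked_l1():
    B, C, H, W = 2, 1, 64, 64
    criterion = MaskedL1Loss(patch_size=16, stride=1, norm_pix=False)
    pred = torch.randn(B, C, H, W)
    target = torch.randn(B, C, H, W)
    mask = torch.randint(0, 2, (B, (H // 16) * (W // 16)))  # 1 = patch is in the loss
    return criterion(pred, target, mask)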