from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import (
    Dict,
    Iterable,
    Optional,
    Tuple,
)

import numpy as np
import torch
from torch import nn

from utils.misc import (
    optional_string,
    iterable_to_str,
)

from .contextual_loss import ContextualLoss
from .color_transfer_loss import ColorTransferLoss
from .regularize_noise import NoiseRegularizer
from .reconstruction import (
    EyeLoss,
    FaceLoss,
    create_perceptual_loss,
    ReconstructionArguments,
)

class LossArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        ReconstructionArguments.add_arguments(parser)

        parser.add_argument("--color_transfer", type=float, default=1e10, help="color transfer loss weight")
        parser.add_argument("--eye", type=float, default=0.1, help="eye loss weight")
        parser.add_argument("--noise_regularize", type=float, default=5e4, help="noise regularization weight")
        # contextual loss
        parser.add_argument("--contextual", type=float, default=0.1, help="contextual loss weight")
        parser.add_argument("--cx_layers", nargs='*', help="contextual loss layers",
                            choices=['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4'],
                            default=['relu3_4', 'relu2_2', 'relu1_2'])

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            ReconstructionArguments.to_string(args)
            + optional_string(args.eye > 0, f"-eye{args.eye}")
            + optional_string(args.color_transfer, f"-color{args.color_transfer:.1e}")
            + optional_string(
                args.contextual,
                f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})"
            )
            #+ optional_string(args.mse, f"-mse{args.mse}")
            + optional_string(args.noise_regularize, f"-NR{args.noise_regularize:.1e}")
        )
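
# Usage sketch (hypothetical invocation; ReconstructionArguments registers its
# own flags via the call inside add_arguments):
#   parser = ArgumentParser()
#   LossArguments.add_arguments(parser)
#   args = parser.parse_args()
#   # With the defaults above, to_string yields a run-name suffix such as
#   # "...-eye0.1-color1.0e+10-cx0.1(...)-NR5.0e+04".
#   print(LossArguments.to_string(args))
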
class BakedMultiContextualLoss(nn.Module):
    """Randomly sample a different image patch for each VGG layer."""
    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__()

        self.cxs = nn.ModuleList([ContextualLoss(use_vgg=True, vgg_layers=[layer])
                                  for layer in args.cx_layers])
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        cx_loss = 0
        for cx in self.cxs:
            # Sample a random crop origin; the last dimension bounds both h
            # and w, so square inputs are assumed.
            h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
            cx_loss = cx(self.sibling[..., h:h+self.size, w:w+self.size],
                         img[..., h:h+self.size, w:w+self.size]) + cx_loss
        return cx_loss
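
# Example (hypothetical shapes): with a 1x3x512x512 sibling and size=256, each
# of the three default cx_layers compares its own random 256x256 crop of the
# sibling against the identically placed crop of `img`.
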
class BakedContextualLoss(ContextualLoss):
    """Contextual loss on a single random crop shared by all VGG layers."""
    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__(use_vgg=True, vgg_layers=args.cx_layers)
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
        return super().forward(self.sibling[..., h:h+self.size, w:w+self.size],
                               img[..., h:h+self.size, w:w+self.size])
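
# Design note: BakedContextualLoss evaluates all cx_layers on one shared crop
# (a single ContextualLoss call), whereas BakedMultiContextualLoss above draws
# an independent crop per layer at the cost of one call per layer.
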
class JointLoss(nn.Module):
    def __init__(
            self,
            args: Namespace,
            target: torch.Tensor,
            sibling: Optional[torch.Tensor],
            sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
    ):
        super().__init__()

        self.weights = {
            "face": 1., "eye": args.eye,
            "contextual": args.contextual, "color_transfer": args.color_transfer,
            "noise": args.noise_regularize,
        }

        # reconstruction losses against the target
        reconstruction = {}
        if args.vgg > 0 or args.vggface > 0:
            percept = create_perceptual_loss(args)
            reconstruction["face"] = FaceLoss(
                target, input_size=args.generator_size, size=args.recon_size, percept=percept)
            # the eye loss reuses `percept`, so it is only available when the
            # perceptual loss is enabled
            if args.eye > 0:
                reconstruction["eye"] = EyeLoss(target, input_size=args.generator_size, percept=percept)
        self.reconstruction = nn.ModuleDict(reconstruction)

        # exemplar losses against the sibling prediction
        exemplar = {}
        if args.contextual > 0 and len(args.cx_layers) > 0:
            assert sibling is not None
            exemplar["contextual"] = BakedContextualLoss(sibling, args)
        if args.color_transfer > 0:
            assert sibling_rgbs is not None
            self.sibling_rgbs = sibling_rgbs
            exemplar["color_transfer"] = ColorTransferLoss(init_rgbs=sibling_rgbs)
        self.exemplar = nn.ModuleDict(exemplar)

        if args.noise_regularize > 0:
            self.noise_criterion = NoiseRegularizer()
    def forward(
        self, img, degrade=None, noises=None, rgbs=None, rgb_level: Optional[int] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            rgbs: results from the ToRGB layers
        """
        # TODO: add current optimization resolution for noises

        losses = {}

        # reconstruction losses
        for name, criterion in self.reconstruction.items():
            losses[name] = criterion(img, degrade=degrade)

        # exemplar losses
        if "contextual" in self.exemplar:
            losses["contextual"] = self.exemplar["contextual"](img)
        if "color_transfer" in self.exemplar:
            assert rgbs is not None
            losses["color_transfer"] = self.exemplar["color_transfer"](rgbs, level=rgb_level)

        # noise regularizer
        if self.weights["noise"] > 0:
            losses["noise"] = self.noise_criterion(noises)

        # weighted sum of the individual (unweighted) terms
        total_loss = 0
        for name, loss in losses.items():
            total_loss = total_loss + self.weights[name] * loss
        return total_loss, losses
    def update_sibling(self, sibling: torch.Tensor):
        assert "contextual" in self.exemplar
        self.exemplar["contextual"].sibling = sibling.detach()
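
# Usage sketch (hypothetical tensors; assumes a StyleGAN-style generator
# supplies `img`, its per-layer `noises`, and the ToRGB outputs `rgbs`):
#   criterion = JointLoss(args, target, sibling, sibling_rgbs)
#   total, parts = criterion(img, degrade=degrade, noises=noises, rgbs=rgbs)
#   total.backward()
# `parts` maps each active term ("face", "eye", "contextual", ...) to its
# unweighted value, which is convenient for logging.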