import torch
import torch.nn as nn

from .utils import replicate_input


class Attack(object):
    """
    Abstract base class for all attack classes.

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max):
        self.predict = predict
        self.loss_fn = loss_fn
        self.clip_min = clip_min
        self.clip_max = clip_max

    def perturb(self, x, **kwargs):
        """
        Virtual method for generating the adversarial examples.

        Arguments:
            x (torch.Tensor): the model's input tensor.
            **kwargs: optional parameters used by child classes.

        Returns:
            adversarial examples.
        """
        error = "Sub-classes must implement perturb."
        raise NotImplementedError(error)

    def __call__(self, *args, **kwargs):
        return self.perturb(*args, **kwargs)


class LabelMixin(object):
    def _get_predicted_label(self, x):
        """
        Compute predicted labels given x. Used to prevent label leaking
        during adversarial training.

        Arguments:
            x (torch.Tensor): the model's input tensor.

        Returns:
            torch.Tensor containing predicted labels.
        """
        with torch.no_grad():
            outputs = self.predict(x)
        _, y = torch.max(outputs, dim=1)
        return y

    def _verify_and_process_inputs(self, x, y):
        # Subclasses are expected to set self.targeted in their __init__.
        if self.targeted:
            assert y is not None
        elif y is None:
            y = self._get_predicted_label(x)

        x = replicate_input(x)
        y = replicate_input(y)
        return x, y
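

# The class below is an illustrative sketch, not part of this module's API:
# it shows how a concrete attack would subclass Attack and LabelMixin,
# implement perturb(), and set self.targeted (which
# _verify_and_process_inputs expects on the subclass). The name FGSMSketch,
# the eps parameter, and the default loss are hypothetical choices for
# demonstration (a one-step FGSM-style untargeted attack).
class FGSMSketch(Attack, LabelMixin):
    def __init__(self, predict, loss_fn=None, eps=0.3,
                 clip_min=0., clip_max=1.):
        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss(reduction="sum")
        super(FGSMSketch, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.targeted = False  # untargeted: maximize loss w.r.t. true label

    def perturb(self, x, y=None):
        # Falls back to predicted labels when y is None (avoids label leaking).
        x, y = self._verify_and_process_inputs(x, y)
        x_adv = x.clone().detach().requires_grad_(True)
        loss = self.loss_fn(self.predict(x_adv), y)
        loss.backward()
        # One signed-gradient ascent step, clipped back to the valid range.
        x_adv = x_adv + self.eps * x_adv.grad.sign()
        return torch.clamp(x_adv, self.clip_min, self.clip_max).detach()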