# ProArd/attacks/base.py
import torch
import torch.nn as nn
from .utils import replicate_input
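
# NOTE (assumption, not stated in this file): `replicate_input` is expected to
# return a detached copy of its tensor argument (e.g. x.detach().clone()), so
# that attacks never mutate the caller's tensors or touch their autograd graph.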


class Attack(object):
    """
    Abstract base class for all attack classes.

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max):
        self.predict = predict
        self.loss_fn = loss_fn
        self.clip_min = clip_min
        self.clip_max = clip_max

    def perturb(self, x, **kwargs):
        """
        Virtual method for generating the adversarial examples.

        Arguments:
            x (torch.Tensor): the model's input tensor.
            **kwargs: optional parameters used by child classes.

        Returns:
            adversarial examples.
        """
        error = "Sub-classes must implement perturb."
        raise NotImplementedError(error)

    def __call__(self, *args, **kwargs):
        return self.perturb(*args, **kwargs)
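

# Illustrative sketch (not part of the original module): a minimal FGSM-style
# subclass showing how `perturb` is typically overridden. The class name and
# the step size `eps` are hypothetical.
class _ExampleFGSM(Attack):
    def __init__(self, predict, loss_fn, eps=0.3, clip_min=0.0, clip_max=1.0):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps

    def perturb(self, x, y):
        # Differentiate the loss w.r.t. a detached copy of the input.
        x = x.detach().clone().requires_grad_(True)
        loss = self.loss_fn(self.predict(x), y)
        grad, = torch.autograd.grad(loss, x)
        # Take one signed-gradient step, then project back into the valid range.
        x_adv = x + self.eps * grad.sign()
        return torch.clamp(x_adv, self.clip_min, self.clip_max).detach()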


class LabelMixin(object):
    """
    Mixin providing shared label handling for attack classes.
    """

    def _get_predicted_label(self, x):
        """
        Compute predicted labels given x. Used to prevent label leaking
        during adversarial training.

        Arguments:
            x (torch.Tensor): the model's input tensor.

        Returns:
            torch.Tensor containing predicted labels.
        """
        with torch.no_grad():
            outputs = self.predict(x)
        _, y = torch.max(outputs, dim=1)
        return y

    def _verify_and_process_inputs(self, x, y):
        """
        Validate the (x, y) pair and fill in missing labels.

        Targeted attacks require target labels; for untargeted attacks,
        missing labels default to the model's own predictions. Both tensors
        are copied before being returned.
        """
        if self.targeted:
            assert y is not None

        if not self.targeted:
            if y is None:
                y = self._get_predicted_label(x)

        x = replicate_input(x)
        y = replicate_input(y)
        return x, y
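

# Usage sketch (illustrative; the class and shapes below are hypothetical):
# subclasses combine Attack with LabelMixin, set `self.targeted`, and call
# `_verify_and_process_inputs` at the top of `perturb`, so untargeted attacks
# fall back to the model's own predictions when no labels are given.
class _ExampleMixinUse(Attack, LabelMixin):
    def __init__(self, predict, loss_fn, clip_min=0.0, clip_max=1.0):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.targeted = False  # required by _verify_and_process_inputs

    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)  # infers y when None
        # A real attack would perturb x here; this sketch returns it unchanged.
        return torch.clamp(x, self.clip_min, self.clip_max)

# For example:
#   model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
#   atk = _ExampleMixinUse(model, nn.CrossEntropyLoss())
#   x_adv = atk(torch.rand(4, 1, 28, 28))  # __call__ forwards to perturb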