# ProArd / attacks / pgd.py
# Uploaded by smi08 using huggingface_hub (commit 7771996, verified).
import numpy as np
import torch
import torch.nn as nn
from .utils import ctx_noparamgrad_and_eval
from .base import Attack, LabelMixin
from typing import Dict
from .utils import batch_clamp
from .utils import batch_multiply
from .utils import clamp
from .utils import clamp_by_pnorm
from .utils import is_float_or_torch_tensor
from .utils import normalize_by_pnorm
from .utils import rand_init_delta
from .utils import replicate_input
from utils.distributed import DistributedMetric
from tqdm import tqdm
from torchpack import distributed as dist
from utils import accuracy
def perturb_iterative(xvar, yvar, predict, nb_iter, eps, eps_iter, loss_fn, delta_init=None, minimize=False, ord=np.inf,
                      clip_min=0.0, clip_max=1.0):
    """
    Run ``nb_iter`` projected gradient steps on an additive perturbation of ``xvar``.

    This is the shared inner loop of the iterative attacks in this module. Each step
    evaluates ``loss_fn(predict(xvar + delta), yvar)``, back-propagates to ``delta``,
    moves ``delta`` along the (sign of / L2-normalized) gradient, and projects it back
    onto the eps-ball and onto the valid input box ``[clip_min, clip_max]``.

    Arguments:
        xvar (torch.Tensor): input data.
        yvar (torch.Tensor): input labels.
        predict (nn.Module): forward pass function.
        nb_iter (int): number of iterations.
        eps (float): maximum distortion.
        eps_iter (float): attack step size.
        loss_fn (nn.Module): loss function.
        delta_init (torch.Tensor): (optional) starting perturbation; zeros if omitted.
            It is switched to ``requires_grad`` and updated in place through ``.data``.
        minimize (bool): (optional) descend the loss instead of ascending it.
        ord (int): (optional) the order of maximum distortion (inf or 2).
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    Returns:
        torch.Tensor containing the perturbed input,
        torch.Tensor containing the perturbation
    Raises:
        NotImplementedError: inside the first iteration, if ``ord`` is neither inf nor 2.
    """
    delta = torch.zeros_like(xvar) if delta_init is None else delta_init
    delta.requires_grad_()

    def _into_valid_range(d):
        # Shift d so that xvar + d lies inside [clip_min, clip_max].
        return clamp(xvar.data + d, clip_min, clip_max) - xvar.data

    for _ in range(nb_iter):
        objective = loss_fn(predict(xvar + delta), yvar)
        (-objective if minimize else objective).backward()
        step_dir = delta.grad.data

        if ord == np.inf:
            # Signed-gradient step, then clamp into the L_inf ball of radius eps.
            delta.data = batch_clamp(eps, delta.data + batch_multiply(eps_iter, step_dir.sign()))
            delta.data = _into_valid_range(delta.data)
        elif ord == 2:
            # L2-normalized gradient step; the eps-ball projection happens after box clipping.
            delta.data = delta.data + batch_multiply(eps_iter, normalize_by_pnorm(step_dir))
            delta.data = _into_valid_range(delta.data)
            if eps is not None:
                delta.data = clamp_by_pnorm(delta.data, ord, eps)
        else:
            raise NotImplementedError("Only ord=inf and ord=2 have been implemented")

        # Gradients accumulate in .grad; reset before the next step.
        delta.grad.data.zero_()

    x_adv = clamp(xvar + delta, clip_min, clip_max)
    return x_adv, x_adv - xvar
class PGDAttack(Attack, LabelMixin):
    """
    The projected gradient descent attack (Madry et al, 2017).
    The attack performs nb_iter steps of size eps_iter, while always staying within eps from the initial point.
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function; defaults to nn.CrossEntropyLoss(reduction="sum").
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        ord (int): (optional) the order of maximum distortion (inf or 2).
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type, 'uniform' or 'normal'.
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            ord=np.inf, targeted=False, rand_init_type='uniform'):
        super(PGDAttack, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.nb_iter = nb_iter
        self.eps_iter = eps_iter
        self.rand_init = rand_init
        self.rand_init_type = rand_init_type
        self.ord = ord
        self.targeted = targeted
        if self.loss_fn is None:
            # Summed (not averaged) loss: each example's input gradient does not depend on batch size.
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
        assert is_float_or_torch_tensor(self.eps_iter)
        assert is_float_or_torch_tensor(self.eps)

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts with an attack length of eps.
        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted
                labels.
                - if self.targeted=True, then y must be the targeted labels.
        Returns:
            torch.Tensor containing perturbed inputs,
            torch.Tensor containing the perturbation
        Raises:
            NotImplementedError: if rand_init is set and rand_init_type is not 'uniform' or 'normal'.
        """
        x, y = self._verify_and_process_inputs(x, y)
        delta = nn.Parameter(torch.zeros_like(x))
        if self.rand_init:
            if self.rand_init_type == 'uniform':
                # The return value is unused, so rand_init_delta presumably fills delta in place.
                rand_init_delta(
                    delta, x, self.ord, self.eps, self.clip_min, self.clip_max)
                # Keep x + delta inside the valid input range.
                delta.data = clamp(
                    x + delta.data, min=self.clip_min, max=self.clip_max) - x
            elif self.rand_init_type == 'normal':
                delta.data = 0.001 * torch.randn_like(x)  # initialize as in TRADES
            else:
                raise NotImplementedError('Only rand_init_type=normal and rand_init_type=uniform have been implemented.')
        x_adv, r_adv = perturb_iterative(
            x, y, self.predict, nb_iter=self.nb_iter, eps=self.eps, eps_iter=self.eps_iter, loss_fn=self.loss_fn,
            minimize=self.targeted, ord=self.ord, clip_min=self.clip_min, clip_max=self.clip_max, delta_init=delta
        )
        return x_adv.data, r_adv.data

    def eval_pgd(self, data_loader_dict: Dict) -> Dict:
        """
        Evaluate ``self.predict`` on clean and PGD-perturbed validation batches.

        The model is put into eval mode and is not switched back afterwards. Batches and
        the criterion are moved with ``.cuda()``, so a CUDA device is required.

        Arguments:
            data_loader_dict (Dict): must contain a "val" entry: an iterable of
                (images, labels) batches that also supports ``len()``.
        Returns:
            Dict with clean metrics "val_top1", "val_top5", "val_loss" and adversarial
            metrics "val_advtop1", "val_advtop5", "val_advloss", each a Python number.
        """
        test_criterion = nn.CrossEntropyLoss().cuda()
        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()
        self.predict.eval()
        val_loader = data_loader_dict["val"]
        with tqdm(
            total=len(val_loader),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in val_loader:
                images, labels = images.cuda(), labels.cuda()
                batch_size = images.shape[0]

                # Clean pass: no gradients are needed. Without no_grad, the autograd graph
                # of this pass would stay alive (via output/loss) during the attack below.
                with torch.no_grad():
                    output = self.predict(images)
                    loss = test_criterion(output, labels)
                    acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_loss.update(loss, batch_size)
                val_top5.update(acc5[0], batch_size)
                val_top1.update(acc1[0], batch_size)

                # The attack itself needs input gradients, so it runs with autograd enabled.
                # ctx_noparamgrad_and_eval (from .utils) presumably stops parameter-gradient
                # accumulation and keeps eval mode -- confirm in .utils.
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)
                    with torch.no_grad():
                        output_adv = self.predict(images_adv)
                with torch.no_grad():
                    loss_adv = test_criterion(output_adv, labels)
                    acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advloss.update(loss_adv, batch_size)
                val_advtop1.update(acc1_adv[0], batch_size)
                val_advtop5.update(acc5_adv[0], batch_size)

                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": batch_size,
                        "img_size": images.shape[2],
                    }
                )
                t.update()
        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results
class LinfPGDAttack(PGDAttack):
    """
    PGD attack constrained to an L-infinity ball (``ord`` is fixed to ``np.inf``).

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type.
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            targeted=False, rand_init_type='uniform'):
        super().__init__(
            predict=predict,
            loss_fn=loss_fn,
            eps=eps,
            nb_iter=nb_iter,
            eps_iter=eps_iter,
            rand_init=rand_init,
            clip_min=clip_min,
            clip_max=clip_max,
            ord=np.inf,
            targeted=targeted,
            rand_init_type=rand_init_type,
        )
class L2PGDAttack(PGDAttack):
    """
    PGD attack constrained to an L2 ball (``ord`` is fixed to 2).

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type.
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            targeted=False, rand_init_type='uniform'):
        super().__init__(
            predict=predict,
            loss_fn=loss_fn,
            eps=eps,
            nb_iter=nb_iter,
            eps_iter=eps_iter,
            rand_init=rand_init,
            clip_min=clip_min,
            clip_max=clip_max,
            ord=2,
            targeted=targeted,
            rand_init_type=rand_init_type,
        )