ProArd / attacks /squred.py
smi08's picture
Upload folder using huggingface_hub
7771996 verified
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from autoattack import AutoAttack
from torchpack import distributed as dist
from tqdm import tqdm

from utils import accuracy
from utils.distributed import DistributedMetric

from .base import Attack,LabelMixin
from .utils import ctx_noparamgrad_and_eval
class Squre_Attack(Attack, LabelMixin):
    """Adversary and evaluator that runs AutoAttack's ``'square'`` attack.

    The instance owns an ``AutoAttack`` object (``self.adversary``) built with
    ``version='standard'`` and then restricted to the single ``'square'``
    attack. ``perturb`` produces adversarial inputs for one batch and
    ``eval_squred`` reports clean and adversarial metrics over a validation
    loader.

    Note: ``loss_fn``, ``nb_iter``, ``eps_iter``, ``rand_init``,
    ``rand_init_type`` and ``targeted`` are stored on the instance but are not
    read by any method defined in this class.
    """

    # Chunk size AutoAttack uses when it splits a batch internally.
    _ATTACK_BATCH_SIZE = 100

    # Translation from the ``ord`` constructor argument to AutoAttack's
    # ``norm`` names.
    _NORM_NAMES = {np.inf: 'Linf', 2: 'L2', 1: 'L1'}

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01,
            rand_init=True, clip_min=0., clip_max=1.,
            ord=np.inf, targeted=False, rand_init_type='uniform'):
        """Store the attack configuration and build the AutoAttack adversary.

        Args:
            predict: Model under attack; called as ``predict(x)`` to obtain
                the outputs that are scored against the labels.
            loss_fn: Loss forwarded to the base class. When the base class
                leaves ``self.loss_fn`` as ``None`` it is replaced by a
                sum-reduced cross-entropy loss.
            eps: Perturbation budget handed to AutoAttack.
            nb_iter: Stored only.
            eps_iter: Stored only.
            rand_init: Stored only.
            clip_min: Forwarded to the base class.
            clip_max: Forwarded to the base class.
            ord: Norm of the threat model: ``np.inf``, ``2`` or ``1``.
            targeted: Stored only.
            rand_init_type: Stored only.

        Raises:
            ValueError: If ``ord`` is not one of the supported norms.
        """
        super(Squre_Attack, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.nb_iter = nb_iter
        self.eps_iter = eps_iter
        self.rand_init = rand_init
        self.rand_init_type = rand_init_type
        self.ord = ord
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

        # Honour the requested norm instead of silently attacking in L-inf.
        try:
            norm = self._NORM_NAMES[ord]
        except (KeyError, TypeError):
            raise ValueError(
                "Unsupported ord for the square attack: {!r} "
                "(expected np.inf, 2 or 1)".format(ord)
            ) from None

        self.adversary = AutoAttack(predict, norm=norm, eps=self.eps, version='standard')
        # Restrict the standard ensemble to the square attack once, up front.
        self.adversary.attacks_to_run = ['square']

    def perturb(self, x, y=None):
        """Generate adversarial examples for one batch.

        Args:
            x: Batch of inputs.
            y: Labels for ``x``. When ``None``, the model's own predictions
                (arg-max over dimension 1 of ``predict(x)``) are used as the
                labels.

        Returns:
            A 2-tuple whose two elements are the same object: the value
            returned by ``AutoAttack.run_standard_evaluation``.
        """
        if y is None:
            # AutoAttack needs concrete labels; fall back to predicted ones.
            with torch.no_grad():
                y = self.predict(x).argmax(dim=1)
        adversarial_examples = self.adversary.run_standard_evaluation(
            x, y, bs=self._ATTACK_BATCH_SIZE)
        return adversarial_examples, adversarial_examples

    def eval_squred(self, data_loader_dict: Dict) -> Dict:
        """Evaluate the model on clean and square-attacked validation data.

        Inputs, labels and the criterion are moved to CUDA, so a CUDA device
        is required. The progress bar is shown only on the master process.

        Args:
            data_loader_dict: Mapping whose ``"val"`` entry is an iterable of
                ``(images, labels)`` batches that also supports ``len()``.

        Returns:
            Dict with the running averages ``val_top1``, ``val_top5``,
            ``val_loss`` (clean inputs) and ``val_advtop1``, ``val_advtop5``,
            ``val_advloss`` (adversarial inputs), each as a Python number.
        """
        test_criterion = nn.CrossEntropyLoss().cuda()

        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()

        self.predict.eval()
        with tqdm(
            total=len(data_loader_dict["val"]),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in data_loader_dict["val"]:
                images, labels = images.cuda(), labels.cuda()
                batch_size = images.shape[0]

                # Clean metrics: evaluation only, so do not build an
                # autograd graph for the forward pass.
                with torch.no_grad():
                    output = self.predict(images)
                    loss = test_criterion(output, labels)
                    acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_loss.update(loss, batch_size)
                val_top5.update(acc5[0], batch_size)
                val_top1.update(acc1[0], batch_size)

                # Adversarial metrics: craft the examples with the model
                # frozen in eval mode, then score them without autograd.
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)
                with torch.no_grad():
                    output_adv = self.predict(images_adv)
                    loss_adv = test_criterion(output_adv, labels)
                    acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advloss.update(loss_adv, batch_size)
                val_advtop1.update(acc1_adv[0], batch_size)
                val_advtop5.update(acc5_adv[0], batch_size)

                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": batch_size,
                        "img_size": images.shape[2],
                    }
                )
                t.update()

        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results