from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from autoattack import AutoAttack
from torchpack import distributed as dist
from tqdm import tqdm

from utils import accuracy
from utils.distributed import DistributedMetric

from .base import Attack, LabelMixin
from .utils import ctx_noparamgrad_and_eval


# Translate the advertorch-style ``ord`` argument into AutoAttack's ``norm``
# names. AutoAttack's Square attack supports the Linf and L2 threat models.
_ORD_TO_AUTOATTACK_NORM = {np.inf: "Linf", 2: "L2"}


class Squre_Attack(Attack, LabelMixin):
    """Square attack executed through ``autoattack.AutoAttack``.

    AutoAttack is restricted to its ``'square'`` component so the attack can
    be used through this package's ``Attack`` interface and evaluated with
    :meth:`eval_squred`.

    Only ``eps`` and ``ord`` are forwarded to AutoAttack. ``loss_fn``,
    ``nb_iter``, ``eps_iter``, ``rand_init``, ``rand_init_type`` and
    ``targeted`` are stored for signature compatibility with the other
    attacks but do not influence the Square attack run here.
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40,
            eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            ord=np.inf, targeted=False, rand_init_type='uniform', bs=100):
        """Build the AutoAttack wrapper.

        :param predict: model mapping an input batch to logits.
        :param eps: perturbation budget passed to AutoAttack.
        :param ord: threat-model norm; ``np.inf`` (L-infinity) or ``2`` (L2).
        :param bs: batch size AutoAttack uses internally when attacking.
        :raises ValueError: if ``ord`` is not a norm supported here.
        """
        super(Squre_Attack, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.nb_iter = nb_iter
        self.eps_iter = eps_iter
        self.rand_init = rand_init
        self.rand_init_type = rand_init_type
        self.ord = ord
        self.targeted = targeted
        self.bs = bs

        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

        # Previously the norm was hard-coded to 'Linf', silently ignoring ord.
        norm = _ORD_TO_AUTOATTACK_NORM.get(ord)
        if norm is None:
            raise ValueError(
                f"Squre_Attack supports ord in (np.inf, 2), got {ord!r}")
        self.adversary = AutoAttack(predict, norm=norm, eps=self.eps,
                                    version='standard')

    def perturb(self, x, y=None):
        """Return adversarial examples for ``x`` found by the Square attack.

        :param x: input batch.
        :param y: labels for ``x``; when ``None`` the model's own predictions
            on ``x`` are used, since AutoAttack requires explicit labels.
        :return: ``(x_adv, x_adv)``, the adversarial batch returned twice
            because callers such as :meth:`eval_squred` unpack two values.
        """
        if y is None:
            with torch.no_grad():
                y = self.predict(x).argmax(dim=1)
        # Re-applied on every call so only the Square attack of the
        # 'standard' suite runs, even if the list was changed externally.
        self.adversary.attacks_to_run = ['square']
        adversarial_examples = self.adversary.run_standard_evaluation(
            x, y, bs=self.bs)
        return adversarial_examples, adversarial_examples

    def eval_squred(self, data_loader_dict: Dict) -> Dict:
        """Measure clean and Square-attack loss / accuracy on the val split.

        :param data_loader_dict: mapping with a ``"val"`` loader that yields
            ``(images, labels)`` batches.
        :return: dict with clean ``val_top1``, ``val_top5``, ``val_loss`` and
            adversarial ``val_advtop1``, ``val_advtop5``, ``val_advloss``.
        """
        test_criterion = nn.CrossEntropyLoss().cuda()

        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()

        self.predict.eval()
        val_loader = data_loader_dict["val"]
        with tqdm(
            total=len(val_loader),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in val_loader:
                images, labels = images.cuda(), labels.cuda()
                n = images.shape[0]

                # Clean pass: evaluation only, so no autograd graph is needed.
                with torch.no_grad():
                    output = self.predict(images)
                    loss = test_criterion(output, labels)
                val_loss.update(loss, n)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_top5.update(acc5[0], n)
                val_top1.update(acc1[0], n)

                # ctx_noparamgrad_and_eval presumably freezes parameter grads
                # and forces eval mode while attacking (defined in .utils).
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)
                    with torch.no_grad():
                        output_adv = self.predict(images_adv)
                        loss_adv = test_criterion(output_adv, labels)
                val_advloss.update(loss_adv, n)
                acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advtop1.update(acc1_adv[0], n)
                val_advtop5.update(acc5_adv[0], n)

                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": n,
                        "img_size": images.shape[2],
                    }
                )
                t.update()

        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results