import numpy as np
import torch
import torch.nn as nn
from typing import Dict

from tqdm import tqdm
from torchpack import distributed as dist

from utils import accuracy
from utils.distributed import DistributedMetric

from .base import Attack, LabelMixin
from .utils import (
    batch_clamp,
    batch_multiply,
    clamp,
    clamp_by_pnorm,
    ctx_noparamgrad_and_eval,
    is_float_or_torch_tensor,
    normalize_by_pnorm,
    rand_init_delta,
    replicate_input,
)

def perturb_iterative(xvar, yvar, predict, nb_iter, eps, eps_iter, loss_fn, delta_init=None, minimize=False, ord=np.inf, 
                      clip_min=0.0, clip_max=1.0):
    """
    Iteratively maximize the loss over the input. It is a shared method for iterative attacks.
    Arguments:
        xvar (torch.Tensor): input data.
        yvar (torch.Tensor): input labels.
        predict (nn.Module): forward pass function.
        nb_iter (int): number of iterations.
        eps (float): maximum distortion.
        eps_iter (float): attack step size.
        loss_fn (nn.Module): loss function.
        delta_init (torch.Tensor): (optional) tensor containing the random initialization.
        minimize (bool): (optional) whether to minimize or maximize the loss.
        ord (int): (optional) the order of maximum distortion (inf or 2).
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    Returns: 
        torch.Tensor containing the perturbed input, 
        torch.Tensor containing the perturbation
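
    Example (illustrative sketch; ``model``, ``x`` and ``y`` are assumed here
    and are not defined in this module)::

        x_adv, r_adv = perturb_iterative(
            x, y, model, nb_iter=10, eps=8 / 255, eps_iter=2 / 255,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"))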
    """
    if delta_init is not None:
        delta = delta_init
    else:
        delta = torch.zeros_like(xvar)

    delta.requires_grad_()
    for ii in range(nb_iter):
        outputs = predict(xvar + delta)
        loss = loss_fn(outputs, yvar)
        if minimize:
            loss = -loss

        loss.backward()
        if ord == np.inf:
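            # L-inf step: ascend along the sign of the gradient, clamp the
            # perturbation back into the L-inf eps-ball, then keep x + delta
            # inside [clip_min, clip_max].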
            grad_sign = delta.grad.data.sign()
            delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
            delta.data = batch_clamp(eps, delta.data)
            delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data
        elif ord == 2:
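            # L2 step: ascend along the L2-normalized gradient, keep x + delta
            # in the valid pixel range, then project delta onto the L2 eps-ball.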
            grad = delta.grad.data
            grad = normalize_by_pnorm(grad)
            delta.data = delta.data + batch_multiply(eps_iter, grad)
            delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data
            if eps is not None:
                delta.data = clamp_by_pnorm(delta.data, ord, eps)
        else:
            error = "Only ord=inf and ord=2 have been implemented"
            raise NotImplementedError(error)
        delta.grad.data.zero_()

    x_adv = clamp(xvar + delta, clip_min, clip_max)
    r_adv = x_adv - xvar
    return x_adv, r_adv


class PGDAttack(Attack, LabelMixin):
    """
    The projected gradient descent (PGD) attack (Madry et al., 2017).
    The attack performs nb_iter steps of size eps_iter while always staying within eps of the initial point.
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.    
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        ord (int): (optional) the order of maximum distortion (inf or 2).
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type.
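
    Example (illustrative sketch; ``model`` is an assumed pretrained
    nn.Module and is not defined in this module)::

        attack = PGDAttack(model, eps=8 / 255, nb_iter=10, eps_iter=2 / 255)
        x_adv, r_adv = attack.perturb(x, y)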
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            ord=np.inf, targeted=False, rand_init_type='uniform'):
        super(PGDAttack, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.nb_iter = nb_iter
        self.eps_iter = eps_iter
        self.rand_init = rand_init
        self.rand_init_type = rand_init_type
        self.ord = ord
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
        assert is_float_or_torch_tensor(self.eps_iter)
        assert is_float_or_torch_tensor(self.eps)

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts, perturbed by at most eps.
        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted labels.
                - if self.targeted=True, then y must be the targeted labels.
        Returns: 
            torch.Tensor containing perturbed inputs,
            torch.Tensor containing the perturbation    
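
        Example (sketch; assumes ``attack = PGDAttack(model, ...)`` for some
        pretrained ``model`` and an input batch ``x`` in [clip_min, clip_max])::

            x_adv, r_adv = attack.perturb(x)  # untargeted; y inferred from predictions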
        """
        x, y = self._verify_and_process_inputs(x, y)

        delta = torch.zeros_like(x)
        delta = nn.Parameter(delta)
        if self.rand_init:
            if self.rand_init_type == 'uniform':
                rand_init_delta(
                    delta, x, self.ord, self.eps, self.clip_min, self.clip_max)
                delta.data = clamp(
                    x + delta.data, min=self.clip_min, max=self.clip_max) - x
            elif self.rand_init_type == 'normal':
                delta.data = 0.001 * torch.randn_like(x) # initialize as in TRADES
            else:
                raise NotImplementedError('Only rand_init_type=normal and rand_init_type=uniform have been implemented.')
        
        x_adv, r_adv = perturb_iterative(
            x, y, self.predict, nb_iter=self.nb_iter, eps=self.eps, eps_iter=self.eps_iter, loss_fn=self.loss_fn, 
            minimize=self.targeted, ord=self.ord, clip_min=self.clip_min, clip_max=self.clip_max, delta_init=delta
        )

        return x_adv.data, r_adv.data

    def eval_pgd(self, data_loader_dict: Dict) -> Dict:
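        """
        Evaluate clean and PGD-adversarial accuracy on the validation split.
        Runs under the repo's distributed evaluation setup (CUDA required).
        Arguments:
            data_loader_dict (Dict): dict holding a "val" data loader.
        Returns:
            Dict with clean and adversarial top-1/top-5 accuracy and loss.

        Usage sketch (illustrative; ``model`` and ``val_loader`` are assumed
        and not defined in this module)::

            attack = PGDAttack(model, eps=8 / 255, nb_iter=10, eps_iter=2 / 255)
            results = attack.eval_pgd({"val": val_loader})
        """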
        test_criterion = nn.CrossEntropyLoss().cuda()
        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()
        self.predict.eval()
        with tqdm(
            total=len(data_loader_dict["val"]),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in data_loader_dict["val"]:
                images, labels = images.cuda(), labels.cuda()
                # clean forward pass
                output = self.predict(images)
                loss = test_criterion(output, labels)
                val_loss.update(loss, images.shape[0])
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_top1.update(acc1[0], images.shape[0])
                val_top5.update(acc5[0], images.shape[0])
                # adversarial forward pass; PGD examples are crafted without
                # accumulating parameter gradients
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)
                output_adv = self.predict(images_adv)
                loss_adv = test_criterion(output_adv, labels)
                val_advloss.update(loss_adv, images.shape[0])
                acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advtop1.update(acc1_adv[0], images.shape[0])
                val_advtop5.update(acc5_adv[0], images.shape[0])
                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": images.shape[0],
                        "img_size": images.shape[2],
                    }
                )
                t.update()

        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results


class LinfPGDAttack(PGDAttack):
    """
    PGD Attack with order=Linf
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.    
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type.
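
    Example (sketch; ``model`` is assumed)::

        attack = LinfPGDAttack(model, eps=8 / 255, nb_iter=10, eps_iter=2 / 255)
        x_adv, _ = attack.perturb(x, y)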
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            targeted=False, rand_init_type='uniform'):
        ord = np.inf
        super(LinfPGDAttack, self).__init__(
            predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init, 
            clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type)


class L2PGDAttack(PGDAttack):
    """
    PGD Attack with order=L2
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.    
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type.
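
    Example (sketch; ``model`` is assumed; eps is an L2 radius here, so a
    budget like eps=1.0 is more typical than the L-inf-style 8/255)::

        attack = L2PGDAttack(model, eps=1.0, nb_iter=10, eps_iter=0.25)
        x_adv, _ = attack.perturb(x, y)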
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            targeted=False, rand_init_type='uniform'):
        ord = 2
        super(L2PGDAttack, self).__init__(
            predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init, 
            clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type)