import torch
import torch.nn as nn
import numpy as np

from contextlib import contextmanager

from attacks import create_attack

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ctx_noparamgrad(object):
    """Context manager that turns off requires_grad for all parameters of
    `module` on entry and restores the previous gradient state on exit."""

    def __init__(self, module):
        self.prev_grad_state = get_param_grad_state(module)
        self.module = module
        set_param_grad_off(module)

    def __enter__(self):
        pass

    def __exit__(self, *args):
        set_param_grad_state(self.module, self.prev_grad_state)
        return False

class ctx_eval(object):
    """Context manager that puts `module` (and all submodules) into eval mode
    on entry and restores the previous training state on exit."""

    def __init__(self, module):
        self.prev_training_state = get_module_training_state(module)
        self.module = module
        set_module_training_off(module)

    def __enter__(self):
        pass

    def __exit__(self, *args):
        set_module_training_state(self.module, self.prev_training_state)
        return False

@contextmanager
def ctx_noparamgrad_and_eval(module):
    """Combined context: parameter gradients off and eval mode, both restored
    on exit."""
    with ctx_noparamgrad(module) as a, ctx_eval(module) as b:
        yield (a, b)
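
# Illustrative usage (a minimal sketch, not part of the original API): craft
# adversarial examples without accumulating parameter gradients or updating
# BatchNorm statistics. `model`, `attack`, `x`, and `y` are assumed to be
# defined by the caller.
#
#     with ctx_noparamgrad_and_eval(model):
#         x_adv, _ = attack.perturb(x, y)  # model frozen and in eval mode here
#     logits = model(x_adv)                # grad/training state restored here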

def get_module_training_state(module):
    return {mod: mod.training for mod in module.modules()}


def set_module_training_state(module, training_state):
    for mod in module.modules():
        mod.training = training_state[mod]


def set_module_training_off(module):
    for mod in module.modules():
        mod.training = False


def get_param_grad_state(module):
    return {param: param.requires_grad for param in module.parameters()}


def set_param_grad_state(module, grad_state):
    for param in module.parameters():
        param.requires_grad = grad_state[param]


def set_param_grad_off(module):
    for param in module.parameters():
        param.requires_grad = False

class MadrysLoss(nn.Module):
    """Adversarial training loss (Madry et al., 2017): cross-entropy on PGD
    adversarial examples, with optional Cutout augmentation. The `cutmix` flag
    is accepted for interface compatibility but is unused here."""

    def __init__(self, step_size=0.007, epsilon=0.031, perturb_steps=10, beta=6.0,
                 distance='l_inf', cutmix=False, adjust_freeze=True, cutout=False,
                 cutout_length=16):
        super(MadrysLoss, self).__init__()
        self.step_size = step_size
        self.epsilon = epsilon
        self.perturb_steps = perturb_steps
        self.beta = beta
        self.distance = distance
        self.cross_entropy = torch.nn.CrossEntropyLoss()
        self.adjust_freeze = adjust_freeze
        self.cutout = cutout
        self.cutout_length = cutout_length

    def forward(self, model, x_natural, labels):
        # Freeze parameters and switch to eval mode while crafting the attack,
        # so gradients flow only to the input and BatchNorm statistics are not
        # updated.
        model.eval()
        if self.adjust_freeze:
            for param in model.parameters():
                param.requires_grad = False

        # Random start around the natural example.
        x_adv = x_natural.detach() + self.step_size * torch.randn(x_natural.shape).to(device).detach()
        if self.distance == 'l_inf':
            # PGD: ascend the cross-entropy loss, project back into the
            # epsilon-ball around x_natural, and clamp to the valid image range.
            for _ in range(self.perturb_steps):
                x_adv.requires_grad_()
                loss_ce = self.cross_entropy(model(x_adv), labels)
                grad = torch.autograd.grad(loss_ce, [x_adv])[0]
                x_adv = x_adv.detach() + self.step_size * torch.sign(grad.detach())
                x_adv = torch.min(torch.max(x_adv, x_natural - self.epsilon), x_natural + self.epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            # Other distance metrics are not implemented; only clamp the
            # randomly perturbed input.
            x_adv = torch.clamp(x_adv, 0.0, 1.0)

        x_adv = x_adv.detach()

        if self.adjust_freeze:
            for param in model.parameters():
                param.requires_grad = True

        if self.cutout:
            # Cutout: zero out a random square patch in each adversarial image.
            batch_size = x_adv.shape[0]
            c, h, w = x_adv.shape[1], x_adv.shape[2], x_adv.shape[3]
            mask = torch.ones(batch_size, c, h, w).float()
            for j in range(batch_size):
                y = np.random.randint(h)
                x = np.random.randint(w)

                y1 = np.clip(y - self.cutout_length // 2, 0, h)
                y2 = np.clip(y + self.cutout_length // 2, 0, h)
                x1 = np.clip(x - self.cutout_length // 2, 0, w)
                x2 = np.clip(x + self.cutout_length // 2, 0, w)

                mask[j, :, y1: y2, x1: x2] = 0.0
            x_adv = x_adv * mask.to(device)

        # Train on the adversarial examples.
        model.train()
        logits = model(x_adv)
        loss = self.cross_entropy(logits, labels)

        return loss
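
# Illustrative usage (a minimal sketch): one training step with MadrysLoss.
# `model`, `optimizer`, and the (x, y) batch are placeholders supplied by the
# surrounding training loop.
#
#     criterion = MadrysLoss(step_size=2/255, epsilon=8/255, perturb_steps=10)
#     optimizer.zero_grad()
#     loss = criterion(model, x, y)  # crafts PGD examples, returns CE loss
#     loss.backward()
#     optimizer.step()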

def sat_loss(model, x, y, optimizer, step_size, epsilon, num_steps, attack_type,
             beta, criterion=torch.nn.CrossEntropyLoss()):
    """
    Standard adversarial training (Madry et al., 2017). `optimizer` and `beta`
    are accepted for interface compatibility but are unused here.
    """
    attack = create_attack(model, criterion, attack_type, epsilon, num_steps, step_size)
    with ctx_noparamgrad_and_eval(model):
        x_adv, _ = attack.perturb(x, y)

    out = model(x_adv)
    loss = criterion(out, y)

    return loss
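
# Illustrative usage (a minimal sketch; `model`, `optimizer`, and the batch
# are assumed): sat_loss builds the attack internally, so a training step is
# just a loss call plus the usual backward/step.
#
#     loss = sat_loss(model, x, y, optimizer, step_size=2/255, epsilon=8/255,
#                     num_steps=10, attack_type='linf-pgd', beta=6.0)
#     loss.backward()
#     optimizer.step()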