import torch
import torch.nn as nn
from attacks import create_attack
import numpy as np
from contextlib import contextmanager

# Use the GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ctx_noparamgrad(object):
    """Context manager that turns off requires_grad for all parameters of a
    module and restores the previous flags on exit."""

    def __init__(self, module):
        self.prev_grad_state = get_param_grad_state(module)
        self.module = module
        set_param_grad_off(module)

    def __enter__(self):
        pass

    def __exit__(self, *args):
        set_param_grad_state(self.module, self.prev_grad_state)
        return False

class ctx_eval(object):
    """Context manager that puts a module (and all submodules) into eval mode
    and restores the previous training flags on exit."""

    def __init__(self, module):
        self.prev_training_state = get_module_training_state(module)
        self.module = module
        set_module_training_off(module)

    def __enter__(self):
        pass

    def __exit__(self, *args):
        set_module_training_state(self.module, self.prev_training_state)
        return False

@contextmanager
def ctx_noparamgrad_and_eval(module):
    """Combined context: parameter gradients off and module in eval mode."""
    with ctx_noparamgrad(module) as a, ctx_eval(module) as b:
        yield (a, b)

def get_module_training_state(module):
    return {mod: mod.training for mod in module.modules()}

def set_module_training_state(module, training_state):
    for mod in module.modules():
        mod.training = training_state[mod]

def set_module_training_off(module):
    for mod in module.modules():
        mod.training = False

def get_param_grad_state(module):
    return {param: param.requires_grad for param in module.parameters()}

def set_param_grad_state(module, grad_state):
    for param in module.parameters():
        param.requires_grad = grad_state[param]

def set_param_grad_off(module):
    for param in module.parameters():
        param.requires_grad = False

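# Illustrative sketch (not from the original file): the context managers above
# are meant to wrap adversarial-example generation so that only input
# gradients are computed and batch-norm/dropout statistics are untouched.
# `classifier` and `attack` below are hypothetical placeholders.
#
#     with ctx_noparamgrad_and_eval(classifier):
#         x_adv, _ = attack.perturb(x, y)
#     # on exit, requires_grad flags and training mode are restored
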
class MadrysLoss(nn.Module):
    """Adversarial training loss of Madry et al. (2017): craft PGD adversarial
    examples from the clean batch, then train on them with cross-entropy."""

    def __init__(self, step_size=0.007, epsilon=0.031, perturb_steps=10, beta=6.0,
                 distance='l_inf', cutmix=False, adjust_freeze=True, cutout=False,
                 cutout_length=16):
        super(MadrysLoss, self).__init__()
        self.step_size = step_size
        self.epsilon = epsilon
        self.perturb_steps = perturb_steps
        self.beta = beta
        self.distance = distance
        self.cross_entropy = torch.nn.CrossEntropyLoss()
        self.adjust_freeze = adjust_freeze
        self.cutout = cutout
        self.cutout_length = cutout_length

    def forward(self, model, x_natural, labels):
        # Freeze the model while crafting adversarial examples.
        model.eval()
        if self.adjust_freeze:
            for param in model.parameters():
                param.requires_grad = False

        # Random start: Gaussian noise scaled by the step size.
        x_adv = x_natural.detach() + self.step_size * torch.randn(x_natural.shape).to(device).detach()
        if self.distance == 'l_inf':
            for _ in range(self.perturb_steps):
                x_adv.requires_grad_()
                loss_ce = self.cross_entropy(model(x_adv), labels)
                grad = torch.autograd.grad(loss_ce, [x_adv])[0]
                # PGD step: move along the gradient sign, project back into the
                # epsilon-ball around x_natural, and clip to the valid range.
                x_adv = x_adv.detach() + self.step_size * torch.sign(grad.detach())
                x_adv = torch.min(torch.max(x_adv, x_natural - self.epsilon), x_natural + self.epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            # Fallback for other distances: keep the random start, clipped.
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
        x_adv = x_adv.detach()

        # Unfreeze the model for the training forward pass.
        if self.adjust_freeze:
            for param in model.parameters():
                param.requires_grad = True

        if self.cutout:
            # Cutout augmentation: zero a random square patch in each image.
            batch_size = x_adv.shape[0]
            c, h, w = x_adv.shape[1], x_adv.shape[2], x_adv.shape[3]
            mask = torch.ones(batch_size, c, h, w).float()
            for j in range(batch_size):
                y = np.random.randint(h)
                x = np.random.randint(w)
                y1 = np.clip(y - self.cutout_length // 2, 0, h)
                y2 = np.clip(y + self.cutout_length // 2, 0, h)
                x1 = np.clip(x - self.cutout_length // 2, 0, w)
                x2 = np.clip(x + self.cutout_length // 2, 0, w)
                mask[j, :, y1:y2, x1:x2] = 0.0
            x_adv = x_adv * mask.to(device)

        # Train on the adversarial examples.
        model.train()
        logits = model(x_adv)
        loss = self.cross_entropy(logits, labels)
        return loss

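# Illustrative usage sketch for MadrysLoss (hypothetical names: `net`,
# `train_loader`, and `opt` are not defined in this file):
#
#     criterion = MadrysLoss(step_size=0.007, epsilon=0.031, perturb_steps=10)
#     for x, y in train_loader:
#         x, y = x.to(device), y.to(device)
#         opt.zero_grad()
#         loss = criterion(net, x, y)  # crafts PGD examples, then scores them
#         loss.backward()
#         opt.step()
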
def sat_loss(model, x, y, optimizer, step_size, epsilon, num_steps, attack_type,
             beta, criterion=torch.nn.CrossEntropyLoss()):
    """
    Standard adversarial training (Madry et al., 2017): train on linf-PGD
    adversarial examples crafted with the model frozen and in eval mode.
    (`optimizer`, `attack_type`, and `beta` are accepted for interface
    compatibility but unused here.)
    """
    attack = create_attack(model, criterion, 'linf-pgd', epsilon, num_steps, step_size)
    with ctx_noparamgrad_and_eval(model):
        x_adv, _ = attack.perturb(x, y)
    out = model(x_adv)
    loss = criterion(out, y)
    return loss
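
if __name__ == "__main__":
    # Minimal smoke test for MadrysLoss on random data (a sketch, not part of
    # the original pipeline): the tiny linear classifier below is a stand-in
    # for a real model.
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
    x = torch.rand(4, 3, 32, 32, device=device)
    y = torch.randint(0, 10, (4,), device=device)
    loss = MadrysLoss(perturb_steps=2)(toy_model, x, y)
    loss.backward()
    print(f'MadrysLoss smoke test: loss = {loss.item():.4f}')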