Overview
This post trains a linear MNIST classifier with advertorch's JPEGFilter applied to the inputs as a JPEG-compression preprocessing defense, then measures how well the defended model holds up against the FGSM attack over a range of epsilon values.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import advertorch.defenses as defenses
import numpy as np
import matplotlib.pyplot as plt
import random

# Fix every random source so runs are reproducible.
seed = 2014
torch.manual_seed(seed)
np.random.seed(seed)  # NumPy RNG
random.seed(seed)     # Python's built-in RNG
train_dataset = datasets.MNIST(root='data/', train=True,
                               transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=500, shuffle=True)
# A single linear layer over the flattened 28x28 pixels; it outputs raw
# logits, which is what CrossEntropyLoss expects.
class LinearClassifier(torch.nn.Module):
    def __init__(self):
        super(LinearClassifier, self).__init__()
        self.line1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.line1(x.view(-1, 28 * 28))

net = LinearClassifier()
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0005)
# Per-example test loader (batch_size=1) so each image can be attacked individually.
test_loader = DataLoader(
    datasets.MNIST('data/', train=False, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=1, shuffle=True)

p = 70      # JPEG quality used by the defense filter
epochs = 5
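Before training, it helps to see what the defense actually does to an input. The quick sanity check below is my own addition (not in the original post): it runs one test digit through the same JPEGFilter and plots the original next to the compressed version.

img, _ = next(iter(test_loader))                 # one MNIST image, shape [1, 1, 28, 28]
filtered = defenses.JPEGFilter(quality=p)(img)   # same filter used as the defense
fig, axes = plt.subplots(1, 2, figsize=(4, 2))
axes[0].imshow(img[0, 0], cmap="gray")
axes[0].set_title("original")
axes[1].imshow(filtered[0, 0], cmap="gray")
axes[1].set_title("JPEG q=%d" % p)
for ax in axes:
    ax.axis("off")
plt.show()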
# Train with the JPEG filter applied to the inputs, so the classifier
# sees the same preprocessing it will get at defense time.
for k in range(epochs):
    sum_loss = 0.0
    train_correct = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        inputs = defenses.JPEGFilter(quality=p)(inputs)
        outputs = net(inputs)
        loss = cost(outputs, labels)
        loss.backward()
        optimizer.step()
        _, pred = torch.max(outputs.data, 1)
        sum_loss += loss.item()
        train_correct += torch.sum(pred == labels.data).item()
    print('[%d/%d] loss:%.03f' % (k + 1, epochs, sum_loss / len(train_loader)))
    print(' correct:%.03f%%' % (100 * train_correct / len(train_dataset)))
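The script above retrains from scratch on every run. If you would rather checkpoint the trained weights, a one-line save works here (the filename is my choice, not from the post):

torch.save(net.state_dict(), 'mnist_linear_jpeg.pt')  # hypothetical path
# later: net.load_state_dict(torch.load('mnist_linear_jpeg.pt'))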
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Perturb each pixel of the input image in the loss-increasing direction
    perturbed_image = image + epsilon * sign_data_grad
    # Clip to keep the result in the valid [0, 1] pixel range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return both the perturbed image and the raw perturbation
    return perturbed_image, epsilon * sign_data_grad
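For reference, fgsm_attack implements the standard Fast Gradient Sign Method (Goodfellow et al., 2015): nudge every pixel by $\epsilon$ in the direction that increases the loss, then clip back to the valid pixel range,

$$x_{\text{adv}} = \operatorname{clip}_{[0,1]}\bigl(x + \epsilon \cdot \operatorname{sign}(\nabla_x \mathcal{L}(\theta, x, y))\bigr).$$

The second return value is the raw perturbation $\epsilon \cdot \operatorname{sign}(\nabla_x \mathcal{L})$, which test() below sums up as a measure of how much noise each attack injects.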
def test(model, test_loader, epsilon):
    # Accuracy counter
    correct = 0
    adv_examples = []   # kept from the original; never populated below
    ns = []             # total L1 mass of the perturbation, per example
    # Loop over all examples in the test set
    for data, target in test_loader:
        # Apply the JPEG defense, then enable gradients w.r.t. the input
        data = defenses.JPEGFilter(quality=p)(data)
        data.requires_grad = True
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1]  # index of the max logit
        # If the initial prediction is already wrong, don't bother attacking
        if init_pred.item() != target.item():
            continue
        # Calculate the loss; the model outputs raw logits, so use
        # cross_entropy rather than nll_loss (which expects log-probabilities)
        loss = F.cross_entropy(output, target)
        # Zero all existing gradients
        model.zero_grad()
        # Calculate the gradient w.r.t. the input in the backward pass
        loss.backward()
        # Collect the input gradient
        data_grad = data.grad.data
        # Call the FGSM attack
        perturbed_data, n = fgsm_attack(data, epsilon, data_grad)
        ns.append(torch.sum(torch.abs(n)).item())
        # Re-apply the defense, then re-classify the perturbed image
        perturbed_data = defenses.JPEGFilter(quality=p)(perturbed_data)
        output = model(perturbed_data)
        # The attack fails if the prediction is still correct
        final_pred = output.max(1, keepdim=True)[1]
        if final_pred.item() == target.item():
            correct += 1
    # Calculate the final accuracy for this epsilon
    final_acc = correct / float(len(test_loader))
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(
        epsilon, correct, len(test_loader), final_acc))
    # Return the accuracy, the (empty) examples list, and the total noise mass
    ns = sum(ns)
    return final_acc, adv_examples, ns
accuracies = []
examples = []   # stays empty: test() never fills adv_examples
noise = []      # total L1 perturbation mass per epsilon (collected, not plotted)
epsilons = [0, .05, .1, .15, .2, .25, .3]
# Run the sweep once per epsilon
for eps in epsilons:
    acc, ex, ns = test(net, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)
    noise.append(ns)
print(accuracies)
plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()
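The curve above measures the defended model: the JPEG filter sits in front of the classifier both when the gradient is taken and when the perturbed image is re-classified. For comparison, the sketch below is my own addition (reusing net, fgsm_attack, and epsilons from above): it runs the same FGSM sweep with the two JPEGFilter calls removed, giving an undefended baseline.

def test_undefended(model, loader, epsilon):
    # Same per-example FGSM evaluation as test(), but with no JPEG
    # preprocessing, so the attack hits the raw classifier.
    correct = 0
    for data, target in loader:
        data.requires_grad = True
        output = model(data)
        if output.max(1, keepdim=True)[1].item() != target.item():
            continue  # already misclassified; skip the attack
        loss = F.cross_entropy(output, target)
        model.zero_grad()
        loss.backward()
        perturbed, _ = fgsm_attack(data, epsilon, data.grad.data)
        if model(perturbed).max(1, keepdim=True)[1].item() == target.item():
            correct += 1
    return correct / float(len(loader))

baseline = [test_undefended(net, test_loader, eps) for eps in epsilons]
print(baseline)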
Conclusion
That covers using JPEG compression as a defense against adversarial-example (FGSM) attacks on the MNIST dataset in PyTorch; hopefully it helps you solve similar program-development problems.