Deep leakage from Gradients论文解析
今天来给大家介绍下2019年NIPS上发表的一篇通过梯度进行原始数据恢复的论文。
论文传送门
**问题背景:**现在分布式机器学习和联邦学习中普遍接受的一个做法是将数据梯度进行共享,多方数据通过共享的梯度信息进行联合建模,即在原始数据不出库的前提下进行建模,那么这样引出作者的一个思考:这样的梯度信息是否是安全的呢?我们知道,梯度与标签和样本特征有关,那么意味着梯度其中包含着部分的标签信息和原始信息,所以作者做了这样一个工作,通过神经网络中的梯度信息去反推原始数据和标签。
**方法:**作者将随机生成一份和真数据同样大小的假输入样本和假的标签,然后把这些假样本和假标签输入到现有的模型当中,然后得到假的模型梯度。方法的目标是生成与原模型相同梯度的假梯度,这样在假样本和假标签就和真实的样本标签一致。目标函数如下:
流程:
这里拿一张小猫图片进行示例,对于输入样本可以通过训练过的网络得到预测值和梯度。而在攻击模型中,将随机输入我们的输入x和标签向量,将模型迁移过来。然后我们将计算我们的梯度与原模型梯度大差值,通过反推更新输入样本和标签信息,以此进行迭代。
算法:
如上所属,对没个输入样本遍历,迭代更新输入的假样本和假标签,直到达到收敛。
实验结果:
从实验中,我们看出,对于输入的图片,经过少轮的迭代即可恢复出大致图像信息。
可以看出,随着网络层数的增加,原始信息恢复的就更多,相比Meils方法,作者的方法显著更小。
代码:
原作者代码传送
代码解读:这里有详细的代码解析。
def main():
seed = 1234
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
dataset = args.dataset
root_path = '.'
data_path = os.path.join(root_path, './data').replace('\\', '/')
save_path = os.path.join(root_path, 'results/DLG_%s' % dataset).replace('\\', '/')
lr = 0.2
num_dummy = 1
iteration = 300
num_exp = 1
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'
tt = transforms.Compose([transforms.ToTensor()])
tp = transforms.Compose([transforms.ToPILImage()])
'''
打印路径而已
'''
print(dataset, 'root_path:', root_path)
print(dataset, 'data_path:', data_path)
print(dataset, 'save_path:', save_path)
if not os.path.exists('results'):
os.mkdir('results')
if not os.path.exists(save_path):
os.mkdir(save_path)
'''
加载数据
'''
if dataset == 'MNIST' or dataset == 'mnist':
image_shape = (28, 28)
num_classes = 10
channel = 1
hidden = 588
dst = datasets.MNIST(data_path, download=True)
elif dataset == 'cifar10' or dataset == 'CIFAR10':
image_shape = (32, 32)
num_classes = 10
channel = 3
hidden = 768
dst = datasets.CIFAR10(data_path, download=True)
elif dataset == 'cifar100' or dataset == 'CIFAR100':
image_shape = (32, 32)
num_classes = 100
channel = 3
hidden = 768
dst = datasets.CIFAR100(data_path, download=True)
elif dataset == 'lfw':
shape_img = (32, 32)
num_classes = 5749
channel = 3
hidden = 768
lfw_path = os.path.join(root_path, './data/lfw')
dst = lfw_dataset(lfw_path, shape_img)
else:
exit('unkown dataset')
for idx_net in range(num_exp):
net = LeNet(channel=channel, hidden=hidden, num_classes=num_classes)
net.apply(weights_init)
print('running %d|%d experiment' % (idx_net, num_exp))
net = net.to(device)
print('%s, Try to generate %d images' % ('DLG', num_dummy))
criterion = nn.CrossEntropyLoss().to(device)
imidx_list = []
for imidx in range(num_dummy):
idx = args.index
imidx_list.append(idx)
tmp_datum = tt(dst[idx][0]).float().to(device)
tmp_datum = tmp_datum.view(1, *tmp_datum.size())
tmp_label = torch.Tensor([dst[idx][1]]).long().to(device)
tmp_label = tmp_label.view(1, )
if imidx == 0:
gt_data = tmp_datum
gt_label = tmp_label
else:
gt_data = torch.cat((gt_data, tmp_datum), dim=0)
gt_label = torch.cat((gt_label, tmp_label), dim=0)
out = net(gt_data)
y = criterion(out, gt_label)
dy_dx = torch.autograd.grad(y, net.parameters())
original_dy_dx = list((_.detach().clone() for _ in dy_dx))
dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
dummy_label = torch.randn((gt_data.shape[0], num_classes)).to(device).requires_grad_(True)
optimizer = torch.optim.LBFGS([dummy_data, dummy_label], lr=lr)
history = []
history_iters = []
grad_difference = []
data_difference = []
train_iters = []
print('lr =', lr)
for iters in range(iteration):
def closure():
optimizer.zero_grad()
pred = net(dummy_data)
dummy_loss = -torch.mean(
torch.sum(torch.softmax(dummy_label, -1) * torch.log(torch.softmax(pred, -1)), dim=-1))
dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
grad_diff = 0
for gx, gy in zip(dummy_dy_dx, original_dy_dx):
grad_diff += ((gx - gy) ** 2).sum()
grad_diff.backward()
return grad_diff
optimizer.step(closure)
current_loss = closure().item()
train_iters.append(iters)
grad_difference.append(current_loss)
data_difference.append(torch.mean((dummy_data - gt_data) ** 2).item())
if iters % int(iteration / 30) == 0:
current_time = str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))
print(current_time, iters, '梯度差 = %.8f, 数据差 = %.8f' % (current_loss, data_difference[-1]))
history.append([tp(dummy_data[imidx].cpu()) for imidx in range(num_dummy)])
history_iters.append(iters)
for imidx in range(num_dummy):
plt.figure(figsize=(12, 8))
plt.subplot(3, 10, 1)
plt.imshow(tp(gt_data[imidx].cpu()), cmap='gray')
for i in range(min(len(history), 29)):
plt.subplot(3, 10, i + 2)
plt.imshow(history[i][imidx], cmap='gray')
plt.title('iter=%d' % (history_iters[i]))
plt.axis('off')
plt.savefig('%s/DLG_on_%s_%05d.png' % (save_path, imidx_list, imidx_list[imidx]))
plt.close()
if current_loss < 0.000001:
break
loss_DLG = grad_difference
label_DLG = torch.argmax(dummy_label, dim=-1).detach().item()
mse_DLG = data_difference
缺陷:读者认为该论文方向新颖,通过逼近现有模型的梯度信息来反推原始数据,并且取得效果不错。但方法条件研究苛刻,需要将原模型的网络架构共享给攻击模型,在实际场景中是难以满足这样要求的,其他方是不可能完整拿到网络结构的,因此单纯实际场景下该方法并不能直接使用,但这篇论文引起联邦从业者的广泛担忧,毕竟梯度信息泄漏是可能是“危险的”。
|