参考书目《联邦学习实战》 杨强 在阅读本书的过程中,我尝试根据书中的代码,自己实现横向联邦学习中的图像分类任务,这里是我对代码和逻辑的理解还有出现的问题,希望对大家的学习有所帮助。 下面的表格是一些实验基本信息:
| 配置信息 | 解释 |
| --- | --- |
| 数据集 | Cifar10(其将样本划分后给每个客户端作为本地数据) |
| 全局迭代次数 | 服务器和客户端的通信次数 |
| 本地模型迭代次数 | 每一次客户端训练的轮数,各个客户端可以相同,也可以不同 |
一些其它基础的模型配置信息在json文件中给出:
{
"model_name" : "resnet18",
"no_models" : 10,
"type" : "cifar",
"global_epochs" : 20,
"local_epochs": 3,
"k" : 6,
"batch_size" : 32,
"lr" : 0.001,
"momentum" : 0.0001,
"lambda" : 0.1
}
获取训练数据集函数dataset.py:
import torchvision.datasets as dataset
import torchvision.transforms as transform
def get_dataset(dir, name):
if name == 'mnist':
train_dataset = dataset.MNIST(dir, train=True, download=True, transform=transform.ToTensor())
eval_dataset = dataset.MNIST(dir, train=False, transform=transform.ToTensor())
elif name == 'cifar':
transform_train = transform.Compose([
transform.RandomCrop(32, padding=4),
transform.RandomHorizontalFlip(),
transform.ToTensor(),
transform.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
transform_test = transform.Compose([
transform.ToTensor(),
transform.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
train_dataset = dataset.CIFAR10(dir, train=True, download=True, transform=transform_train)
eval_dataset = dataset.CIFAR10(dir, train=False, transform=transform_test)
return train_dataset, eval_dataset
这是一个简单的测试,采用本地模拟的方式进行客户端和服务器的交互,现在定义一个服务端类Server,其中的聚合函数采用的是FedAvg算法更新全局模型,公式如下:
$$G_{t+1}=G_t+\lambda\sum_{i=1}^{m}\left(L_i^{t+1}-G_t\right)$$

其中
$G_t$ 表示第 $t$ 轮聚合后的全局模型,$L_i^{t+1}$ 表示第 $i$ 个客户端在第 $t+1$ 轮本地更新后的模型,$G_{t+1}$ 表示第 $t+1$ 轮聚合后的全局模型。由于 models 库可能有所改变,这里只能指明模型的类型为 resnet18:
import torch
import torch.utils.data
from torchvision import models
class Server(object):
    """Federated-learning server.

    Holds the global model, applies FedAvg-style aggregation of client
    parameter diffs, and evaluates the global model on a held-out set.
    """

    def __init__(self, conf, eval_dataset):
        self.conf = conf
        # Look the model constructor up by name instead of eval():
        # identical effect, but the config file can no longer inject
        # arbitrary code.
        self.global_model = getattr(models, conf['model_name'])()
        self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"],
                                                       shuffle=True)
        if torch.cuda.is_available():
            self.global_model = self.global_model.cuda()

    def model_aggregrate(self, weight_accumulator):
        """Add conf['lambda'] * accumulated client diffs to the global
        model's parameters, in place.

        Args:
            weight_accumulator: dict mapping state_dict names to summed
                per-client diff tensors.
        """
        for name, data in self.global_model.state_dict().items():
            update_per_layer = weight_accumulator[name] * self.conf["lambda"]
            if data.type() != update_per_layer.type():
                # Integer buffers (e.g. BatchNorm's num_batches_tracked)
                # need the update cast before the in-place add.
                data.add_(update_per_layer.to(torch.int64))
            else:
                data.add_(update_per_layer)

    def model_eval(self):
        """Evaluate the global model on eval_loader.

        Returns:
            (accuracy_percent, mean_cross_entropy_loss)
        """
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        # Evaluation needs no autograd graph; no_grad saves memory/time
        # without changing any computed value.
        with torch.no_grad():
            for batch_id, batch in enumerate(self.eval_loader):
                data, target = batch
                dataset_size += data.size()[0]
                if torch.cuda.is_available():
                    data, target = data.cuda(), target.cuda()
                output = self.global_model(data)
                # Sum (not mean) per batch so the final division gives a
                # correct per-sample average even with a ragged last batch.
                total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
                pred = output.data.max(1)[1]
                correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        acc = 100.0 * (float(correct) / float(dataset_size))
        total_l = total_loss / dataset_size
        return acc, total_l
客户端进行本地的训练,注意数据集采用的是整个数据集中的一部分,最后需要计算出本地模型与之前全局模型的差值用于传输给服务器更新:
import copy

import torch
class Client(object):
    """A simulated federated client.

    Trains a private copy of the model on its own contiguous shard of the
    training data and returns the parameter delta versus the global model.
    """

    def __init__(self, conf, model, train_dataset, id=1):
        self.conf = conf
        # BUG FIX: deep-copy the incoming model. Previously every client
        # stored a reference to the SAME object the server owns (main passes
        # server.global_model), so "local" training mutated the global model
        # directly and the returned diff was always zero.
        self.local_model = copy.deepcopy(model)
        self.client_id = id
        self.train_dataset = train_dataset
        # Each client gets a contiguous 1/no_models slice of the dataset.
        all_range = list(range(len(self.train_dataset)))
        data_len = int(len(self.train_dataset) / self.conf['no_models'])
        indices = all_range[id * data_len : (id + 1) * data_len]
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf['batch_size'],
                                                        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))
        if torch.cuda.is_available():
            self.local_model = self.local_model.cuda()

    def local_train(self, model):
        """Sync to the global `model`, train locally, and return the update.

        Args:
            model: the current global model; its weights are copied in
                before training and it is NOT modified.

        Returns:
            dict mapping state_dict names to (local - global) tensors.
        """
        # Start every round from the freshly-received global weights.
        for name, param in model.state_dict().items():
            self.local_model.state_dict()[name].copy_(param.clone())
        optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'], momentum=self.conf['momentum'])
        self.local_model.train()
        for e in range(self.conf['local_epochs']):
            for batch_id, batch in enumerate(self.train_loader):
                data, target = batch
                if torch.cuda.is_available():
                    data, target = data.cuda(), target.cuda()
                optimizer.zero_grad()
                output = self.local_model(data)
                loss = torch.nn.functional.cross_entropy(output, target)
                loss.backward()
                optimizer.step()
            print('本地模型{}完成第{}轮训练'.format(self.client_id, e))
        diff = dict()
        for name, data in self.local_model.state_dict().items():
            diff[name] = (data - model.state_dict()[name])
        return diff
最后是主函数,其中我将准确率信息和loss记录在excel中用于绘图:
import json
import random
import time
import torch
import dataset
from server import Server
from client import Client
import pandas as pd
import matplotlib.pyplot as plt
if __name__ == '__main__':
    # Load run configuration, build the server and the pool of clients.
    with open('config.json', 'r') as f:
        conf = json.load(f)

    train_datasets, eval_datasets = dataset.get_dataset('./data/', conf['type'])
    server = Server(conf, eval_datasets)
    clients = [Client(conf, server.global_model, train_datasets, cid)
               for cid in range(conf['no_models'])]

    accs = []
    losses = []
    for epoch in range(conf['global_epochs']):
        # Each communication round: sample k clients, collect their diffs.
        candidates = random.sample(clients, conf['k'])
        weight_accumulator = {name: torch.zeros_like(params)
                              for name, params in server.global_model.state_dict().items()}
        for client in candidates:
            diff = client.local_train(server.global_model)
            for name in server.global_model.state_dict():
                weight_accumulator[name].add_(diff[name])

        # Fold the accumulated updates into the global model and evaluate.
        server.model_aggregrate(weight_accumulator)
        acc, loss = server.model_eval()
        accs.append(acc)
        losses.append(loss)
        print('全局模型:第{}轮完成!准确率:{:.2f} loss: {:.2f}'.format(epoch, acc, loss))

    # Persist the per-round curves for later plotting.
    df = pd.DataFrame([accs, losses])
    df.to_excel("data_{}.xlsx".format(int(time.time())))
|