1. Common configuration concepts
Developing a federated learning system involves a large number of configuration parameters. The most commonly used settings include the following:
● Number of clients trained per round: in every global round, the server first selects a subset of all clients to perform local training. Training on only a subset not only leaves global convergence unaffected, it also improves training efficiency.
● Number of global epochs: the number of communication rounds between the server and the clients. A maximum number of global epochs is usually set, but training may terminate early once the model satisfies the convergence criterion.
● Number of local epochs: the number of epochs each client runs when training its local model. This number may be the same for every client or differ between clients.
● Local training settings: the hyperparameters used for local training, such as the learning rate, batch size, and the optimizer.
● Model information: the model architecture used for the current task.
● Data information: the data used for federated training.
Before training starts, the configuration is sent to both the server and the clients and stored there. Whenever the configuration changes, it is synchronized to all participants so that every party keeps a consistent copy.
An example of this configuration looks as follows:
conf.json
{
    "model_name" : "resnet18",
    "no_models" : 10,
    "type" : "cifar",
    "global_epochs" : 20,
    "local_epochs" : 3,
    "k" : 5,
    "batch_size" : 32,
    "lr" : 0.001,
    "momentum" : 0.0001,
    "lambda" : 0.1
}
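To relate these keys back to the concepts above (this reading is inferred from how the code below consumes them): no_models is the total number of simulated clients, k is the number of clients sampled in each global round, global_epochs and local_epochs are the global and local iteration counts, lambda is the weight applied to the accumulated client updates during aggregation, and model_name, type, batch_size, lr and momentum configure the model, the dataset, and the local optimizer. A minimal loading sketch (it simply mirrors what main.py does later):

import json

# Load conf.json and inspect a few of the fields described above.
with open("conf.json", "r") as f:
    conf = json.load(f)

print(conf["no_models"])   # total number of clients
print(conf["k"])           # clients sampled per global round
print(conf["lambda"])      # weight applied to the summed client updates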
2. The server side
The server's main job is to aggregate the local models uploaded by the selected clients. The code is commented in detail; the specifics can be read directly from the code.
server.py
import models, torch


class Server(object):

    def __init__(self, conf, eval_dataset):
        # Global configuration (parsed from conf.json)
        self.conf = conf
        # Global model that the server maintains and aggregates into
        self.global_model = models.get_model(self.conf["model_name"])
        # Loader for evaluating the global model after each round
        self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True)

    def model_aggregate(self, weight_accumulator):
        # weight_accumulator holds the sum of the selected clients' parameter updates (diffs)
        for name, data in self.global_model.state_dict().items():
            # Scale the accumulated update by the aggregation weight lambda
            update_per_layer = weight_accumulator[name] * self.conf["lambda"]
            if data.type() != update_per_layer.type():
                # Integer buffers (e.g. BatchNorm's num_batches_tracked) need an explicit cast
                data.add_(update_per_layer.to(torch.int64))
            else:
                data.add_(update_per_layer)

    def model_eval(self):
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        for batch_id, batch in enumerate(self.eval_loader):
            data, target = batch
            dataset_size += data.size()[0]
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            output = self.global_model(data)
            # Sum the per-sample loss so it can be averaged over the whole evaluation set
            total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            # The index of the largest logit is the predicted class
            pred = output.data.max(1)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        acc = 100.0 * (float(correct) / float(dataset_size))
        total_l = total_loss / dataset_size
        return acc, total_l
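To make the aggregation rule concrete: for each parameter, the server computes global = global + lambda * (sum of the selected clients' diffs). Below is a toy sketch of that arithmetic with a single made-up "layer" and two made-up client updates; it only illustrates what model_aggregate does to each tensor:

import torch

lam = 0.1  # conf["lambda"]

# A single fictitious parameter tensor and two fictitious client updates (diffs).
global_param = torch.tensor([1.0, 2.0])
diff_client_a = torch.tensor([0.2, -0.4])
diff_client_b = torch.tensor([0.4, 0.0])

# main.py sums the client diffs into weight_accumulator ...
weight_accumulator = diff_client_a + diff_client_b
# ... and model_aggregate adds them to the global parameter, scaled by lambda.
global_param += lam * weight_accumulator
print(global_param)  # tensor([1.0600, 1.9600])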
3. The client side
The client's main job is to receive instructions and the global model sent down by the server, and to train a local model on its own data. The code is commented in detail; the specifics can be read directly from the code.
client.py
import models, torch, copy


class Client(object):

    def __init__(self, conf, model, train_dataset, id=-1):
        # Global configuration (parsed from conf.json)
        self.conf = conf
        # Local model with the same architecture as the global model
        self.local_model = models.get_model(self.conf["model_name"])
        self.client_id = id
        self.train_dataset = train_dataset
        # Give each client a contiguous, non-overlapping slice of the training set
        all_range = list(range(len(self.train_dataset)))
        data_len = int(len(self.train_dataset) / self.conf['no_models'])
        train_indices = all_range[id * data_len: (id + 1) * data_len]
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf["batch_size"],
                                                        sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))

    def local_train(self, model):
        # Start local training from the current global model parameters
        for name, param in model.state_dict().items():
            self.local_model.state_dict()[name].copy_(param.clone())
        optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'], momentum=self.conf['momentum'])
        self.local_model.train()
        for e in range(self.conf["local_epochs"]):
            for batch_id, batch in enumerate(self.train_loader):
                data, target = batch
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()
                optimizer.zero_grad()
                output = self.local_model(data)
                loss = torch.nn.functional.cross_entropy(output, target)
                loss.backward()
                optimizer.step()
            print("Epoch %d done." % e)

        # Return the parameter-wise difference between the local and global models
        diff = dict()
        for name, data in self.local_model.state_dict().items():
            diff[name] = (data - model.state_dict()[name])
        return diff
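One detail worth calling out is the data partition in __init__: client i receives the contiguous index range [i * data_len, (i + 1) * data_len) of the shared training set and draws batches from it via SubsetRandomSampler, so the clients' shards do not overlap. A small sketch of the index arithmetic, using the CIFAR-10 training-set size and the no_models value from conf.json:

dataset_len = 50000              # CIFAR-10 training images
no_models = 10                   # conf["no_models"]
data_len = dataset_len // no_models

# Each client id maps to one contiguous, non-overlapping slice of indices.
for client_id in range(no_models):
    start, end = client_id * data_len, (client_id + 1) * data_len
    print(client_id, start, end)  # e.g. client 0 -> [0, 5000), client 1 -> [5000, 10000)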
4. The training dataset
The dataset is loaded according to the type field in the configuration file.
datasets.py
import torch
from torchvision import datasets, transforms


def get_dataset(dir, name):
    if name == 'mnist':
        train_dataset = datasets.MNIST(dir, train=True, download=True, transform=transforms.ToTensor())
        eval_dataset = datasets.MNIST(dir, train=False, transform=transforms.ToTensor())
    elif name == 'cifar':
        # Standard CIFAR-10 augmentation and normalization for training
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        # Only normalization for evaluation
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)
        eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)
    return train_dataset, eval_dataset
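A minimal usage sketch (CIFAR-10 is downloaded to ./data/ on the first call; the split sizes are the standard 50,000 training and 10,000 test images):

import datasets

# Load (or download) CIFAR-10 and report the split sizes.
train_dataset, eval_dataset = datasets.get_dataset("./data/", "cifar")
print(len(train_dataset), len(eval_dataset))  # 50000 10000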
5. Putting the training together
Once the configuration file, the server, and the client are all defined, we can put them together to perform horizontal federated learning for image classification.
main.py
import argparse, json
import datetime
import os
import logging
import torch, random

from server import *
from client import *
import models, datasets


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Federated Learning')
    parser.add_argument('-c', '--conf', dest='conf')
    args = parser.parse_args()

    # Load the configuration file
    with open(args.conf, 'r') as f:
        conf = json.load(f)

    # Prepare the datasets, the server, and the full pool of clients
    train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])
    server = Server(conf, eval_datasets)
    clients = []
    for c in range(conf["no_models"]):
        clients.append(Client(conf, server.global_model, train_datasets, c))

    print("\n\n")
    for e in range(conf["global_epochs"]):
        # Randomly sample k clients to participate in this round
        candidates = random.sample(clients, conf["k"])

        # Accumulator for the sum of the selected clients' parameter updates
        weight_accumulator = {}
        for name, params in server.global_model.state_dict().items():
            weight_accumulator[name] = torch.zeros_like(params)

        # Each selected client trains locally and returns its update (diff)
        for c in candidates:
            diff = c.local_train(server.global_model)
            for name, params in server.global_model.state_dict().items():
                weight_accumulator[name].add_(diff[name])

        # Aggregate the updates into the global model, then evaluate it
        server.model_aggregate(weight_accumulator)
        acc, loss = server.model_eval()
        print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))
6. Running the project
Simply run the following command:
python main.py -c conf.json