Reproducing the FedAvg Model in PyTorch
This federated learning model implements the paper Communication-Efficient Learning of Deep Networks from Decentralized Data. You should be familiar with that paper before reading on, since this post is mainly code-driven. Source code: FedAvg
The FedAvg Algorithm
The overall idea of FedAvg is as follows:
- The server initializes the global model weights, randomly selects a fraction of the clients to participate in the round, and broadcasts the weights to them.
- Each selected client loads the global weights and trains locally for E epochs, where each epoch iterates over the client's local data in mini-batches of size B; when local training finishes, the client sends its updated weights back to the server.
- The server collects the clients' weights, averages them to obtain the new global weights, and broadcasts those for the next global round (a minimal sketch of one round follows this list).
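Before diving into the files, here is a minimal, self-contained sketch of one FedAvg round against two toy regression clients. fedavg_round and its toy data are illustrative stand-ins invented for this sketch, not part of the code below:

import copy
import torch
from torch import nn

def fedavg_round(global_model, client_datasets, lr=0.1, local_epochs=2):
    """One FedAvg round: broadcast, local training, weight averaging."""
    client_weights = []
    for x, y in client_datasets:                   # each client's local data
        local = copy.deepcopy(global_model)        # broadcast global weights
        opt = torch.optim.SGD(local.parameters(), lr=lr)
        for _ in range(local_epochs):              # E local epochs
            opt.zero_grad()
            loss = nn.functional.mse_loss(local(x), y)
            loss.backward()
            opt.step()
        client_weights.append(local.state_dict())  # upload local weights
    # server side: element-wise average of the client weights
    avg = copy.deepcopy(client_weights[0])
    for key in avg:
        for w in client_weights[1:]:
            avg[key] += w[key]
        avg[key] = avg[key] / len(client_weights)
    global_model.load_state_dict(avg)
    return global_model

# two toy "clients", each holding a tiny regression dataset
model = nn.Linear(3, 1)
clients = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(2)]
model = fedavg_round(model, clients)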
Implementation
FedAvg.py: the main training script
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt
import torch
from tensorboardX import SummaryWriter
from options import args_parser
from update import LocalUpdate, test_inference
from model import CNNMnist
from utils import get_dataset, average_weights

if __name__ == '__main__':
    start_time = time.time()
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')
    args = args_parser()

    if args.gpu and torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print(f'device is {device}')

    train_dataset, test_dataset, user_group = get_dataset(args)

    if args.model == 'cnn' and args.dataset == 'mnist':
        global_model = CNNMnist(args=args)
    else:
        exit('No suitable model found; please add one')
    global_model.to(device)
    global_model.train()
    print(global_model)

    # initial global weights, broadcast to the clients in the first round
    global_weight = global_model.state_dict()

    train_loss, train_acc = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    for epoch in tqdm(range(args.epochs)):
        local_weight, local_losses = [], []
        print(f'\n global training round: {epoch + 1} | \n')
        global_model.train()

        # randomly sample a fraction C of the K clients for this round
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # each selected client trains on a copy of the current global model
        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_group[idx], logger=logger)
            w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)
            local_weight.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # server step: average the client weights into the new global weights
        global_weight = average_weights(local_weight)
        global_model.load_state_dict(global_weight)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # evaluate the updated global model on the participating clients
        list_acc, list_loss = [], []
        global_model.eval()
        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_group[idx], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_acc.append(sum(list_acc) / len(list_acc))

        if (epoch + 1) % print_every == 0:
            print(f'\n avg training stats after {epoch + 1} global rounds: ')
            print(f'training loss: {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 * train_acc[-1]))

    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_acc[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

    os.makedirs('./save', exist_ok=True)  # make sure the output directory exists
    file_name = './save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
    )
    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_acc], f)
    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))

    matplotlib.use('Agg')
    plt.figure()
    plt.title('Training Loss vs Communication rounds')
    plt.plot(range(len(train_loss)), train_loss, color='r')
    plt.ylabel('Training loss')
    plt.xlabel('Communication Rounds')
    plt.savefig('./save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
    ))

    plt.figure()
    plt.title('Average Accuracy vs Communication rounds')
    plt.plot(range(len(train_acc)), train_acc, color='k')
    plt.ylabel('Average Accuracy')
    plt.xlabel('Communication Rounds')
    plt.savefig('./save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs
    ))
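The script pickles the per-round training loss and accuracy, so the curves can be reloaded later without retraining. A small sketch, assuming a run with the default arguments (the file name below is simply what the format string above produces for those defaults):

import pickle

with open('./save/fed_mnist_cnn_10_C[0.1]_iid[1]_E[10]_B[10].pkl', 'rb') as f:
    train_loss, train_acc = pickle.load(f)
print(len(train_loss), train_loss[-1], train_acc[-1])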
options.py: the runtime arguments
import argparse


def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=10, help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=10, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum (default: 0.5)')
    # model arguments
    parser.add_argument('--model', type=str, default='cnn', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of imgs")
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32,
                        help="number of filters for conv nets -- 32 for mini-imagenet, 64 for omniglot.")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")
    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    # parsed as int (1/0): a bare bool argument would treat the string 'False' as truthy
    parser.add_argument('--gpu', type=int, default=1, help="set to 1 to use cuda if available, 0 for CPU")
    parser.add_argument('--optimizer', type=str, default='sgd', help="type of optimizer")
    parser.add_argument('--iid', type=int, default=1, help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--unequal', type=int, default=0,
                        help='whether to use unequal data splits for non-i.i.d setting (use 0 for equal splits)')
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', type=int, default=1, help='verbose')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    args = parser.parse_args()
    return args
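One caveat worth knowing: args_parser() calls parse_args() with no arguments, so it always reads sys.argv; in a notebook or a unit test you can hand argparse an explicit argument list instead. A minimal sketch with a stripped-down parser (not the full one above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--gpu', type=int, default=1)

args = parser.parse_args([])                   # use all defaults
cpu_args = parser.parse_args(['--gpu', '0'])   # override, just like the CLI
print(args.epochs, args.gpu, cpu_args.gpu)     # -> 10 1 0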
update.py: the local update logic
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around the PyTorch Dataset class."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        # image is already a tensor after ToTensor(); clone/detach avoids the
        # torch.tensor(tensor) copy warning
        return image.clone().detach(), torch.tensor(label)


class LocalUpdate(object):
    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
        self.device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
        # NLLLoss pairs with the log_softmax output of CNNMnist
        self.criterion = nn.NLLLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes (an 80/10/10 split of the client's local data).
        """
        idxs_train = idxs[:int(0.8 * len(idxs))]
        idxs_val = idxs[int(0.8 * len(idxs)):int(0.9 * len(idxs))]
        idxs_test = idxs[int(0.9 * len(idxs)):]
        trainloader = DataLoader(DatasetSplit(dataset, idxs_train), batch_size=self.args.local_bs, shuffle=True)
        # guard against a zero batch size when a client holds very little data
        validloader = DataLoader(DatasetSplit(dataset, idxs_val), batch_size=max(int(len(idxs_val) / 10), 1), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test), batch_size=max(int(len(idxs_test) / 10), 1), shuffle=False)
        return trainloader, validloader, testloader

    def update_weights(self, model, global_round):
        model.train()
        epoch_loss = []
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=1e-4)
        for local_epoch in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)] \t Loss: {:.6f}'.format(
                        global_round, local_epoch, batch_idx * len(images),
                        len(self.trainloader.dataset), 100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        # return the client's final weights and its mean training loss
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def inference(self, model):
        """Returns the inference accuracy and loss on the client's test split."""
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0
        for batch_idx, (images, labels) in enumerate(self.testloader):
            images, labels = images.to(self.device), labels.to(self.device)
            outputs = model(images)
            batch_loss = self.criterion(outputs, labels)
            loss += batch_loss.item()
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
        accuracy = correct / total
        return accuracy, loss


def test_inference(args, model, test_dataset):
    """Returns the test accuracy and loss on the global test set."""
    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
    criterion = nn.NLLLoss().to(device)
    testloader = DataLoader(test_dataset, batch_size=128, shuffle=False)
    for batch_idx, (images, labels) in enumerate(testloader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        batch_loss = criterion(outputs, labels)
        loss += batch_loss.item()
        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        correct += torch.sum(torch.eq(pred_labels, labels)).item()
        total += len(labels)
    accuracy = correct / total
    return accuracy, loss
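To see what DatasetSplit actually does, the sketch below restricts MNIST to a pretend client's 600 indices; it assumes DatasetSplit from the file above is importable and uses the same data path as utils.py:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

mnist = datasets.MNIST('../data/pytorch/', train=True, download=True,
                       transform=transforms.ToTensor())
client_idxs = list(range(600))          # pretend these indices belong to one client
split = DatasetSplit(mnist, client_idxs)
print(len(split))                       # 600, not 60000
image, label = split[0]                 # indexing goes through client_idxs
loader = DataLoader(split, batch_size=10, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)       # torch.Size([10, 1, 28, 28]) torch.Size([10])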
model.py: the model definition
from torch import nn
import torch.nn.functional as F


class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)   # 20 channels * 4 * 4 after two conv+pool stages
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                    # 28x28 -> 12x12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))   # 12x12 -> 4x4
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])          # flatten to (batch, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
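A quick shape sanity check for this network: two 5x5 conv + 2x2 max-pool stages turn a 28x28 input into 20 channels of 4x4, hence the 320-unit input to fc1. A minimal sketch, using an argparse.Namespace in place of the parsed options:

import argparse
import torch

args = argparse.Namespace(num_channels=1, num_classes=10)
model = CNNMnist(args=args)
x = torch.randn(4, 1, 28, 28)    # batch of 4 fake 28x28 grayscale images
out = model(x)
print(out.shape)                 # torch.Size([4, 10]), log-probabilities
print(out.exp().sum(dim=1))      # each row sums to ~1 after exp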
utils.py: data helpers for loading the dataset, splitting it across users, and averaging weights
import copy
import numpy as np
import torch
from torchvision import datasets, transforms


def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the MNIST dataset.
    :param dataset:
    :param num_users:
    :return: dict of image indices per user
    """
    num_items = int(len(dataset) / num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    if args.dataset == 'mnist':
        data_dir = '../data/pytorch/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=apply_transform)
        test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=apply_transform)
        if args.iid:
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            exit('only the IID split is implemented in this post')
        return train_dataset, test_dataset, user_groups


def average_weights(w):
    """ Returns the element-wise average of a list of state_dicts."""
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
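average_weights is just an element-wise mean over the clients' state_dicts, which a tiny sketch makes easy to verify:

import torch
from torch import nn

m1, m2 = nn.Linear(2, 1), nn.Linear(2, 1)
avg = average_weights([m1.state_dict(), m2.state_dict()])
print(avg['weight'])               # mean of the two weight tensors
print((m1.weight + m2.weight) / 2) # same values

Note that models with integer buffers (e.g. BatchNorm's num_batches_tracked) would need special handling in such an average; the plain CNN used here has none.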
Running
By default the script runs with the default parameters from options.py:
python FedAvg.py
To customize parameters, for example to train on the CPU and run more global epochs:
python FedAvg.py --epochs=15 --gpu=0
(The --gpu flag is parsed as an int above; with a bare argparse argument, passing --gpu=False would hand the model the truthy string 'False' and CUDA would not actually be disabled.)
If you want to browse the full source code, it is here: fedavg