I. 前言
Per-FedAvg的原理请见:arXiv | Per-FedAvg:一种联邦元学习方法。
II. 数据集介绍
联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。
数据集为某城市十个地区的风电功率,我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。
III. Per-FedAvg
算法伪代码:
1. 服务器端
服务器端和FedAvg一致,这里不再详细介绍了,可以看看前面几篇文章。
2. 客户端
对于每个客户端,我们定义它的元函数
F
i
(
w
)
F_i(w)
Fi?(w):
为了在本地训练中对
F
i
(
w
)
F_i(w)
Fi?(w)进行更新,我们需要计算其梯度:
代码实现如下:
def train(args, model):
model.train()
Dtr, Dte, m, n = nn_seq(model.name, args.B)
model.len = len(Dtr)
print('training...')
data = [x for x in iter(Dtr)]
for epoch in range(args.E):
model = one_step(args, data, model, lr=args.alpha)
model = one_step(args, data, model, lr=args.beta)
return model
def one_step(args, data, model, lr):
ind = np.random.randint(0, high=len(data), size=None, dtype=int)
seq, label = data[ind]
seq = seq.to(args.device)
label = label.to(args.device)
y_pred = model(seq)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_function = nn.MSELoss().to(args.device)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return model
3. 本地自适应更新
得到初始模型后,需要在本地进行1轮迭代更新:
def local_adaptation(args, model):
model.train()
Dtr, Dte = nn_seq_wind(model.name, 50)
optimizer = torch.optim.Adam(model.parameters(), lr=args.alpha)
loss_function = nn.MSELoss().to(args.device)
loss = 0
for epoch in range(1):
for seq, label in Dtr:
seq, label = seq.to(args.device), label.to(args.device)
y_pred = model(seq)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return model
|