IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 联邦学习(FL,Federated Learning) 之FedAvg算法 -> 正文阅读

[人工智能]联邦学习(FL,Federated Learning) 之FedAvg算法

Communication-Efficient Learning of Deep Networks from Decentralized Data?

论文地址:[1602.05629] Communication-Efficient Learning of Deep Networks from Decentralized Data (arxiv.org)

FL主要瓶颈 : ?? ?

1. 通信速率不稳定,且可能不可靠 ?? ?

2. 聚合服务器的容量有限,同时与server通信的client的数量受限

解决方案:

在FL的每一步考虑:? ? ?

????????1. 减少client数量 ?? ?

????????2. 减少通信带宽

FedAvg算法采取策略:

????????增加客户端计算,限制通信频率(在上传更新的梯度之前执行多次本地梯度下降迭代)

FedAvg算法 ?? ???

????????随机选择m个客户端采样,对这m个客户端的梯度更新进行平均以形成全局更新,同时用当前全局模型替换未采样的客户端 ?? ?

优点:相对于FedSGD在相同效果情况下,通讯成本大大降低 ?? ?

缺点:最终的模型是有偏倚的,不同于预期的每个客户端确定性聚合后的模型。

FedAvg客户端抽样:

? ? ? ? ? ? ? ?

  • 在每次迭代中对随机选择参与的客户端子集St进行统一抽样,并下发当前全局模型参数θt,在客户端本地进行训练更新梯度,并上传至服务端进行平均形成更新参数θi^(t+1);
  • 不属于抽样子集的客户端的更新则由当前的全局模型参数θt代替;
  • 然后服务端进行平均产生新的全局参数θt+1。

FedAvg算法伪代码:

? ? ? ? ? ? ? ??

?FedAvg算法步骤:

1. 在每一轮迭代的步骤t,服务端发送当前全局模型参数θ给客户端

2. 非抽样子集中的客户端根据θt,通过SGD更新本地模型

3. 抽样子集中每个客户端上传更新后的本地参数θt+1

4. 在迭代步骤t+1,服务端根据全局模型参数θi(t+1)计算出加权平均值θt+1

? ? ? ? ??

? ? ? ? ? ? ?

其中,由此可得最终的优化的目标函数。

?优化目标:

pi表示权重,一般表达式为nk/n。FedAvg算法最终取Li(θ)的加权平均值。

算法实现代码:

 # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc)/len(list_acc))

源代码地址:

Federated-Learning-PyTorch/federated_main.py at master · AshwinRJ/Federated-Learning-PyTorch · GitHub

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-16 11:44:45  更:2021-08-16 11:45:16 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 0:54:17-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码