IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 数据安全--28--Pyhton完成横向联邦图像分类 -> 正文阅读

[人工智能]数据安全--28--Pyhton完成横向联邦图像分类

一、常见的配置概念

联邦学习在开发过程中会涉及大量的参数配置,其中比较常用的参数设置包括以下几个:

● 训练的客户端数量:每一轮的迭代,服务端会首先从所有的客户端中挑选部分客户端进行本地训练,挑出部分不仅不会影响全局收敛的效果,而且能够提升训练的效率。
● 全局迭代次数:即服务端和客户端的通信次数。通常会设置一个最大的全局迭代次数,但在训练过程中,只要模型满足收敛的条件,那么训练也可以提前终止。
● 本地模型的迭代次数:即每一个客户端在进行本地模型训练时的迭代次数。每一个客户端的本地模型的迭代次数可以相同,也可以不同。
● 本地训练相关的算法配置:本地模型进行训练时的参数设置,如学习率、训练样本大小、使用的优化算法等。
● 模型信息:即当前任务我们使用的模型结构。
● 数据信息:联邦学习训练的数据。

联邦学习在模型训练之前,会将配置信息分别发送到服务端和客户端中保存,如果配置信息发生改变,也会同时对所有参与方进行同步,以保证各参与方的配置信息一致。

以上配置体现形式如下:

conf.json

{
	"model_name" : "resnet18",         // 模型名称
	"no_models" : 10,                  // 客户端数量
	"type" : "cifar",                  // 数据集信息
	"global_epochs" : 20,              // 全局迭代次数,即服务端与客户端的通信迭代次数
	"local_epochs" : 3,                // 本地模型训练迭代次数
	"k" : 5,                           // 每一轮迭代时,服务端会从所有客户端中挑选k个客户端参与训练
	"batch_size" : 32,                 // 本地训练每一轮的样本数
	"lr" : 0.001,                      // 本地训练的超参数设置
	"momentum" : 0.0001,
	"lambda" : 0.1
}

二、server端

服务端的主要功能是将被选择的客户端上传的本地模型进行模型聚合。
已对代码做出了具体的注释说明,具体细节阅读代码即可。

server.py

import models, torch

class Server(object):
	# 定义构造函数
	def __init__(self, conf, eval_dataset):
	    # 将配置信息拷贝到服务端中
		self.conf = conf 
		self.global_model = models.get_model(self.conf["model_name"]) 
		# 按照配置中的模型信息获取模型
		self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True)
		
	# 定义模型聚合函数,weight_accumulator存储了每一个客户上传的参数变化值
	def model_aggregate(self, weight_accumulator):
		for name, data in self.global_model.state_dict().items():
			update_per_layer = weight_accumulator[name] * self.conf["lambda"]
			if data.type() != update_per_layer.type():
				data.add_(update_per_layer.to(torch.int64))
			else:
				data.add_(update_per_layer)
			
	# 定义模型评估函数,对当前的全局模型,利用评估数据评估当前的全局模型性能
	def model_eval(self):
		self.global_model.eval()
		total_loss = 0.0
		correct = 0
		dataset_size = 0
		for batch_id, batch in enumerate(self.eval_loader):
			data, target = batch 
			dataset_size += data.size()[0]
			
			if torch.cuda.is_available():
				data = data.cuda()
				target = target.cuda()
				
			output = self.global_model(data)
			# 把损失值聚合起来
			total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
			# 获取最大的对数概率的索引值
			pred = output.data.max(1)[1]
			correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
		
		# acc是计算准确率,total_l是计算损失值 
		acc = 100.0 * (float(correct) / float(dataset_size))
		total_l = total_loss / dataset_size

		return acc, total_l

三、client端

客户端主要功能是接收服务端的下发指令和全局模型,利用本地数据进行局部模型训练。
已对代码做出了具体的注释说明,具体细节阅读代码即可。

client.py

import models, torch, copy

class Client(object):
	# 定义构造函数
	def __init__(self, conf, model, train_dataset, id = -1):
		self.conf = conf
		self.local_model = models.get_model(self.conf["model_name"])
		# 客户端ID 
		self.client_id = id
		# 客户端本地数据集
		self.train_dataset = train_dataset  
		all_range = list(range(len(self.train_dataset)))
		data_len = int(len(self.train_dataset) / self.conf['no_models'])
		train_indices = all_range[id * data_len: (id + 1) * data_len]
		self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf["batch_size"], 
									sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))
	
	# 定义模型本地训练函数				
	def local_train(self, model):
		for name, param in model.state_dict().items():
			# 客户端首先用服务端下发的全局模型覆盖本地模型
			self.local_model.state_dict()[name].copy_(param.clone())
	
		# 定义最优化函数器,用于本地模型训练
		optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'], momentum=self.conf['momentum'])
		# 本地模型训练
		self.local_model.train()
		for e in range(self.conf["local_epochs"]):
			for batch_id, batch in enumerate(self.train_loader):
				data, target = batch
				if torch.cuda.is_available():
					data = data.cuda()
					target = target.cuda()
			
				optimizer.zero_grad()
				output = self.local_model(data)
				loss = torch.nn.functional.cross_entropy(output, target)
				loss.backward()
				optimizer.step()
			print("Epoch %d done." % e)	
		diff = dict()
		for name, data in self.local_model.state_dict().items():
			diff[name] = (data - model.state_dict()[name])
			
		return diff

四、训练数据集

按照配置文件中的type字段信息,获取数据集。

datasets.py

import torch 
from torchvision import datasets, transforms

def get_dataset(dir, name):

	if name=='mnist':
		train_dataset = datasets.MNIST(dir, train=True, download=True, transform=transforms.ToTensor())
		eval_dataset = datasets.MNIST(dir, train=False, transform=transforms.ToTensor())
		
	elif name=='cifar':
		transform_train = transforms.Compose([
			transforms.RandomCrop(32, padding=4),
			transforms.RandomHorizontalFlip(),
			transforms.ToTensor(),
			transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
		])

		transform_test = transforms.Compose([
			transforms.ToTensor(),
			transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
		])
		
		train_dataset = datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)
		eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)
		
	return train_dataset, eval_dataset

五、整合训练

当配置文件、服务端和客户端都定义完后,就可以整合起来完成一次横向联邦学习的图像分类了。

main.py

import argparse, json
import datetime
import os
import logging
import torch, random
from server import *
from client import *
import models, datasets

if __name__ == '__main__':

	parser = argparse.ArgumentParser(description='Federated Learning')
	parser.add_argument('-c', '--conf', dest='conf')
	args = parser.parse_args()
	
	# 读取配置文件信息
	with open(args.conf, 'r') as f:
		conf = json.load(f)	
	
	# 分别定义一个服务端对象和多个客户端对象,用来模拟横向联邦训练场景
	train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])
	server = Server(conf, eval_datasets)
	clients = []
	# 创建多个客户端
	for c in range(conf["no_models"]):
		clients.append(Client(conf, server.global_model, train_datasets, c))	
	print("\n\n")
	
	# 每一轮的迭代,服务端会从当前的客户端集合中随机挑选一部分参与本轮迭代训练
	for e in range(conf["global_epochs"]):
		# 采样k个客户端参与本轮联邦训练
		candidates = random.sample(clients, conf["k"])
		weight_accumulator = {}
		for name, params in server.global_model.state_dict().items():
			weight_accumulator[name] = torch.zeros_like(params)
		for c in candidates:
			diff = c.local_train(server.global_model)
			for name, params in server.global_model.state_dict().items():
				weight_accumulator[name].add_(diff[name])
		
		# 调用模型聚合函数model_aggregate,来更新全局模型
		server.model_aggregate(weight_accumulator)
		acc, loss = server.model_eval()
		print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))

六、项目运行

执行以下命令即可:

python main.py -c conf.json
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-01 23:22:47  更:2022-04-01 23:23:30 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 0:50:45-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码