IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 联邦学习的修仙之路_3 -> 正文阅读

[人工智能]联邦学习的修仙之路_3

4. TensorFlow Federated 实现Minst数据集识别

前两篇博客回顾?重新学习?了tensorflow 结构以及 tff 的两层API,讲道理官网对这两层的讲解真的有够迷。我之前一直纠结这个‘两层’的含义,现在稍微懂了一点:其中高层的FLearning是指不关注模型底层架构,把现有的模型拿来之后做简单的转换;而FCore则指可以重新创建或修改底层模型架构的API。

这篇博文的出发点是官方文档给的第一个TFF案例,(我学代码比较习惯从直接看一篇跑通的代码开始,不然一直学理论也搞不出个所以然来)。不得不说,这篇官方tutorial虽然是英文文档,但讲的非常清晰(比前几篇指南好了N个层次),相信只要不是英语过于苦手的朋友都可以读的懂这篇官方文档(shown as below):

https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification

写在最前面:官方给的例子主体都是由FLearning的高层API撰写的(所以在没有特殊标注之前,也基于官网同步)。

4.1? 导入并熟悉Minst数据集

和常规的机器学习步骤相同,这里的第一部也需要进行数据集导入,处理:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

这里使用内置数据集Mnist进行导入,导入之后可以分别使用如下函数进行查看长度,类型及打印(第一个图像),具体代码原文中展示十分详细,在这里不予赘述;仅列出几个关键字及作用。

element_type_structureAttributes, The element type information of the client datasets. elements returned by datasets in this ClientData object.

create_tf_dataset_for_client

Method,
create_tf_dataset_for_client(
? ? client_id: str
) 

-> tf.data.Dataset
from matplotlib import pyplot as plt plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal') plt.grid(False) _ = plt.show()

用于画图的固定套组,老朋友了。

官方文档里使用(包括但不限于)上述关键字分别查看了数据的数量,数据的标签和实际图像(label, pixels)但这些内容都属于帮助我们熟悉数据集,真正需要用于tff过程的只有最上面导入数据集的那一句代码。

在此之后,原文还探索了每个用户的书写特征(根据id并统计出条形图)并计算mean产生图像(这里证明了联邦学习所处在的非独立同分布情况)因为这部分也属于探索数据集,就不再分析了;可以查原文,讲的比较清晰。

4.2 处理Mnist数据集

在导入数据集之后,按正常流程将数据集进行处理数据集:将其拉平,重复,打乱。值得注意的是,这里将处理数据的过程(囊括到一个方法里),然后通过调用传参进行调用。

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

这里可以看到,前面一段定义了一些超参数;第二段实现图像像素拉平;第三段实现数据的重复,打乱;这部分代码可以用如下代码进行检验:

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?next(iter(preprocessed_example_dataset)))

sample_batch

这里啰嗦一句next和iter;list、tuple等都是可迭代对象,我们可以通过iter()函数获取这些可迭代对象的迭代器。然后我们可以对获取到的迭代器不断使?next()函数来获取下?条数据。

在准备好这些之后,分出是个客户端并为这十个客户端分配数据集:

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

4.3 准备/生成Model(Keras)

首先利用Keras构建模型:

def create_keras_model():
? return tf.keras.models.Sequential([
? ? ? tf.keras.layers.InputLayer(input_shape=(784,)),
? ? ? tf.keras.layers.Dense(10, kernel_initializer='zeros'),
? ? ? tf.keras.layers.Softmax(),
? ])

结构比较简单,输入层 + 稠密层 + 激活层。构造好模型之后将其转换 tff 实例,这步也是高阶API的精髓:

def model_fn():
? # We _must_ create a new model here, and _not_ capture it from an external
? # scope. TFF will call this within different graph contexts.
? keras_model = create_keras_model()
? return tff.learning.from_keras_model(
? ? ? keras_model,
? ? ? input_spec=preprocessed_example_dataset.element_spec,
? ? ? loss=tf.keras.losses.SparseCategoricalCrossentropy(),
? ? ? metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

4.4 开始训练

在构建好模型之后,可以开始进行训练。这里使用了build_federated_averaging_process来创建了一个交互的训练过程。这里定义了两个学习率:客户端和服务器,其中前者负责本地的更新而后者作用于avg。

可以用 type_signature 来打印出函数签名(),(函数签名由函数原型组成。它告诉你的是关于函数的一般信息,它的名称,参数,它的范围以及其他杂项信息。可以确定传入的参数是符合要求的)

接着将iterative_process进行initialize得到state,即:

state = iterative_process.initialize()

值得注意的是这里的state并不是指 ‘状态’,根据官方的解释:The?initialize_fn?function must return an object which is expected as input to and returned by the?next_fn?function. By convention, we refer to this object as?state.

再用两个参数去接 state 和 metric 就可以开始训练优化过程:

state, metrics = iterative_process.next(state, federated_train_data)

再使用for循环进行循环训练优化,就可以了:

for round_num in range(2, 11):
? state, metrics = iterative_process.next(state, federated_train_data)
? print('round {:2d}, metrics={}'.format(round_num, metrics))

4.5 Evaluation

在进行了一定轮次的训练之后(或者在准确度达到一定的程度之后)可以停止训练并开始进行模型的评估。原文说也是在防止过拟合。直接调用并创建得到实例:

evaluation = tff.learning.build_federated_evaluation(MnistModel) 

在得到实例之后,得到 train_metrics,再str()打印即可;后续如果想测试准确度,再重复使用evaluation即可。

train_metrics = evaluation(state.model, federated_train_data)

str(train_metrics)

5 结语

这篇博客顺理了一遍FL高阶API的Mnist的识别;可以发现在这个项目里tff的东西其实比较少,大部分都是keras的东西还有数据处理。这里可能也就体现出来了高阶API的特点。下一篇准备搞一下低阶API,因为最开始也是从低阶开始学这个项目的,毕竟是原汁原味的TFF 不是 \( ̄▽ ̄)/

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-05 17:21:26  更:2021-08-05 17:24:16 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/7 20:46:10-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码