IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> TensorflowAPI:tf.keras搭建网络八股,改进鸢尾花分类 -> 正文阅读

[人工智能]TensorflowAPI:tf.keras搭建网络八股,改进鸢尾花分类

keras介绍
tf.keras是tensorflow2引入的高封装度的框架,可以用于快速搭建神经网络模型,keras为支持快速实验而生,能够把想法迅速转换为结果,是深度学习框架之中最终易上手的一个,它提供了一致而简洁的API,能够极大地减少一般应用下的工作量,提高代码地封装程度和复用性。

一.用tf.keras创建网络的步骤

1.import 引入相应的python库

2.train,test告知要喂入的网络的训练集和测试集是什么,指定训练集的输入特征,x_train和训练集的标签y_train,以及测试集的输入特征和测试集的标签。

3.model = tf,keras,models,Seqential 在Seqential中搭建网络结构,逐层表述每层网络,走一边前向传播。

4.model.compile 在complie()中配置训练方法。告知训练器选择哪种优化器,选择哪个损失函数,哪种评测指标

5。model.fit? 在fit中执行训练过程,告知训练集和测试集的训练特征和标签,告知每个batch是多少,要迭代多少次数据集。

6.model.summary 用summary()打印出网络的结构和参数统计

二.改进鸢尾花分类

相关函数说明

1.tf.keras.models.Sequential()

Sequential函数是一个容器,描述了神经网络的网络结构,在Sequential函数的输入参数中描述从输入层到输出层的网络结构。

?全连接层:tf.keras.layers.Dense( 神经元个数,
????????????????????????????????????????????????????????????????activation=”激活函数”,
????????????????????????????????????????????????????????????????kernel_regularizer=”正则化方式”)
其中:
activation激活函数(字符串给出)可选relu、softmax、sigmoid、tanh等
kernel_regularizer正则化可选tf.keras.regularizers.l1()、
????????????????????????????????????????tf.keras.regularizers.l2()
2.Model.compile()

Model.compile( optimizer = 优化器,
????????????????????????????????loss = 损失函数,
????????????????????????????????metrics = [“准确率”])
Compile用于配置神经网络的训练方法,告知训练时使用的优化器、损失函数和准确率评测标准。
其中:
optimizer可以是字符串形式给出的优化器名字,也可以是函数形式,使用函数形式可以设置学习率、动量和超参数。

可选项包括:
‘sgd’or tf.optimizers.SGD( lr=学习率,
????????????????????????????????????????????????decay=学习率衰减率,
????????????????????????????????????????????????momentum=动量参数)
‘adagrad’or tf.keras.optimizers.Adagrad(lr=学习率,
????????????????????????????????????????????????????????????????????????decay=学习率衰减率)
‘adadelta’or tf.keras.optimizers.Adadelta(lr=学习率,
????????????????????????????????????????????????????????????????????????decay=学习率衰减率)
‘adam’or tf.keras.optimizers.Adam (lr=学习率,
????????????????????????????????????????????????????????????????decay=学习率衰减率)

Loss可以是字符串形式给出的损失函数的名字,也可以是函数形式。
可选项包括:
‘mse’or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy
or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
损失函数常需要经过softmax等函数将输出转化为概率分布的形式。from_logits则用来标注该损失函数是否需要转换为概率的形式,取False时表示转化为概率分布,取True时表示没有转化为概率分布,直接输出。

4.Metrics标注网络评测指标。
可选项包括:
‘accuracy’:y_和y都是数值,

????????如y_=[1] y=[1]。
‘categorical_accuracy’:y_和y都是以独热码和概率分布表示。
????????如y_=[0, 1, 0], y=[0.256, 0.695, 0.048]。
‘sparse_ categorical_accuracy’:y_是以数值形式给出,y是以独热码形式给出。
????????如y_=[1],y=[0.256, 0.695, 0.048]。

5.model.fit()

model.fit(训练集的输入特征, 训练集的标签, batch_size, epochs,
? ? ? ? ? ? ? ? validation_data = (测试集的输入特征,测试集的标签),
????????????????validataion_split = 从测试集划分多少比例给训练集,
????????????????validation_freq = 测试的epoch间隔次数)
fit函数用于执行训练过程

6.model.summary()

summary函数用于打印网络结构和参数统计

?上图是model.summary()对鸢尾花分类网络的网络结构和参数统计,对于一个输入为4输出为3的全连接网络,共有15个参数。

代码清单:

# 使用tf.keras改进鸢尾花数据集的预测
# 1.import
import tensorflow as tf
from sklearn import datasets
import numpy as np

# 2.train test
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_train)	#shuffle-洗牌
np.random.seed(116)
np.random.shuffle(y_train)


# 3.model.Sequential()
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(3,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())
])
# 4.complie
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy']
              )
# 5.model.fit
model.fit(x_train,y_train,batch_size=32,epochs=500,validation_split=0.2,validation_freq=20)

# 6.summary
model.summary()
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-19 12:04:06  更:2021-08-19 12:05:25 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 23:44:32-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码