IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> tensorflow 1.X版本 与2.X版本的区别 sparse_categorical_crossentropy损失函数踩雷 -> 正文阅读

[Python知识库]tensorflow 1.X版本 与2.X版本的区别 sparse_categorical_crossentropy损失函数踩雷

卷积神经网络

2.X版本的tensorflow是有Input层的

# Create the Student Model
student = keras.Sequential(
        [
            keras.Input(shape=(28,28,1)),
            layers.Conv2D(16,(3,3),strides = (2,2),padding = "same"),
            layers.LeakyReLU(alpha=0.2),
            layers.MaxPooling2D(pool_size=(2,2),strides=(1,1),padding="same"),
            layers.Conv2D(32,(3,3),strides=(2,2),padding="same"),
            layers.Flatten(),
            layers.Dense(10),
            
        ],
        name = "student" # 加这一行可以打印模型结构的时候顺便打印模型名字
)
student.summary() # 打印当前模型的结构

1.X版本则会报错

TypeError: The added layer must be an instance of class Layer. Found: Tensor("input_3:0", shape=(?, 28, 28, 1), dtype=float32)

解决办法:把input放到第一个conv中,input会把变量变成tensor,影响后面的层

from keras import layers
student = keras.Sequential(
        [
            # keras.Input(shape=(28,28,1)), # 版本问题报错,输入改到conv中
            layers.Conv2D(16,(3,3),input_shape=(28,28,1),strides = (2,2),padding = "same"),
            layers.LeakyReLU(alpha=0.2),
            layers.MaxPooling2D(pool_size=(2,2),strides=(1,1),padding="same"),
            layers.Conv2D(32,(3,3),strides=(2,2),padding="same"),
            layers.Flatten(),
            layers.Dense(10),
            
        ],
        name = "student" 
)
student.summary() # 打印当前模型的结构

编译阶段
2.X版本的tensorflow

teacher.compile(
        optimizer = keras.optimizers.Adam(),
        loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics = [keras.metrics.SparseCategoricalAccuracy()],
        )

# train and evaluation
teacher.fit(x_train,y_train,epochs = 1) #实际情况会训练更多的轮数,如100或更多
teacher.evaluate(x_test,y_test)

在1.X版本中会报错

解决办法:

student.compile(
        optimizer = keras.optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
        )

# train and evaluation
student.fit(x_train,y_train,epochs = 3) 
student.evaluate(x_test,y_test)

但是注意!!!
如果只是简单的这么更改会使得训练根本无法提升acc:
在这里插入图片描述
可以看到不管训练多少轮acc都没有增加,loss没有下降。

但是用之前的loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)训练过程中就很正常:
在这里插入图片描述
acc轻轻松松上0.85

究其原因:原来问题出在logit=True这个参数上,logit=True相当于给输出加了一个softmax的输出,将Dense输出的数值映射到[0,1]范围内,如果直接把loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)替换成了loss='sparse_categorical_crossentropy'则需要收到在模型的Dense层后面加一个softmax激活函数:

from keras import layers
student = keras.Sequential(
        [
            # keras.Input(shape=(28,28,1)), # 版本问题报错,输入改到conv中
            layers.Conv2D(16,(3,3),input_shape=(28,28,1),strides = (2,2),padding = "same"),
            layers.LeakyReLU(alpha=0.2),
            layers.MaxPooling2D(pool_size=(2,2),strides=(1,1),padding="same"),
            layers.Conv2D(32,(3,3),strides=(2,2),padding="same"),
            layers.Flatten(),
            layers.Dense(10,activation="softmax"),
            
        ],
        name = "student" 
)
student.summary() # 打印当前模型的结构

这样就能正常训练了:
在这里插入图片描述
可以打印model的输出看一下,不加activation="softmax"的输出是怎么样的

y_pre = my_model.predict(np.reshape(x_train[:10],(-1,28,28,1))) # 最后dense层未加softmax,得到的是一个非常大范围的值,未映射到[0,1]之间
print(y_pre)

在这里插入图片描述
加了softmax之后的输出

y_pre = my_model.predict(np.reshape(x_train[:10],(-1,28,28,1)))
print(y_pre)

在这里插入图片描述

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-03-03 16:08:24  更:2022-03-03 16:08:26 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/31 5:59:31-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码