IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 【Keras】保存模型的前几层,删除最后几层 -> 正文阅读

[Python知识库]【Keras】保存模型的前几层,删除最后几层

1 前言

?? 需求:在使用Keras的过程中,只想保留模型的前几层,删除最后一层,以便网络进行增量训练。
?? 以sklearn中的鸢尾花数据集为例,建立一个多层感知机,以用来删除网络的最后一层
?? 使用以下代码进行删除训练好神经网络的最后一层

model.pop()

2 代码

2.1 删除前代码

?? 载入数据集

import tensorflow as tf
import numpy as np
from sklearn.datasets import load_iris
from tqdm import tqdm

data = load_iris()
iris_target = data.target
iris_data = np.float32(data.data)

?? 建立模型并训练

from sklearn.model_selection import train_test_split
from keras.layers import Dense, Dropout
import tensorflow as tf
import keras
from keras.models import Sequential
from tensorflow.keras.optimizers import SGD, Adam


X_train, X_test, y_train, y_test = train_test_split(iris_data, iris_target, test_size = 0.25)

def build_mlp(max_features):
    model = Sequential()
    model.add(Dense(units = 30, input_dim = max_features, activation = "relu"))
    model.add(Dropout(0.2))
    model.add(Dense(units = 30, activation = "relu"))
    model.add(Dense(units = 3, activation = "softmax"))
    # sgd = SGD(lr = 0.002)

    model.compile(optimizer="adam", loss=tf.losses.sparse_categorical_crossentropy, metrics = ["accuracy"])
    return model
# 设置模型提前停止训练
callback = keras.callbacks.EarlyStopping(monitor='loss', patience=5)
model = build_mlp(4)
model.fit(X_train, y_train, epochs=10,  batch_size=128, verbose=1, callbacks = [callback])
score1 = model.evaluate(X_test, y_test)
print(score1)

2.2 删除前模型

?? 查看模型结构代码

model.summary()

?? 模型结构

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_12 (Dense)            (None, 30)                150       
                                                                 
 dropout_3 (Dropout)         (None, 30)                0         
                                                                 
 dense_13 (Dense)            (None, 30)                930       
                                                                 
 dense_14 (Dense)            (None, 3)                 93        
                                                                 
=================================================================
Total params: 1,173
Trainable params: 1,173
Non-trainable params: 0
_________________________________________________________________

?? 模型预测代码及结果

model.predict([[5.1, 3.5, 1.4, 0.2]])
# 输出结果为:array([[0.15257263, 0.56145775, 0.2859696 ]], dtype=float32)

2.3 删除最后一层代码

?? 删除代码

model.pop()

2.4 删除后结果

?? 查看删除后模型结构

model.summary()

?? 删除后模型结构

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_12 (Dense)            (None, 30)                150       
                                                                 
 dropout_3 (Dropout)         (None, 30)                0         
                                                                 
 dense_13 (Dense)            (None, 30)                930       
                                                                 
=================================================================
Total params: 1,080
Trainable params: 1,080
Non-trainable params: 0
_________________________________________________________________

?? 删除最后一层后模型预测代码及结果

model.save("test_model.h5")
new_model = tf.keras.models.load_model("test_model.h5")
new_model.predict([[5.1, 3.5, 1.4, 0.2]])

?? 输出结果

array([[0.6801046 , 0.        , 1.6479366 , 2.1548553 , 0.        ,
        0.        , 0.        , 0.        , 0.3491744 , 1.1407335 ,
        2.2483244 , 0.        , 0.71208394, 0.23256747, 0.        ,
        0.        , 1.1044086 , 1.492352  , 0.        , 0.        ,
        1.6636133 , 0.0603824 , 0.02222613, 0.        , 0.        ,
        1.625694  , 0.        , 0.        , 0.        , 3.2165537 ]],
      dtype=float32)

3 总结

?? pop()函数是一层层删除训练后的模型,从后往前删除,就像栈的机制。

4 参考资料

📗 知乎连接:keras可以只保存模型前几层,而不是整个模型及权重值吗?

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-03-12 17:27:53  更:2022-03-12 17:29:54 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 18:42:37-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计