IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> tensorflow2 交叉验证 -> 正文阅读

[人工智能]tensorflow2 交叉验证

作者:recommend-item-box type_download clearfix

交叉验证在fit()函数的参数里边,完整参数传送

https://blog.csdn.net/Forrest97/article/details/106635664

fit()里边相关交叉验证的参数

validation_data=test,就是自己划分好的测试集

validation_steps, 验证样本总数 Total validation Samples/验证样本大小Validation Batch Size,多少组验证样本的数据后代入网络验证,可能每次输出都能看到验证结果,val_loss 和 val_accuracy 会输出验证精度

validation_batch_size:一般是32

validation_freq:仅当validation_data设置时有效,表示训练完几组epoch后,进行验证

先给一段代码,以下代码的数据集是60000张图片,50000张来训练集,10000张来测试集。其中每张照片为32*32的彩色照片,每个像素点包括RGB三个数值,一共有10中分类结果:飞机、汽车、鸟、猫咪、鹿子、狗子、小青蛙、马儿、船、大卡车。

先用DNN建模后,存储模型,再重新调用模型。

import keras
import tensorflow as tf
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics

def preprocess(x,y):
    x = 2*tf.cast(x,dtype=tf.float32)/225.-1.
    y = tf.cast(y,dtype=tf.int32)
    return x,y

batch_size = 128
tf.random.set_seed(1)
(x,y),(test_x,test_y) = datasets.cifar10.load_data()
y = tf.squeeze(y)
test_y = tf.squeeze(test_y)
y = tf.one_hot(y,depth=10)
test_y = tf.one_hot(test_y,depth=10)

train = tf.data.Dataset.from_tensor_slices((x,y))
train = train.map(preprocess).shuffle(60000).batch(batch_size=batch_size)
test = tf.data.Dataset.from_tensor_slices((test_x,test_y))
test = test.map(preprocess).batch(batch_size=batch_size)

sample = next(iter(train))

class mydense(layers.Layer):
    def __init__(self,input_dim,output_dim):
        super(mydense, self).__init__()
        self.kernel = self.add_weight('w',[input_dim,output_dim])
        self.bias = self.add_weight('b',[1,output_dim])
    def call(self,inputs,training = None):
        x = inputs @ self.kernel+self.bias
        return x

class my_network(keras.Model):
    def __init__(self):
        super(my_network,self).__init__()
        self.fc1=mydense(32*32*3,256)
        self.fc2=mydense(256,128)
        self.fc3=mydense(128,64)
        self.fc4=mydense(64,32)
        self.fc5=mydense(32,10)

    def call(self,inputs,training = None):
        x = tf.reshape(inputs,[-1,32*32*3])
        x = self.fc1(x)
        x=tf.nn.relu(x)
        x = self.fc2(x)
        x=tf.nn.relu(x)
        x=self.fc3(x)
        x = tf.nn.relu(x)
        x = tf.nn.relu(self.fc4(x))
        x=self.fc5(x)
        return x
network = my_network()
network.compile(optimizer=optimizers.Adam(lr=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )
#交叉验证validation_steps : 验证样本总数 Total validation Samples/验证批量大小Validation Batch Size,多少组验证样本的数据后代入网络验证)
#validation_freq:仅当validation_data设置时有效,表示训练完几组epoch后,进行验证。
network.fit(train,epochs=5,validation_data=test,validation_freq=2)
#print("validation_batch_size",network.fit.validation_freq)
network.evaluate(test)
network.save_weights('weights/mynetwork')
del network
print('saved weights')
network = my_network()
network.compile(optimizer=optimizers.Adam(lr=1e-4),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )
network.load_weights('weights/mynetwork')
network.fit(train,epochs=5,validation_data=test,validation_batch_size=80)
network.evaluate(test)

?输出结果,当设置

validation_freq=2时,第一条fit函数输出是隔一次输出测试集的验证结果,就是在epoch/2的整数倍后,验证数据
Epoch 1/5
391/391 [==============================] - 3s 6ms/step - loss: 1.7262 - accuracy: 0.3867
Epoch 2/5
391/391 [==============================] - 4s 9ms/step - loss: 1.4883 - accuracy: 0.4739 - val_loss: 1.4526 - val_accuracy: 0.4891
Epoch 3/5
391/391 [==============================] - 3s 7ms/step - loss: 1.3768 - accuracy: 0.5139
Epoch 4/5
391/391 [==============================] - 4s 10ms/step - loss: 1.2961 - accuracy: 0.5428 - val_loss: 1.4032 - val_accuracy: 0.5021
Epoch 5/5
391/391 [==============================] - 3s 9ms/step - loss: 1.2186 - accuracy: 0.5733
79/79 [==============================] - 0s 5ms/step - loss: 1.3872 - accuracy: 0.5197
saved weights
Epoch 1/5
391/391 [==============================] - 4s 11ms/step - loss: 1.1536 - accuracy: 0.5942 - val_loss: 1.3735 - val_accuracy: 0.5234
Epoch 2/5
391/391 [==============================] - 4s 11ms/step - loss: 1.0945 - accuracy: 0.6120 - val_loss: 1.3713 - val_accuracy: 0.5252
Epoch 3/5
391/391 [==============================] - 4s 11ms/step - loss: 1.0385 - accuracy: 0.6360 - val_loss: 1.3727 - val_accuracy: 0.5346
Epoch 4/5
391/391 [==============================] - 4s 11ms/step - loss: 0.9848 - accuracy: 0.6520 - val_loss: 1.4310 - val_accuracy: 0.5297
Epoch 5/5
391/391 [==============================] - 5s 12ms/step - loss: 0.9356 - accuracy: 0.6687 - val_loss: 1.4377 - val_accuracy: 0.5266
79/79 [==============================] - 0s 5ms/step - loss: 1.4377 - accuracy: 0.5266

当只设置validation_batch_size=88、88、30,第二条fit函数输出是每次都输出val_loss和val_accuracy,且每次结果都一样,如下:

?当只设置validation_steps=30、80,第二条fit函数输出是当30时每次都输出val_loss和val_accuracy,如下:

saved weights
Epoch 1/5
391/391 [==============================] - 3s 9ms/step - loss: 1.1536 - accuracy: 0.5942 - val_loss: 1.3933 - val_accuracy: 0.5185
Epoch 2/5
391/391 [==============================] - 4s 10ms/step - loss: 1.0945 - accuracy: 0.6120 - val_loss: 1.3774 - val_accuracy: 0.5286
Epoch 3/5
391/391 [==============================] - 4s 11ms/step - loss: 1.0385 - accuracy: 0.6360 - val_loss: 1.3880 - val_accuracy: 0.5365
Epoch 4/5
391/391 [==============================] - 4s 11ms/step - loss: 0.9848 - accuracy: 0.6520 - val_loss: 1.4451 - val_accuracy: 0.5326
Epoch 5/5
391/391 [==============================] - 4s 10ms/step - loss: 0.9356 - accuracy: 0.6687 - val_loss: 1.4410 - val_accuracy: 0.5318
79/79 [==============================] - 0s 5ms/step - loss: 1.4377 - accuracy: 0.5266

当设置validation_steps=80,超过验证数据集总数,没有进行验证。

?结论:

比较有效果的参数设置:

validation_data和validation_batch_size一起用

或者validation_data和validation_freq一起用

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-27 16:13:42  更:2021-07-27 16:14:27 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 10:04:06-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码