IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> TensorFlow_01_CNN实现图像分类 -> 正文阅读

[人工智能]TensorFlow_01_CNN实现图像分类

导入相关的包。

# 导入tensorflow
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

防止cudnn报错。


physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
tf.config.experimental.set_memory_growth(physical_devices[0], True)

下载并准备数据集

下载并准备 CIFAR10 数据集CIFAR10 数据集包含 10 类,共 60000 张彩色图片,每类图片有 6000 张。此数据集中 50000 个样例被作为训练集,剩余 10000 个样例作为测试集。 类之间相互独立,不存在重叠的部分

(train_images, train_labels), (test_images, test_labels) = \
    datasets.cifar10.load_data()
# 将像素值归一化到0和1之间
train_images, test_images = train_images / 255.0, test_images / 255.0

验证数据

为了验证数据集看起来是否正确,我们绘制训练集样本的前25张图像并在每张图像的前25张图像并在每张图像的下方显示类别名称:

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
               'frog', 'horse', 'ship', 'truck']

将图像的标签图像化出来。

plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])

    plt.xlabel(class_names[train_labels[i][0]])
plt.show()

图像分类的结果如下所示:
在这里插入图片描述

构造卷积神经网络模型

下方展示的6行代码声明了一个常见的卷积神经网络,由几个Conv2D和MaxPooling2D层组成.cnn的形状为(image_height,image_width,color_channels)的张量作为输入,忽略批次大小,我们将cnn的形状处理成(32,32,3)的输入,即CIFAR的图像格式. 通过将参数input_shape传递给第一层来实现此目的.

model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

在上面的结构中,可以看到每个 Conv2D 和 MaxPooling2D 层的输出都是一个三维的张量 (Tensor),其形状描述了 (height, width, channels)。越深的层中,宽度和高度都会收缩。每个 Conv2D 层输出的通道数量 (channels) 取决于声明层时的第一个参数(如:上面代码中的 32 或 64)。这样,由于宽度和高度的收缩,便可以(从运算的角度)增加每个 Conv2D 层输出的通道数量 (channels)。

增加 Dense 层

为了完成模型,您需要将卷积基(形状为 (4, 4, 64))的最后一个输出张量馈送到一个或多个 Dense 层以执行分类。Dense 层将向量作为输入(即 1 维),而当前输出为 3 维张量。
首先,将 3 维输出展平(或展开)为 1 维,然后在顶部添加一个或多个 Dense 层。CIFAR 有 10 个输出类,因此使用具有 10 个输出的最终 Dense 层。

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))
# 网络摘要显示 (4, 4, 64) 输出在经过两个 Dense 层之前被展平为形状为 (1024) 的向量。

编译并训练模型

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=20,
                    validation_data=(test_images, test_labels))

评估模型

将训练的轮次和精确度之间的关系可视化出来

plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()

在这里插入图片描述
训练的过程如下所示:

1563/1563 [==============================] - 12s 5ms/step - loss: 1.7593 - accuracy: 0.3453 - val_loss: 1.2124 - val_accuracy: 0.5645
Epoch 2/20
1563/1563 [==============================] - 5s 3ms/step - loss: 1.1851 - accuracy: 0.5807 - val_loss: 1.0649 - val_accuracy: 0.6252
Epoch 3/20
1563/1563 [==============================] - 5s 3ms/step - loss: 1.0019 - accuracy: 0.6482 - val_loss: 1.0100 - val_accuracy: 0.6387
Epoch 4/20
1563/1563 [==============================] - 5s 4ms/step - loss: 0.8828 - accuracy: 0.6892 - val_loss: 0.9207 - val_accuracy: 0.6747
Epoch 5/20
1563/1563 [==============================] - 5s 4ms/step - loss: 0.7989 - accuracy: 0.7174 - val_loss: 0.9033 - val_accuracy: 0.6845
Epoch 6/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.7364 - accuracy: 0.7418 - val_loss: 0.8639 - val_accuracy: 0.7010
Epoch 7/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.6762 - accuracy: 0.7635 - val_loss: 0.8704 - val_accuracy: 0.7051
Epoch 8/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.6319 - accuracy: 0.7763 - val_loss: 0.8738 - val_accuracy: 0.7019
Epoch 9/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.5740 - accuracy: 0.7976 - val_loss: 0.8616 - val_accuracy: 0.7125
Epoch 10/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.5305 - accuracy: 0.8131 - val_loss: 0.9049 - val_accuracy: 0.7179
Epoch 11/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.4818 - accuracy: 0.8296 - val_loss: 0.9193 - val_accuracy: 0.7085
Epoch 12/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.4470 - accuracy: 0.8411 - val_loss: 0.9930 - val_accuracy: 0.7021
Epoch 13/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.4080 - accuracy: 0.8545 - val_loss: 1.0348 - val_accuracy: 0.6978
Epoch 14/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.3666 - accuracy: 0.8696 - val_loss: 1.0691 - val_accuracy: 0.6992
Epoch 15/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.3352 - accuracy: 0.8810 - val_loss: 1.0883 - val_accuracy: 0.7075
Epoch 16/20
1563/1563 [==============================] - 7s 4ms/step - loss: 0.3126 - accuracy: 0.8894 - val_loss: 1.1247 - val_accuracy: 0.7076
Epoch 17/20
1563/1563 [==============================] - 6s 4ms/step - loss: 0.2762 - accuracy: 0.9009 - val_loss: 1.2253 - val_accuracy: 0.7074
Epoch 18/20
1563/1563 [==============================] - 7s 4ms/step - loss: 0.2590 - accuracy: 0.9082 - val_loss: 1.3303 - val_accuracy: 0.6965
Epoch 19/20
1563/1563 [==============================] - 7s 4ms/step - loss: 0.2447 - accuracy: 0.9128 - val_loss: 1.3879 - val_accuracy: 0.6945
Epoch 20/20
1563/1563 [==============================] - 7s 4ms/step - loss: 0.2122 - accuracy: 0.9252 - val_loss: 1.4484 - val_accuracy: 0.6920
313/313 - 1s - loss: 1.4484 - accuracy: 0.6920
test_acc 0.6919999718666077
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-24 15:33:09  更:2021-08-24 15:36:10 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 22:58:54-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码