IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 学习Tensorflow2官方Demo——Lenet,以及遇到的问题 -> 正文阅读

[人工智能]学习Tensorflow2官方Demo——Lenet,以及遇到的问题

前言

??TensorFlow是一个面向所有开发人员的开源机器学习框架。 它用于实现机器学习和深度学习应用程序。为了开发和研究有关人工智能,Google团队创建了TensorFlow。 TensorFlow是使用Python编程语言设计的,因此它是一个易于理解的框架。
??Tensorflow Tensor的通道排序:[batch, height, width, channel]

1.官方Demo的项目目录

在这里插入图片描述

2.模型

代码:

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model


class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu') #卷积层:卷积核的个数为32,卷积核的大小为3*3,stride默认为1,激活函数为relu
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')#全连接层:结点个数为128
        self.d2 = Dense(10, activation='softmax')

    def call(self, x, **kwargs): #网络的正向传播过程
        x = self.conv1(x)      # input[batch, 28, 28, 1] output[batch, 26, 26, 32]
        x = self.flatten(x)    # output [batch, 21632]
        x = self.d1(x)         # output [batch, 128]
        return self.d2(x)      # output [batch, 10]

3.训练

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
from model import MyModel


def main():
    mnist = tf.keras.datasets.mnist #一个手写数据集

    # download and load data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Add a channels dimension
    #载入的手写数据图片只有宽度+高度信息,故再添加一个深度信息
    x_train = x_train[..., tf.newaxis]
    x_test = x_test[..., tf.newaxis]

    # create data generator
    #载入数据的方式是将图像和它所对应的标签以一个元组的形式传入from_tensor_slices((x_train, y_train))
    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(10000).batch(32)
    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

    # create model
    model = MyModel()

    # define loss
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    # define optimizer
    optimizer = tf.keras.optimizers.Adam()

    # define train_loss and train_accuracy
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    # define train_loss and train_accuracy
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    # define train function including calculating loss, applying gradient and calculating accuracy
    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions = model(images)
            loss = loss_object(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, predictions)

    # define test function including calculating loss and calculating accuracy
    @tf.function
    def test_step(images, labels):
        predictions = model(images)
        t_loss = loss_object(labels, predictions)

        test_loss(t_loss)
        test_accuracy(labels, predictions)

    EPOCHS = 5

    for epoch in range(EPOCHS):
        train_loss.reset_states()        # clear history info
        train_accuracy.reset_states()    # clear history info
        test_loss.reset_states()         # clear history info
        test_accuracy.reset_states()     # clear history info

        for images, labels in train_ds:
            train_step(images, labels)

        for test_images, test_labels in test_ds:
            test_step(test_images, test_labels)

        template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
        print(template.format(epoch + 1,
                              train_loss.result(),
                              train_accuracy.result() * 100,
                              test_loss.result(),
                              test_accuracy.result() * 100))


if __name__ == '__main__':
    main()

运行结果:
在这里插入图片描述

4.遇到的问题

  1. 没有安装CUDA和CUDNN,导致代码不能正常执行
    解决办法:👉指路

  2. 出现Could not load dynamic library ‘cudart64_110.dll‘; dlerror: cudart64_110.dll not found 的问题
    解决办法:👉指路

  3. 后台提示:由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败;
    解决办法:👉指路

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-05 17:21:26  更:2021-08-05 17:21:38 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/7 22:49:14-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码