| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> PyTorch实现Mnist手写数字识别 -> 正文阅读 |
|
[人工智能]PyTorch实现Mnist手写数字识别 |
首先下载读取Mnist数据集
随机查看数据
? ?将数据转换为tensor张量形式
使用nn.moduel构建网络,torch.nn.functional中有很多功能,也会很常用。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些
打印构建的网络net,可以看到有两个隐藏层,一个输出层。输出层的输出特征是10而不是1,因为这是一个十分类的网络,会对每一个类都输出一个概率,因此输出的是一个由10个概率组成的一维矩阵。例如[0,0.1,0.05,0,0,0,0,0,0.85,0],此输出就代表此输入图像为8的概率为0.85,为1的概率为0.1,为3的概率为0.05,其余数字的概率均为0。 ? ?构建网络时,已经自动进行了权重以及偏置的初始化,可以用下面的代码进行打印。nn.moduel构建网络有如下特点。
??打印定义好名字里的权重和偏置项 ?使用tenordataset和dataloader来简化batch_size需要编写的代码,调用这两个工具包即可完成batch_size的数据拆分。具体代码如下:
接下来写训练函数fit,其中loss_batch用于每一个batch的损失值计算。除此之外,一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout;测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout。
最后的准备工作是写get_model,导入网络模型。
然后三行代码完成手写数字的识别
心得:通过这个简单的网络熟悉PyTorch的神经网络编写过程,这个代码其实更注重调用,并不是完全按照前向传播后向传播的顺序一步一步构建一个网络,而是写了很多函数,最后主要的三行代码就完成了网络的训练。? ? ? ? ? |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 | -2024/11/27 12:39:22- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |