IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 由mnist引发的思考,pytorch中的交叉熵误差函数nn.CrossEntropy做了什么? -> 正文阅读

[人工智能]由mnist引发的思考,pytorch中的交叉熵误差函数nn.CrossEntropy做了什么?

引入

在MNIST手写体实验中,关于在交叉熵损失函数计算误差时,神经网络输出为10个,当标签设置为何种情况时才能满足交叉熵损失函数的计算公式,来探究这个问题。

实验一

直接打印出每个数据的标签内容

代码如下:

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([  # 设置预处理的方式,里面依次填写预处理的方法
    transforms.ToTensor(),  # 将数据转换为tensor对象
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = datasets.MNIST('data', train=True, download=True, transform=transform)
if __name__ == '__main__':
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=2)
    for i, data in enumerate(trainloader, 0):
        data_, label = data
        print("label:"+str(label.numpy()))
        if i == 10:
            break

label:[6]
label:[7]
label:[2]
label:[3]
label:[5]
label:[5]
label:[7]
label:[6]
label:[1]
label:[6]
label:[9]

从上面的实验看出label就是该手写体所代表的数字,那么nn.CrossEntorpy是如何对只有一个值的tensor进行计算的?

实验二

经过查阅资料,得出label中的tensor代表的是所有预测种类中的第几类,例如上面实验从的第一个label:[6]就代表为第6类,即该手写体图片的正确答案是5


该函数的数学公式可以写为:
KaTeX parse error: No such environment: align at position 7: \begin{?a?l?i?g?n?}?loss(x,y)&=-\lo…
其中x代表损失函数式的输入,y代表target(或标签)中所代表的类别,C为所有的类别数量。

import numpy as np
import torch
from torch import nn
from torchvision import datasets, transforms

transform = transforms.Compose([  # 设置预处理的方式,里面依次填写预处理的方法
    transforms.ToTensor(),  # 将数据转换为tensor对象
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = datasets.MNIST('data', train=True, download=True, transform=transform)
if __name__ == '__main__':
    inputs = torch.Tensor([[-0.5, -0.2, -0.3]])
    target = torch.tensor([0])
    criterion = nn.CrossEntropyLoss()
    output = criterion(inputs, target)
    print("通过nn.CrossEntropy计算的结果", output)  # 使用nn.CrossEntropy函数
    # 直接计算
    my_out = -inputs[:, 0] + np.log(torch.sum(torch.exp(inputs[0:1])))
    print("通过公式计算的结果:", my_out)
    # 通过logSoftmax和NLLose计算的结果
    log_softmax_function = nn.LogSoftmax(dim=1)
    loss = nn.NLLLoss()
    logSoftmax_NLLose_output = loss(log_softmax_function(inputs), target)
    print("通过LogSoftmax和NLLLose函数计算的结果", logSoftmax_NLLose_output)

通过nn.CrossEntropy计算的结果 tensor(1.2729)
通过公式计算的结果: tensor([1.2729])
通过LogSoftmax和NLLLose函数计算的结果 tensor(1.2729)

结论

PYTORCH中的CrossEntropy函数结合了取log的softmax函数和NLLOSE误差函数来计算loss,label中只用给出结果所分的类别编号即可。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-13 21:47:25  更:2022-03-13 21:49:14 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/28 14:37:25-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码