IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 交叉熵损失计算过程 -> 正文阅读

[人工智能]交叉熵损失计算过程

声明

1,本文整体偏向小白风。
2,尽量少贴公式,就讲下原理。我觉得讲清交叉熵根本不需要一堆公式和各种术语。

前言

交叉熵损失常用于分类任务。
优点是误差较大时,学习速度较快。
本文以pytorch中自带的实现函数为依据,解释下交叉熵损失的计算过程。

二分类任务单样本

以minst数据集识别为例,就是一个典型的多分类任务。
经过网上搜索,一通复制黏贴,网络计算,最终输出维度应该是10(对应十分类,下文用out指代输出)。此处,先简化下问题,假设现在只识别0和1,将问题简化为二分类任务。
那损失函数的入参就是out和label,label就是样本的标签。
再简化下例子,假设现在就一个样本。也就是说,现在任务成了,二分类任务的一个样本的输出。那输出应该是类似这样的

tensor[[0, 1.0]]

这个输出并不是概率值,而是计算值。那首先就需要将其归一化到0~1之间。

exp(0) / (exp(0)+exp(1.0))
exp(1.0) / (exp(0)+exp(1.0))

其实就是以自然数e为底数,做次方运算。
假设这个样本的标签是0
根据二分类交叉熵损失公式(百度下,网上很多)

-ln(yi*pi+(1-yi)(1-pi))
yi--第i个样本的标签值,此处假设的单样本,标签为0;
pi--第i个样本预测为正样本(也就是预测为1)的概率值,也就是--exp(1) / (exp(0)+exp(1.0))=2.7183/(1+2.7183)=0.7311;
此时,yi=0,则yi*pi+(1-yi)(1-pi) = 1-pi=1-0.7311=0.2689。
ln0.2689 = -1.313,取负值就是1.313。

以上结果可自行用pytorch验证下。

伪代码:cross(tensor[[0, 1.0]], tensor[0])

二分类多样本

好了,二分类任务的单样本,这个最简单的例子就完成了。
多样本呢?无非是计算所有样本后取平均值。easy,一笔带过。

多分类任务呢

网上也能找到多分类交叉熵损失函数计算公式。
但是,我觉得可以简化一下,多分类单样本的损失计算公式为:

-ln(pi[label-index])

也就是说,以标签值的index作为pi的index,从概率值数组拿出对应的概率值进行计算。
其实就是官方公式的变种。因为,官方公式中的Yi,k除了在标签值时取值为1,其余为0。那累加项其实就一项。
用人话说就是,假如一个样本要预测0-9,实际标签是0,那计算结果是0-9的十个概率值,可我肯定关心的是输出为0的概率值。那我只要把输出为0的概率值拿来计算就行了。
当然,官方公式那样写多半有其用意。随着学习深入,随缘探究吧。

小结

整体步骤归纳下:
1,对输出做归一化。貌似好像这一步就是softmax;
2,对标签值对应的概率值取对数,再取负值。
3,对各样本损失值取平均数;

扩展

计算过程很简单。
扣下细节。为什么要对概率值取对数,再取负数呢。

1,先说取负数。由于概率值永远小于1,softmax不可能计算出概率值等于1的。因为e的次方算不出0。进而,ln的曲线可以自己在纸上划一下,x取值小于1的情况下,lnx输出永远小于0。且x值越小,则计算结果越小,其绝对值越大。那我们肯定希望对实际标签的预测概率值越小,则误差越大。那取个负值,相当于取了其绝对值,也就顺理成章。
2,为什么要取对数。前面说了,交叉熵的优点就是“误差较大时,学习速度较快”。我们肯定希望离终点很远的时候步子能大一点,也就是误差计算结果能大一点。那只要把上面第一点说的那张图画出来,你看下x轴的步长映射都y轴上的步长变化,也就明白了–实现这一目的是利用了对数曲线的特性。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-29 12:08:38  更:2022-04-29 12:09:06 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 17:52:52-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码