IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 正则化无法解决神经网络过拟合 怎么办!? -> 正文阅读

[人工智能]正则化无法解决神经网络过拟合 怎么办!?

1、背景描述

这是一个“简单的”回归问题,为了锻炼自己的炼丹能力,我把函数关系搞得特别复杂:

x_raw = torch.rand((1000, 10)) * 2
y_raw = torch.cat((x_raw[:, [0]] * x_raw[:, [1]],
                   -x_raw[:, [2]] * x_raw[:, [3]],
                   x_raw[:, [9]] / x_raw[:, [8]] - x_raw[:, [7]],
                   -x_raw[:, [6]].abs() ** 0.5 + x_raw[:, [5]] + 2,
                   -x_raw[:, [4]] / x_raw[:, [0]]), dim=1).sum(dim=1, keepdim=True)

你看这是一个十变量函数
y = x 0 x 1 ? x 2 x 3 + x 9 x 8 ? x 7 ? ∣ x 6 ∣ + x 5 + 2 ? x 4 x 0 y=x_0x_1-x_2x_3+\frac{x_9}{x_8}-x_7-\sqrt{|x_6|}+x_5+2-\frac{x_4}{x_0} y=x0?x1??x2?x3?+x8?x9???x7??x6? ?+x5?+2?x0?x4??
按理来说神经网络能够拟合任意函数,只要网络够深,参数够多。但偏偏我昨天训练的时候就发生了过拟合!
网络结构是四层MLP,毕竟这种回归问题应该用不到什么fancy的layer。
在这里插入图片描述
这边train_acc和vali_acc都是 A d j u s t e d ? R 2 \mathrm{Adjusted}\ R^2 Adjusted?R2,可以看到vali_acc甚至出现了负值,而train_acc比vali_acc高得多
过拟合咋办呢,常见的方法我都试了,权重衰减、Dropout、BatchNorm、flooding、提前终止,但没什么用,比方说稍微改大一点权重衰减的系数,就出现了欠拟合!

2、原因探析

  • 深入想想,为啥会过拟合?为啥权重衰减可以抑制过拟合?不正是因为权重太大了,要控制权重再一定范围内吗?一个过大的权重,会使得模型对输入变得十分敏感——输入稍微改动一下,可能输出就会有很大的波动。因此才会想到要权重衰减。
    进一步,那为什么神经网络会学到很大的权重?我想可能是因为回归标签本来就很大吧,一个值为1000的标签,如果网络权重都很小,输入又做了标准化(归一化),那怎么可能预测出1000这么大的数呢,所以至少在网络最后一层权重会比较大,这样Loss也会比较大,进而导致梯度大,这又带来梯度爆炸的问题了。
  • 另一方面,数据集可能不是很好,会出现一些outliers异常值,但网络并不知道那些是异常值啊,为了使预测接近异常值,预测出的曲线(单个特征)或曲面(两个特征)自己波动就很大,那就很容易过拟合了,也会表现出对输入的敏感性。

3、问题解决

考虑到上面两点,我们要对标签值进行一些预处理。我以前是只对样本特征预处理的,但是这次也要preprocess标签。
相应地,一方面,要控制标签整体的范围,不能太大,因此做标准化处理。

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
y = torch.from_numpy(scaler.fit_transform(y_raw)).type(torch.float32)

P.S. 在这里插入图片描述
另一方面,删除异常值,这里异常值的判断标准是与均值的距离超出1.96个标准差。

y_range = (y.mean() - 1.96 * y.std(), y.mean() + 1.96 * y.std())
y_normal_ind = ((y < y_range[1]) & (y > y_range[0])).flatten()
y = y[y_normal_ind]
x = x[y_normal_ind, :]

在这里插入图片描述
可以对比一下处理前后的标签值分布~
处理前,标签里有几个零星的绝对值为几百的数以及不少绝对值为十几的数,当然大部分数还是集中在2附近,处理过后之后范围大大缩小了。

4、改进效果

在这里插入图片描述
训练比之前稳定多了~
在这里插入图片描述

训练集、验证集、测试集上预测效果都还不错。

5、数据标签真的很丑!

后来我进一步分析,为什么数据标签会这么丑呢?这肯定跟函数关系式有关,我的关系式中有两个分式,而所有的变量服从(-2,2)的均匀分布,难保分母不会有一个很接近于0的值,进而导致该项很大很大,从而使得标签在某个点的值突然上升或突然下降。
在这里插入图片描述
上图中,(1,1),(1,2),(2,1),(2,2),(3,1)分别是函数第1、2、3、4、5项,可以看到,分式项分布范围很广,甚至可以达到三四百这么大的值,神经网络学习有多么困难,可想而知。
我把除法运算都改成了乘法运算后,分布都在差不多的较小范围里了。
在这里插入图片描述

再次训练网络,调整R2达到了0.99以上!!!啊啊啊!
在这里插入图片描述
好家伙,不愧是我!

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-11 12:23:47  更:2021-08-11 12:26:40 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 20:28:48-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码