IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch学习(五)——欠拟合和过拟合 -> 正文阅读

[人工智能]Pytorch学习(五)——欠拟合和过拟合

本文主要介绍模型训练过程中出现的欠拟合和过拟合问题,以及进行模型选和处理过拟合的一般方法。

训练误差和泛化误差

通俗来讲,训练误差(training error)指模型在训练数据集上表现出的误差。泛化误差(generalization error)指的是模型在任意一个测试数据样本上表现出的误差的期望,并常常通过测试数据集上的误差来近似。可以通过损失函数来计算训练误差和泛化误差。

在机器学习里,我们通常假设训练数据集和测试数据集里的每一个样本都是从同一个概率分布中相互独立地生成的。基于独立同分布假设,给定任意一个机器学习模型,它的训练误差和泛化误差都是一样的。如果我们将模型参数设置成随机值,那么泛化误差和训练误差会非常接近。通过在训练数据集上训练模型,我们能够得到一组最小化训练误差的权重参数。所以,训练误差的期望小于或等于泛化误差。也就是说,一般情况下,由训练数据集学到的模型参数会使模型在训练数据集上的表现优于或等于在测试数据集上的表现。在训练的过程中,应当关注降低泛化误差,然而,由于无法从训练误差估计泛化误差,一味的降低训练误差并不意味着泛化误差一定会降低。

欠拟合和过拟合

模型训练中经常出现两类典型问题:一类是模型无法得到较低的训练误差,我们将这一现象称之为欠拟合(underfitting)。另一类是模型的训练误差远小于他在测试数据集上的误差,我们称该现象为过拟合(overfitting)。

欠拟合可理解为模型对训练数据的特征提取不充分,没有学习到数据背后的规律,或者评判标准过于宽松,导致模型在测试数据集上无法做出正确判断。表现为:训练误差和泛化误差都相对较高。

过拟合可理解为模型对特征信息提取过多,把数据噪声当作规律学习,评判标准过于严格。表现为:训练误差低,泛化误差高。

西瓜书中的这张图片,能够很直观的介绍欠拟合和过拟合:

欠拟合和过拟合
在实践中,要尽可能同时应对欠拟合和过拟合。虽然有很多因素可能导致这两种拟合问题,但是我们这里重点讨论两个因素:模型复杂度和训练数据集大小。

当样本特征较少、模型复杂度较低时,对样本的特征提取不够充分,就可能导致欠拟合问题。

当数据集质量不高、噪声较大、训练样本数较少,或是模型复杂度较高、参数过多,就会导致学习到的特征并不普遍适用,模型高度拟合训练数据,出现过拟合问题。

模型复杂度

以多项式函数拟合为例,给定一个由标量数据特征x和对应的标量标签y组成的训练数据集,多项式函数拟合的的目标是找一个K阶多项式函数:
在这里插入图片描述
高阶多项式函数模型参数更多,模型函数的选择空间更大,所以高阶多项式函数比低阶多项式函数的复杂度更高。因此,高阶多项式函数比低阶多项式函数更容易在相同的训练数据集上得到更低的训练误差。模型复杂度和误差之间的关系通常如下图所示。给定训练数据集,如果模型的复杂度过低,很容易出现欠拟合;如果模型的复杂度过高,则容易出现过拟合。因此需要针对数据集选择合适复杂度的模型。
在这里插入图片描述

训练数据集大小

影响欠拟合和过拟合的另一个重要因素是训练数据集的大小。一般来说,如果训练数据集中样本数过少,过拟合更容易发生。此外,泛化误差不会随着训练数据集里的样本数量增加而增大。因此,在计算资源允许的范围内,我们通常希望训练数据集大一些,特别是在模型复杂度较高时。

处理欠拟合和过拟合的方法

上文提到,增大训练数据集可能会减轻过拟合,但是获取额外的训练数据往往代价高昂。在训练数据集固定的情况下,一般使用权重衰减和丢弃法来解决过拟合问题。

权重衰减

权重衰减等价于L2 范数正则化,通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。

L2 范数正则化在模型原损失函数的基础上添加L2 范数惩罚项,从而得到训练所需要最小化的函数。L2 范数惩罚项指的是模型权重参数每个元素的平方和与一个正的常数的乘积,假设有如下线性回归损失函数:
在这里插入图片描述
将权重参数用向量 w = [ w1 , w2 ] 表示,带有 L2 范数惩罚项的新损失函数为:
在这里插入图片描述
上式中L2 范数平方 ||w||2 展开后得到 w12 + w22,其中超参 λ > 0。当权重参数均为0时,惩罚项最小。当 λ 较大时,惩罚项在损失函数中的比重较大,这通常会使学到的权重参数的元素接近0.当 λ 设置为0时,惩罚项完全不起作用。有了L2 范数惩罚项之后,在小批量随机梯度下降中,权重w1和w2的迭代方式更改为:
在这里插入图片描述
可见,L2 范数正则化令权重 w12 和 w22 先自乘小于1的数,再减去不含惩罚项的梯度。因此L2 范数正则化又叫权重衰减。权重衰减通过惩罚绝对值较大的模型参数为需要学习的模型增加了限制,这可能对过拟合有效。

权重衰减实现

在构造优化器时,通过weight_decay参数来指定权重衰减超参,默认下,Pytorch会对权重和偏差同时衰减。我们可以分别对权重和偏差构造优化器实例,从而只对权重衰减。

    net = nn.Linear(num_inputs, 1)
    nn.init.normal_(net.weight, mean=0, std=1)
    nn.init.normal_(net.bias, mean=0, std=1)

    optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd)
    optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr)

丢弃法(dropout)

除了权重衰减以外,深度学习模型常常使用丢弃法(dropout)来应对过拟合问题。

假设有输入个数为4,单隐藏层,隐藏单元个数为5,且隐藏单元hi ( i = 1, … , 5) 的计算表达式为:

在这里插入图片描述
这里?是激活函数,x1, … , x4 是输入,隐藏单元i的权重参数为w1i, …, w4i, 偏差参数为bi。当对该隐藏层使用丢弃法时,该层的隐藏单元将有一定概率被丢弃掉。设丢弃概率为 p ,那么有 p 的概率 hi 会被清零,有 1 - p 的概率 hi 会除以 1 - p 做拉伸。丢弃概率是丢弃法的超参数。具体来说,设随机变量 ξi 为0和1的概率分别为 p1 - p。使用丢弃法时我们计算新的隐藏单元:

在这里插入图片描述
对上述隐藏层使用丢弃法,一种可能的结果如下图所示。其中h2h5 被清零。这时输出值的计算不再依赖h2h5 ,在反向传播时,与这两个隐藏层单元相关的权重梯度均为0。由于在训练中隐藏层单元的丢弃是随机的,即h1,…,h5 中的任何一个都有可能被清零,输出层的计算无法过度依赖 h1,…,h5 中的任一个,从而在模型训练时起到正则化的作用,并可以用来应对过拟合。注意,在模型推理时,不使用dropout。
在这里插入图片描述

dropout实现

手动实现:

def dropout(X, drop_prob):
    X = X.float()
    assert 0 <= drop_prob <= 1
    keep_prob = 1 - drop_prob
    if keep_prob == 0:
        return torch.zeros_like(X)
    mask = (torch.rand(X.shape) < keep_prob).float()

    return mask * X / keep_prob

Pytorch实现:

在训练模型时,利用nn模块中的Dropout层实现:

nn.Dropout(0.5)

模型选择

在机器学习中,通常需要评估若干候选模型的表现,并从中选择模型。这一过程称之为模型选择(model selection)。可供选择的候选模型可以是有不同超参数的同类模型。以多层感知机为例,可以选择隐藏层的个数以及每个隐藏层中的隐藏单元个数和激活函数。为了得到有效的模型,需要使用验证数据集来进行模型选择。

K折交叉验证

由于验证数据集不参与模型训练,当训练数据不够用时,预留大量的验证数据集显得太奢侈。通常使用K折交叉验证的方法,通过把原始训练数据集分割成K个不重合的子数据集,然后我们做K次模型训练和验证。每一次,我们使用一个子数据集验证模型,并使用其他K - 1个子数据集来训练模型。在这K次训练和验证中,每次用来验证模型的子数据集都不同。最后我们对这K次训练误差和验证误差分别求平均。

def get_k_fold_data(k, i, features, labels):
    assert k > 1
    fold_size = features.shape[0] // k
    x_train, y_train = None, None
    for j in range(k):
        idx = slice(j * fold_size, (j + 1) * fold_size)
        x_part, y_part = features[idx, :], labels[idx]

        if j == i:
            x_valid, y_valid = x_part, y_part
        elif x_train is None:
            x_train, y_train = x_part, y_part
        else:
            x_train = torch.cat((x_train, x_part), dim=0)
            y_train = torch.cat((y_train, y_part), dim=0)

    return x_train, y_train, x_valid, y_valid
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-17 01:27:50  更:2021-08-17 01:28:22 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 0:57:59-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码