IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Dropout层的个人理解和具体使用 -> 正文阅读

[人工智能]Dropout层的个人理解和具体使用

一、Dropout层的作用

??dropout 能够避免过拟合,我们往往会在全连接层这类参数比较多的层中使用dropout;在训练包含dropout层的神经网络中,每个批次的训练数据都是随机选择,实质是训练了多个子神经网络,因为在不同的子网络中随机忽略的权重的位置不同,最后在测试的过程中,将这些小的子网络组合起来,类似一种投票的机制来作预测,有点类似于集成学习的感觉。

??关于dropout,有nn.Dropout和nn.functional.dropout两种。推荐使用nn.xxx,因为一般情况下只有训练train时才用dropout,在eval不需要dropout。使用nn.Dropout,在调用model.eval()后,模型的dropout层和批归一化(batchnorm)都关闭,但用nn.functional.dropout,在没有设置training模式下调用model.eval()后不会关闭dropout。
??这里关闭dropout等的目的是为了测试我们训练好的网络。在eval模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在train训练阶段已经学出的mean和var值。同时我们在用模型做预测的时候也应该声明model.eval()。

注??:为了进一步加速模型的测试,我们可以设置with torch.no_grad(),主要是用于停止autograd模块的工作,以起到加速和节省显存的作用,具体行为就是停止梯度gradient计算和储存,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为,这样我们可以使用更大的batch进行测试。

model.eval()下不启用 Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

二、Dropout层的使用方法

??以下是nn.Dropout和nn.functional.dropout两种具体的使用方法:

class Dropout1(nn.Module):
   def __init__(self):
       super(Dropout1, self).__init__()
       self.fc = nn.Linear(100,20)
 
   def forward(self, input):
       out = self.fc(input)
       out = F.dropout(out, p=0.5, training=self.training)  # 这里必须给traning设置为True
       return out
# 如果设置为F.dropout(out, p=0.5)实际上是没有任何用的, 因为它的training状态一直是默认值False. 由于F.dropout只是相当于引用的一个外部函数, 模型整体的training状态变化也不会引起F.dropout这个函数的training状态发生变化. 所以,在训练模式下out = F.dropout(out) 就是 out = out. 
Net = Dropout1()
Net.train()

#或者直接使用nn.Dropout() (nn.Dropout()实际上是对F.dropout的一个包装, 自动将self.training传入,两者没有本质的差别)
class Dropout2(nn.Module):
  def __init__(self):
      super(Dropout2, self).__init__()
      self.fc = nn.Linear(100,20)
      self.dropout = nn.Dropout(p=0.5)
 
  def forward(self, input):
      out = self.fc(input)
      out = self.dropout(out)
      return out
Net = Dropout2()
Net.train()

三、RNN中使用Dropout(进阶)

??偶然间我在一个项目中发现其在LSTM上下层连接的时候使用了dropout来控制模型的过拟合;通过我查阅了一些资料发现dropour只能用于特定的rnn连接上,在gru上边会出现错误,其具体的原因来自于模型的具体结构问题,具体详情可以参考:https://www.jianshu.com/p/be34e53d54e6

# https://github.com/allenai/allennlp/blob/master/allennlp/modules/input_variational_dropout.py
class RNNDropout(nn.Dropout):
    """
    Dropout layer for the inputs of RNNs.
    Apply the same dropout mask to all the elements of the same sequence in
    a batch of sequences of size (batch, sequences_length, embedding_dim).
    """

    def forward(self, sequences_batch):
        """
        Apply dropout to the input batch of sequences.
        Args:
            sequences_batch: A batch of sequences of vectors that will serve
                as input to an RNN.
                Tensor of size (batch, sequences_length, emebdding_dim).
        Returns:
            A new tensor on which dropout has been applied.
        """
        ones = sequences_batch.data.new_ones(sequences_batch.shape[0],
                                             sequences_batch.shape[-1])
        dropout_mask = nn.functional.dropout(ones, self.p, self.training,
                                             inplace=False)
        return dropout_mask.unsqueeze(1) * sequences_batch 
        # 这样就会随机mask掉词向量中的部分参数


以上是我对dropout的一些个人理解,欢迎大家补充或提出见解👏

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-25 12:12:14  更:2021-08-25 12:12:33 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 22:42:17-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码