1. Introduction to Dropout
Dropout is another technique for combating overfitting; unlike the methods seen so far, it works by modifying the network itself. To understand its mechanism, first consider the structure of an ordinary neural network during training, as shown in the figure.
Dropout is usually applied to the hidden layers of a network. While it is active, it temporarily switches off a portion of the neurons, and a parameter controls the probability that each neuron is switched off. The resulting network structure is shown in the figure.
The procedure in more detail:
1. At training time, first choose a value for the Dropout rate, e.g. 0.4. This means roughly 60% of the neurons keep working and roughly 40% are switched off.
2. For every neuron in a layer that uses Dropout (usually a hidden layer), draw a random number between 0 and 1. If the number is below 0.6, the neuron is set to the working state; if it is 0.6 or above, the neuron is switched off, meaning it takes no part in computation or training, as if it did not exist.
3. With some neurons working and some switched off, the layer's output changes. As in the figure above, if half of a hidden layer is inactive, the output shrinks: in the computation Wx + b, setting part of the activations to 0 necessarily makes the result smaller. So that the total signal of a Dropout layer does not change too much, the outputs of the working neurons are additionally divided by the keep probability, here 0.6 (this keeps the expected total unchanged).
4. Steps 1-3 are repeated throughout training, so each time a fresh random subset of neurons takes part.
5. At test time, all neurons participate in the computation.
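The steps above can be sketched in a few lines of NumPy. This is a minimal illustration of "inverted dropout" (the scale-at-training-time variant described in step 3), not the exact implementation used by any framework; the array size and random seed are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)

def dropout(x, rate=0.4, training=True):
    """Inverted dropout: switch off ~`rate` of the units and divide the
    survivors by the keep probability so the expected total is unchanged."""
    if not training:
        return x  # test time: every neuron participates, no scaling
    keep_prob = 1.0 - rate
    mask = rng.random(x.shape) < keep_prob  # random number < 0.6 -> working
    return x * mask / keep_prob

x = np.ones(1000)
y = dropout(x, rate=0.4)        # roughly 600 entries survive, each scaled by 1/0.6
z = dropout(x, training=False)  # unchanged at test time
```

Because of the division by the keep probability, `y.sum()` stays close to `x.sum()` even though about 40% of the entries are zeroed out.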
Why does Dropout work? This is hard to prove through mathematical derivation. As mentioned when introducing the ReLU activation function, the signals in a neural network are redundant: a prediction does not require every hidden neuron to be active; a subset of them is enough. One way to think about Dropout is that training with it resembles training many different, structurally simpler neural networks, whose results are then combined at test time. Another view is that Dropout weakens the co-adaptation between neurons and forces the network to predict using fewer features, which makes the model more robust.
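The "ensemble of simpler sub-networks" reading can be checked numerically with a toy example: averaging the outputs of many randomly dropped-out forward passes through a single linear layer recovers the deterministic full-network output. The layer sizes and seed below are arbitrary, chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(5, 20))  # one linear layer: 20 inputs -> 5 outputs
x = rng.normal(size=20)
keep_prob = 0.6               # corresponds to a Dropout rate of 0.4

full = W @ x  # test-time output: all neurons active

# sample many random sub-networks and average their (rescaled) outputs
masks = rng.random((50000, 20)) < keep_prob
preds = (x * masks / keep_prob) @ W.T  # one row per sub-network, shape (50000, 5)
avg = preds.mean(axis=0)
# avg is close to full: the dropout "ensemble" agrees with the full network
```

This is exactly why, at test time, simply running the full network (with inverted-dropout scaling already applied during training) approximates averaging over all the sub-networks seen in training.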
Beyond these two views there are many other possible interpretations. Many tricks in deep learning can neither be derived mathematically nor explained easily; what matters is that they help us get better results in practice.
Dropout is well suited to scenarios with only a small amount of data but a complex model to train. Such scenarios are common in computer vision, which is why Dropout is frequently used in the image domain.
2. A Dropout Program
Here we apply Dropout to MNIST digit recognition. We build two models, one with Dropout and one without, and compare how quickly they converge.
The code was developed in a Jupyter Notebook.
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.optimizers import SGD
import matplotlib.pyplot as plt
import numpy as np

# Load MNIST and normalize pixel values to [0, 1]
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
# model1: two hidden layers, each followed by Dropout with rate 0.4
model1 = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(units=200, activation='tanh'),
    Dropout(0.4),
    Dense(units=100, activation='tanh'),
    Dropout(0.4),
    Dense(units=10, activation='softmax')
])

# model2: same architecture, but a Dropout rate of 0 (effectively no Dropout)
model2 = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(units=200, activation='tanh'),
    Dropout(0),
    Dense(units=100, activation='tanh'),
    Dropout(0),
    Dense(units=10, activation='softmax')
])
sgd = SGD(0.2)  # plain SGD with learning rate 0.2
model1.compile(optimizer=sgd,
               loss='categorical_crossentropy',
               metrics=['accuracy'])
model2.compile(optimizer=sgd,
               loss='categorical_crossentropy',
               metrics=['accuracy'])

epochs = 30
batch_size = 32
history1 = model1.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
                      validation_data=(x_test, y_test))
history2 = model2.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
                      validation_data=(x_test, y_test))
Training log:
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.4173 - accuracy: 0.8728 - val_loss: 0.2200 - val_accuracy: 0.9337
Epoch 2/30
60000/60000 [==============================] - 5s 78us/sample - loss: 0.2786 - accuracy: 0.9171 - val_loss: 0.1616 - val_accuracy: 0.9516
Epoch 3/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.2384 - accuracy: 0.9293 - val_loss: 0.1603 - val_accuracy: 0.9519
Epoch 4/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.2182 - accuracy: 0.9347 - val_loss: 0.1393 - val_accuracy: 0.9577
Epoch 5/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.2014 - accuracy: 0.9400 - val_loss: 0.1257 - val_accuracy: 0.9626
Epoch 6/30
60000/60000 [==============================] - 5s 75us/sample - loss: 0.1881 - accuracy: 0.9453 - val_loss: 0.1236 - val_accuracy: 0.9651
Epoch 7/30
60000/60000 [==============================] - 5s 83us/sample - loss: 0.1748 - accuracy: 0.9483 - val_loss: 0.1107 - val_accuracy: 0.9670
Epoch 8/30
60000/60000 [==============================] - 6s 104us/sample - loss: 0.1683 - accuracy: 0.9494 - val_loss: 0.1131 - val_accuracy: 0.9662
Epoch 9/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1597 - accuracy: 0.9517 - val_loss: 0.1066 - val_accuracy: 0.9677
Epoch 10/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1534 - accuracy: 0.9541 - val_loss: 0.0945 - val_accuracy: 0.9709
Epoch 11/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1511 - accuracy: 0.9547 - val_loss: 0.1054 - val_accuracy: 0.9674
Epoch 12/30
60000/60000 [==============================] - 6s 97us/sample - loss: 0.1481 - accuracy: 0.9548 - val_loss: 0.0930 - val_accuracy: 0.9730
Epoch 13/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1406 - accuracy: 0.9586 - val_loss: 0.0937 - val_accuracy: 0.9707
Epoch 14/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1381 - accuracy: 0.9588 - val_loss: 0.0904 - val_accuracy: 0.9735
Epoch 15/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1348 - accuracy: 0.9597 - val_loss: 0.0934 - val_accuracy: 0.9724
Epoch 16/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1304 - accuracy: 0.9614 - val_loss: 0.0865 - val_accuracy: 0.9747
Epoch 17/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1262 - accuracy: 0.9628 - val_loss: 0.0871 - val_accuracy: 0.9745
Epoch 18/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.1255 - accuracy: 0.9628 - val_loss: 0.0856 - val_accuracy: 0.9735
Epoch 19/30
60000/60000 [==============================] - 6s 100us/sample - loss: 0.1248 - accuracy: 0.9616 - val_loss: 0.0826 - val_accuracy: 0.9747
Epoch 20/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1180 - accuracy: 0.9651 - val_loss: 0.0847 - val_accuracy: 0.9752
Epoch 21/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1163 - accuracy: 0.9648 - val_loss: 0.0869 - val_accuracy: 0.9747
Epoch 22/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1171 - accuracy: 0.9650 - val_loss: 0.0813 - val_accuracy: 0.9764
Epoch 23/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1160 - accuracy: 0.9647 - val_loss: 0.0872 - val_accuracy: 0.9746
Epoch 24/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1100 - accuracy: 0.9664 - val_loss: 0.0850 - val_accuracy: 0.9759
Epoch 25/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1095 - accuracy: 0.9671 - val_loss: 0.0815 - val_accuracy: 0.9769
Epoch 26/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.1087 - accuracy: 0.9668 - val_loss: 0.0799 - val_accuracy: 0.9774
Epoch 27/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.1084 - accuracy: 0.9674 - val_loss: 0.0811 - val_accuracy: 0.9779
Epoch 28/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1055 - accuracy: 0.9683 - val_loss: 0.0794 - val_accuracy: 0.9761
Epoch 29/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.1030 - accuracy: 0.9689 - val_loss: 0.0803 - val_accuracy: 0.9767
Epoch 30/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1036 - accuracy: 0.9682 - val_loss: 0.0770 - val_accuracy: 0.9777
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 6s 99us/sample - loss: 0.2536 - accuracy: 0.9230 - val_loss: 0.1502 - val_accuracy: 0.9537
Epoch 2/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1172 - accuracy: 0.9641 - val_loss: 0.1013 - val_accuracy: 0.9688
Epoch 3/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.0809 - accuracy: 0.9757 - val_loss: 0.1021 - val_accuracy: 0.9659
Epoch 4/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.0598 - accuracy: 0.9816 - val_loss: 0.0958 - val_accuracy: 0.9699
Epoch 5/30
60000/60000 [==============================] - 6s 93us/sample - loss: 0.0457 - accuracy: 0.9857 - val_loss: 0.0867 - val_accuracy: 0.9749
Epoch 6/30
60000/60000 [==============================] - 6s 93us/sample - loss: 0.0353 - accuracy: 0.9892 - val_loss: 0.0729 - val_accuracy: 0.9770
Epoch 7/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.0244 - accuracy: 0.9932 - val_loss: 0.0774 - val_accuracy: 0.9762
Epoch 8/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.0191 - accuracy: 0.9947 - val_loss: 0.0688 - val_accuracy: 0.9782
Epoch 9/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.0141 - accuracy: 0.9966 - val_loss: 0.0946 - val_accuracy: 0.9702
Epoch 10/30
60000/60000 [==============================] - 7s 111us/sample - loss: 0.0097 - accuracy: 0.9978 - val_loss: 0.0704 - val_accuracy: 0.9785
Epoch 11/30
60000/60000 [==============================] - 6s 107us/sample - loss: 0.0058 - accuracy: 0.9991 - val_loss: 0.0629 - val_accuracy: 0.9813
Epoch 12/30
60000/60000 [==============================] - 6s 99us/sample - loss: 0.0043 - accuracy: 0.9995 - val_loss: 0.0684 - val_accuracy: 0.9800
Epoch 13/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.0030 - accuracy: 0.9998 - val_loss: 0.0646 - val_accuracy: 0.9808
Epoch 14/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.0022 - accuracy: 0.9999 - val_loss: 0.0643 - val_accuracy: 0.9815
Epoch 15/30
60000/60000 [==============================] - 6s 106us/sample - loss: 0.0017 - accuracy: 1.0000 - val_loss: 0.0678 - val_accuracy: 0.9804
Epoch 16/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0660 - val_accuracy: 0.9811
Epoch 17/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.0013 - accuracy: 1.0000 - val_loss: 0.0667 - val_accuracy: 0.9812
Epoch 18/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.0011 - accuracy: 1.0000 - val_loss: 0.0670 - val_accuracy: 0.9814
Epoch 19/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.0010 - accuracy: 1.0000 - val_loss: 0.0668 - val_accuracy: 0.9814
Epoch 20/30
60000/60000 [==============================] - 6s 95us/sample - loss: 9.3235e-04 - accuracy: 1.0000 - val_loss: 0.0676 - val_accuracy: 0.9817
Epoch 21/30
60000/60000 [==============================] - 6s 95us/sample - loss: 8.5067e-04 - accuracy: 1.0000 - val_loss: 0.0673 - val_accuracy: 0.9815
Epoch 22/30
60000/60000 [==============================] - 6s 95us/sample - loss: 7.8290e-04 - accuracy: 1.0000 - val_loss: 0.0688 - val_accuracy: 0.9813
Epoch 23/30
60000/60000 [==============================] - 6s 95us/sample - loss: 7.2826e-04 - accuracy: 1.0000 - val_loss: 0.0682 - val_accuracy: 0.9814
Epoch 24/30
60000/60000 [==============================] - 6s 97us/sample - loss: 6.8046e-04 - accuracy: 1.0000 - val_loss: 0.0691 - val_accuracy: 0.9811
Epoch 25/30
60000/60000 [==============================] - 5s 91us/sample - loss: 6.3994e-04 - accuracy: 1.0000 - val_loss: 0.0696 - val_accuracy: 0.9812
Epoch 26/30
60000/60000 [==============================] - 5s 91us/sample - loss: 5.9906e-04 - accuracy: 1.0000 - val_loss: 0.0699 - val_accuracy: 0.9812
Epoch 27/30
60000/60000 [==============================] - 6s 92us/sample - loss: 5.6810e-04 - accuracy: 1.0000 - val_loss: 0.0696 - val_accuracy: 0.9815
Epoch 28/30
60000/60000 [==============================] - 6s 98us/sample - loss: 5.3810e-04 - accuracy: 1.0000 - val_loss: 0.0707 - val_accuracy: 0.9812
Epoch 29/30
60000/60000 [==============================] - 6s 96us/sample - loss: 5.1041e-04 - accuracy: 1.0000 - val_loss: 0.0707 - val_accuracy: 0.9811
Epoch 30/30
60000/60000 [==============================] - 6s 96us/sample - loss: 4.8516e-04 - accuracy: 1.0000 - val_loss: 0.0712 - val_accuracy: 0.9819
Since two models are being compared, the log above contains the results of both training runs.
plt.plot(np.arange(epochs), history1.history['val_accuracy'], c='b', label='Dropout')
plt.plot(np.arange(epochs), history2.history['val_accuracy'], c='y', label='FC')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.show()
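The convergence comparison visible in the plot can also be checked numerically. The per-epoch validation accuracies below are transcribed from the training log above; `first_epoch_reaching` is a small helper written here just for this comparison:

```python
# val_accuracy per epoch, transcribed from the training log above
model1_val = [0.9337, 0.9516, 0.9519, 0.9577, 0.9626, 0.9651, 0.9670, 0.9662,
              0.9677, 0.9709, 0.9674, 0.9730, 0.9707, 0.9735, 0.9724, 0.9747,
              0.9745, 0.9735, 0.9747, 0.9752, 0.9747, 0.9764, 0.9746, 0.9759,
              0.9769, 0.9774, 0.9779, 0.9761, 0.9767, 0.9777]  # with Dropout
model2_val = [0.9537, 0.9688, 0.9659, 0.9699, 0.9749, 0.9770, 0.9762, 0.9782,
              0.9702, 0.9785, 0.9813, 0.9800, 0.9808, 0.9815, 0.9804, 0.9811,
              0.9812, 0.9814, 0.9814, 0.9817, 0.9815, 0.9813, 0.9814, 0.9811,
              0.9812, 0.9812, 0.9815, 0.9812, 0.9811, 0.9819]  # without Dropout

def first_epoch_reaching(vals, threshold):
    """1-based index of the first epoch whose accuracy meets the threshold."""
    for epoch, v in enumerate(vals, start=1):
        if v >= threshold:
            return epoch
    return None

# model2 (no Dropout) crosses 97% at epoch 5; model1 (Dropout) only at epoch 10
print(first_epoch_reaching(model2_val, 0.97), first_epoch_reaching(model1_val, 0.97))
print(max(model1_val), max(model2_val))  # peak val accuracy within 30 epochs
```

The model without Dropout reaches any given accuracy threshold in fewer epochs, consistent with the slower convergence of the Dropout model discussed below.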
The first 30 epochs in the log belong to model1, which uses Dropout; the following 30 epochs belong to model2, which does not. With Dropout, the gap between training accuracy and validation accuracy stays small, which shows that Dropout does indeed counteract overfitting. There is also an interesting detail: in model1's 30 epochs the validation accuracy is even higher than the training accuracy. This is because Dropout is still active while the training accuracy is computed, but is switched off when the validation accuracy is computed; with Dropout active, the measured accuracy is slightly lower. We can also see that the validation accuracy of model2, without Dropout, appears higher than that of model1 with Dropout.
In fact, a model trained with Dropout converges more slowly, so it needs more training epochs to reach its best result.
Here model2, without Dropout, peaks at roughly 98.2% validation accuracy within 30 epochs; given enough training epochs, model1 with Dropout can reach roughly 98.8% validation accuracy.