0. Previous Posts
[1] Deep Learning with Pytorch - Tensor Definition and Tensor Creation
[2] Deep Learning with Pytorch - Tensor Operations: Concatenation, Splitting, Indexing, and Transformation
[3] Deep Learning with Pytorch - Tensor Math Operations
[4] Deep Learning with Pytorch - Linear Regression
[5] Deep Learning with Pytorch - Computational Graphs and the Dynamic Graph Mechanism
[6] Deep Learning with Pytorch - autograd and Logistic Regression
[7] Deep Learning with Pytorch - DataLoader and Dataset (with an RMB Banknote Binary Classification Example)
[8] Deep Learning with Pytorch - Image Preprocessing with transforms
[9] Deep Learning with Pytorch - transforms Image Augmentation (Cropping, Flipping, Rotation)
[10] Deep Learning with Pytorch - transforms Image Operations and Custom Methods
[11] Deep Learning with Pytorch - Model Creation and nn.Module
[12] Deep Learning with Pytorch - Model Containers and Building AlexNet
[13] Deep Learning with Pytorch - Convolutional Layers (1D/2D/3D Convolution, nn.Conv2d, Transposed Convolution nn.ConvTranspose)
[14] Deep Learning with Pytorch - Pooling, Linear, and Activation Function Layers
[15] Deep Learning with Pytorch - Weight Initialization
Deep Learning with Pytorch - Weight Initialization: Xavier and Kaiming
1. Vanishing and Exploding Gradients
Suppose every layer has n = 256 neurons, the input X follows N(0, 1), and the weights are drawn from N(0, 1). For a linear layer H = WX with independent zero-mean entries, D(H) = n · D(W) · D(X), so each layer multiplies the std by √n. The std is therefore √256 = 16 after the first layer, 256 after the second, 256·√256 after the third, 256² after the fourth, and (√256)^k after the k-th layer in general; within a few dozen layers the activations overflow to nan. Conversely, if the weight variance is too small the activations shrink toward zero and gradients vanish. Choosing D(W) = 1/n, i.e. std = 1/√n, keeps the output std at roughly 1 through every layer.
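To make the compounding concrete, here is a minimal sketch (the 256-unit width and 16-sample batch are illustrative, matching the experiment in section 5):

import torch

torch.manual_seed(1)

n = 256
x = torch.randn(16, n)             # input batch with std ~ 1
for k in range(1, 5):
    x = x @ torch.randn(n, n)      # one linear layer, N(0, 1) weights, no activation
    print(k, x.std().item())       # std multiplies by roughly sqrt(256) = 16 each layer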
2. Xavier Initialization
Here n_i is the number of neurons in the input layer and n_{i+1} is the number of neurons in the output layer. Xavier initialization chooses the weight variance so that the variance of the activations stays the same across layers in both the forward and backward pass; it is designed for saturating activations such as tanh and sigmoid.
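Stated as formulas (Glorot and Bengio, 2010): balancing the forward constraint n_i · D(W) = 1 against the backward constraint n_{i+1} · D(W) = 1 gives

D(W) = 2 / (n_i + n_{i+1})

and since a uniform distribution W ~ U(-a, a) has variance a²/3, the uniform bound is

a = √6 / √(n_i + n_{i+1})

scaled by the activation's gain (5/3 for tanh). A minimal sketch comparing the hand-computed bound with nn.init.xavier_uniform_ (the 256-unit fan sizes are illustrative):

import numpy as np
import torch
import torch.nn as nn

fan_in, fan_out = 256, 256                 # illustrative layer sizes
gain = nn.init.calculate_gain('tanh')      # 5/3 for tanh

# manual Xavier uniform bound: a = gain * sqrt(6 / (fan_in + fan_out))
a = gain * np.sqrt(6 / (fan_in + fan_out))

w = torch.empty(fan_out, fan_in)
nn.init.xavier_uniform_(w, gain=gain)      # built-in equivalent
print(a, w.abs().max().item())             # max |w| stays within the bound a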
3. Kaiming Initialization
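Xavier assumes a roughly linear (saturating) activation, which ReLU violates by zeroing half of its inputs. Kaiming (He) initialization (He et al., 2015) corrects for this:

for ReLU:        D(W) = 2 / n_i,  i.e. std(W) = √(2 / n_i)
for Leaky ReLU:  D(W) = 2 / ((1 + a²) · n_i), where a is the negative slope

A minimal check against nn.init.kaiming_normal_ (the 256-unit fan-in is illustrative):

import torch
import torch.nn as nn

n_i = 256                                  # fan_in, illustrative
w = torch.empty(n_i, n_i)
nn.init.kaiming_normal_(w, nonlinearity='relu')
print(w.std().item(), (2 / n_i) ** 0.5)    # both close to sqrt(2/256) ≈ 0.088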
4. Ten Initialization Methods
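The ten methods, as exposed by torch.nn.init, are:

1. Xavier uniform distribution: nn.init.xavier_uniform_()
2. Xavier normal distribution: nn.init.xavier_normal_()
3. Kaiming uniform distribution: nn.init.kaiming_uniform_()
4. Kaiming normal distribution: nn.init.kaiming_normal_()
5. Uniform distribution: nn.init.uniform_()
6. Normal distribution: nn.init.normal_()
7. Constant: nn.init.constant_()
8. Orthogonal matrix initialization: nn.init.orthogonal_()
9. Identity matrix initialization: nn.init.eye_()
10. Sparse matrix initialization: nn.init.sparse_()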
Whichever method is used, the goal is to keep each layer's output variance in a reasonable range. PyTorch's nn.init.calculate_gain(nonlinearity, param=None) returns the variance-change scale of an activation function: gain = std of the input / std of the output. For example, tanh shrinks the data's std, so its gain is greater than 1.
5. Code Example
"""
# @file name : grad_vanish_explod.py
# @brief : 梯度消失与爆炸实验
"""
import os
import torch
import random
import numpy as np
import torch.nn as nn
# stand-in for the course helper `from tools.common_tools import set_seed`,
# defined inline so the script runs standalone
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed(1)  # fix the random seed for reproducibility
class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            x = torch.relu(x)

            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

        return x

    def initialize(self):
        # The four schemes below are alternatives; enable one at a time.
        # If several ran at once, each would simply overwrite the previous one.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # 1. N(0, 1): std grows by sqrt(neural_num) per layer -> nan
                # nn.init.normal_(m.weight.data)

                # 2. N(0, 1/n): keeps the std near 1 for purely linear layers
                # nn.init.normal_(m.weight.data, std=np.sqrt(1 / self.neural_num))

                # 3. Xavier, suited to saturating activations such as tanh
                # nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('tanh'))

                # 4. Kaiming, suited to the ReLU used in forward() above
                nn.init.kaiming_normal_(m.weight.data)
# experiment 1: forward pass through a 100-layer, 256-unit MLP
flag = 0
if flag:
    layer_nums = 100
    neural_nums = 256
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))  # standard normal input

    output = net(inputs)
    print(output)
# experiment 2: measure tanh's variance-change scale (gain) empirically
flag = 1
if flag:
    x = torch.randn(10000)
    out = torch.tanh(x)

    gain = x.std() / out.std()  # empirical gain: input std / output std
    print('gain:{}'.format(gain))

    tanh_gain = nn.init.calculate_gain('tanh')
    print('tanh_gain in PyTorch:', tanh_gain)
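Running experiment 2 should print an empirical gain of roughly 1.6 alongside PyTorch's built-in tanh gain of 5/3 ≈ 1.6667. To reproduce the nan blow-up from section 1, set flag = 1 in experiment 1 and enable the plain N(0, 1) initialization inside initialize(); with kaiming_normal_ left active, the std should instead stay on the order of 1 across all 100 layers.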