IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> [十五]深度学习Pytorch-权值初始化(Xavier和Kaiming) -> 正文阅读

[人工智能][十五]深度学习Pytorch-权值初始化(Xavier和Kaiming)

0. 往期内容

[一]深度学习Pytorch-张量定义与张量创建

[二]深度学习Pytorch-张量的操作:拼接、切分、索引和变换

[三]深度学习Pytorch-张量数学运算

[四]深度学习Pytorch-线性回归

[五]深度学习Pytorch-计算图与动态图机制

[六]深度学习Pytorch-autograd与逻辑回归

[七]深度学习Pytorch-DataLoader与Dataset(含人民币二分类实战)

[八]深度学习Pytorch-图像预处理transforms

[九]深度学习Pytorch-transforms图像增强(剪裁、翻转、旋转)

[十]深度学习Pytorch-transforms图像操作及自定义方法

[十一]深度学习Pytorch-模型创建与nn.Module

[十二]深度学习Pytorch-模型容器与AlexNet构建

[十三]深度学习Pytorch-卷积层(1D/2D/3D卷积、卷积nn.Conv2d、转置卷积nn.ConvTranspose)

[十四]深度学习Pytorch-池化层、线性层、激活函数层

[十五]深度学习Pytorch-权值初始化

1. 梯度消失与爆炸

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
第一个网络层的std为根号256,第二个网络层的std为256,第三个网络层的std为256根号256,第四个网络层的std为256^2。
同理一直到第n个网络层的std为(根号n)^n。最终会nan。
在这里插入图片描述

2. Xavier初始化

在这里插入图片描述
ni是输入层的神经元个数,ni+1是输出层的神经元个数。

3. Kaiming初始化

在这里插入图片描述

4. 十种初始化方法

在这里插入图片描述
在这里插入图片描述尺度=输出的方差/输入的方差

5. 代码示例

# -*- coding: utf-8 -*-
"""
# @file name  : grad_vanish_explod.py
# @brief      : 梯度消失与爆炸实验
"""
import os
import torch
import random
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子


class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        #构建100层线性层,每层有256个神经元
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)]) 
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            x = torch.relu(x)

            print("layer:{}, std:{}".format(i, x.std()))
            #判断是否为nan
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

        return x

    def initialize(self):
        #遍历每一个模块modules
        for m in self.modules():
            #判断模块是否为线性层
            if isinstance(m, nn.Linear):
                ##初始值为标准正态分布,均值为0,标准差为1,但是每层的标准差会越来越大,发生std爆炸
                nn.init.normal_(m.weight.data) #标准正态分布,均值为0,标准差为1,标准差会发生爆炸
                
                #此时的初始化权值可以使每层的均值为0,std为1.
                nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))   

                #手动计算
                # a = np.sqrt(6 / (self.neural_num + self.neural_num))
                #计算激活函数的增益
                # tanh_gain = nn.init.calculate_gain('tanh')
                # a *= tanh_gain
                #设置均匀分布来初始化权值
                # nn.init.uniform_(m.weight.data, -a, a)

                #自动计算
                nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('tanh'))

                #手动计算
                # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))

                #自动计算
                nn.init.kaiming_normal_(m.weight.data)

flag = 0
# flag = 1

if flag:
    layer_nums = 100 #100层神经网络
    neural_nums = 256 #每一层神经元个数为256
    batch_size = 16 

    net = MLP(neural_nums, layer_nums)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))  # 标准正态分布normal: mean=0, std=1

    output = net(inputs)
    print(output)

# ======================================= calculate gain =======================================

# flag = 0
flag = 1

if flag:

    x = torch.randn(10000) #标准正态分布创建10000个数据点
    out = torch.tanh(x)

    gain = x.std() / out.std()
    print('gain:{}'.format(gain))

    tanh_gain = nn.init.calculate_gain('tanh')
    print('tanh_gain in PyTorch:', tanh_gain)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-07 22:41:41  更:2022-04-07 22:42:33 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:22:12-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码