IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> (二)详解Pytorch中的BatchNorm模块 -> 正文阅读

[人工智能](二)详解Pytorch中的BatchNorm模块


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


0.简介

Batch Normalization在训练过程中对网络的输入输出进行归一化,可有效防止梯度爆炸和梯度消失,能加快网络的收敛速度

y = x ? E ( x ) ( V a r ( x ) + ? ) γ + β y = \frac{x-E(x)}{\sqrt(Var(x)+\epsilon)}\gamma+\beta y=( ?Var(x)+?)x?E(x)?γ+β

如上式,x表示的是输入变量,E(x)Var(x)分别表示x的那每个特征维度在batch size上所求得的梯度及方差。 ? \epsilon ?是为了防止除以0,通常为1e-5, γ \gamma γ β \beta β是可学习的参数,在torch BatchNorm API中,可通过设置affine=True/False来设置这两个参数是固定还是可学习的。True表示可学习,False表示不可学习,默认 γ = 1 \gamma=1 γ=1, β = 0 \beta=0 β=0

1.BatchNorm1d

BatchNorm1d是对NXCNXCXL维度的向量做Batch Normalization,N表示Batch Size的大小,C表示数据的维度,L表示每个维度又有多少维组成。

在这里插入图片描述

如上图,表示了一组NXCXL=3X2X3的数据,
使用BatchNorm1d后的输出为:

from torch import nn
batch = nn.BatchNorm1d(2, affine=False)
t = torch.tensor([[[7,4,6],[1,2,3]],[[3,4,2],[2,4,6]],[[9,0,7],[3,8,5]]])
t = t.float()
batch(t)
"""
输出为:
tensor([[[ 0.8750, -0.2500,  0.5000],
         [-1.3250, -0.8480, -0.3710]],

        [[-0.6250, -0.2500, -1.0000],
         [-0.8480,  0.1060,  1.0600]],

        [[ 1.6250, -1.7500,  0.8750],
         [-0.3710,  2.0140,  0.5830]]])
"""

上述的计算过程等价为:

因为affine=False因此 γ = 1 , β = 0 \gamma=1,\beta=0 γ=1,β=0,期望的计算是单独在每个维度上对Batch计算的,等价为

在特征维度0上的均值
E ( x ) = 7 + 4 + 6 + 3 + 4 + 2 + 9 + 0 + 7 3 × 3 = 4.6667 E(x) = \frac{7+4+6+3+4+2+9+0+7}{3\times3} = 4.6667 E(x)=3×37+4+6+3+4+2+9+0+7?=4.6667
同理可计算方差为:‵Var(X) = 2.6667`

tmp = t[:,0,:] 
print(tmp.mean()) 
print(tmp.var(unbiased=False).sqrt())
print((tmp-tmp.mean())/(tmp.var(unbiased=False).sqrt()+1e-5))

"""
Output:
tensor(4.6667)
tensor(2.6667)
tensor([[ 0.8750, -0.2500,  0.5000],
        [-0.6250, -0.2500, -1.0000],
        [ 1.6250, -1.7500,  0.8750]])
"""

注意在上述计算方差的过程中没有使用Bessel’s correction贝塞尔校正,除以的是n而不是n-1,因此通过这种方式计算的方差是有偏的。上面的结果与BatchNorm1d的输出是一致的。

2.BatchNorm2d

from torch import nn
batch = nn.BatchNorm2d(2, affine=False) 
img = torch.randint(0, 255, (2,2,3,3)) 
img = img.float() 
print(img)
print(batch(img))
t = img[:,0,:,:] 
print(t.mean()) 
print(t.var().sqrt())
print((t-t.mean())/(t.var(unbiased=False).sqrt()+1e-5))

"""
Output: 
tensor([[[[ 97., 163., 130.],
          [ 26.,  83., 183.],
          [165., 108., 242.]],

         [[113., 184., 236.],
          [159., 223., 247.],
          [ 48., 104., 111.]]],


        [[[110.,  93., 115.],
          [237., 168., 120.],
          [149., 115.,  48.]],

         [[117.,  22.,  43.],
          [202.,  63., 209.],
          [104., 135.,  99.]]]])
tensor([[[[-0.6115,  0.5873, -0.0121],
          [-1.9012, -0.8658,  0.9506],
          [ 0.6236, -0.4117,  2.0223]],

         [[-0.3169,  0.7350,  1.5054],
          [ 0.3646,  1.3128,  1.6683],
          [-1.2798, -0.4502, -0.3465]]],


        [[[-0.3754, -0.6842, -0.2846],
          [ 1.9315,  0.6781, -0.1938],
          [ 0.3330, -0.2846, -1.5016]],

         [[-0.2576, -1.6650, -1.3539],
          [ 1.0016, -1.0576,  1.1054],
          [-0.4502,  0.0091, -0.5243]]]])
tensor(130.6667)
tensor(56.6486)
tensor([[[-0.6115,  0.5873, -0.0121],
         [-1.9012, -0.8658,  0.9506],
         [ 0.6236, -0.4117,  2.0223]],

        [[-0.3754, -0.6842, -0.2846],
         [ 1.9315,  0.6781, -0.1938],
         [ 0.3330, -0.2846, -1.5016]]])
"""

BatchNorm2d的输入维度是NCHW形式的4维变量,计算均值和方差时是以C为标准逐各通道上计算的,每个通道上有一个均值和方差。在NHW上进行计算。

3.BatchNorm3d

batch = nn.BatchNorm3d(2, affine=False)
t = torch.randint(0, 3, (2,2,3,3,3))
t = t.float()
print(batch(t))
tmp = t[:,0,:,:,:] 
print(tmp.mean()) 
print(tmp.var().sqrt())
print((tmp-tmp.mean())/(tmp.var(unbiased=False).sqrt()+1e-5))

参考资料


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-09 12:39:40  更:2022-05-09 12:42:56 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/4 16:10:21-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码