IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch中的广播机制 -> 正文阅读

[人工智能]Pytorch中的广播机制

广播条件

两个张量只有都满足下面两个条件,才可以广播:

  1. 每个张量都至少有一个维度
  2. 对两个张量的维度从后往前(从右向左) 处理,维度的大小(这个维度的长度)必须要么相等要么其中一个为1,或者其中一个张量后面不存在维度了

例:

>>>import torch
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
'''相同的形状总是可以广播的'''

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
'''不能广播,因为两个张量都必须只有一个维度'''

'''可以将尾部对齐(can line up trailing dimensions)'''
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
'''
x和y可以广播
# 倒数第一个维度:x size == y size == 1
# 倒数第二个维度: y has size 1
# 倒数第三个维度 == y size
# 倒数第四个维度: y后面不再有维度了
——————英文原文如下,表达语句值得学习——————
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
'''

>>> x=torch.empty(3,2,4,1)
>>> y=torch.empty(  3,1,1)
'''x和y不能广播,因为倒数第三个维度大小不同,且不为1'''

———————————————————————————

运算法则

如果两个张量x, y是可广播的,结果的张量大小按如下方式计算:

  1. 如果x和y的维度数量不同,对维度数量少的张量增加新的维度,且维度大小为1,使得两个张量的维度数量相同
  2. 对每个维度,结果的维度大小是x和y的维度大小的最大值。(其实如果某个维度大小不同,那么有一个维度大小肯定是1)

例1:

import torch
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

上面以维度的角度展示了运算法则。下面从元素级运算,展示具体操作原理:

例2:

import torch

x = torch.arange(0,24).reshape(2,4,3)
print(x)
y = torch.arange(0,3).reshape(1,3)
print(y)
z = x + y
print(z)
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],
         
        [[12, 13, 14],
         [15, 16, 17],
         [18, 19, 20],
         [21, 22, 23]]])
         
tensor([[0, 1, 2]])

tensor([[[ 0,  2,  4],
         [ 3,  5,  7],
         [ 6,  8, 10],
         [ 9, 11, 13]],
         
        [[12, 14, 16],
         [15, 17, 19],
         [18, 20, 22],
         [21, 23, 25]]])

例3:

x = torch.arange(0,24).reshape(2,4,3)
print(x)
y = torch.arange(0,4).reshape(4,1)
print(y)
z = x + y
print(z)
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],
         
        [[12, 13, 14],
         [15, 16, 17],
         [18, 19, 20],
         [21, 22, 23]]])
         
tensor([[0],
        [1],
        [2],
        [3]])
        
tensor([[[ 0,  1,  2],
         [ 4,  5,  6],
         [ 8,  9, 10],
         [12, 13, 14]],
         
        [[12, 13, 14],
         [16, 17, 18],
         [20, 21, 22],
         [24, 25, 26]]])

例4:

x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(0,3).reshape(3,1)
print(y)
z = x + y
print(z)

结果为:

tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],
          
         [[ 6,  7],
          [ 8,  9],
          [10, 11]],
          
         [[12, 13],
          [14, 15],
          [16, 17]],
          
         [[18, 19],
          [20, 21],
          [22, 23]]],
          
        [[[24, 25],
          [26, 27],
          [28, 29]],
          
         [[30, 31],
          [32, 33],
          [34, 35]],
          
         [[36, 37],
          [38, 39],
          [40, 41]],
          
         [[42, 43],
          [44, 45],
          [46, 47]]]])
          
tensor([[0],
        [1],
        [2]])
        
tensor([[[[ 0,  1],
          [ 3,  4],
          [ 6,  7]],
          
         [[ 6,  7],
          [ 9, 10],
          [12, 13]],
          
         [[12, 13],
          [15, 16],
          [18, 19]],
          
         [[18, 19],
          [21, 22],
          [24, 25]]],
          
        [[[24, 25],
          [27, 28],
          [30, 31]],
          
         [[30, 31],
          [33, 34],
          [36, 37]],
          
         [[36, 37],
          [39, 40],
          [42, 43]],
          
         [[42, 43],
          [45, 46],
          [48, 49]]]])

例5:

x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(0,8).reshape(4,1,2)
print(y)
z = x + y
print(z)
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],
          
         [[ 6,  7],
          [ 8,  9],
          [10, 11]],
          
         [[12, 13],
          [14, 15],
          [16, 17]],
          
         [[18, 19],
          [20, 21],
          [22, 23]]],
          
        [[[24, 25],
          [26, 27],
          [28, 29]],
          
         [[30, 31],
          [32, 33],
          [34, 35]],
          
         [[36, 37],
          [38, 39],
          [40, 41]],
          
         [[42, 43],
          [44, 45],
          [46, 47]]]])
          
tensor([[[0, 1]],

        [[2, 3]],
        
        [[4, 5]],
                        
        [[6, 7]]])
        
tensor([[[[ 0,  2],
          [ 2,  4],
          [ 4,  6]],
          
         [[ 8, 10],
          [10, 12],
          [12, 14]],
          
         [[16, 18],
          [18, 20],
          [20, 22]],
          
         [[24, 26],
          [26, 28],
          [28, 30]]],
          
        [[[24, 26],
          [26, 28],
          [28, 30]],
          
         [[32, 34],
          [34, 36],
          [36, 38]],
          
         [[40, 42],
          [42, 44],
          [44, 46]],
          
         [[48, 50],
          [50, 52],
          [52, 54]]]])

例6:

x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(2,4).reshape(1,1,2)
print(y)
z = x + y
print(z)
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],
          
         [[ 6,  7],
          [ 8,  9],
          [10, 11]],
          
         [[12, 13],
          [14, 15],
          [16, 17]],
          
         [[18, 19],         
          [20, 21],
          [22, 23]]],
          
        [[[24, 25],
          [26, 27],
          [28, 29]],
          
         [[30, 31],
          [32, 33],
          [34, 35]],
          
         [[36, 37],
          [38, 39],
          [40, 41]],
          
         [[42, 43],
          [44, 45],
          [46, 47]]]])
          
tensor([[[2, 3]]])

tensor([[[[ 2,  4],
          [ 4,  6],
          [ 6,  8]],
          
         [[ 8, 10],
          [10, 12],
          [12, 14]],
          
         [[14, 16],
          [16, 18],
          [18, 20]],
          
         [[20, 22],
          [22, 24],
          [24, 26]]],
          
        [[[26, 28],
          [28, 30],
          [30, 32]],
          
         [[32, 34],
          [34, 36],
          [36, 38]],
          
         [[38, 40],
          [40, 42],
          [42, 44]],
          
         [[44, 46],
          [46, 48],
          [48, 50]]]])

例7:

x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(2,8).reshape(2,1,3,1)
print(y)
z = x + y
print(z)
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],
          
         [[ 6,  7],
          [ 8,  9],
          [10, 11]],
          
         [[12, 13],
          [14, 15],
          [16, 17]],
          
         [[18, 19],
          [20, 21],
          [22, 23]]],
          
        [[[24, 25],
          [26, 27],
          [28, 29]],
          
         [[30, 31],
          [32, 33],
          [34, 35]],
          
         [[36, 37],
          [38, 39],
          [40, 41]],
          
         [[42, 43],
          [44, 45],
          [46, 47]]]])
          
tensor([[[[2],
          [3],
          [4]]],
          
        [[[5],
          [6],
          [7]]]])
          
tensor([[[[ 2,  3],
          [ 5,  6],
          [ 8,  9]],
          
         [[ 8,  9],
          [11, 12],
          [14, 15]],
          
         [[14, 15],
          [17, 18],
          [20, 21]],
          
         [[20, 21],
          [23, 24],
          [26, 27]]],
          
        [[[29, 30],
          [32, 33],
          [35, 36]],
          
         [[35, 36],
          [38, 39],
          [41, 42]],
          
         [[41, 42],
          [44, 45],
          [47, 48]],
          
         [[47, 48],
          [50, 51],
          [53, 54]]]])

例8:

x = torch.arange(0,24).reshape(2,4,3,1)
print(x)
y = torch.arange(0,12).reshape(2,1,3,2)
print(y)
z = x + y
print(z)
tensor([[[[ 0],
          [ 1],
          [ 2]],
          
         [[ 3],
          [ 4],
          [ 5]],
          
         [[ 6],
          [ 7],
          [ 8]],
          
         [[ 9],
          [10],
          [11]]],
          
        [[[12],
          [13],
          [14]],
          
         [[15],
          [16],
          [17]],
          
         [[18],
          [19],
          [20]],
          
         [[21],
          [22],
          [23]]]])
          
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]]],
          
        [[[ 6,  7],
          [ 8,  9],
          [10, 11]]]])
          
tensor([[[[ 0,  1],
          [ 3,  4],
          [ 6,  7]],
         [[ 3,  4],
         
          [ 6,  7],
          [ 9, 10]],
          
         [[ 6,  7],         
          [ 9, 10],
          [12, 13]],
          
         [[ 9, 10],
          [12, 13],
          [15, 16]]],
          
        [[[18, 19],
          [21, 22],
          [24, 25]],
          
         [[21, 22],
          [24, 25],
          [27, 28]],
          
         [[24, 25],
          [27, 28],
          [30, 31]],
          
         [[27, 28],
          [30, 31],
          [33, 34]]]])
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-11 22:11:31  更:2022-03-11 22:15:40 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 14:32:54-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码