IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch中dim的理解 -> 正文阅读

[人工智能]Pytorch中dim的理解

dim的定义

dim 表示维度

x = torch.randn(2, 3, 3)

print(x)
print(x.size())
print(x.dim())

输出:

tensor([[[-1.6943, -2.1487,  1.2332],
         [-0.2261, -0.1596,  1.5513],
         [ 2.0383, -0.6982, -2.1481]],

        [[ 0.4201, -2.7373,  0.2424],
         [-1.1152,  1.3682, -1.8322],
         [ 0.1957, -0.2920,  0.1845]]])
torch.Size([2, 3, 3])
3

这样看着不是很清晰,但如果将[]格式化:


[
    [
        [-1.6943, -2.1487,  1.2332],

        [-0.2261, -0.1596,  1.5513],

        [ 2.0383, -0.6982, -2.1481]
    ],

    [
        [ 0.4201, -2.7373,  0.2424],

        [-1.1152,  1.3682, -1.8322],

        [ 0.1957, -0.2920,  0.1845]
    ]
]

  • 维度(2, 3, 3)就很明显了, 是从矩阵的外部到内部
  • x.dim() = 3意味着x有三个维度, dim = (0, 1, 2),
    • 0对应着x.size()中的(2, 3, 3)
    • 1对应着x.size()中的(2, 3, 3)
    • 2对应着x.size()中的(2, 3, 3)

dim的理解

dim = 0时, 指的是 x(3, 3)
也就是:

x = torch.randn(2, 3, 3)
print(x)

for i in x:
    print(i)
    print(i.size())

输出:

tensor([[[-1.4251, -0.8321,  1.0230],
         [ 0.2008,  0.5929, -0.7696],
         [-0.3721, -1.0837, -0.6642]],

        [[-0.5337,  0.7808,  0.4419],
         [-0.4683,  0.3847,  0.0747],
         [ 1.0156, -0.4933,  1.5340]]])


tensor(
    [
        [-1.4251, -0.8321,  1.0230],
        [ 0.2008,  0.5929, -0.7696],
        [-0.3721, -1.0837, -0.6642]
    ]
)
torch.Size([3, 3])

tensor(
    [
        [-0.5337,  0.7808,  0.4419],
        [-0.4683,  0.3847,  0.0747],
        [ 1.0156, -0.4933,  1.5340]
    ]
)
torch.Size([3, 3])

所以说当dim=0时, 相当于去除x中的dim = 0的维度

验证

  • torch.argmax(tensor)
    返回tensor中值最大的数的下标, 比较的是同型张量
    Example:
    >>> x = torch.tensor([1, 5, 8, 4, 6])
    >>> torch.argmax(x)
    tensor(2)
import torch

x = torch.randn(2, 3, 3)

print(x)

print('='*50, end='\n\n')
for i in x:
    print(i)
    print(i.size())

print('='*50, end='\n\n')

print(x.size())
print(x.dim())

print('='*50, end='\n\n')

y = torch.argmax(x, dim=0)

print(y)
print(y.size())

输出:

tensor(
    [
        [
            [-1.3918,  0.0620, -0.4111],
            [ 1.9623, -1.3399, -0.4673],
            [-0.0185, -1.9024,  0.1340]
        ],

        [
            [ 0.7135, -0.5290, -0.7656],
            [ 0.2642,  0.5956, -0.0718],
            [-0.7465, -0.8098, -0.0874]
        ]
    ]
)
==================================================

tensor([[-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]])
torch.Size([3, 3])

tensor([[ 0.7135, -0.5290, -0.7656],
        [ 0.2642,  0.5956, -0.0718],
        [-0.7465, -0.8098, -0.0874]])
torch.Size([3, 3])
==================================================

torch.Size([2, 3, 3])
3
==================================================

tensor([[1, 0, 0],
        [0, 1, 1],
        [0, 1, 0]])
torch.Size([3, 3])
  • 分析一下 y[0] = [1, 0, 0], 为什么呢?
    有两种想法:

    1. 它比较的是 [-1.3918, 0.0620, -0.4111][ 0.7135, -0.5290, -0.7656]
      其中:
      [-1.3918, 0.7135], 0.7135比较大, 所以返回 1
      [0.0620, -0.5290], 0.0620比较大, 所以返回 0
      [-0.4111, -0.7656], -0.4111比较大, 所以返回 0
    2. 如果比较的是x[i]中的每一列, 得到的是2x3的输出, 例如 x[0]:
        [-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]
    

    比较每一列, 经过torch.argmax得到的是 [1, 0, 2]

  • 如果按照去掉dim = 0的部分, x':

    [
        [-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]
    ],
    
    [
        [ 0.7135, -0.5290, -0.7656],
        [ 0.2642,  0.5956, -0.0718],
        [-0.7465, -0.8098, -0.0874]
    ]
    

    也就是两个size = (3, 3)tensor, 这为什么不是第二种情况就比较合理了
    因为比较的是两个tensor, 而第二种情况是分别在一个tensor内的比较, 再将两个tensor的比较结果合并

    • 总结: 比较的是去掉指定维度后的第一个维度, 比如这里的:(2, 3, 3) -> (3, 3), 得到的结果的size是去掉指定dimsize
  • 如果只有两个维度, 或许会好理解一些:

    import torch
    
    x = torch.randn(2,3)
    
    print(x)
    
    y = torch.argmax(x, dim=0)
    
    print(y)
    print(y.size())
    

    输出:

    tensor(
        [
            [ 0.0251, -0.3640,  0.1965],
            [ 0.6902,  0.9846,  0.2035]
        ]
    )
    
    tensor([1, 1, 1])
    torch.Size([3])
    

    去掉dim = 0, 比较的就是 [ 0.0251, -0.3640, 0.1965][ 0.6902, 0.9846, 0.2035]
    dim = (2, 3) -> dim(3)

  • 这时候再回来看上面3个维度的例子:

    [                                           
        [-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]
    ],
    [
        [ 0.7135, -0.5290, -0.7656],
        [ 0.2642,  0.5956, -0.0718],
        [-0.7465, -0.8098, -0.0874]
    ]
    

    比较两者时相当于在下面的tensortorch.argmax()

    [
        [-1.3918,  0.0620, -0.4111],
        [ 0.7135, -0.5290, -0.7656]
    ]
    
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-26 11:41:49  更:2022-04-26 11:44:53 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 18:25:27-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码