IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> torch 的 3种矩阵乘法运算 -> 正文阅读

[人工智能]torch 的 3种矩阵乘法运算

torch.tensor * torch.tensor

当操作符是最最最自然的 “*” 时,执行的时 element-wise 乘法,操作数会 broadcast。
更多细节请见Tensor unsqueeze 以 broadcast

torch.mm(就是执行矩阵乘法,1维不能作参数)

就是执行矩阵乘法
torch.mm(input, mat2, *, out=None) → Tensor
Performs a matrix multiplication of the matrices input and mat2.
If input is a (n × \times × m) tensor, mat2 is a (m × \times × p) tensor, out will be a (n × \times × p)tensor.

例1: [3,3] [3,3]

import torch
Mat1 = torch.tensor([[1, 6, 7],
                      [2, 5, 8],
                      [3, 4, 9]])
Mat2 = torch.tensor([[9, 6, 3],
                      [8, 5, 2],
                      [7, 4, 1]])

Mat3 = torch.mm(Mat1, Mat2, out=None)
print(Mat3)

输出

tensor([[106,  64,  22],
        [114,  69,  24],
        [122,  74,  26]])

例2 [1,2] [2,3]

import torch
Mat1 = torch.tensor([[1, 3]])
print("Mat1's shape: ",Mat1.shape)
Mat2 = torch.tensor([[6, 4, 2],
                      [5, 3, 1]])
print("Mat2's shape: ",Mat2.shape)
Mat3 = torch.mm(Mat1, Mat2, out=None)
print(Mat3)

输出:

Mat1's shape:  torch.Size([1, 2])
Mat2's shape:  torch.Size([2, 3])
tensor([[21, 13,  5]])

例3 与例2对比,只相差一个括号

将 Mat1 修改为

Mat1 = torch.tensor([1, 3])

输出:

Mat1's shape:  torch.Size([2])
Mat2's shape:  torch.Size([2, 3])
    Mat3 = torch.mm(Mat1, Mat2, out=None)
RuntimeError: self must be a matrix

例4 [2,2] [3,3] 报错

import torch
Mat1 = torch.tensor([[1, 3],
                      [2, 4]] )

Mat2 = torch.tensor([[6, 4, 2],
                      [5, 3, 1],
                     [7,8,9]])

Mat3 = torch.mm(Mat1, Mat2, out=None)
print(Mat3)

输出:

    Mat3 = torch.mm(Mat1, Mat2, out=None)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2 and 3x3)

torch.mutual

1 维tensor 和 1 维 tensor(向量点乘)

import torch
Vec1 = torch.tensor([6,4,2])
print("Vec1's shape: ",Vec1.shape)
Vec2 = torch.tensor([5,3,1])
print("Vec2's shape: ",Vec2.shape)
Vec3 = torch.matmul(Vec1, Vec2)
print("Vec3: ",Vec3)
print("Vec3's shape: ",Vec3.shape,"\n")

输出

Vec1's shape:  torch.Size([3])
Vec2's shape:  torch.Size([3])
Vec3:  tensor(44)
Vec3's shape:  torch.Size([]) 

注意,Vec3 的 shape 是 torch.Size([]) 。
但是如果直接 print(torch.tensor([44]).shape)
会得到 torch.Size([1]) 而不是 torch.Size([])

2 维tensor 和 2 维 tensor

Mat1 = torch.tensor([[1, 3]])
print("Mat1's shape: ",Mat1.shape)
Mat2 = torch.tensor([[6, 4, 2],
                      [5, 3, 1]])
print("Mat2's shape: ",Mat2.shape)
Mat3 = torch.matmul(Mat1, Mat2)
print("Mat3: ",Mat3)
print("Mat3's shape: ",Mat3.shape,"\n")

Mat4 = torch.matmul(Mat2, Mat1)
print("Mat4: ",Mat4)
print("Mat4's shape: ",Mat4.shape,"\n")

输出

Mat1's shape:  torch.Size([1, 2])
Mat2's shape:  torch.Size([2, 3])

Mat3:  tensor([[21, 13,  5]])

Mat3's shape:  torch.Size([1, 3]) 

Traceback (most recent call last):
  File "D:/Test2022.py", line 29, in <module>
    Mat4 = torch.matmul(Mat2, Mat1)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 1x2)

注意,上面的 Mat3 结果和 torch.mm 计算出来的例 2 的结果一样的。
说明 2 维 tensor 与 2 维 tensor,torch.matmul 函数也是执行矩阵乘法。
注意 Mat4 的报错。

1维-2 维,2维-1维

import torch

# first argument 1D and second argument 2D
mat1_1 = torch.tensor([3, 6, 2])

mat1_2 = torch.tensor([[1, 2, 3],
                       [4, 3, 8],
                       [1, 7, 2]])

out_1 = torch.matmul(mat1_1, mat1_2)
print("\n1D-2D matmul :\n", out_1)

# first argument 2D and second argument 1D
mat2_1 = torch.tensor([[1, 2, 3],
                       [4, 3, 8],
                       [1, 7, 2]])

mat2_2 = torch.tensor([3, 6, 2])

# assigning to output tensor
out_2 = torch.matmul(mat2_1, mat2_2)

print("\n2D-1D matmul :\n", out_2)

输出:

1D-2D matmul :
 tensor([29, 38, 61])

2D-1D matmul :
 tensor([21, 46, 49])

第一种情况1D-2D matmul 可以用
1 × 3 1\times3 1×3 3 × 3 3\times3 3×3 的矩阵乘法来理解。
第二种情况可以用
3 × 3 3\times3 3×3 3 × 1 3\times1 3×1 的矩阵乘法来理解。

批矩阵乘法

import torch

# first argument 1D and second argument 2D
Mat1 = torch.tensor([[[1, 4,-9],
                       [2,-5,8],
                       [3,6,-7]],
                      [[2, 4, 6],
                       [1, 3, 5],
                       [7, 8, 9]]])
print("Mat1's shape: ",Mat1.shape)
Mat2 = torch.tensor([[[1, 2, 3],
                       [4, 3, 8],
                       [1, 7, 2]],
                      [[1, 7, 2],
                       [3, 2, 3],
                       [1, 1, 2]]])
print("Mat2's shape: ",Mat2.shape)
Out1 = torch.matmul(Mat1, Mat2)
print("\n3D-3D matmul :\n", Out1)
print("Out1's shape: ",Out1.shape)

输出:

Mat1's shape:  torch.Size([2, 3, 3])
Mat2's shape:  torch.Size([2, 3, 3])

3D-3D matmul :
 tensor([[[  8, -49,  17],
         [-10,  45, -18],
         [ 20, -25,  43]],

        [[ 20,  28,  28],
         [ 15,  18,  21],
         [ 40,  74,  56]]])
Out1's shape:  torch.Size([2, 3, 3])

Process finished with exit code 0

注,输入的两个 tensor 的 shape 都是
[ 2 , 3 , 3 ] [2,3,3] [2,3,3]
输出的 tensor 的shape 也是
[ 2 , 3 , 3 ] [2,3,3] [2,3,3]
实际上是 2 个
3 × 3 3\times3 3×3
的矩阵对应相乘,拼成一个
[ 2 , 3 , 3 ] [2,3,3] [2,3,3]
的输出

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-08-19 19:04:57  更:2022-08-19 19:06:39 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年10日历 -2025/10/22 8:22:58-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码