Pytorch矩阵运算
此文章随着博主的学习而持续更新
主要是总结在阅读代码时遇到的tensor计算问题
对角矩阵
获取矩阵对角元素
x = torch.randn(3, 3)
y = torch.diagonal(x)
print(y)
"""
tensor([[ 0.7220, -1.8137, 0.5217],
[ 0.5010, -0.0773, 0.4702],
[ 0.3320, 0.0329, 1.1394]])
tensor([ 0.7220, -0.0773, 1.1394])
"""
获取矩阵非对角元素
原理
假设一个n*n矩阵,这里下标从1开始,
k
k
k表示第k行
则对角元素的小标位置为
(
k
?
1
)
n
+
k
(k-1)n+k
(k?1)n+k 将n*n转换为(n-1)(n+1)时,会损失最后一个元素,即最后一个对角矩阵元素
将(1)式除以新矩阵的列数(n+1),余数为1,说明原矩阵对角元素位于新矩阵第一列的位置
x = torch.randn(3, 3)
y = x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
print(y)
"""
tensor([[-1.1656, -1.6680, -0.5478],
[-0.6240, 1.7954, -0.2933],
[-0.1587, 0.7560, 0.9796]])
tensor([-1.6680, -0.5478, -0.6240, -0.2933, -0.1587, 0.7560])
"""
|