1. Storage
Storage概述
-
t
e
n
s
o
r
tensor
tensor 分为头信息区
(
T
e
n
s
o
r
)
(Tensor)
(Tensor)和存储区
(
S
t
o
r
a
g
e
)
(Storage)
(Storage)。
- 信息区
(
T
e
n
s
o
r
)
(Tensor)
(Tensor)主要保存着
t
e
n
s
o
r
tensor
tensor 的形状
(
s
i
z
e
)
(size)
(size)、步长
(
s
t
r
i
d
e
)
(stride)
(stride)、数据类型
(
t
y
p
e
)
(type)
(type)等信息,而真正的数据则保存成连续数组,存储在存储区
(
S
t
o
r
a
g
e
)
(Storage)
(Storage)。
- 因为数据动辄成千上万,因此信息区元素占用内存较少,主要内存占用取决于
t
e
n
s
o
r
tensor
tensor 中元素的数目,即存储区的大小。
- 不同的
t
e
n
s
o
r
tensor
tensor 的头信息一般不同,但是可能使用相同的
s
t
o
r
a
g
e
storage
storage。
操作Storage
- 查看
S
t
o
r
a
g
e
Storage
Storage
import torch
t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
t.storage()
- 索引
S
t
o
r
a
g
e
Storage
Storage
import torch
t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
t.storage()[1]
- 更改
S
t
o
r
a
g
e
Storage
Storage
import torch
t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
t.storage()[1]=10
t.storage()
2. offset
s
t
o
r
a
g
e
_
o
f
f
s
e
t
:
storage\_offset:
storage_offset:
t
e
n
s
o
r
tensor
tensor 的第一个元素在
s
t
o
r
a
g
e
storage
storage 中的索引。
import torch
x = torch.tensor([1, 2, 3, 4, 5])
x.storage_offset()
x[3:].storage_offset()
3. stride
-
s
t
r
i
d
e
stride
stride 是
s
t
o
r
a
g
e
storage
storage 中对应于
t
e
n
s
o
r
tensor
tensor 的相邻维度间第一个索引的跨度,也叫步长。
- 当我们根据下标索引查找
t
e
n
s
o
r
tensor
tensor 中的任意元素时,将某维度的下标索引和对应的步长相乘,然后将所有维度乘积相加就可以了。根据
t
e
n
s
e
r
tenser
tenser 中的索引
i
,
j
i,j
i,j 查找
s
t
o
r
a
g
e
storage
storage 中对应索引的公式是
s
t
o
r
a
g
e
_
o
f
f
s
e
t
+
s
t
r
i
d
e
[
0
]
?
i
+
s
t
r
i
d
e
[
1
]
?
j
storage\_offset+stride[0]*i+stride[1]*j
storage_offset+stride[0]?i+stride[1]?j,因为是从
s
t
o
r
a
g
e
storage
storage 的开头查找,所以
s
t
o
r
a
g
e
_
o
f
f
s
e
t
=
0
storage\_offset=0
storage_offset=0。
示例:上图是一个
s
t
o
r
a
g
e
storage
storage,与它对应的
t
e
n
s
o
r
(
[
[
1.0
,
2.0
,
3.0
]
,
[
4.0
,
5.0
,
6.0
]
]
)
tensor([[1.0,2.0,3.0], [4.0,5.0,6.0]])
tensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]) 如下图所示: 那么
t
e
n
s
o
r
tensor
tensor 的
s
t
r
i
d
e
=
(
3
,
1
)
stride=(3,1)
stride=(3,1),因为从第一行的第一个索引到第二行第一个索引跨度是
3
3
3,从第一列到第二列的跨度是
1
1
1。
import torch
t = torch.tensor([[[1,3,5],[2,4,6]],
[[1,3,5],[2,4,6]]])
t.stride()
|