一、split
split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
将张量分割成子张量.
- 如果 num_or_size_splits 是整数类型,num_split,则 value 沿维度 axis 分割成为 num_split 更小的张量.要求 num_split 均匀分配 value.shape[axis]。
- 如果 num_or_size_splits 不是整数类型,则它被认为是一个张量 size_splits,然后将 value 分割成 len(size_splits) 块.第 i 部分的形状与 value 的大小相同,除了沿维度 axis 之外的大小 size_splits[i]。
import pandas as pd
import tensorflow as tf
x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
print("x = \n", pd.DataFrame(x.numpy()))
print("-" * 200)
s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
print("s0 = \n", pd.DataFrame(s0.numpy()))
print("-" * 50)
print("s1 = \n", pd.DataFrame(s1.numpy()))
print("-" * 50)
print("s2 = \n", pd.DataFrame(s2.numpy()))
print("-" * 200)
t0, t1, t2 = tf.split(x, num_or_size_splits=[4, 15, 11], axis=1)
print("t0 = \n", pd.DataFrame(t0.numpy()))
print("-" * 50)
print("t1 = \n", pd.DataFrame(t1.numpy()))
print("-" * 50)
print("t2 = \n", pd.DataFrame(t2.numpy()))
print("-" * 200)
打印结果:
x =
0 1 2 ... 27 28 29
0 -0.888679 0.882839 0.739282 ... -0.688343 -0.930151 -0.875597
1 -0.153850 -0.319729 -0.098402 ... 0.489693 -0.170844 -0.091632
2 0.003379 0.187339 0.795501 ... 0.379071 -0.256689 0.564788
3 -0.372030 0.340384 -0.875375 ... -0.214336 0.717279 0.092451
4 -0.495783 0.257741 -0.358638 ... -0.921029 -0.830439 0.507138
[5 rows x 30 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
s0 =
0 1 2 ... 7 8 9
0 -0.888679 0.882839 0.739282 ... -0.403924 0.196670 -0.098327
1 -0.153850 -0.319729 -0.098402 ... 0.418904 0.081062 0.173876
2 0.003379 0.187339 0.795501 ... 0.615282 -0.385442 -0.311836
3 -0.372030 0.340384 -0.875375 ... -0.252203 -0.587342 0.321012
4 -0.495783 0.257741 -0.358638 ... 0.552696 0.620588 0.132702
[5 rows x 10 columns]
--------------------------------------------------
s1 =
0 1 2 ... 7 8 9
0 0.509016 0.740289 -0.964265 ... 0.459772 -0.697755 -0.540041
1 0.904286 0.986134 -0.409174 ... 0.187198 -0.445747 0.813097
2 -0.137152 0.934053 -0.751823 ... 0.309953 0.716927 0.848913
3 0.096014 0.069597 0.777320 ... -0.907295 -0.384888 0.764411
4 -0.706331 -0.901017 -0.529774 ... -0.301620 0.066731 0.770751
[5 rows x 10 columns]
--------------------------------------------------
s2 =
0 1 2 ... 7 8 9
0 -0.356173 -0.040504 0.150185 ... -0.688343 -0.930151 -0.875597
1 -0.436071 -0.224807 0.383009 ... 0.489693 -0.170844 -0.091632
2 0.169518 0.384529 -0.600068 ... 0.379071 -0.256689 0.564788
3 0.038849 0.754196 -0.049200 ... -0.214336 0.717279 0.092451
4 0.245371 -0.548065 0.338353 ... -0.921029 -0.830439 0.507138
[5 rows x 10 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
t0 =
0 1 2 3
0 -0.888679 0.882839 0.739282 0.454827
1 -0.153850 -0.319729 -0.098402 -0.764573
2 0.003379 0.187339 0.795501 -0.467434
3 -0.372030 0.340384 -0.875375 0.350312
4 -0.495783 0.257741 -0.358638 0.301579
--------------------------------------------------
t1 =
0 1 2 ... 12 13 14
0 -0.484341 -0.429574 0.999090 ... 0.634394 0.459772 -0.697755
1 0.325134 -0.227807 -0.890493 ... 0.152983 0.187198 -0.445747
2 -0.074674 -0.037023 0.830544 ... -0.993245 0.309953 0.716927
3 0.044287 0.245083 -0.858829 ... -0.583070 -0.907295 -0.384888
4 -0.105187 0.293733 0.783647 ... 0.397994 -0.301620 0.066731
[5 rows x 15 columns]
--------------------------------------------------
t2 =
0 1 2 ... 8 9 10
0 -0.540041 -0.356173 -0.040504 ... -0.688343 -0.930151 -0.875597
1 0.813097 -0.436071 -0.224807 ... 0.489693 -0.170844 -0.091632
2 0.848913 0.169518 0.384529 ... 0.379071 -0.256689 0.564788
3 0.764411 0.038849 0.754196 ... -0.214336 0.717279 0.092451
4 0.770751 0.245371 -0.548065 ... -0.921029 -0.830439 0.507138
[5 rows x 11 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Process finished with exit code 0
二、unstack
将秩为 R 的张量的给定维度出栈为秩为 (R-1) 的张量.
通过沿 axis 维度将 num 张量从 value 中分离出来.如果没有指定 num(默认值),则从 value 的形状推断.如果 value.shape[axis] 不知道,则引发 ValueError.
例如,给定一个具有形状 (A, B, C, D) 的张量.
- 如果
axis == 0 ,那么 output 中的第 i 个张量就是切片 value[i, :, :, :],并且 output 中的每个张量都具有形状 (B, C, D).(请注意,出栈的维度已经消失,不像split). - 如果 axis == 1,那么 output 中的第 i 个张量就是切片 value[:, i, :, :],并且 output 中的每个张量都具有形状 (A, C, D).
tf.unstack(value, num=None, axis=0, name='unstack')
- value: A rank R > 0 Tensor to be unstacked.
- num: An int. The length of the dimension axis. Automatically inferred if None (the default).
- axis: An int. The axis to unstack along. Defaults to the first dimension. Negative - values: wrap around, so the valid range is [-R, R).
- name: A name for the operation (optional).
import tensorflow as tf
x = tf.reshape(tf.range(12), (3, 4))
print("x = \n", x)
print("-" * 200)
p, q, r = tf.unstack(x)
print("p = ", p)
print("-" * 50)
print("q = ", q)
print("-" * 50)
print("r = ", r)
print("-" * 200)
打印结果:
x =
tf.Tensor(
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]], shape=(3, 4), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
p = tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
--------------------------------------------------
q = tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
--------------------------------------------------
r = tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Process finished with exit code 0
|