1.Keras框架需要用到Lambda函数
keras.layers.Lambda(function, output_shape=None, mask=None, arguments=None)
function:函数,分割用到的是tf.split
output_shape:输出维度,下面代码是分成2块,在num_or_size_splits决定
axis=2,感觉是按列分;axis=1,是按照行分;axis=0,请求大家评论区补充
from keras.layers import Lambda
from tensorflow import tf
X_shortcut = X #(?, 2, 256)
x11= Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 2})(X) #(?,2,128)*2
x11= Lambda(tf.split, arguments={'axis': 1, 'num_or_size_splits': 2})(X) #(?,1,256)*2
x11= Lambda(tf.split, arguments={'axis': 0, 'num_or_size_splits': 2})(X) #(?,2,256)*2
2 Tensorflow框架直接使用tf.split
tf.split(value,num_or_size_splits,axis=0,num=None,name='split')
num_or_size_splits:每个分割后的张量的尺寸
当axis=0:按行分;当axis=1:按列分
参考:
https://www.csdn.net/tags/OtDaUgysNTMyMTgtYmxvZwO0O0OO0O0O.html
tf.split()_放下扳手&拿起键盘的博客-CSDN博客
|