【BUG记录】 layers.BatchNormalization()的使用
我的目标是同tensorflow改写以下的pytorch搭建的CNN模型
#PyTorch搭建的模型
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.conv=nn.Sequential(
# first layer
nn.Conv2d(1,32,kernel_size=(2,5)),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d((1,2)),
# second layer
nn.Conv2d(32, 32, kernel_size=(2, 3)),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d((1, 2)),
# second layer
nn.Conv2d(32, 32, kernel_size=(2, 2)),
nn.BatchNorm2d(32),
nn.ReLU(),
)
self.dense_layer=nn.Sequential(
nn.Flatten(),
nn.Linear(1120,1156),
nn.BatchNorm1d(1156),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(1156, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256,5),
)
def forward(self,input):
feature_cnn=self.conv(input)
# feature_cnn=feature_cnn.view(-1,32*5*7)
output=self.dense_layer(feature_cnn)
return output
以下是实际搭建的模型
#tensorflow搭建模型
class ConvNet(Model):
def __init__(self):
self.filter=filter
super(ConvNet,self).__init__()
self.conv=keras.Sequential([
# first layer
layers.Conv2D(32,kernel_size=(2,5)),
layers.BatchNormalization(32),
layers.ReLU(),
layers.MaxPool2D((1, 2)),
# second layer
layers.Conv2D(32, kernel_size=(2,3)),
layers.BatchNormalization(32),
layers.ReLU(),
layers.MaxPool2D((1, 2)),
# third layer
layers.Conv2D(32, kernel_size=(2, 2)),
layers.BatchNormalization(32),
layers.ReLU(),
])
self.dense_layer=keras.Sequential([
layers.Flatten(),
layers.Dense(1056),
layers.BatchNormalization(1056),
layers.ReLU(),
layers.Dropout(rate=0.5),
layers.Dense(512),
layers.BatchNormalization(512),
layers.ReLU(),
layers.Dense(256),
layers.BatchNormalization(256),
layers.ReLU(),
layers.Dense(5)
])
def call(self,x,is_training=False):
x=tf.reshape(x, [-1, 8, 40, 1])
x=self.conv(x)
# x=tf.reshape(x,[-1,32*5*7])
x=self.dense_layer(x)
if not is_training:
x=tf.nn.softmax(x)
return x
在实际运行之后出现以下问题
Traceback (most recent call last):
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3361, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-31e68b6411fb>", line 1, in <cell line: 1>
runfile('D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code/cnn_model.py', wdir='D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code')
File "D:\Software\Professional\Pycharm\Pycharm\PyCharm Community Edition 2021.3.3\plugins\python-ce\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "D:\Software\Professional\Pycharm\Pycharm\PyCharm Community Edition 2021.3.3\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code/cnn_model.py", line 40, in <module>
pred=conv_net(batch_x)
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1030, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "D:\PostGraduate\Shanghai University\Research_Group\Task\EMG\self\code\EMG-left_and_right_arms - vote\main_code\cnn_function.py", line 83, in call
x=self.conv(x)
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\base_layer.py", line 1006, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\sequential.py", line 389, in call
outputs = layer(inputs, **kwargs)
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\base_layer.py", line 1006, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\functional.py", line 1442, in call
return getattr(self._module, self._method_name)(*args, **kwargs)
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1023, in __call__
self._maybe_build(inputs)
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2625, in _maybe_build
self.build(input_shapes) # pylint:disable=not-callable
File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\layers\normalization.py", line 315, in build
raise ValueError('Invalid axis: %s' % (self.axis,))
ValueError: Invalid axis: ListWrapper([32])
最后发现错误原因出现在layers.BatchNormalization(32) (tensorflow)和nn.BatchNorm2d(32) (pytorch)上, 我们来看tf中的定义layers.BatchNormalization() 可参考链接: link
tf.keras.layers.BatchNormalization(
axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
beta_initializer='zeros', gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones', beta_regularizer=None,
gamma_regularizer=None, beta_constraint=None, gamma_constraint=None, **kwargs
)
而在pytorch中,链接: link.
因此正确的方式是将layers.BatchNormalization(32) 改成layers.BatchNormalization(axis=3) 即可。
|