IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【BUG记录】layers.BatchNormalization()的使用 -> 正文阅读

[人工智能]【BUG记录】layers.BatchNormalization()的使用

【BUG记录】 layers.BatchNormalization()的使用

我的目标是同tensorflow改写以下的pytorch搭建的CNN模型

#PyTorch搭建的模型
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()

        self.conv=nn.Sequential(
            # first layer
            nn.Conv2d(1,32,kernel_size=(2,5)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((1,2)),
            # second layer
            nn.Conv2d(32, 32, kernel_size=(2, 3)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),
            # second layer
            nn.Conv2d(32, 32, kernel_size=(2, 2)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        self.dense_layer=nn.Sequential(
            nn.Flatten(),
            nn.Linear(1120,1156),
            nn.BatchNorm1d(1156),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(1156, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(256,5),
        )

    def forward(self,input):
        feature_cnn=self.conv(input)
        # feature_cnn=feature_cnn.view(-1,32*5*7)
        output=self.dense_layer(feature_cnn)

        return output

以下是实际搭建的模型

#tensorflow搭建模型
class ConvNet(Model):

    def __init__(self):
        self.filter=filter
        super(ConvNet,self).__init__()

        self.conv=keras.Sequential([
            # first layer
            layers.Conv2D(32,kernel_size=(2,5)),
            layers.BatchNormalization(32),
            layers.ReLU(),
            layers.MaxPool2D((1, 2)),
            # second layer
            layers.Conv2D(32, kernel_size=(2,3)),
            layers.BatchNormalization(32),
            layers.ReLU(),
            layers.MaxPool2D((1, 2)),
            # third layer
            layers.Conv2D(32, kernel_size=(2, 2)),
            layers.BatchNormalization(32),
            layers.ReLU(),
        ])

        self.dense_layer=keras.Sequential([
            layers.Flatten(),

            layers.Dense(1056),
            layers.BatchNormalization(1056),
            layers.ReLU(),
            layers.Dropout(rate=0.5),

            layers.Dense(512),
            layers.BatchNormalization(512),
            layers.ReLU(),

            layers.Dense(256),
            layers.BatchNormalization(256),
            layers.ReLU(),

            layers.Dense(5)
        ])

    def call(self,x,is_training=False):
        x=tf.reshape(x, [-1, 8, 40, 1])
        x=self.conv(x)
        # x=tf.reshape(x,[-1,32*5*7])
        x=self.dense_layer(x)

        if not is_training:
            x=tf.nn.softmax(x)
        return x

在实际运行之后出现以下问题

Traceback (most recent call last):
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3361, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-31e68b6411fb>", line 1, in <cell line: 1>
    runfile('D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code/cnn_model.py', wdir='D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code')
  File "D:\Software\Professional\Pycharm\Pycharm\PyCharm Community Edition 2021.3.3\plugins\python-ce\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "D:\Software\Professional\Pycharm\Pycharm\PyCharm Community Edition 2021.3.3\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code/cnn_model.py", line 40, in <module>
    pred=conv_net(batch_x)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1030, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\PostGraduate\Shanghai University\Research_Group\Task\EMG\self\code\EMG-left_and_right_arms - vote\main_code\cnn_function.py", line 83, in call
    x=self.conv(x)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\base_layer.py", line 1006, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\sequential.py", line 389, in call
    outputs = layer(inputs, **kwargs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\base_layer.py", line 1006, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\functional.py", line 1442, in call
    return getattr(self._module, self._method_name)(*args, **kwargs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1023, in __call__
    self._maybe_build(inputs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2625, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\layers\normalization.py", line 315, in build
    raise ValueError('Invalid axis: %s' % (self.axis,))
ValueError: Invalid axis: ListWrapper([32])

最后发现错误原因出现在layers.BatchNormalization(32)(tensorflow)和nn.BatchNorm2d(32)(pytorch)上,
我们来看tf中的定义layers.BatchNormalization()
可参考链接: link

tf.keras.layers.BatchNormalization(
    axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
    beta_initializer='zeros', gamma_initializer='ones',
    moving_mean_initializer='zeros',
    moving_variance_initializer='ones', beta_regularizer=None,
    gamma_regularizer=None, beta_constraint=None, gamma_constraint=None, **kwargs
)

在这里插入图片描述
而在pytorch中,链接: link.

因此正确的方式是将layers.BatchNormalization(32)改成layers.BatchNormalization(axis=3)即可。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-21 20:50:48  更:2022-03-21 20:54:24 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 2:07:13-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码