IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于Convolutional Block Attention Module (CBAM)的Multi-Attention模型设计与实现 -> 正文阅读

[人工智能]基于Convolutional Block Attention Module (CBAM)的Multi-Attention模型设计与实现

本文主要介绍的 Multi-Attention 方法只是一个研究思路。只是尝试了在基本模型的顶部集成多注意力机制。

其Multi-Attention模型结构如下所示:

?模型本质上是并行添加了 CBAM 和 DeepMoji 注意力机制,并在最后将它们的特征进行合并。 此外,我们通过集成全局加权平均池 (GWAP) 方法,对 CBAM 机制及其空间模块的末尾部分进行了修改。

经过对该模型的输出,我们将Output特征放入全连接的神经网络模型中进行最终的模型训练。

其结构图如下:

为了验证模型是否能够运行,我们在mnist手写汉字数据集上进行了实验。其训练主函数代码为:

image_size = 71
def resize(mnist):
     train_data = []
     for img in mnist:
            resized_img = cv2.resize(img, (image_size, image_size))
            train_data.append(resized_img)
     return train_data
(xtrain, train_target), (_, _) = tf.keras.datasets.mnist.load_data()

# resize and prepare the input 
xtrain = resize(xtrain)
xtrain = np.expand_dims(xtrain, axis=-1)
xtrain = np.repeat(xtrain, 3, axis=-1)
xtrain = xtrain.astype('float32') / 255

# train set / target 
ytrain = tf.keras.utils.to_categorical(train_target, num_classes=10)
class Classifier(tf.keras.Model):
    def __init__(self, dim):
        super(Classifier, self).__init__()
        # Defining All Layers in __init__
        # Layer of Block
        self.Base  = tf.keras.applications.Xception(
            input_shape=(image_size, image_size, 3),
            weights=None,
            include_top=False)
        # Keras Built-in
        self.GAP1 = tf.keras.layers.GlobalAveragePooling2D()
        self.GAP2 = tf.keras.layers.GlobalAveragePooling2D()
        self.BAT  = tf.keras.layers.BatchNormalization()
        self.ADD  = tf.keras.layers.Add()
        self.AVG  = tf.keras.layers.Average()
        self.DROP = tf.keras.layers.Dropout(rate=0.5)
        # Customs
        self.CAN  = ChannelAttentionModule()
        self.SPN1 = SpatialAttentionModule()
        self.SPN2 = SpatialAttentionModule()
        self.AWG  = AttentionWeightedAverage2D()
        # Tail
        self.DENS = tf.keras.layers.Dense(512, activation=tf.nn.relu)
        self.OUT  = tf.keras.layers.Dense(10,  activation='softmax', dtype=tf.float32)
    
    def call(self, input_tensor, training=False):
        # Base Inputs
        x  = self.Base(input_tensor)
        # Attention Modules 1
        # Channel Attention + Spatial Attention 
        canx   = self.CAN(x)*x
        spnx   = self.SPN1(canx)*canx
        # Global Weighted Average Poolin
        gapx   = self.GAP1(spnx)
        wvgx   = self.GAP2(self.SPN2(canx))
        gapavg = self.AVG([gapx, wvgx])
        # Attention Modules 2
        # Attention Weighted Average (AWG)
        awgavg = self.AWG(x)
        # Summation of Attentions
        x = self.ADD([gapavg, awgavg])
        # Tails
        x = self.BAT(x)
        x = self.DENS(x)
        x  = self.DROP(x, training=training)
        return self.OUT(x)
   
    def build_graph(self):
        x = tf.keras.layers.Input(shape=(image_size, image_size, 3))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

model.compile( metrics=['accuracy'],loss = 'categorical_crossentropy', optimizer = 'adam')
model.fit(xtrain, ytrain, epochs=5)

训练过程如下:

完整代码可下载:https://download.csdn.net/download/weixin_40651515/21121787?

?模型运行环境:python3.7.6、tensorflow==2.2.0、keras==2.3.1等。

?

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-18 12:42:30  更:2021-08-18 12:43:26 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 0:52:52-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码