IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 在Keras中实现Multi-head-attention -> 正文阅读

[人工智能]在Keras中实现Multi-head-attention

多头注意力机制其实本质上就是将多个注意力结果进行拼接后输出,目前有多种拼接的方法。
第一种:拼接后乘以一个可训练矩阵进行维度转换。
在这里插入图片描述例如有32维数据,则设置8个head,每个head有32维,则最后拼接结果为32x8=256维,再设置一权值矩阵为W0为(256x32),则最后结果为
【1x256】x【256x32】=1x32
第二种方法:将每一个头的维度缩小再对每个头的结果直接拼接为最后的输出维度。
在这里插入图片描述
例如有128维数据,则设置8个head,每个head有16维,则最后拼接结果为16x8=128维。
第三种方法:对每个头的结果进行求和后求平均,此方法多用于GAT。
在这里插入图片描述
本次我们应用第一种方法来做Multi-head-attention,Attention的实现代码在我另一篇篇文章中已经实现(https://blog.csdn.net/qq_41669355/article/details/121362089),同时也借鉴了苏神在自定义层中的一个类,以实现在自定义层中调用已有的层(https://spaces.ac.cn/archives/4765)。

class MAtt(OurLayer):
    def __init__(self, out_dim, **kwargs):
        super(MAtt, self).__init__(**kwargs)
        self.out_dim = out_dim

    def build(self, input_shape):
        super(MAtt, self).build(input_shape)
        self.head1 = MyAttention(out_dim=self.out_dim)
        self.head2 = MyAttention(out_dim=self.out_dim)
        self.head3 = MyAttention(out_dim=self.out_dim)
        self.head4 = MyAttention(out_dim=self.out_dim)
        self.w0 = Dense(self.out_dim, use_bias=False)
    def call(self, inputs):
        # input_size = tf.shape(inputs)
        h1 = self.reuse(self.head1,inputs)
        h2 = self.reuse(self.head2,inputs)
        h3 = self.reuse(self.head3,inputs)
        h4 = self.reuse(self.head4,inputs)
        # h_r=tf.reshape(tf.concat([h1,h2,h3,h4],-1),(input_size[0],input_size[1],1,input_size[-1]*4))
        # h_r=tf.reshape(tf.multiply(h_r,self.w0),(input_size[0],input_size[1],input_size[-1]))
        h_r=self.reuse(self.w0,tf.concat([h1, h2, h3, h4], -1))
        # h_r=average([h1,h2,h3,h4])
        return h_r
    def compute_output_shape(self, input_shape):
        return (input_shape[0],input_shape[1], self.out_dim)

这个代码目前还是有缺陷,比如不能自定义head数,目前固定为4,将该代码放置Keras的IMBD任务进行测试,结果如下:
在这里插入图片描述
最后的结果应该过拟合了,或许还有些参数没调好(比如key_size),但懒得调了,以后有需要的话再优化叭。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-17 12:45:44  更:2021-11-17 12:47:11 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 5:38:44-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码