IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 知识图谱源码【DKN】(一)KCN详解 -> 正文阅读

[人工智能]知识图谱源码【DKN】(一)KCN详解

_ init _()函数

参数: self, config, pretrained_word_embedding, pretrained_entity_embedding, pretrained_context_embedding
config: 设置的固定的参数!
pretrained_word_embedding: 根据下面的使用是一个bool类型,表示是不是单词被预训练过!
pretrained_entity_embedding、pretrained_context_embedding: 同理也是

super(KCNN, self).__init__()   #调用父类(也就是torch.nn.Module)的初始化函数,建立继承关系
self.config = config           #config是预定义的超参数集合! 各种配置文件都在这里面
if pretrained_word_embedding is None:  #如果没有预训练词典,那么就用第一个
           self.word_embedding = nn.Embedding(config.num_words,
                                              config.word_embedding_dim,
                                              padding_idx=0)
else:  #我们就用预训练的词典
     self.word_embedding = nn.Embedding.from_pretrained(
         pretrained_word_embedding, freeze=False, padding_idx=0)
if pretrained_entity_embedding is None:
     self.entity_embedding = nn.Embedding(config.num_entities,
                                          config.entity_embedding_dim,
                                          padding_idx=0)
else: #实体的嵌入也是一样的,同上
     self.entity_embedding = nn.Embedding.from_pretrained(
         pretrained_entity_embedding, freeze=False, padding_idx=0)
if config.use_context:   #上下文嵌入也是一样的,同上
    if pretrained_context_embedding is None:
        self.context_embedding = nn.Embedding(
            config.num_entities,
            config.entity_embedding_dim,
            padding_idx=0)
    else:
        self.context_embedding = nn.Embedding.from_pretrained(
            pretrained_context_embedding, freeze=False, padding_idx=0)
            
self.transform_matrix = nn.Parameter(   #确定transform的参数矩阵
        torch.empty(self.config.entity_embedding_dim,
                    self.config.word_embedding_dim).uniform_(-0.1, 0.1))
self.transform_bias = nn.Parameter(      #确定transform的偏置的参数矩阵
    torch.empty(self.config.word_embedding_dim).uniform_(-0.1, 0.1))

#下面是定义一个模块字典, 字典通过x也就是模块的大小来访问是哪个卷积! 
self.conv_filters = nn.ModuleDict({    #(3/2, num_filters, (x, word_embedding_dim))
            str(x): nn.Conv2d(3 if self.config.use_context else 2,
                              self.config.num_filters,
                              (x, self.config.word_embedding_dim))
            for x in self.config.window_sizes
        })
self.additive_attention = AdditiveAttention(
	self.config.query_vector_dim, self.config.num_filters)

   def forward(self, news):
        """
        Args:
          news:       #输入的news参数是个字典! (title, title_entities)  个数
            {
                "title": batch_size * num_words_title,
                "title_entities": batch_size * num_words_title
            }

        Returns:
            final_vector: batch_size, len(window_sizes) * num_filters
        """
        # batch_size, num_words_title, word_embedding_dim
        word_vector = self.word_embedding(news["title"].to(device))    #获取新闻中的单词向量
        # batch_size, num_words_title, entity_embedding_dim
        entity_vector = self.entity_embedding(                         #获取新闻中的实体向量
            news["title_entities"].to(device))
        if self.config.use_context:
            # batch_size, num_words_title, entity_embedding_dim
            context_vector = self.context_embedding(               #获取新闻中上下文向量(也就是关系向量)
                news["title_entities"].to(device))

        # batch_size, num_words_title, word_embedding_dim
        transformed_entity_vector = torch.tanh(
            torch.add(torch.matmul(entity_vector, self.transform_matrix),
                      self.transform_bias))

        if self.config.use_context:                          
            # batch_size, num_words_title, word_embedding_dim
            transformed_context_vector = torch.tanh(       #将上下文向量经过transform
                torch.add(torch.matmul(context_vector, self.transform_matrix),
                          self.transform_bias))

            # batch_size, 3, num_words_title, word_embedding_dim
            multi_channel_vector = torch.stack([          #将三个向量进行concat
                word_vector, transformed_entity_vector,
                transformed_context_vector
            ],
                                               dim=1)
        else:
            # batch_size, 2, num_words_title, word_embedding_dim
            multi_channel_vector = torch.stack(    #否则直接进行concat
                [word_vector, transformed_entity_vector], dim=1)

        pooled_vectors = []
        for x in self.config.window_sizes:       #进行预先设定好的,根据窗口大小进行操作! 
            # batch_size, num_filters, num_words_title + 1 - x
            convoluted = self.conv_filters[str(x)](
                multi_channel_vector).squeeze(dim=3)
            # batch_size, num_filters, num_words_title + 1 - x
            activated = F.relu(convoluted)
            # batch_size, num_filters
            # Here we use a additive attention module
            # instead of pooling in the paper
            pooled = self.additive_attention(activated.transpose(1, 2))
            # pooled = activated.max(dim=-1)[0]
            # # or
            # # pooled = F.max_pool1d(activated, activated.size(2)).squeeze(dim=2)
            pooled_vectors.append(pooled)
        # batch_size, len(window_sizes) * num_filters
        final_vector = torch.cat(pooled_vectors, dim=1)   # 最终的向量是需要concat的!
        return final_vector

补充:

1、 python中的继承!

python2.7中的继承

super是superclass的缩写,而且在super()中要包含两个实参,子类名和对象self ! 这些必不可少!

同时父类中必须含有object这个原始父类

2、 torch.nn.Module()

如果自己想研究,官方文档

它是所有的神经网络的根父类! 你的神经网络必然要继承!
模块也可以包含其他模块,允许将它们嵌套在树结构中。所以呢,你可以将子模块指定为常规属性。常规定义子模块的方法如下:
在这里插入图片描述

以这种方式分配的子模块将被注册(也就是成为你的该分支下的子类),当你调用to()等方法的时候时,它们的参数也将被转换,等等。
当然子模块就可以包含各种线性or卷积等操作了! 也就是模型

该模型的方法: 参考博文

3、 torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None)

参考博文
是torch.nn中集成的进行词嵌入的方法!

4. torch.nn

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-20 15:05:55  更:2021-08-20 15:07:00 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 1:02:34-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码