IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> TextCNN_pytorch实现 -> 正文阅读

[人工智能]TextCNN_pytorch实现

import numpy as np
import torch
from torch.functional import split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
"""

filter_list:

Conv2d(1, 3, kernel_size=(2, 4), stride=(1, 1))
    1:表示输入channel为1;
    3:表示输出channel为3;
    kernel_size: 卷积核大小为[2x4];
    stride=(1, 1): 步长为1进行滑动

filter_list:
    '0':Conv2d(1, 3, kernel_size=(2, 4), stride=(1, 1))
    '1':Conv2d(1, 3, kernel_size=(2, 4), stride=(1, 1))
    '2':Conv2d(1, 3, kernel_size=(2, 4), stride=(1, 1))
    len():3

a.permute(2, 0, 1): 若a维度为 [6, 1, 3],
    将a的维度转化为 [3, 6, 1]

"""


class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.num_filters_total = num_filters * len(filter_sizes)  # 3 * 3 = 9
        self.W = nn.Embedding(vocab_size, embedding_size)  # (16, 4)的词表,根据索引得到
        self.Weight = nn.Linear(self.num_filters_total,
                                num_classes,
                                bias=False)  # Weight:[9, 2]  输入9个特征,输出2个特征,二分类
        self.Bias = nn.Parameter(torch.ones(
            [num_classes]))  # 加个偏置,一维长度为2的向量  data:tensor([1., 1.])
        self.filter_list = nn.ModuleList([
            nn.Conv2d(1, num_filters, (size, embedding_size))
            for size in filter_sizes
        ])  # nn.Conv2d(1, 3, (2, 4))

    def forward(self, x):
        embedded_chars = self.W(
            x)  # embedded_chars:[6, 3, 4] 词向量维度为4,输入x为[6, 3]
        # print('embedded_chars_size:{}'.format(embedded_chars.size()))
        embedded_chars = embedded_chars.unsqueeze(
            1)  # 在索引为1的地方扩充一个维度,变为[6, 1, 3, 4]

        pooled_outputs = []

        #  就相当于进行了三个滤波器,均由输入数据进过Conv2d(1, 3, kernel_size=(2, 4), stride=(1, 1))卷积
        #  得到了三个卷积后的结果,也就是输入1,输出3的含义, 输入维度:[6, 3] 输出维度:[6, 1, 1, 3]

        for i, conv in enumerate(self.filter_list):
            # print(i, conv)
            h = F.relu(conv(embedded_chars))  # h: torch.Size([6, 3, 2, 1])
            # print("h_size:{}".format(h.size()))
            mp = nn.MaxPool2d(
                (sequence_length - filter_sizes[i] + 1, 1))  #构建一个维度为 2x1的最大池化

            pooled = mp(h).permute(0, 3, 2, 1)  # mp(h): [6, 3, 1, 1]
            # print("pooled_size:{}".format(
            # pooled.size()))  #pooled_size:torch.Size([6, 1, 1, 3])
            pooled_outputs.append(
                pooled)  # 最终三个卷积后的结果append到pooled_outputs,将三个滤波器卷积后的结果拼接到list中
            # print("pooledn_outputs.size:{}, pooled_outputs:{}".format(
            # len(pooled_outputs), pooled_outputs))

        h_pool = torch.cat(
            pooled_outputs,
            len(filter_sizes))  # shape: [6, 1, 1 ,9],将三个6113cat成6119
        # print("h_pool_size:{}".format(h_pool.size()))
        h_pool_flat = torch.reshape(
            h_pool, [-1, self.num_filters_total])  # reshape为:[6, 9]维度
        # print("h_pool_flat_size:{}".format(h_pool_flat.size()))
        model = self.Weight(
            h_pool_flat) + self.Bias  # 变为 [6, 2] + 长度为2的向量,得到模型的输出
        # print('model:{}'.format(model))
        return model


"""
x.view(a,b,c) 将维度变为 [a, b, c]
"""

if __name__ == '__main__':
    embedding_size = 4
    sequence_length = 3
    num_classes = 2
    filter_sizes = [2, 2, 2]  #  ****啥作用?
    num_filters = 3  # 不是很懂啊

    sentences = [
        "i love you", "he loves me", "she likes baseball", "i hate you",
        "sorry for that", "this is awful"
    ]
    labels = [1, 1, 1, 0, 0, 0]

    # word_list = " ".join(sentences).split()
    # print('word_list:{}'.format(word_list))
    # word_list = list(set(word_list))
    # print('word_list_len:{}, word_list:{}'.format(len(word_list), word_list))
    # word_dict = {w: i for i, w in enumerate(word_list)}

    word_dict = {
        w: i
        for i, w in enumerate(list(set(' '.join(sentences).split())))
    }
    vocab_size = len(word_dict)

    model = TextCNN()  # 先允许TextCNN的__init__()函数

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    inputs = torch.LongTensor([[word_dict[i] for i in sen.split()]
                               for sen in sentences])
    targets = torch.LongTensor([out for out in labels])

    # input_ = ([[word_dict[n] for n in sen.split()]
    #            for sen in sentences])  # input_ :[batch_size, 3] 3为每条数据的长度
    # print('input_size:{}, input:{}'.format(
    #     torch.Tensor(input_).size(), input_))
    # inputs = torch.LongTensor(np.asarray(input_))
    # print('inputs_size:{}, inputs:{}'.format(
    #     inputs.size(), inputs))  # inputs: [6, 3] 一批数据6条,每条数据三个单词

    # 训练
    for epoch in range(5001):
        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, targets)
        if epoch % 1000 == 0:
            print("Epoch:", "%04d" % (epoch), 'cost=', '{:.6f}'.format(loss))

        loss.backward()
        optimizer.step()

    # Test
    test_text = 'he likes you'
    test = [[word_dict[t] for t in test_text.split()]]
    test = torch.LongTensor(test)

    # tests = [np.asarray([word_dict[n] for n in test_text.split()])]
    # test_batch = torch.LongTensor(tests)

    predict = model(test).data.max(1, keepdim=True)[1]

    print("predict:{}".format(predict))

    # 保存模型
    # torch.save(model, './model/TextCNN.pkl')

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-20 15:05:55  更:2021-08-20 15:09:58 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年10日历 -2025/10/27 5:43:36-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码