IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> python 基于卷积神经网络的情感分类源码实现 -> 正文阅读

[人工智能]python 基于卷积神经网络的情感分类源码实现

python 基于卷积神经网络的情感分类源码实现

一数据处理模块(utils文件夹下的CNNprocess.py)

# -*- coding: utf-8 -*-
# @Time    : 2021-08-18 11:06
# @Author  : XAT
# @FileName: CNNprocess.py
# @Software: PyCharm


# 词表隐射
from collections import defaultdict
from platform import uname

from torch.utils.data import Dataset, DataLoader, TensorDataset
from nltk.corpus import sentence_polarity
from torch.utils.data import DataLoader
import torch
from torch.nn.utils.rnn import pad_sequence


class Vocab:
    def __init__(self, tokens=None):
        self.idx_to_token = list()
        self.token_to_idx = dict()

        if tokens is not None:
            if "<unk>" not in tokens:
                tokens = tokens + ["<unk>"]
            for token in tokens:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
            self.unk = self.token_to_idx['<unk>']

    @classmethod
    def build(cls, text, min_freq=1, reserved_tokens=None):
        token_freqs = defaultdict(int)
        for sentence in text:
            for token in sentence:
                token_freqs[token] += 1
        uniq_tokens = ["<unk>"] + (reserved_tokens if reserved_tokens else [])
        uniq_tokens += [token for token, freq in token_freqs.items() \
                        if freq >= min_freq and token != "<unk>"]
        return cls(uniq_tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, token):
        return self.token_to_idx.get(token, self.unk)

    def convert_tokens_to_ids(self, tokens):
        return [self.token_to_idx.get(token, self.unk) for token in tokens]

    def convert_ids_to_tokens(self, indices):
        return [self.idx_to_token[index] for index in indices]


def load_sentence_polarity():
    vocab = Vocab.build(sentence_polarity.sents())
    print('test---vocab:', len(vocab))

    train_data = [(vocab.convert_tokens_to_ids(sentence), 0) for sentence in
                  sentence_polarity.sents(categories='pos')[:4000]] \
                 + [(vocab.convert_tokens_to_ids(sentence), 1) for sentence in
                    sentence_polarity.sents(categories='neg')[:4000]]

    # 其余数据作为测试集
    test_data = [(vocab.convert_tokens_to_ids(sentence), 0) for sentence in
                 sentence_polarity.sents(categories='pos')[4000:]] \
                + [(vocab.convert_tokens_to_ids(sentence), 1) for sentence in
                   sentence_polarity.sents(categories='neg')[4000:]]

    return train_data, test_data, vocab


class BowDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)  # 返回数据集中样例的数目

    def __getitem__(self, i):
        return self.data[i]  # 返回下标为i的样例


# collate_fn参数指向一个函数,用于对一个批次的样本进行整理,如将其转换成张量 (MLP函数)
# def collate_fn(examples):
#     inputs = [torch.tensor(ex[0]) for ex in examples]
#     # 输出的目标targets为该批次中全部样例输出结果(0或者1)构成的张量
#     targets = torch.tensor([ex[1] for ex in examples], dtype=torch.long)
#     # 获取一个批次中每个样例的序列长度
#     offsets = [0] + [i.shape[0] for i in inputs]
#     # 根据序列的长度,转换为每个序列起始位置的偏移量(offsets)
#     offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
#     # 将inputs 列表中的张量拼接成一个大的张量
#     inputs = torch.cat(inputs)
#     return inputs, offsets, targets

####################   CNN的collate_fn函数   ################
def collate_fn(examples):
    inputs = [torch.tensor(ex[0]) for ex in examples]
    targets = torch.tensor([ex[1] for ex in examples], dtype=torch.long)
    inputs = pad_sequence(inputs, batch_first=True) # 对批次内的样本进行补齐,使其具有相同长度(同最大长度序列),padding不足补0

    return inputs, targets

二、CNN模型模块(Models文件夹下的CNNModel.py)

# -*- coding: utf-8 -*-
# @Time    : 2021-08-18 10:55
# @Author  : XAT
# @FileName: CNNModel.py.py
# @Software: PyCharm

import torch.nn as nn
from torch.nn import functional as F

class CNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, filter_size, num_filter, num_class):
        super(CNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv1d = nn.Conv1d(embedding_dim, num_filter, filter_size, padding=1)
        self.activate = F.relu
        self.linear = nn.Linear(num_filter, num_class)

    def forward(self, inputs):
        embedding = self.embedding(inputs)
        conv1d = self.conv1d(embedding.permute(0, 2, 1))
        convolution = self.activate(conv1d)
        pooling = F.max_pool1d(convolution, kernel_size=convolution.shape[2])
        outputs = self.linear(pooling.squeeze(dim=2))
        log_probs = F.log_softmax(outputs, dim=1)
        return log_probs

三、主函数(main.py)

import sys
sys.path.append("../") #此处是为了实现在本地运行时,调用同级文件夹下的包/函数/类

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from nltk.corpus import sentence_polarity
from Models.MLPModel import MLP
from utils.MLPprocess import (BowDataset, Vocab, collate_fn, load_sentence_polarity)

from Models.CNNModel import CNN
from utils.CNNprocess import BowDataset, Vocab, collate_fn, load_sentence_polarity

import torch.optim as optim
from tqdm.auto import tqdm  # tqdm是一个python模块,能以进度条的方式显示迭代的进度


import nltk
nltk.download('sentence_polarity')
mlp = MLP(vocab_size=8,embedding_dim=3,hidden_dim=5,num_class=2)
#输入为两个长度为4的整数序列S
# inputs = torch.tensor([[0, 1, 2, 1], [4, 6, 6, 7]], dtype=torch.long)
# outputs = mlp(inputs)
# print(outputs)

# print("\n")
# print("Parameters: ")
# for name, param in mlp.named_parameters():
#     print(name,param.data)


#超参数设置
embedding_dim = 128
hidden_dim = 256
num_class = 2
batch_size = 32
num_epoch = 5
filter_size = 3
num_filter = 100

# 加载数据
train_data, test_data, vocab = load_sentence_polarity()
train_dataset = BowDataset(train_data)
test_data = BowDataset(test_data)

train_data_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
test_data_loader = DataLoader(test_data, batch_size=batch_size, collate_fn= collate_fn, shuffle=False)

#加载模型
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('vocab: ', vocab)
# model = MLP(len(vocab), embedding_dim, hidden_dim, num_class) #MLP
model = CNN(len(vocab), embedding_dim, filter_size, num_filter, num_class)  #CNN
#vocab_size, embedding_dim, filter_size, num_filter, num_class
model.to(device) #将模型加载到cpu或GPU设备

#训练过程
nll_loss = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in range(num_epoch):
    total_loss = 0
    for batch in tqdm(train_data_loader, desc=f"Training Epoch {epoch}"):
        # inputs, offsets, targets = [x.to(device) for x in batch] #MLP
        # log_probs = model(inputs, offsets) #MLP
        # loss = nll_loss(log_probs, targets) #MLP

        inputs, targets = [x.to(device) for x in batch]  # CNN Cnn中没有offset
        log_probs = model(inputs)# CNN
        loss = nll_loss(log_probs, targets)# CNN

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Loss: {total_loss: .2f}")


#测试过程
acc = 0
for batch in tqdm(test_data_loader, desc=f"Testing"):
    # inputs, offsets, targets = [x.to(device) for x in batch]  #MLP
    inputs,  targets = [x.to(device) for x in batch]  # CNN
    with torch.no_grad():
        # output = model(inputs, offsets) #MLP
        output = model(inputs)
        acc += (output.argmax(dim=1) == targets).sum().item()

#输出在测试集上的准确率
print(f"Acc: {acc /len(test_data_loader):.2f}")


  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-19 12:04:06  更:2021-08-19 12:05:35 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 23:54:15-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码