IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 二十八、基于TextCNN的中文文本分类四 -> 正文阅读

[人工智能]二十八、基于TextCNN的中文文本分类四

1. 模型的训练和评估

1.1 模型预测的流程

  • 模型预测的流程包括对文本预处理
  • 构建预测数据迭代器
  • 调用模型完成预测

1.2 TextCNN文本分类流程

  1. 准备数据:从THUCNews中抽取了20万条新闻标题,共10个预测类别

  2. 数据预处理:构建词汇表、文本向量化、按批次读取数据

  3. 模型构建:输入层->Embeding层->全连接层->输出层

  4. 模型的训练、评估和预测

1.3 代码实现

  • 步骤一:使用测试数据评估模型predict_eval.py
# coding: UTF-8
# coding:utf-8
import torch
import numpy as np
from unit28.train_eval import evaluate

MAX_VOCAB_SIZE = 10000
UNK, PAD = '<UNK>', '<PAD>'

tokenizer = lambda x: [y for y in x]  # char-level

def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path)) # 加载训练好的的模型
    model.eval()  # 开启评价模式

    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)

  • 步骤二:加载待分类数据predict_eval.py
def load_dataset(text, vocab, config, pad_size=32):
    contents = []
    for line in text:
        lin = line.strip()
        if not lin:
            continue
        words_line = []
        token = tokenizer(line)
        seq_len = len(token)
        if pad_size:
            if len(token) < pad_size:
                token.extend([PAD] * (pad_size - len(token)))
            else:
                token = token[:pad_size]
                seq_len = pad_size
        # word to id
        for word in token:
            words_line.append(vocab.get(word, vocab.get(UNK)))
        contents.append((words_line, int(0), seq_len))
    return contents  # [([...], 0), ([...], 1), ...]
  • 步骤三:加载训练好的模型进行预测predict_eval.py
def match_label(pred, config):
    label_list = config.class_list
    return label_list[pred]


def final_predict(config, model, data_iter):
    map_location = lambda storage, loc: storage
    model.load_state_dict(torch.load(config.save_path, map_location=map_location))
    model.eval()
    predict_all = np.array([])
    with torch.no_grad():
        for texts, _ in data_iter:
            outputs = model(texts)
            pred = torch.max(outputs.data, 1)[1].cpu().numpy()
            pred_label = [match_label(i, config) for i in pred]
            predict_all = np.append(predict_all, pred_label)
    return predict_all
  • 步骤四:主函数run.py
# coding:utf-8

from unit28.TextCNN import Config
from unit28.TextCNN import Model
from unit28.load_data import build_dataset
from unit28.load_data_iter import build_iterator
from unit28.predict_eval import test,load_dataset,final_predict

text = ['国考28日网上查报名序号查询后务必牢记报名参加2011年国家公务员的考生,如果您已通过资格审查,那么请于10月28日8:00后,登录考录专题网站查询自己的“关键数字”——报名序号。'
            '国家公务员局等部门提醒:报名序号是报考人员报名确认和下载打印准考证等事项的重要依据和关键字,请务必牢记。此外,由于年龄在35周岁以上、40周岁以下的应届毕业硕士研究生和'
            '博士研究生(非在职),不通过网络进行报名,所以,这类人报名须直接与要报考的招录机关联系,通过电话传真或发送电子邮件等方式报名。',
            '高品质低价格东芝L315双核本3999元作者:徐彬【北京行情】2月20日东芝SatelliteL300(参数图片文章评论)采用14.1英寸WXGA宽屏幕设计,配备了IntelPentiumDual-CoreT2390'
            '双核处理器(1.86GHz主频/1MB二级缓存/533MHz前端总线)、IntelGM965芯片组、1GBDDR2内存、120GB硬盘、DVD刻录光驱和IntelGMAX3100集成显卡。目前,它的经销商报价为3999元。',
            '国安少帅曾两度出山救危局他已托起京师一代才俊新浪体育讯随着联赛中的连续不胜,卫冕冠军北京国安的队员心里到了崩溃的边缘,俱乐部董事会连夜开会做出了更换主教练洪元硕的决定。'
            '而接替洪元硕的,正是上赛季在李章洙下课风波中同样下课的国安俱乐部副总魏克兴。生于1963年的魏克兴球员时代并没有特别辉煌的履历,但也绝对称得上特别:15岁在北京青年队获青年'
            '联赛最佳射手,22岁进入国家队,著名的5-19一战中,他是国家队的替补队员。',
            '汤盈盈撞人心情未平复眼泛泪光拒谈悔意(附图)新浪娱乐讯汤盈盈日前醉驾撞车伤人被捕,',
            '甲醇期货今日挂牌上市继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。基准价均为3050元/吨继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式'
            '挂牌交易。郑州商品交易所(郑商所)昨日公布首批甲醇期货8合约的上市挂牌基准价,均为3050元/吨。据此推算,买卖一手甲醇合约至少需要12200元。业内人士认为,作为国际市场上的'
            '首个甲醇期货品种,其今日挂牌后可能会因炒新资金追捧而出现冲高走势,脉冲式行情过后可能有所回落,不过,投资者在上市初期应关注期现价差异常带来的无风险套利交易机会。',
            '佟丽娅穿白色羽毛长裙美翻,自曝跳舞的女孩能吃苦',
            '江欣燕透露汤盈盈钱嘉乐分手 用冷笑话补救']

if __name__ == "__main__":
    config = Config()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, False)
    # 1. 批量加载测试数据
    test_iter = build_iterator(test_data,config, False)
    config.n_vocab = len(vocab)
    # 2. 加载模型结构
    model = Model(config).to(config.device)
    # 3. 测试
    test(config, model, test_iter)

    print("+++++++++++++++++")

    # 4. 预测

    content = load_dataset(text, vocab, config)
    predict_iter = build_iterator(content, config, predict=True)

    result = final_predict(config, model, predict_iter)
    for i, j in enumerate(result):
        print('text:{}'.format(text[i]), '\t', 'label:{}'.format(j))

1.4 运行结果

运行结果:

D:\Users\tarena\PycharmProjects\nlp\venv\Scripts\python.exe D:/Users/tarena/PycharmProjects/nlp/unit28/run.py
Loading data...
Vocab size: 4762
180000it [00:02, 71001.39it/s]
10000it [00:00, 49459.62it/s]
10000it [00:00, 80214.20it/s]
Test Loss:  0.43,  Test Acc: 86.52%
Precision, Recall and F1-Score...
              precision    recall  f1-score   support

          财经     0.8903    0.8520    0.8707      1000
          房产     0.9414    0.8510    0.8939      1000
          股票     0.8416    0.7650    0.8015      1000
          教育     0.9266    0.9470    0.9367      1000
          科技     0.7047    0.8710    0.7791      1000
          社会     0.8651    0.8660    0.8656      1000
          时政     0.7977    0.8870    0.8400      1000
          体育     0.8968    0.9390    0.9174      1000
          游戏     0.9573    0.8070    0.8757      1000
          娱乐     0.8947    0.8670    0.8807      1000

    accuracy                         0.8652     10000
   macro avg     0.8716    0.8652    0.8661     10000
weighted avg     0.8716    0.8652    0.8661     10000

Confusion Matrix...
[[852   9  58   5  28   9  24  10   1   4]
 [ 21 851  26   8  28  20  19   7   2  18]
 [ 64  20 765   4  71   0  63   7   3   3]
 [  1   0   3 947   8  10  11   7   2  11]
 [  3   6  25   8 871  22  31   6  12  16]
 [  5  12   2  20  25 866  47   6   2  15]
 [  6   1  18  12  34  28 887  10   0   4]
 [  1   0   2   2  17   8  14 939   1  16]
 [  1   1   6   4 123   7  11  25 807  15]
 [  3   4   4  12  31  31   5  30  13 867]]
+++++++++++++++++
text:国考28日网上查报名序号查询后务必牢记报名参加2011年国家公务员的考生,如果您已通过资格审查,那么请于1028800后,登录考录专题网站查询自己的“关键数字”——报名序号。国家公务员局等部门提醒:报名序号是报考人员报名确认和下载打印准考证等事项的重要依据和关键字,请务必牢记。此外,由于年龄在35周岁以上、40周岁以下的应届毕业硕士研究生和博士研究生(非在职),不通过网络进行报名,所以,这类人报名须直接与要报考的招录机关联系,通过电话传真或发送电子邮件等方式报名。 	 label:教育
text:高品质低价格东芝L315双核本3999元作者:徐彬【北京行情】220日东芝SatelliteL300(参数图片文章评论)采用14.1英寸WXGA宽屏幕设计,配备了IntelPentiumDual-CoreT2390双核处理器(1.86GHz主频/1MB二级缓存/533MHz前端总线)、IntelGM965芯片组、1GBDDR2内存、120GB硬盘、DVD刻录光驱和IntelGMAX3100集成显卡。目前,它的经销商报价为3999元。 	 label:科技
text:国安少帅曾两度出山救危局他已托起京师一代才俊新浪体育讯随着联赛中的连续不胜,卫冕冠军北京国安的队员心里到了崩溃的边缘,俱乐部董事会连夜开会做出了更换主教练洪元硕的决定。而接替洪元硕的,正是上赛季在李章洙下课风波中同样下课的国安俱乐部副总魏克兴。生于1963年的魏克兴球员时代并没有特别辉煌的履历,但也绝对称得上特别:15岁在北京青年队获青年联赛最佳射手,22岁进入国家队,著名的5-19一战中,他是国家队的替补队员。 	 label:体育
text:汤盈盈撞人心情未平复眼泛泪光拒谈悔意(附图)新浪娱乐讯汤盈盈日前醉驾撞车伤人被捕, 	 label:娱乐
text:甲醇期货今日挂牌上市继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。基准价均为3050元/吨继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。郑州商品交易所(郑商所)昨日公布首批甲醇期货8合约的上市挂牌基准价,均为3050元/吨。据此推算,买卖一手甲醇合约至少需要12200元。业内人士认为,作为国际市场上的首个甲醇期货品种,其今日挂牌后可能会因炒新资金追捧而出现冲高走势,脉冲式行情过后可能有所回落,不过,投资者在上市初期应关注期现价差异常带来的无风险套利交易机会。 	 label:财经
text:佟丽娅穿白色羽毛长裙美翻,自曝跳舞的女孩能吃苦 	 label:娱乐
text:江欣燕透露汤盈盈钱嘉乐分手 用冷笑话补救 	 label:娱乐

Process finished with exit code 0

1.6 完整代码

"""predict_eval.py"""

# coding:utf-8
import torch
import numpy as np
from unit28.train_eval import evaluate


MAX_VOCAB_SIZE = 10000
UNK, PAD = '<UNK>', '<PAD>'

tokenizer = lambda x: [y for y in x]  # char-level


def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path)) # 加载训练好的的模型
    model.eval()  # 开启评价模式

    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)


def load_dataset(text, vocab, config, pad_size=32):
    contents = []
    for line in text:
        lin = line.strip()
        if not lin:
            continue
        words_line = []
        token = tokenizer(line)
        seq_len = len(token)
        if pad_size:
            if len(token) < pad_size:
                token.extend([PAD] * (pad_size - len(token)))
            else:
                token = token[:pad_size]
                seq_len = pad_size
        # word to id
        for word in token:
            words_line.append(vocab.get(word, vocab.get(UNK)))
        contents.append((words_line, int(0), seq_len))
    return contents  # [([...], 0), ([...], 1), ...]


def match_label(pred, config):
    label_list = config.class_list
    return label_list[pred]


def final_predict(config, model, data_iter):
    map_location = lambda storage, loc: storage
    model.load_state_dict(torch.load(config.save_path, map_location=map_location))
    model.eval()
    predict_all = np.array([])
    with torch.no_grad():
        for texts, _ in data_iter:
            outputs = model(texts)
            pred = torch.max(outputs.data, 1)[1].cpu().numpy()
            pred_label = [match_label(i, config) for i in pred]
            predict_all = np.append(predict_all, pred_label)
    return predict_all

"""run.py"""
# coding:utf-8

from unit27.TextCNN import Config
from unit27.TextCNN import Model
from unit27.load_data import build_dataset
from unit27.load_data_iter import build_iterator
from unit27.train_eval import train

if __name__ == "__main__":
    config = Config()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, False)
    # 1. 批量加载数据
    train_iter = build_iterator(train_data, config, False)
    dev_iter = build_iterator(dev_data,config,False)

    config.n_vocab = len(vocab)
    # 2. 构建模型
    model = Model(config).to(config.device)
    print(model.parameters)

    # init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter)
terator(train_data, config, False)
    dev_iter = build_iterator(dev_data,config,False)

    config.n_vocab = len(vocab)
    # 2. 构建模型
    model = Model(config).to(config.device)
    print(model.parameters)

    # init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-13 11:44:22  更:2022-05-13 11:47:55 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/1 23:36:29-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码