1. 模型的训练和评估
1.1 模型预测的流程
- 模型预测的流程包括对文本预处理
- 构建预测数据迭代器
- 调用模型完成预测
1.2 TextCNN文本分类流程
-
准备数据:从THUCNews中抽取了20万条新闻标题,共10个预测类别 -
数据预处理:构建词汇表、文本向量化、按批次读取数据 -
模型构建:输入层->Embeding层->全连接层->输出层 -
模型的训练、评估和预测
1.3 代码实现
- 步骤一:使用测试数据评估模型
predict_eval.py
import torch
import numpy as np
from unit28.train_eval import evaluate
MAX_VOCAB_SIZE = 10000
UNK, PAD = '<UNK>', '<PAD>'
tokenizer = lambda x: [y for y in x]
def test(config, model, test_iter):
model.load_state_dict(torch.load(config.save_path))
model.eval()
test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
print(msg.format(test_loss, test_acc))
print("Precision, Recall and F1-Score...")
print(test_report)
print("Confusion Matrix...")
print(test_confusion)
- 步骤二:加载待分类数据
predict_eval.py
def load_dataset(text, vocab, config, pad_size=32):
contents = []
for line in text:
lin = line.strip()
if not lin:
continue
words_line = []
token = tokenizer(line)
seq_len = len(token)
if pad_size:
if len(token) < pad_size:
token.extend([PAD] * (pad_size - len(token)))
else:
token = token[:pad_size]
seq_len = pad_size
for word in token:
words_line.append(vocab.get(word, vocab.get(UNK)))
contents.append((words_line, int(0), seq_len))
return contents
- 步骤三:加载训练好的模型进行预测
predict_eval.py
def match_label(pred, config):
label_list = config.class_list
return label_list[pred]
def final_predict(config, model, data_iter):
map_location = lambda storage, loc: storage
model.load_state_dict(torch.load(config.save_path, map_location=map_location))
model.eval()
predict_all = np.array([])
with torch.no_grad():
for texts, _ in data_iter:
outputs = model(texts)
pred = torch.max(outputs.data, 1)[1].cpu().numpy()
pred_label = [match_label(i, config) for i in pred]
predict_all = np.append(predict_all, pred_label)
return predict_all
from unit28.TextCNN import Config
from unit28.TextCNN import Model
from unit28.load_data import build_dataset
from unit28.load_data_iter import build_iterator
from unit28.predict_eval import test,load_dataset,final_predict
text = ['国考28日网上查报名序号查询后务必牢记报名参加2011年国家公务员的考生,如果您已通过资格审查,那么请于10月28日8:00后,登录考录专题网站查询自己的“关键数字”——报名序号。'
'国家公务员局等部门提醒:报名序号是报考人员报名确认和下载打印准考证等事项的重要依据和关键字,请务必牢记。此外,由于年龄在35周岁以上、40周岁以下的应届毕业硕士研究生和'
'博士研究生(非在职),不通过网络进行报名,所以,这类人报名须直接与要报考的招录机关联系,通过电话传真或发送电子邮件等方式报名。',
'高品质低价格东芝L315双核本3999元作者:徐彬【北京行情】2月20日东芝SatelliteL300(参数图片文章评论)采用14.1英寸WXGA宽屏幕设计,配备了IntelPentiumDual-CoreT2390'
'双核处理器(1.86GHz主频/1MB二级缓存/533MHz前端总线)、IntelGM965芯片组、1GBDDR2内存、120GB硬盘、DVD刻录光驱和IntelGMAX3100集成显卡。目前,它的经销商报价为3999元。',
'国安少帅曾两度出山救危局他已托起京师一代才俊新浪体育讯随着联赛中的连续不胜,卫冕冠军北京国安的队员心里到了崩溃的边缘,俱乐部董事会连夜开会做出了更换主教练洪元硕的决定。'
'而接替洪元硕的,正是上赛季在李章洙下课风波中同样下课的国安俱乐部副总魏克兴。生于1963年的魏克兴球员时代并没有特别辉煌的履历,但也绝对称得上特别:15岁在北京青年队获青年'
'联赛最佳射手,22岁进入国家队,著名的5-19一战中,他是国家队的替补队员。',
'汤盈盈撞人心情未平复眼泛泪光拒谈悔意(附图)新浪娱乐讯汤盈盈日前醉驾撞车伤人被捕,',
'甲醇期货今日挂牌上市继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。基准价均为3050元/吨继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式'
'挂牌交易。郑州商品交易所(郑商所)昨日公布首批甲醇期货8合约的上市挂牌基准价,均为3050元/吨。据此推算,买卖一手甲醇合约至少需要12200元。业内人士认为,作为国际市场上的'
'首个甲醇期货品种,其今日挂牌后可能会因炒新资金追捧而出现冲高走势,脉冲式行情过后可能有所回落,不过,投资者在上市初期应关注期现价差异常带来的无风险套利交易机会。',
'佟丽娅穿白色羽毛长裙美翻,自曝跳舞的女孩能吃苦',
'江欣燕透露汤盈盈钱嘉乐分手 用冷笑话补救']
if __name__ == "__main__":
config = Config()
print("Loading data...")
vocab, train_data, dev_data, test_data = build_dataset(config, False)
test_iter = build_iterator(test_data,config, False)
config.n_vocab = len(vocab)
model = Model(config).to(config.device)
test(config, model, test_iter)
print("+++++++++++++++++")
content = load_dataset(text, vocab, config)
predict_iter = build_iterator(content, config, predict=True)
result = final_predict(config, model, predict_iter)
for i, j in enumerate(result):
print('text:{}'.format(text[i]), '\t', 'label:{}'.format(j))
1.4 运行结果
运行结果:
D:\Users\tarena\PycharmProjects\nlp\venv\Scripts\python.exe D:/Users/tarena/PycharmProjects/nlp/unit28/run.py
Loading data...
Vocab size: 4762
180000it [00:02, 71001.39it/s]
10000it [00:00, 49459.62it/s]
10000it [00:00, 80214.20it/s]
Test Loss: 0.43, Test Acc: 86.52%
Precision, Recall and F1-Score...
precision recall f1-score support
财经 0.8903 0.8520 0.8707 1000
房产 0.9414 0.8510 0.8939 1000
股票 0.8416 0.7650 0.8015 1000
教育 0.9266 0.9470 0.9367 1000
科技 0.7047 0.8710 0.7791 1000
社会 0.8651 0.8660 0.8656 1000
时政 0.7977 0.8870 0.8400 1000
体育 0.8968 0.9390 0.9174 1000
游戏 0.9573 0.8070 0.8757 1000
娱乐 0.8947 0.8670 0.8807 1000
accuracy 0.8652 10000
macro avg 0.8716 0.8652 0.8661 10000
weighted avg 0.8716 0.8652 0.8661 10000
Confusion Matrix...
[[852 9 58 5 28 9 24 10 1 4]
[ 21 851 26 8 28 20 19 7 2 18]
[ 64 20 765 4 71 0 63 7 3 3]
[ 1 0 3 947 8 10 11 7 2 11]
[ 3 6 25 8 871 22 31 6 12 16]
[ 5 12 2 20 25 866 47 6 2 15]
[ 6 1 18 12 34 28 887 10 0 4]
[ 1 0 2 2 17 8 14 939 1 16]
[ 1 1 6 4 123 7 11 25 807 15]
[ 3 4 4 12 31 31 5 30 13 867]]
+++++++++++++++++
text:国考28日网上查报名序号查询后务必牢记报名参加2011年国家公务员的考生,如果您已通过资格审查,那么请于10月28日8:00后,登录考录专题网站查询自己的“关键数字”——报名序号。国家公务员局等部门提醒:报名序号是报考人员报名确认和下载打印准考证等事项的重要依据和关键字,请务必牢记。此外,由于年龄在35周岁以上、40周岁以下的应届毕业硕士研究生和博士研究生(非在职),不通过网络进行报名,所以,这类人报名须直接与要报考的招录机关联系,通过电话传真或发送电子邮件等方式报名。 label:教育
text:高品质低价格东芝L315双核本3999元作者:徐彬【北京行情】2月20日东芝SatelliteL300(参数图片文章评论)采用14.1英寸WXGA宽屏幕设计,配备了IntelPentiumDual-CoreT2390双核处理器(1.86GHz主频/1MB二级缓存/533MHz前端总线)、IntelGM965芯片组、1GBDDR2内存、120GB硬盘、DVD刻录光驱和IntelGMAX3100集成显卡。目前,它的经销商报价为3999元。 label:科技
text:国安少帅曾两度出山救危局他已托起京师一代才俊新浪体育讯随着联赛中的连续不胜,卫冕冠军北京国安的队员心里到了崩溃的边缘,俱乐部董事会连夜开会做出了更换主教练洪元硕的决定。而接替洪元硕的,正是上赛季在李章洙下课风波中同样下课的国安俱乐部副总魏克兴。生于1963年的魏克兴球员时代并没有特别辉煌的履历,但也绝对称得上特别:15岁在北京青年队获青年联赛最佳射手,22岁进入国家队,著名的5-19一战中,他是国家队的替补队员。 label:体育
text:汤盈盈撞人心情未平复眼泛泪光拒谈悔意(附图)新浪娱乐讯汤盈盈日前醉驾撞车伤人被捕, label:娱乐
text:甲醇期货今日挂牌上市继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。基准价均为3050元/吨继上半年焦炭、铅期货上市后,酝酿已久的甲醇期货将在今日正式挂牌交易。郑州商品交易所(郑商所)昨日公布首批甲醇期货8合约的上市挂牌基准价,均为3050元/吨。据此推算,买卖一手甲醇合约至少需要12200元。业内人士认为,作为国际市场上的首个甲醇期货品种,其今日挂牌后可能会因炒新资金追捧而出现冲高走势,脉冲式行情过后可能有所回落,不过,投资者在上市初期应关注期现价差异常带来的无风险套利交易机会。 label:财经
text:佟丽娅穿白色羽毛长裙美翻,自曝跳舞的女孩能吃苦 label:娱乐
text:江欣燕透露汤盈盈钱嘉乐分手 用冷笑话补救 label:娱乐
Process finished with exit code 0
1.6 完整代码
"""predict_eval.py"""
import torch
import numpy as np
from unit28.train_eval import evaluate
MAX_VOCAB_SIZE = 10000
UNK, PAD = '<UNK>', '<PAD>'
tokenizer = lambda x: [y for y in x]
def test(config, model, test_iter):
model.load_state_dict(torch.load(config.save_path))
model.eval()
test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
print(msg.format(test_loss, test_acc))
print("Precision, Recall and F1-Score...")
print(test_report)
print("Confusion Matrix...")
print(test_confusion)
def load_dataset(text, vocab, config, pad_size=32):
contents = []
for line in text:
lin = line.strip()
if not lin:
continue
words_line = []
token = tokenizer(line)
seq_len = len(token)
if pad_size:
if len(token) < pad_size:
token.extend([PAD] * (pad_size - len(token)))
else:
token = token[:pad_size]
seq_len = pad_size
for word in token:
words_line.append(vocab.get(word, vocab.get(UNK)))
contents.append((words_line, int(0), seq_len))
return contents
def match_label(pred, config):
label_list = config.class_list
return label_list[pred]
def final_predict(config, model, data_iter):
map_location = lambda storage, loc: storage
model.load_state_dict(torch.load(config.save_path, map_location=map_location))
model.eval()
predict_all = np.array([])
with torch.no_grad():
for texts, _ in data_iter:
outputs = model(texts)
pred = torch.max(outputs.data, 1)[1].cpu().numpy()
pred_label = [match_label(i, config) for i in pred]
predict_all = np.append(predict_all, pred_label)
return predict_all
"""run.py"""
from unit27.TextCNN import Config
from unit27.TextCNN import Model
from unit27.load_data import build_dataset
from unit27.load_data_iter import build_iterator
from unit27.train_eval import train
if __name__ == "__main__":
config = Config()
print("Loading data...")
vocab, train_data, dev_data, test_data = build_dataset(config, False)
train_iter = build_iterator(train_data, config, False)
dev_iter = build_iterator(dev_data,config,False)
config.n_vocab = len(vocab)
model = Model(config).to(config.device)
print(model.parameters)
print(model.parameters)
train(config, model, train_iter, dev_iter)
terator(train_data, config, False)
dev_iter = build_iterator(dev_data,config,False)
config.n_vocab = len(vocab)
model = Model(config).to(config.device)
print(model.parameters)
print(model.parameters)
train(config, model, train_iter, dev_iter)
|