IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【TextCNN完整版】快速+高准确率的baseline -> 正文阅读

[人工智能]【TextCNN完整版】快速+高准确率的baseline

前言:
2个月前写过一篇《TextCNN的完整步骤(不到60行代码)》,但是并没有考虑到后续工程化部署以及数据量较大的情况(无法全部加载到内存里),所以今天根据实际案例做了一次改造和优化。
TextCNN的操作步骤一般可以分为以下几步:
1、数据整理:日常工作中的文本可能不像比赛一样直接给你一个csv文件,你可能需要自己整合起来;另外textcnn在训练和预测时,不认分类变量(如上海, 北京等),所以必须通过map或label_encoder的方式修改,到最后样本预测结束后再map_reverse回去。
2、建立词库:tokenizer.fit_on_texts,这一步非常重要,如果后面出现训练准确率一直在个位数的情况,请到这一条仔细检查下;
3、制作tf数据集:如果文本太多内存装不下,建议还是上batch(32或64)吧。但是需要注意的是,如果你的train_data和valid_data都做成了dataset,那么test_data也必须做成dataset,虽然label目前还没有,但可以虚拟成均为0;
4、构建TextCNN网络:这个没什么好说的,具体是[2,3,4]还是[3,4,5]都可以;
5、设定weight权重:在分类任务中,绝大部分都是不平衡的,尤其是多分类,所以设定weight权重还是很必要的;
6、训练模型:可调超参包括learning_rate(建议3e-4),epochs(建议30-40,反正会设早停),optimizer(Adam就很好),EARLY_STOP_PATIENCE(早停次数,3次即可);
7、模型固化:tensorflow2中直接可以model.save(’./model/text_cnn.h5’),在本文中就不演示了;
8、模型加载:textcnn_model = tf.keras.models.load_model(‘service/model/text_cnn.h5’);
9、样本预测:text_cnn_model.predict(test_dataset),注意出来的结果是0-1的浮点数,需要通过np.argmax(predictions, axis=-1)选择正确的标签;

具体代码如下:

一、导入数据

import os
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.utils import resample
from sklearn.model_selection import train_test_split
#训练数据导入
train_type_list = []
train_text_list = []

train_dir_name_list = os.listdir('./train/')
train_dir_name_list.remove('.DS_Store')

for dir_name in train_dir_name_list:
    for file in os.listdir('./train/'+dir_name+'/'):
        train_type_list.append(dir_name.split('-')[1])
        train_text_list.append(open('./train/'+str(dir_name)+'/'+str(file),'r',encoding='gb18030',errors='ignore').read().replace('\n', ' ').replace('\u3000', ''))
        
print(len(train_type_list))
#标签字典
cls_num = len(set(train_type_list))
cls_dict = {}
for k,v in enumerate(set(train_type_list)):
    cls_dict[k] = v
cls_dict_reverse = {v:k for k,v in cls_dict.items()}
train_data = pd.DataFrame({'text':train_text_list,'target':train_type_list})
train_data['target'] = train_data['target'].map(cls_dict_reverse)
train_data = resample(train_data)
train_data.head()
#预测数据导入
test_text_list = []
test_filename = []
for file in os.listdir('./test'):
    test_filename.append(file)
    test_text_list.append(open('./test/'+file,'r', encoding='gb18030', errors='ignore').read().replace('\n',' '))
test_data = pd.DataFrame({'text':test_text_list, 'filename':test_filename})
test_data['target']=0

二、TF数据准备

X_train, X_val, y_train, y_val = train_test_split(train_data['text'], train_data['target'], test_size=0.1, random_state=27)

#tokenizer
NUM_LABEL = cls_num #类别数量
BATCH_SIZE = 32
MAX_LEN = 200 #最长序列长度
BUFFER_SIZE = tf.constant(train_data.shape[0], dtype=tf.int64)

tokenizer = tf.keras.preprocessing.text.Tokenizer(char_level=True)
tokenizer.fit_on_texts(X_train)
def build_tf_dataset(text, label, is_train=False):
    '''制作tf数据集'''
    sequence = tokenizer.texts_to_sequences(text)
    sequence_padded = tf.keras.preprocessing.sequence.pad_sequences(sequence,padding='post',maxlen=MAX_LEN)
    dataset = tf.data.Dataset.from_tensor_slices((sequence_padded, label))
    if is_train:
        dataset = dataset.shuffle(BUFFER_SIZE)
        dataset = dataset.batch(BATCH_SIZE)
        dataset = dataset.prefetch(BUFFER_SIZE)
    else:
        dataset = dataset.batch(BATCH_SIZE)
        dataset = dataset.prefetch(BATCH_SIZE)
    return dataset
train_dataset = build_tf_dataset(X_train, y_train, is_train=True)
val_dataset = build_tf_dataset(X_val, y_val, is_train=False)
test_dataset = build_tf_dataset(test_data['text'], test_data['target'], is_train=False)

三、构建TextCNN网络

VOCAB_SIZE = len(tokenizer.index_word) + 1
print(VOCAB_SIZE)
EMBEDDING_DIM = 100

FILTERS = [3, 4, 5]
NUM_FILTERS = 128 #卷积核的大小
DENSE_DIM = 256 #全连接层大小
CLASS_NUM = 20 #类别数量
DROPOUT_RATE = 0.5 #dropout比例
def build_text_cnn_model():
    inputs = tf.keras.Input(shape=(None,))
    embed = tf.keras.layers.Embedding(
        input_dim=VOCAB_SIZE,
        output_dim=EMBEDDING_DIM,
        trainable=True,
        mask_zero=True)(inputs)
    embed = tf.keras.layers.Dropout(DROPOUT_RATE)(embed)
    
    pool_outputs = []
    for filter_size in FILTERS:
        conv = tf.keras.layers.Conv1D(NUM_FILTERS,
                                     filter_size,
                                     padding='same',
                                     activation='relu',
                                     data_format='channels_last',
                                     use_bias=True)(embed)
        max_pool = tf.keras.layers.GlobalMaxPooling1D(data_format='channels_last')(conv)
        pool_outputs.append(max_pool)
    
    outputs = tf.keras.layers.concatenate(pool_outputs, axis=-1)
    outputs = tf.keras.layers.Dense(DENSE_DIM, activation='relu')(outputs)
    outputs = tf.keras.layers.Dropout(DROPOUT_RATE)(outputs)
    outputs = tf.keras.layers.Dense(CLASS_NUM, activation='softmax')(outputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

text_cnn_model = build_text_cnn_model()
text_cnn_model.summary()
#设定weight权重
df_weight = train_data['target'].value_counts().sort_index().reset_index()
df_weight['weight'] = df_weight['target'].min() / df_weight['target']
df_weight_dict = {k:v for k,v in zip(df_weight['index'], df_weight['weight'])}
df_weight_dict

四、开始训练

LR = 3e-4
EPOCHS = 30
EARLY_STOP_PATIENCE = 2
loss = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(LR)

text_cnn_model.compile(loss=loss,
                      optimizer=optimizer,
                      metrics=['accuracy'])

callback = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy',
                                           patience=EARLY_STOP_PATIENCE,
                                           restore_best_weights=True)

history = text_cnn_model.fit(train_dataset,
                            epochs=EPOCHS,
                            callbacks=[callback],
                            validation_data=val_dataset,
                            class_weight=df_weight_dict
                            )

在这里插入图片描述
在CPU上效果也不差,准确率能达到90%左右。

五、预测和导出结果

test_predict = text_cnn_model.predict(test_dataset)
preds = np.argmax(test_predict, axis=-1)
test_data['category'] = preds
test_data['category'] = test_data['category'].map(cls_dict)

test_data[['filename','category']].to_csv('zhanglei.csv', index=False)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-28 11:16:07  更:2021-11-28 11:17:11 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 2:51:41-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码