mzcn
The Chinese-language version of matchzoo-py.
This package is an open-source derivative of matchzoo-py. MatchZoo is a general-purpose text matching toolkit, designed to make it easy to implement, compare, and share state-of-the-art deep text matching models.
matchzoo-py targets English, where preprocessing is relatively straightforward; Chinese text requires extra preprocessing. Building on others' prior work, I extended matchzoo-py into the mzcn package.
mzcn preprocesses Chinese corpora by keeping only the text proper, stripping emoticons, removing whitespace, filtering stopwords, and so on, so that users can preprocess Chinese text quickly. Its usage is essentially the same as matchzoo-py's.
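For intuition, those cleaning steps correspond roughly to the standalone sketch below. This only illustrates the kind of normalization mzcn applies (mzcn does use jieba for segmentation, as the preprocessing logs below show); the regular expression, the toy stopword list, and the function name are illustrative assumptions, not mzcn's actual API.

import re
import jieba  # mzcn relies on jieba for Chinese word segmentation

def clean_chinese(text, stopwords=frozenset({'的', '了'})):
    # Keep only Chinese characters: drops emoji, Latin text, digits,
    # punctuation, and whitespace in one pass.
    text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
    # Segment into words and filter stopwords.
    return [tok for tok in jieba.cut(text) if tok not in stopwords]

print(clean_chinese('今天的天气真好!😀'))  # e.g. ['今天', '天气', '真好']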
Quick Start
Define the loss function and metrics
import torch
import numpy as np
import pandas as pd
import mzcn as mz
print('matchzoo version', mz.__version__)
ranking_task = mz.tasks.Ranking(losses=mz.losses.RankHingeLoss())
ranking_task.metrics = [
    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),
    mz.metrics.MeanAveragePrecision()
]
print("`ranking_task` initialized with metrics", ranking_task.metrics)
matchzoo version 1.0
`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]
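RankHingeLoss is a pairwise loss: each training example is a positive document paired with a sampled negative, which is why the Dataset further below is built with mode='pair'. If your data is a binary matched/unmatched corpus instead, matchzoo-py also offers a classification task; a minimal sketch:

classification_task = mz.tasks.Classification(num_classes=2)
classification_task.metrics = ['acc']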
Prepare input data
def load_data(tmp_data, tmp_task):
    # Wrap a DataFrame into a MatchZoo DataPack for the given task.
    df_data = mz.pack(tmp_data, task=tmp_task)
    return df_data

train = pd.read_csv('./data/train_data.csv').sample(100)
dev = pd.read_csv('./data/dev_data.csv').sample(50)
test = pd.read_csv('./data/test_data.csv').sample(30)

train_pack_raw = load_data(train, ranking_task)
dev_pack_raw = load_data(dev, ranking_task)
test_pack_raw = load_data(test, ranking_task)
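mz.pack expects a pandas DataFrame with text_left (queries), text_right (candidate documents), and a label column; per matchzoo-py's pack, id_left/id_right are generated when absent. A minimal sketch of the expected layout, with made-up rows:

toy = pd.DataFrame({
    'text_left':  ['今天天气怎么样', '今天天气怎么样'],
    'text_right': ['今天是晴天', '我喜欢吃苹果'],
    'label':      [1, 0],  # 1 = relevant pair, 0 = irrelevant pair
})
toy_pack = load_data(toy, ranking_task)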
Preprocess the datasets
preprocessor = mz.models.aNMM.get_default_preprocessor()
train_pack_processed = preprocessor.fit_transform(train_pack_raw)
dev_pack_processed = preprocessor.transform(dev_pack_raw)
test_pack_processed = preprocessor.transform(test_pack_raw)
Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\ADMINI~1\AppData\Local\Temp\jieba.cache
Loading model cost 1.062 seconds.
Prefix dict has been built successfully.
Processing text_left with chain_transform of ChineseRemoveBlack => ChineseSimplified => ChineseEmotion => IsChinese => ChineseStopRemoval => ChineseTokenizeDemo => Tokenize => Lowercase => PuncRemoval: 100%|█| 92/92 [00:01<00:00, 61.25it/s]
Processing text_right with chain_transform of ChineseRemoveBlack => ChineseSimplified => ChineseEmotion => IsChinese => ChineseStopRemoval => ChineseTokenizeDemo => Tokenize => Lowercase => PuncRemoval: 100%|█| 93/93 [00:00<00:00, 216.90it/s]
Processing text_right with append: 100%|████████████████████████████████████████████| 93/93 [00:00<00:00, 92741.39it/s]
Building FrequencyFilter from a datapack.: 100%|████████████████████████████████████| 93/93 [00:00<00:00, 46575.55it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 93/93 [00:00<00:00, 46503.37it/s]
Processing text_left with extend: 100%|█████████████████████████████████████████████| 92/92 [00:00<00:00, 15340.54it/s]
Processing text_right with extend: 100%|████████████████████████████████████████████| 93/93 [00:00<00:00, 93073.32it/s]
Building Vocabulary from a datapack.: 100%|██████████████████████████████████████| 817/817 [00:00<00:00, 203900.18it/s]
Processing text_left with chain_transform of ChineseRemoveBlack => ChineseSimplified => ChineseEmotion => IsChinese => ChineseStopRemoval => ChineseTokenizeDemo => Tokenize => Lowercase => PuncRemoval: 100%|█| 92/92 [00:00<00:00, 218.14it/s]
Processing text_right with chain_transform of ChineseRemoveBlack => ChineseSimplified => ChineseEmotion => IsChinese => ChineseStopRemoval => ChineseTokenizeDemo => Tokenize => Lowercase => PuncRemoval: 100%|█| 93/93 [00:00<00:00, 227.51it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 93/93 [00:00<00:00, 46536.66it/s]
Processing text_left with transform: 100%|██████████████████████████████████████████| 92/92 [00:00<00:00, 30685.96it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 93/93 [00:00<00:00, 31014.57it/s]
Processing length_left with len: 100%|██████████████████████████████████████████████| 92/92 [00:00<00:00, 92138.48it/s]
Processing length_right with len: 100%|█████████████████████████████████████████████| 93/93 [00:00<00:00, 46497.83it/s]
Processing text_left with chain_transform of ChineseRemoveBlack => ChineseSimplified => ChineseEmotion => IsChinese => ChineseStopRemoval => ChineseTokenizeDemo => Tokenize => Lowercase => PuncRemoval: 100%|█| 45/45 [00:00<00:00, 202.82it/s]
Processing text_right with chain_transform of ChineseRemoveBlack => ChineseSimplified => ChineseEmotion => IsChinese => ChineseStopRemoval => ChineseTokenizeDemo => Tokenize => Lowercase => PuncRemoval: 100%|█| 50/50 [00:00<00:00, 215.62it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 50/50 [00:00<00:00, 49920.30it/s]
Processing text_left with transform: 100%|██████████████████████████████████████████| 45/45 [00:00<00:00, 11257.53it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 50/50 [00:00<00:00, 50135.12it/s]
Processing length_left with len: 100%|██████████████████████████████████████████████| 45/45 [00:00<00:00, 22512.37it/s]
Processing length_right with len: 100%|█████████████████████████████████████████████| 50/50 [00:00<00:00, 12510.60it/s]
Processing text_left with chain_transform of ChineseRemoveBlack => ChineseSimplified => ChineseEmotion => IsChinese => ChineseStopRemoval => ChineseTokenizeDemo => Tokenize => Lowercase => PuncRemoval: 100%|█| 30/30 [00:00<00:00, 209.93it/s]
Processing text_right with chain_transform of ChineseRemoveBlack => ChineseSimplified => ChineseEmotion => IsChinese => ChineseStopRemoval => ChineseTokenizeDemo => Tokenize => Lowercase => PuncRemoval: 100%|█| 28/28 [00:00<00:00, 209.05it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 28/28 [00:00<00:00, 28062.25it/s]
Processing text_left with transform: 100%|██████████████████████████████████████████| 30/30 [00:00<00:00, 10006.29it/s]
Processing text_right with transform: 100%|█████████████████████████████████████████| 28/28 [00:00<00:00, 14031.12it/s]
Processing length_left with len: 100%|███████████████████████████████████████████████| 30/30 [00:00<00:00, 7504.12it/s]
Processing length_right with len: 100%|█████████████████████████████████████████████| 28/28 [00:00<00:00, 13924.65it/s]
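After fit_transform, the preprocessor's context holds the fitted vocabulary statistics that the model's embedding layer must agree with; the Embedding(319, 100) printed later comes from exactly this value. A quick check (the keys follow matchzoo-py's preprocessor context):

print(preprocessor.context['vocab_size'])
print(preprocessor.context['embedding_input_dim'])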
Generate the training data
trainset = mz.dataloader.Dataset(
    data_pack=train_pack_processed,
    mode='pair',   # pairwise sampling to match RankHingeLoss
    num_dup=2,     # duplicate each positive pair twice
    num_neg=1      # sample one negative per positive
)
devset = mz.dataloader.Dataset(
    data_pack=dev_pack_processed
)
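In 'pair' mode the Dataset duplicates each positive (query, document) pair num_dup times and samples num_neg negatives for it, which is exactly the layout RankHingeLoss consumes. If you raise num_neg, the loss must be told as well; a hedged sketch of keeping the two in sync:

# num_neg on the loss must match num_neg on the Dataset
# (the defaults above, RankHingeLoss() and num_neg=1, already agree).
ranking_task_4neg = mz.tasks.Ranking(losses=mz.losses.RankHingeLoss(num_neg=4))
trainset_4neg = mz.dataloader.Dataset(
    data_pack=train_pack_processed,
    mode='pair',
    num_dup=2,
    num_neg=4
)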
Build the data loaders
padding_callback = mz.models.aNMM.get_default_padding_callback()

trainloader = mz.dataloader.DataLoader(
    dataset=trainset,
    stage='train',
    callback=padding_callback,
)
devloader = mz.dataloader.DataLoader(
    dataset=devset,
    stage='dev',
    callback=padding_callback,
)
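Each iteration over a DataLoader yields an (inputs, target) tuple, where inputs is a dict of padded tensors produced by the padding callback. A quick way to eyeball the batch shapes (a sketch; the exact field names come from aNMM's padding callback):

batch_inputs, batch_target = next(iter(trainloader))
for name, value in batch_inputs.items():
    print(name, getattr(value, 'shape', value))
print('target', batch_target.shape)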
Define the model
model = mz.models.aNMM()

model.params['task'] = ranking_task
model.params['embedding_output_dim'] = 100
model.params['embedding_input_dim'] = preprocessor.context['embedding_input_dim']
model.params['dropout_rate'] = 0.1

model.build()
print(model)
aNMM(
  (embedding): Embedding(319, 100, padding_idx=0)
  (matching): Matching()
  (hidden_layers): Sequential(
    (0): Sequential(
      (0): Linear(in_features=200, out_features=100, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=100, out_features=1, bias=True)
      (1): ReLU()
    )
  )
  (q_attention): Attention(
    (linear): Linear(in_features=100, out_features=1, bias=False)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (out): Linear(in_features=1, out_features=1, bias=True)
)
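The Embedding(319, 100) above is randomly initialized. matchzoo-py models can instead start from pre-trained vectors; a sketch following matchzoo-py's embedding API, where the vector file path is a placeholder:

# Hypothetical path to 100-dimensional word2vec-format Chinese vectors.
embedding = mz.embedding.load_from_file('./data/zh_word2vec_100d.txt', mode='word2vec')
term_index = preprocessor.context['vocab_unit'].state['term_index']
model.params['embedding'] = embedding.build_matrix(term_index)
model.build()  # rebuild so the pre-trained weights are picked up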
Train the model
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

trainer = mz.trainers.Trainer(
    model=model,
    optimizer=optimizer,
    trainloader=trainloader,
    validloader=devloader,
    validate_interval=None,
    epochs=10
)

trainer.run()
[Iter-1 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-2 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-3 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-4 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-5 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-6 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-7 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-8 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-9 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
[Iter-10 Loss-1.000]:
  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.2121 - normalized_discounted_cumulative_gain@5(0.0): 0.2121 - mean_average_precision(0.0): 0.2121
Cost time: 3.3411495685577393s
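After training, the test pack prepared earlier can be scored the same way as the dev set; a sketch using the Trainer's evaluate/predict API:

testset = mz.dataloader.Dataset(data_pack=test_pack_processed)
testloader = mz.dataloader.DataLoader(
    dataset=testset,
    stage='dev',
    callback=padding_callback,
)
print(trainer.evaluate(testloader))   # dict mapping each metric to its score
scores = trainer.predict(testloader)  # raw matching scores, one per pair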
Install
mzcn depends on matchzoo-py. There are two ways to install mzcn:
Install mzcn from PyPI:
pip install mzcn
Install mzcn from the GitHub source:
git clone https://github.com/yingdajun/mzcn.git
cd mzcn
python setup.py install
Citation
This is my first package and my skills are limited; I hope it proves useful, and corrections are welcome. Listed here are all the packages and resources referenced.
matchzoo-py
@inproceedings{Guo:2019:MLP:3331184.3331403,
  author    = {Guo, Jiafeng and Fan, Yixing and Ji, Xiang and Cheng, Xueqi},
  title     = {MatchZoo: A Learning, Practicing, and Developing System for Neural Text Matching},
  booktitle = {Proceedings of the 42nd International ACM SIGIR Conference on Research and Development in Information Retrieval},
  series    = {SIGIR '19},
  year      = {2019},
  isbn      = {978-1-4503-6172-9},
  location  = {Paris, France},
  pages     = {1297--1300},
  numpages  = {4},
  url       = {http://doi.acm.org/10.1145/3331184.3331403},
  doi       = {10.1145/3331184.3331403},
  acmid     = {3331403},
  publisher = {ACM},
  address   = {New York, NY, USA},
  keywords  = {matchzoo, neural network, text matching},
}
Blog post by CSDN author SK-Berry:
https://blog.csdn.net/sk_berry/article/details/104984599