To fill out a waybill, all of the recipient's information can simply be pasted into the client; the client automatically recognizes the province/city, name, phone number and other fields, files each into the right slot, and prints the label for attachment. No manual form-filling is needed, which noticeably speeds up the workflow.
Adapted from: https://aistudio.baidu.com/aistudio/projectdetail/1329361
We take a pretrained model and fine-tune it to build a waybill information extraction model.
1. Imports
from functools import partial
import paddle
from paddlenlp.datasets import MapDataset
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import ErnieTokenizer, ErnieForTokenClassification
from paddlenlp.metrics import ChunkEvaluator
from paddle.utils.download import get_path_from_url
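The code below follows the PaddleNLP 2.x-era API used in the original tutorial; if any of the imports above fail, checking the installed versions is a good first step (a minimal sketch):
import paddle
import paddlenlp
print(paddle.__version__, paddlenlp.__version__)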
2. Data preparation
URL = "https://paddlenlp.bj.bcebos.com/paddlenlp/datasets/waybill.tar.gz"
get_path_from_url(URL, "./")
epochs = 10
batch_size = 16
def load_dict(dict_path):
    # Map each tag in tag.dic (one tag per line) to an integer id.
    vocab = {}
    with open(dict_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            vocab[line.strip('\n')] = i
    return vocab
with open("./data/test.txt", 'r', encoding='utf-8') as f:
i = 0
for line in f:
print(line)
i += 1
if i > 5:
break
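As the loader further below assumes, each file starts with a header line, and every subsequent line holds a character sequence and its tag sequence, tab-separated, with individual items joined by '\002'. Parsing one such line by hand (a minimal sketch):
with open('./data/test.txt', 'r', encoding='utf-8') as f:
    next(f)  # skip the header line
    words, labels = f.readline().rstrip('\n').split('\t')
print(words.split('\002'))   # the characters
print(labels.split('\002'))  # one tag per character, e.g. 'P-B', 'T-I', 'O'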
def convert_example(example, tokenizer, label_vocab):
tokens, labels = example
tokenized_input = tokenizer(
tokens, return_length=True, is_split_into_words=True)
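    # The tokenizer adds [CLS] and [SEP]; pad the label sequence with 'O' to match.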
labels = ['O'] + labels + ['O']
tokenized_input['labels'] = [label_vocab[x] for x in labels]
return tokenized_input['input_ids'], tokenized_input['token_type_ids'], \
tokenized_input['seq_len'], tokenized_input['labels']
def load_dataset(datafiles):
    # Build a MapDataset from one or more tab-separated data files.
    def read(data_path):
        with open(data_path, 'r', encoding='utf-8') as fp:
            next(fp)  # skip the header line
            for line in fp:
                words, labels = line.strip('\n').split('\t')
                words = words.split('\002')
                labels = labels.split('\002')
                yield words, labels

    if isinstance(datafiles, str):
        return MapDataset(list(read(datafiles)))
    elif isinstance(datafiles, (list, tuple)):
        return [MapDataset(list(read(datafile))) for datafile in datafiles]
train_ds, dev_ds, test_ds = load_dataset(datafiles=(
'./data/train.txt', './data/dev.txt', './data/test.txt'))
label_vocab = load_dict('./data/tag.dic')
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
trans_func = partial(convert_example, tokenizer=tokenizer, label_vocab=label_vocab)
train_ds.map(trans_func)
dev_ds.map(trans_func)
test_ds.map(trans_func)
print(train_ds[0])
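After the map, each sample is a plain tuple in the order returned by convert_example; unpacking one (illustrative):
# Each mapped sample is (input_ids, token_type_ids, seq_len, label_ids).
input_ids, token_type_ids, seq_len, label_ids = train_ds[0]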
ignore_label = -1
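# Pad ids to the longest sequence in the batch, stack the lengths, and pad
# labels with ignore_label so padded positions are skipped by the loss.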
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id),
Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
Stack(dtype='int64'),
Pad(axis=0, pad_val=ignore_label, dtype='int64')
): fn(samples)
train_loader = paddle.io.DataLoader(
dataset=train_ds,
batch_size=batch_size,
return_list=True,
collate_fn=batchify_fn)
dev_loader = paddle.io.DataLoader(
dataset=dev_ds,
batch_size=batch_size,
return_list=True,
collate_fn=batchify_fn)
test_loader = paddle.io.DataLoader(
dataset=test_ds,
batch_size=batch_size,
return_list=True,
collate_fn=batchify_fn)
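A quick sanity check that collation produces well-shaped batches (illustrative):
input_ids, token_type_ids, seq_lens, labels = next(iter(dev_loader))
print(input_ids.shape, seq_lens.shape, labels.shape)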
3. Helper functions
3.1 Evaluation function
@paddle.no_grad()
def evaluate(model, metric, data_loader):
    model.eval()
    metric.reset()
    for input_ids, seg_ids, lens, labels in data_loader:
        logits = model(input_ids, seg_ids)
        preds = paddle.argmax(logits, axis=-1)
        # NOTE: newer PaddleNLP versions drop the first argument:
        # metric.compute(lens, preds, labels)
        n_infer, n_label, n_correct = metric.compute(None, lens, preds, labels)
        metric.update(n_infer.numpy(), n_label.numpy(), n_correct.numpy())
    precision, recall, f1_score = metric.accumulate()
    print("eval precision: %f - recall: %f - f1: %f" %
          (precision, recall, f1_score))
    model.train()
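ChunkEvaluator scores whole entities (chunk-level precision/recall/F1) rather than per-token accuracy; once training is done, the same helper can be pointed at the test split (illustrative):
evaluate(model, metric, test_loader)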
3.2 Prediction function
@paddle.no_grad()
def predict(model, data_loader, ds, label_vocab):
    # evaluate() leaves the model in train mode, so switch off dropout here.
    model.eval()
    pred_list = []
    len_list = []
    for input_ids, seg_ids, lens, labels in data_loader:
        logits = model(input_ids, seg_ids)
        pred = paddle.argmax(logits, axis=-1)
        pred_list.append(pred.numpy())
        len_list.append(lens.numpy())
    preds = parse_decodes(ds, pred_list, len_list, label_vocab)
    return preds
3.3 Decoding the predictions
def parse_decodes(ds, decodes, lens, label_vocab):
    # Flatten the per-batch predictions and lengths.
    decodes = [x for batch in decodes for x in batch]
    lens = [x for batch in lens for x in batch]
    id_label = dict(zip(label_vocab.values(), label_vocab.keys()))

    outputs = []
    for idx, end in enumerate(lens):
        sent = ds.data[idx][0][:end]
        # Skip the [CLS] position, then map ids back to tag strings.
        tags = [id_label[x] for x in decodes[idx][1:end]]
        sent_out = []
        tags_out = []
        words = ""
        for s, t in zip(sent, tags):
            # A '-B' suffix (or 'O') starts a new chunk; flush the previous one.
            if t.endswith('-B') or t == 'O':
                if len(words):
                    sent_out.append(words)
                tags_out.append(t.split('-')[0])
                words = s
            else:
                words += s
        # Flush the final chunk.
        if len(sent_out) < len(tags_out):
            sent_out.append(words)
        outputs.append(''.join(
            [str((s, t)) for s, t in zip(sent_out, tags_out)]))
    return outputs
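To make the grouping concrete, a toy trace (hypothetical input, using the same suffix-style tags):
# sent: ['广', '东', '省', '张', '三']
# tags: ['A1-B', 'A1-I', 'A1-I', 'P-B', 'P-I']
# -> ('广东省', 'A1')('张三', 'P')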
4. Training
model = ErnieForTokenClassification.from_pretrained("ernie-1.0", num_classes=len(label_vocab))
metric = ChunkEvaluator(label_list=label_vocab.keys(), suffix=True)
loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
optimizer = paddle.optimizer.AdamW(learning_rate=2e-5, parameters=model.parameters())
step = 0
for epoch in range(epochs):
    for input_ids, token_type_ids, length, labels in train_loader:
        logits = model(input_ids, token_type_ids)
        loss = paddle.mean(loss_fn(logits, labels))
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        step += 1
        print("epoch:%d - step:%d - loss: %f" % (epoch, step, loss))
    # Evaluate on the dev set and save a checkpoint after each epoch.
    evaluate(model, metric, dev_loader)
    paddle.save(model.state_dict(),
                './ernie_result/model_%d.pdparams' % step)
state_dict = paddle.load("./ernie_result/model_450.pdparams")
model.load_dict(state_dict)
preds = predict(model, test_loader, test_ds, label_vocab)
file_path = "ernie_results.txt"
with open(file_path, "w", encoding="utf8") as fout:
fout.write("\n".join(preds))
print(
"The results have been saved in the file: %s, some examples are shown below: "
% file_path)
print("\n".join(preds[:10]))
Training log:
epoch:0 - step:1 - loss: 2.788503
epoch:0 - step:2 - loss: 2.520449
epoch:0 - step:3 - loss: 2.365216
epoch:0 - step:4 - loss: 2.255839
epoch:0 - step:5 - loss: 2.108390
epoch:0 - step:6 - loss: 2.006438
...
epoch:0 - step:100 - loss: 0.045199
eval precision: 0.969141 - recall: 0.977292 - f1: 0.973199
epoch:1 - step:101 - loss: 0.026065
...
epoch:1 - step:200 - loss: 0.012335
eval precision: 0.984925 - recall: 0.989066 - f1: 0.986991
epoch:2 - step:201 - loss: 0.014337
...
epoch:2 - step:300 - loss: 0.004556
eval precision: 0.987427 - recall: 0.990749 - f1: 0.989085
epoch:3 - step:301 - loss: 0.003423
...
epoch:3 - step:400 - loss: 0.002968
eval precision: 0.987427 - recall: 0.990749 - f1: 0.989085
epoch:4 - step:401 - loss: 0.001868
...
epoch:4 - step:500 - loss: 0.016371
eval precision: 0.989933 - recall: 0.992431 - f1: 0.991180
epoch:5 - step:501 - loss: 0.006276
...
epoch:5 - step:530 - loss: 0.001634
...
Sample predictions:
The results have been saved in the file: ernie_results.txt, some examples are shown below:
('黑龙江省', 'A1')('双鸭山市', 'A2')('尖山区', 'A3')('八马路与东平行路交叉口北40米', 'A4')('韦业涛', 'P')('18600009172', 'T')
('广西壮族自治区', 'A1')('桂林市', 'A2')('雁山区', 'A3')('雁山镇西龙村老年活动中心', 'A4')('17610348888', 'T')('羊卓卫', 'P')
('15652864561', 'T')('河南省', 'A1')('开封市', 'A2')('顺河回族区', 'A3')('顺河区公园路32号', 'A4')('赵本山', 'P')
('河北省', 'A1')('唐山市', 'A2')('玉田县', 'A3')('无终大街159号', 'A4')('18614253058', 'T')('尚汉生', 'P')
('台湾', 'A1')('台中市', 'A2')('北区', 'A3')('北区锦新街18号', 'A4')('18511226708', 'T')('蓟丽', 'P')
('廖梓琪', 'P')('18514743222', 'T')('湖北省', 'A1')('宜昌市', 'A2')('长阳土家族自治县', 'A3')('贺家坪镇贺家坪村一组临河1号', 'A4')
('江苏省', 'A1')('南通市', 'A2')('海门市', 'A3')('孝威村孝威路88号', 'A4')('18611840623', 'T')('计星仪', 'P')
('17601674746', 'T')('赵春丽', 'P')('内蒙古自治区', 'A1')('乌兰察布市', 'A2')('凉城县', 'A3')('新建街', 'A4')
('云南省', 'A1')('临沧市', 'A2')('耿马傣族佤族自治县', 'A3')('鑫源路法院对面', 'A4')('许贞爱', 'P')('18510566685', 'T')
('四川省', 'A1')('成都市', 'A2')('双流区', 'A3')('东升镇北仓路196号', 'A4')('耿丕岭', 'P')('18513466161', 'T')