IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【自然语言处理】【文本生成】使用Transformers中的BART进行文本摘要 -> 正文阅读

[人工智能]【自然语言处理】【文本生成】使用Transformers中的BART进行文本摘要

使用Transformers中的BART进行文本摘要

相关博客
【自然语言处理】【文本生成】使用Transformers中的BART进行文本摘要
【自然语言处理】【文本生成】Transformers中使用约束Beam Search指导文本生成
【自然语言处理】【文本生成】Transformers中用于语言生成的不同解码方法
【自然语言处理】【文本生成】BART:用于自然语言生成、翻译和理解的降噪Sequence-to-Sequence预训练
【自然语言处理】【文本生成】UniLM:用于自然语言理解和生成的统一语言模型预训练
【自然语言处理】【多模态】OFA:通过简单的sequence-to-sequence学习框架统一架构、任务和模态

? 本文是一个基于 Transformers \text{Transformers} Transformers的文本生成代码示例。该示例中使用中文版本的 BART \text{BART} BART模型,数据则使用 NLPCC2017 \text{NLPCC2017} NLPCC2017的中文摘要数据集。数据位于百度网盘nlpcc2017_clean.json,提取码为knci

零、包引入

import torch
import datasets
import lawrouge
import numpy as np

from typing import List, Dict
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM, 
                          DataCollatorForSeq2Seq, 
                          Seq2SeqTrainingArguments, 
                          Seq2SeqTrainer, 
                          BartForConditionalGeneration)

一、定义参数

batch_size = 32
epochs = 5
max_input_length = 512 # 最大输入长度
max_target_length = 128 # 最大输出长度
learning_rate = 1e-04

二、加载数据

# 读取数据
dataset = load_dataset('json', data_files='nlpcc2017_clean.json', field='data')
# 加载tokenizer,中文bart使用bert的tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

三、数据处理

1. 调整数据格式

def flatten(example):
    return {
        "document": example["content"],
        "summary": example["title"],
        "id":"0"
    }

# 将原始数据中的content和title转换为document和summary
dataset = dataset["train"].map(flatten, remove_columns=["title", "content"])

2. 划分数据集

train_dataset, valid_dataset = dataset.train_test_split(test_size=0.1,shuffle=True,seed=42).values()
train_dataset, test_dataset = train_dataset.train_test_split(test_size=0.1,shuffle=True,seed=42).values()
datasets = datasets.DatasetDict({"train":train_dataset,"validation": valid_dataset,"test":test_dataset})
# print(datasets["train"][2])

3. tokenized

def preprocess_function(examples):
    """
    document作为输入,summary作为标签
    """
    inputs = [doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = datasets
tokenized_datasets = tokenized_datasets.map(preprocess_function, batched=True, remove_columns=["document", "summary", "id"])
# print(tokenized_datasets["train"][2].keys())
# print(tokenized_datasets["train"][2])

4. 定义 collate_fn \text{collate\_fn} collate_fn

def collate_fn(features: Dict):
    batch_input_ids = [torch.LongTensor(feature["input_ids"]) for feature in features]
    batch_attention_mask = [torch.LongTensor(feature["attention_mask"]) for feature in features]
    batch_labels = [torch.LongTensor(feature["labels"]) for feature in features]
    
    # padding
    batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
    batch_attention_mask = pad_sequence(batch_attention_mask, batch_first=True, padding_value=0)
    batch_labels = pad_sequence(batch_labels, batch_first=True, padding_value=-100)
    return {
        "input_ids": batch_input_ids,
        "attention_mask": batch_attention_mask,
        "labels": batch_labels
    }

# 构建DataLoader来验证collate_fn
dataloader = DataLoader(tokenized_datasets["test"], shuffle=False, batch_size=4, collate_fn=collate_fn)
batch = next(iter(dataloader))
# print(batch)

四、加载模型

model = AutoModelForSeq2SeqLM.from_pretrained("fnlp/bart-base-chinese")
# output = model(**batch) # 验证前向传播
# print(output)

五、模型训练

1. 定义评估函数

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # 将id解码为文字
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # 替换标签中的-100
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # 去掉解码后的空格
    decoded_preds = ["".join(pred.replace(" ", "")) for pred in decoded_preds]
    decoded_labels = ["".join(label.replace(" ", "")) for label in decoded_labels]
    # 分词计算rouge
    # decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]
    # decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]
    # 计算rouge
    rouge = lawrouge.Rouge()
    result = rouge.get_scores(decoded_preds, decoded_labels,avg=True)
    result = {'rouge-1': result['rouge-1']['f'], 'rouge-2': result['rouge-2']['f'], 'rouge-l': result['rouge-l']['f']}
    result = {key: value * 100 for key, value in result.items()}
    return result

2. 设置训练参数

# 设置训练参数
args = Seq2SeqTrainingArguments(
    output_dir="results", # 模型保存路径
    num_train_epochs=epochs,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    warmup_steps=500,
    weight_decay=0.001,
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=500,
    evaluation_strategy="steps",
    save_total_limit=3,
    generation_max_length=max_target_length, # 生成的最大长度
    generation_num_beams=1, # beam search
    load_best_model_at_end=True,
    metric_for_best_model="rouge-1"
)

3. 定义trainer

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=collate_fn,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

4. 训练

train_result = trainer.train()
# 打印验证集上的结果
print(trainer.evaluate(tokenized_datasets["validation"]))
# 打印测试集上的结果
print(trainer.evaluate(tokenized_datasets["test"]))
# 保存最优模型
trainer.save_model("results/best")

六、生成

# 加载训练好的模型
model = BartForConditionalGeneration.from_pretrained("results/best")
model = model.to("cuda")
# 从测试集中挑选4个样本
test_examples = test_dataset["document"][:4]
inputs = tokenizer(
        test_examples,
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
        return_tensors="pt",
    )
input_ids = inputs.input_ids.to(model.device)
attention_mask = inputs.attention_mask.to(model.device)
# 生成
outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=128)
# 将token转换为文字
output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
output_str = [s.replace(" ","") for s in output_str]
print(output_str)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-10-08 20:42:06  更:2022-10-08 20:44:54 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/29 7:59:58-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码