IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> transformers加载roberta实现序列标注任务 -> 正文阅读

[人工智能]transformers加载roberta实现序列标注任务

transformers加载roberta实现序列标注任务

最近在断断续续的学习hugging face的transformers,主要是运用transformers加载各种预训练模型实现序列标注。本次博客的主要内容是争对加载roberta预训练模型做序列标注任务。大致内容如下:
(1)roberta 模型
(2)transformers实现序列标注

roberta模型

争对bert模型,有很多改进版本的模型,roberta模型与bert模型有以下几个不同的地方,其中roberta的全称为Robustly optimized BERT approach。
(1)roberta的训练语料增大,模型效果比bert好
(2)roberta使用动态的mask,这个与bert的mask不一样,bert在模型训练的时候,每一个epoch的数据mask都一样,而roberta改进了这个maks的方式,采用动态的mask,即每一个epoch的输入的数据的mask不一样
(3)roberta采用Byte-Pair Encoding的编码方式,即采用character- and word-level representations
(4)roberta取消了nsp,论文中实验证明nsp损失下游任务
在这里插入图片描述
(5)roberta的优化器对adam的参数进行了修改
(6)roberta训练的时候,可以采取更大的batch size

transformers实现序列标注

这里介绍transformers加载roberta预训练模型实现序列标注,采用的是哈工大的roberta-wwm模型。
(1)tokenizer

from tokenizers import BertWordPieceTokenizer
def get_tokenizer(model_path):
    vocab_file = os.path.join(model_path, "vocab.txt")
    tokenizer = BertWordPieceTokenizer(vocab_file,
                                       lowercase=True)
    return tokenizer

(2)模型输入特征话
模型的输入和bert的输入保持一致

def convert_example_to_feature(context, context_tags, tokenizer):
    code = tokenizer.encode(context)
    new_tags = []
    for offset in code.offsets:
        if offset != (0, 0):
            start_index, end_index = offset
            new_tags.append(context_tags[start_index])

    assert len(code.ids) == len(code.type_ids) == len(code.attention_mask)
    return code.ids, code.type_ids, code.attention_mask, new_tags


def create_inputs_targets_roberta(sentences, tags, tag2id, max_len, tokenizer):
    tokenizer.enable_padding(length=max_len)
    tokenizer.enable_truncation(max_length=max_len)
    dataset_dict = {
        "input_ids": [],
        "token_type_ids": [],
        "attention_mask": [],
        "tags": []
    }

    for sentence, tag in zip(sentences, tags):
        sentence = ''.join(sentence)
        input_ids, token_type_ids, attention_mask, \
        post_tags = convert_example_to_feature(sentence, tag, tokenizer)
        dataset_dict["input_ids"].append(input_ids)
        dataset_dict["token_type_ids"].append(token_type_ids)
        dataset_dict["attention_mask"].append(attention_mask)
        if len(post_tags) < max_len - 2:
            pad_bio_tags = post_tags + [tag2id['O']] * (max_len - 2 - len(post_tags))
        else:
            pad_bio_tags = post_tags[:max_len - 2]
        dataset_dict["tags"].append([tag2id['O']] + pad_bio_tags + [tag2id['O']])

    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])

    x = [
        dataset_dict["input_ids"],
        dataset_dict["token_type_ids"],
        dataset_dict["attention_mask"],
    ]
    y = dataset_dict["tags"]
    return x, y

(3)模型
在transformers加载roberta模型进行fine-tuning,中文序列标注使用的是TFBertForTokenClassification。模型的代码如下:

	model = TFBertForTokenClassification.from_pretrained(args["pretrain_model_path"],
                                                         from_pt=True,
                                                         num_labels=len(list(tag2id.keys())))
  
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, epsilon=1e-08)
    # we do not have one-hot vectors, we can use sparse categorical cross entropy and accuracy
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
    model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

    model.summary()
    model.fit(train_x,
              train_y,
              epochs=epoch,
              verbose=1,
              batch_size=batch_size,
              validation_data=(dev_x, dev_y),
              validation_batch_size=batch_size
             )   # , validation_split=0.1

    # model save
    model_file = os.path.join(args["output_path"], "ner_model.h5")
    model.save_weights(model_file, overwrite=True)
    # save pb model
    tf.keras.models.save_model(model, args["pb_path"], save_format="tf")

(4)模型评价

id2tag = {value: key for key, value in tag2id.items()}
pred_logits = model.predict(data, batch_size=batch_size)[0]
# pred shape [batch_size, max_len]
preds = np.argmax(pred_logits, axis=2).tolist()

assert len(preds) == len(seq_len_list)
# get predcit label
predict_label = []
target_label = []
for i in range(len(preds)):
    pred = preds[i][1:]
    temp = []
    true_label = label[i][:min(seq_len_list[i], len(pred))]
    for j in range(min(seq_len_list[i], len(pred))):
        temp.append(id2tag[pred[j]])
    assert len(temp) == len(true_label)
    target_label.append(true_label)
    predict_label.append(temp)

    # 计算 precision, recall, f1_score
precision = precision_score(target_label, predict_label, average="macro", zero_division=0)
recall = recall_score(target_label, predict_label, average="macro", zero_division=0)
f1 = f1_score(target_label, predict_label, average="macro",
              zero_division=0)
logger.info(classification_report(target_label, predict_label))

后起补充在不同数据集上的结果,如有错误,欢迎大家指证。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-05 20:20:08  更:2021-07-05 20:20:34 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年3日历 -2024/3/29 14:27:14-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码