IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> (情感倾向分类)2分类文本任务,Bert微调详细教程 -> 正文阅读

[人工智能](情感倾向分类)2分类文本任务,Bert微调详细教程

🎀

Dataset:SST-2

Model:bert-base-cased

?

transformers库的使用【三】对预训练模型进行微调

Transformers实战——使用Trainer类训练和评估自己的数据和模型

HuggingFace




从在线库中载入SST2数据集

from datasets import load_dataset
dataset = load_dataset('glue','sst2')

Tokenizer:将input转换为模型可以处理的格式。

from_pretrained方法让你快速加载任何架构的预训练模型,这样你就不必投入时间和资源从头开始训练一个模型。

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["sentence"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

tokenized_datasets删除其中的sentence列,idx列;然后把label列,列名改为labels。这都是按照bert模型需要的处理的。

tokenized_datasets = tokenized_datasets.remove_columns(["sentence","idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

取子集,训练集取1000个,测试集取200个,充分打散。

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(200))

把数据集装入DataLoader

from torch.utils.data import DataLoader

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=1)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=1)

载入预训练好的序列分类模型

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

在初始化BertForSequenceClassification时,没有使用bert-base-cased的模型检查点的一些权重。['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias',' cls.predictions.transform.LayerNorm.weight' ]

- 如果你从另一个任务或另一个架构上训练的模型的检查点初始化 BertForSequenceClassification,这是预期的(例如,从 BertForPreTraining 模型初始化 BertForSequenceClassification 模型)。

- 如果你从一个你期望完全相同的模型的检查点初始化BertForSequenceClassification(从一个BertForSequenceClassification模型初始化一个BertForSequenceClassification模型),这是不可能的。

BertForSequenceClassification的一些权重没有从bert-base-cased的模型检查点初始化,而是被新初始化。['分类器.权重', '分类器.偏置']

你可能应该在一个下游任务上训练这个模型,以便能够使用它进行预测和推理。

载入预训练好的训练参数

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="test_trainer",
    per_device_train_batch_size=1,  # batch size per device during training
    per_device_eval_batch_size=1,   # batch size for evaluation
)

这个时候如果实例化一个Trainer

from transformers import Trainer
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset= small_train_dataset,
    eval_dataset=small_eval_dataset
)

然后训练(因为用了预训练的参数和模型,所以这叫微调)

训练模型使用trainer对象的train方法
trainer.train()

load_metric的作用是使模型能在训练期间进行模型评估。该函数接收“预测的标签”和“真实的标签”。

import numpy as np
from datasets import load_metric
metric = load_metric("accuracy")
def compute_metric(eval_pred):
    logits,labels = eval_pred
    predictions = np.argmax(logits,axis=-1)
    return metric.compute(predictions = predictions,references = labels)

评估模型

评估模型使用trainer对象的evaluate方法
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset= small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics = compute_metric,
)
trainer.evaluate()

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-06 11:03:14  更:2022-05-06 11:05:20 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/4 15:21:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码