任务简介:
学习一个简单的BERT意图分类项目,了解BERT进行NLP任务时的流程。
任务说明(本节):
- 构建BERT分类模型
- 损失函数计算
导入必需的第三方库:
输入:
%cd ../
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
from transformers import BertPreTrainedModel, BertModel, BertConfig, DistilBertConfig, AlbertConfig
from transformers import BertTokenizer, DistilBertTokenizer, AlbertTokenizer
from bert_finetune_cls.model import ClsBERT
from bert_finetune_cls.utils import init_logger, load_tokenizer, get_intent_labels
from bert_finetune_cls.data_loader import load_and_cache_examples
输出:
D:\notebook_workspace\BERT_cls
一、意图分类任务的分类头(Dropout + 全连接层)
代码:
class IntentClassifier(nn.Module):
    """Classification head mapping a feature vector to intent logits.

    Applies dropout and then a single fully-connected layer.

    Args:
        input_dim: size of the last dimension of the incoming features.
        num_intent_labels: number of output logits (one per intent label).
        dropout_rate: dropout probability applied before the linear layer.
    """

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super(IntentClassifier, self).__init__()
        # Attribute names are kept as-is: they become state_dict keys.
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        """Return unnormalised logits for ``x`` (dropout, then linear)."""
        return self.linear(self.dropout(x))
二、模型主要架构
代码:
class ClsBERT(BertPreTrainedModel):
    """BERT encoder followed by an intent-classification head.

    The pooled output of :class:`BertModel` is fed to
    :class:`IntentClassifier`, which produces one logit per intent label.

    Args:
        config: BERT configuration; ``config.hidden_size`` is the input size
            of the classification head.
        args: run options; ``args.dropout_rate`` sets the head's dropout.
            The object is stored as ``self.args``.
        intent_label_lst: list of intent labels; its length is the number of
            output logits.
    """

    def __init__(self, config, args, intent_label_lst):
        super(ClsBERT, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.bert = BertModel(config=config)
        self.intent_classifier = IntentClassifier(
            config.hidden_size, self.num_intent_labels, args.dropout_rate
        )
        # Initialise the newly added head with the BERT scheme
        # (config.initializer_range), as the Hugging Face task heads do.
        # When the model is built via from_pretrained(), the pretrained
        # encoder weights are loaded over this afterwards.
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                intent_label_ids=None):
        """Run BERT plus the intent head, optionally computing the loss.

        Args:
            input_ids: token ids, passed straight to ``self.bert``.
            attention_mask: optional mask forwarded to ``self.bert``.
            token_type_ids: optional segment ids forwarded to ``self.bert``.
            intent_label_ids: optional targets. With more than one label they
                are class indices for CrossEntropyLoss. With a single label
                they are regression targets for MSELoss.

        Returns:
            tuple: ``(intent_loss, intent_logits, *extra)`` when labels are
            given, otherwise ``(intent_logits, *extra)``. Here ``extra`` is
            whatever the BERT output holds beyond its first two entries.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        pooled_output = outputs[1]  # output of BertPooler (dense + tanh)
        intent_logits = self.intent_classifier(pooled_output)

        # Keep any additional encoder outputs (e.g. hidden states/attentions
        # when they are enabled) after the logits.
        outputs = (intent_logits,) + outputs[2:]

        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                # Single output -> regression. MSELoss needs a floating-point
                # target matching the logits' dtype.
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(
                    intent_logits.view(-1),
                    intent_label_ids.view(-1).to(intent_logits.dtype),
                )
            else:
                # CrossEntropyLoss applies log-softmax internally, so raw
                # logits are passed in.
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(
                    intent_logits.view(-1, self.num_intent_labels),
                    intent_label_ids.view(-1),
                )
            outputs = (intent_loss,) + outputs

        return outputs
三、交叉熵损失函数 CrossEntropyLoss
PyTorch 中的 `CrossEntropyLoss()` 把 softmax -> log -> NLLLoss 三步合并为一步,等价于 `LogSoftmax` + `NLLLoss`。所以模型只需输出未经归一化的 logits,我们不需要自己再求 softmax。
$$L = -\sum_{i=1}^{N} y_i \log \hat{y}_i$$

其中 $N$ 为类别数;$y_i$ 是真实类别的 one-hot 分布,只有真实类别的概率为 1,其他都是 0;$\hat{y}_i$ 是经由 softmax 后的预测分布。由于 $y$ 是 one-hot 向量,求和中只剩真实类别 $c$ 这一项,即 $L = -\log \hat{y}_c$。PyTorch 默认(`reduction='mean'`)还会对一个 batch 内各样本的损失取平均。
计算过程可以分为三步:
- softmax 将输出数据规范化为一个概率分布;
- 然后对 softmax 之后的结果取 log;
- 最后将其输入负对数似然损失函数(NLLLoss)。
1. 实例化模型
代码:
# model_type -> (config class, model class, tokenizer class)
MODEL_CLASSES = dict(
    bert=(BertConfig, ClsBERT, BertTokenizer),
)

# model_type -> path handed to from_pretrained() (local BERT resources)
MODEL_PATH_MAP = dict(
    bert='bert_finetune_cls/resources/uncased_L-2_H-128_A-2',
)
class Args:
    """Plain attribute container standing in for parsed command-line options.

    Only these three fields have class-level defaults (``None``). The other
    options are attached to the instance after construction.
    """

    task = data_dir = intent_label_file = None
args = Args()
args.task = "atis"
args.data_dir = "bert_finetune_cls/data"
args.intent_label_file = "intent_label.txt"
args.max_seq_len = 50
args.model_type = "bert"
args.model_dir = "bert_finetune_cls/experiments/outputs/clsbert_0"
args.model_name_or_path = MODEL_PATH_MAP[args.model_type]
args.train_batch_size = 4
args.dropout_rate = 0.1

# Look up the classes for the selected model type instead of hard-coding them.
config_class, model_class, _ = MODEL_CLASSES[args.model_type]

tokenizer = load_tokenizer(args)
config = config_class.from_pretrained(args.model_name_or_path)
intent_label_lst = get_intent_labels(args)

# Load the pretrained encoder weights from the checkpoint directory. Calling
# the constructor directly (ClsBERT(config, ...)) would leave BERT randomly
# initialised. Keyword arguments other than `config` are forwarded by
# from_pretrained() to ClsBERT.__init__.
model = model_class.from_pretrained(
    args.model_name_or_path,
    config=config,
    args=args,
    intent_label_lst=intent_label_lst,
)
查看 tokenizer:
输入:
# Inspect the loaded tokenizer; its repr shows the vocab size and special tokens.
print(tokenizer)
输出:
PreTrainedTokenizer(name_or_path='bert_finetune_cls/resources/uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
查看config:
输入:
# Inspect the model configuration (hidden size, layers, attention heads, vocab size, ...).
print(config)
输出:
BertConfig {
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 128,
"initializer_range": 0.02,
"intermediate_size": 512,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 2,
"num_hidden_layers": 2,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.2.2",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}
查看intent标签:
输入:
# Inspect the intent labels; their count sets the classifier head's output size.
print(intent_label_lst)
输出:
['UNK', 'atis_abbreviation', 'atis_aircraft', 'atis_aircraft#atis_flight#atis_flight_no', 'atis_airfare', 'atis_airline', 'atis_airline#atis_flight_no', 'atis_airport', 'atis_capacity', 'atis_cheapest', 'atis_city', 'atis_distance', 'atis_flight', 'atis_flight#atis_airfare', 'atis_flight_no', 'atis_flight_time', 'atis_ground_fare', 'atis_ground_service', 'atis_ground_service#atis_ground_fare', 'atis_meal', 'atis_quantity', 'atis_restriction']
查看模型参数:
输入:
# Print the module tree: the BERT encoder plus the intent_classifier head.
print(model)
输出:
ClsBERT(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 128, padding_idx=0)
(position_embeddings): Embedding(512, 128)
(token_type_embeddings): Embedding(2, 128)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(1): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=128, out_features=128, bias=True)
(activation): Tanh()
)
)
(intent_classifier): IntentClassifier(
(dropout): Dropout(p=0.1, inplace=False)
(linear): Linear(in_features=128, out_features=22, bias=True)
)
)
2. 加载数据、定义损失函数
输入:
train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

device = "cpu"
# Keep the model on the same device as the batches. This is a no-op on CPU,
# but it is required if `device` is switched to a GPU.
model.to(device)

# The loss module holds no state between calls, so build it once.
intent_loss_fct = nn.CrossEntropyLoss()

# Inspect only the first two batches, repeating the model's forward steps by hand.
num_demo_batches = 2
for step, batch in enumerate(train_dataloader):
    if step >= num_demo_batches:
        # `break`, not `continue`: stop drawing batches instead of loading and
        # collating the rest of the dataset for nothing.
        break
    batch = tuple(t.to(device) for t in batch)
    # Tensor order within a batch as used here:
    # (input_ids, attention_mask, token_type_ids, intent_label_ids).
    inputs = {"input_ids": batch[0],
              "attention_mask": batch[1],
              "token_type_ids": batch[2],
              "intent_label_ids": batch[3],}
    input_ids = inputs["input_ids"]
    print("input_ids: ", input_ids)
    attention_mask = inputs["attention_mask"]
    token_type_ids = inputs["token_type_ids"]
    intent_label_ids = inputs["intent_label_ids"]

    # Encoder -> pooled output -> classification head.
    outputs = model.bert(input_ids, attention_mask=attention_mask,
                         token_type_ids=token_type_ids)
    pooled_output = outputs[1]
    intent_logits = model.intent_classifier(pooled_output)
    print("intent_logits: ", intent_logits)
    print("intent_logits: ", intent_logits.shape)

    # Raw logits go straight into CrossEntropyLoss (it applies log-softmax itself).
    intent_loss = intent_loss_fct(intent_logits.view(-1, model.num_intent_labels), intent_label_ids.view(-1))
    print("intent_loss: ", intent_loss)
输出:
tensor([[ 101, 1045, 2215, ..., 0, 0, 0],
[ 101, 2461, 4440, ..., 0, 0, 0],
[ 101, 2265, 2033, ..., 0, 0, 0],
...,
[ 101, 2425, 2033, ..., 0, 0, 0],
[ 101, 1045, 1005, ..., 0, 0, 0],
[ 101, 2003, 2045, ..., 0, 0, 0]])
<class 'torch.Tensor'>
input_ids: tensor([[ 101, 2054, 2515, 1996, 13258, 3642, 1042, 1998, 1042, 2078,
2812, 102, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 101, 2265, 2033, 7599, 2013, 5759, 2000, 6278, 102, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 101, 2265, 2033, 2035, 7599, 2029, 2681, 6278, 4826, 1998,
7180, 1999, 6222, 2044, 1019, 1051, 1005, 5119, 7610, 102,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 101, 2054, 2003, 1996, 10036, 4355, 13258, 2008, 1045, 2064,
2131, 2090, 4407, 1998, 2624, 3799, 102, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
intent_logits: tensor([[-6.6785e-02, -2.6867e-01, 9.0055e-02, 2.8000e-01, -5.4302e-03,
2.4999e-02, -9.4538e-03, -9.2926e-02, 6.5185e-02, 5.5395e-02,
6.3910e-02, 2.0029e-01, 1.0393e-02, -3.6474e-03, 2.6379e-01,
-1.2984e-01, -2.4014e-01, 1.9804e-01, 2.8028e-01, -1.1415e-01,
-1.2588e-01, -1.6676e-01],
[-3.2886e-02, -3.4531e-01, 1.6030e-01, 2.0762e-01, 1.3981e-02,
7.5839e-02, 7.0054e-02, -1.1200e-01, 1.3198e-01, 1.2292e-01,
9.7293e-02, 1.6416e-01, 1.4267e-01, 3.2855e-02, 2.9160e-01,
-2.1743e-01, -2.5785e-01, 1.5410e-01, 2.6627e-01, -1.1073e-01,
-1.1497e-01, -1.1996e-01],
[-6.2165e-02, -3.4817e-01, -2.6278e-02, 2.3015e-01, -2.6621e-02,
8.9725e-02, -1.9474e-04, 2.2616e-02, 2.0806e-01, 9.1309e-02,
1.1127e-01, 2.1002e-01, 8.9848e-02, 7.9987e-02, 3.9927e-01,
-1.8766e-01, -2.6148e-01, 9.6200e-02, 2.2902e-01, -1.9097e-01,
-8.6905e-02, -7.7062e-02],
[-8.3226e-02, -2.4569e-01, 1.0448e-02, 2.0898e-01, 1.5219e-02,
8.2574e-02, 5.7453e-02, -3.6945e-02, 1.3960e-01, 1.3904e-01,
1.1725e-01, 2.2389e-01, 1.2740e-01, 9.4694e-03, 3.2417e-01,
-2.2192e-01, -2.8386e-01, 1.0251e-01, 2.8867e-01, -2.0414e-01,
-9.8556e-02, -1.7579e-01]], grad_fn=<AddmmBackward>)
intent_logits: torch.Size([4, 22])
intent_loss: tensor(3.1330, grad_fn=<NllLossBackward>)
input_ids: tensor([[ 101, 2129, 2055, 7599, 2013, 5759, 2000, 5865, 2006, 9317,
2851, 102, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 101, 2029, 2142, 7608, 3462, 10029, 2013, 3731, 2000, 2624,
3799, 1998, 3084, 1037, 2644, 7840, 1999, 4407, 102, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 101, 2054, 2137, 7608, 7599, 2013, 6708, 2000, 9184, 18280,
6708, 2044, 1020, 7610, 2006, 9317, 102, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 101, 2265, 2033, 1996, 10036, 4355, 7599, 2013, 6222, 2000,
5759, 102, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
intent_logits: tensor([[-0.0413, -0.2474, -0.0339, 0.2605, -0.0834, 0.0339, 0.0849, -0.0924,
0.1840, 0.0764, 0.1212, 0.2463, 0.0661, 0.0312, 0.2384, -0.1402,
-0.3652, 0.1549, 0.1569, -0.1146, 0.0442, -0.0687],
[-0.0079, -0.2114, -0.0082, 0.2215, -0.0067, 0.0373, 0.1558, -0.1708,
0.0856, 0.1550, 0.0960, 0.2478, 0.0917, 0.0946, 0.2417, -0.1291,
-0.1098, 0.2406, 0.2189, -0.0778, -0.1112, -0.1197],
[-0.0593, -0.2813, 0.1841, 0.2781, 0.1761, 0.0876, 0.1257, -0.0545,
0.1210, 0.0936, 0.1363, 0.0835, 0.1384, 0.1019, 0.2470, -0.1459,
-0.2072, 0.0371, 0.1957, -0.0550, -0.0101, -0.0650],
[-0.1047, -0.3152, 0.0247, 0.2151, -0.0135, 0.0146, 0.1752, 0.0308,
0.1882, 0.1092, 0.1121, 0.1614, 0.1212, 0.0079, 0.2502, -0.0955,
-0.2104, 0.0865, 0.2006, -0.1559, -0.0587, -0.1162]],
grad_fn=<AddmmBackward>)
intent_logits: torch.Size([4, 22])
intent_loss: tensor(3.0337, grad_fn=<NllLossBackward>)
|