IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【BERT-意图分类】2. 模型构建与损失函数 -> 正文阅读

[人工智能]【BERT-意图分类】2. 模型构建与损失函数


任务简介:

学习一个简单的BERT意图分类项目,了解BERT进行NLP任务时的流程。

任务说明(本节):

  1. 构建BERT分类模型
  2. 损失函数计算

导入必须的第三方库:

输入:

%cd ../
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
from transformers import BertPreTrainedModel, BertModel, BertConfig, DistilBertConfig, AlbertConfig
from transformers import BertTokenizer, DistilBertTokenizer, AlbertTokenizer
# 下面三个import的类和函数在上一篇笔记已经记录
from bert_finetune_cls.model import ClsBERT
from bert_finetune_cls.utils import init_logger, load_tokenizer, get_intent_labels
from bert_finetune_cls.data_loader import load_and_cache_examples

输出:

D:\notebook_workspace\BERT_cls

一、意图分类任务的MLP层

代码:

# intent分类的MLP全连接层
class IntentClassifier(nn.Module):
    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)  # nn.Linear(input神经元数量,output神经元数量)

    def forward(self, x):
        # x: [batch_size, input_dim]
        x = self.dropout(x)
        return self.linear(x)

二、模型主要架构

代码:

class ClsBERT(BertPreTrainedModel):
    def __init__(self, config, args, intent_label_lst):
        super(ClsBERT, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.bert = BertModel(config=config)  # Load pretrained bert
        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)


    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        
        pooled_output = outputs[1]  # [CLS]
 
        intent_logits = self.intent_classifier(pooled_output)

        outputs = ((intent_logits),) + outputs[2:]  # add hidden states and attention if they are here

        # 1. Intent Softmax
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))

            outputs = (intent_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

三、交叉熵损失函数 CrossEntropyLoss

PyTorch中CrossEntropyLoss()函数的主要是将softmax -> log -> NLLLoss合并到一块得到的结果, 所以我们自己不需要求softmax。
L = ? ∑ i = 1 N y i ? log ? y i ^ L=- \sum_{i=1}^{N}y_i* \log \hat{y_i} L=?i=1N?yi??logyi?^?
y i y_i yi?是真正类别的one-hot分布,只有真实类别的概率为1,其他都是 0, y i ^ \hat{y_i} yi?^? 是经由softmax后的分布

  • softmax将输出数据规范化为一个概率分布。

  • 然后将Softmax之后的结果取log

  • 输入负对数损失函数

1. 实例化模型

代码:

MODEL_CLASSES = {
    'bert': (BertConfig, ClsBERT, BertTokenizer),
}

MODEL_PATH_MAP = {
    'bert': 'bert_finetune_cls/resources/uncased_L-2_H-128_A-2',
}


# 先构建参数
class Args():
    task =  None
    data_dir =  None
    intent_label_file =  None


args = Args()
args.task = "atis"
args.data_dir = "bert_finetune_cls/data"
args.intent_label_file = "intent_label.txt"
args.max_seq_len = 50
args.model_type = "bert"
args.model_dir = "bert_finetune_cls/experiments/outputs/clsbert_0"
args.model_name_or_path = MODEL_PATH_MAP[args.model_type]
args.train_batch_size = 4
args.dropout_rate = 0.1

tokenizer = load_tokenizer(args)
config = MODEL_CLASSES[args.model_type][0].from_pretrained(args.model_name_or_path)
intent_label_lst = get_intent_labels(args)
model = ClsBERT(config, args, intent_label_lst)

查看 tokenizer:

输入:

print(tokenizer)

输出:

PreTrainedTokenizer(name_or_path='bert_finetune_cls/resources/uncased_L-2_H-128_A-2', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

查看config:

输入:

print(config)

输出:

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.2.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

查看intent标签:

输入:

print(intent_label_lst)

输出:

['UNK', 'atis_abbreviation', 'atis_aircraft', 'atis_aircraft#atis_flight#atis_flight_no', 'atis_airfare', 'atis_airline', 'atis_airline#atis_flight_no', 'atis_airport', 'atis_capacity', 'atis_cheapest', 'atis_city', 'atis_distance', 'atis_flight', 'atis_flight#atis_airfare', 'atis_flight_no', 'atis_flight_time', 'atis_ground_fare', 'atis_ground_service', 'atis_ground_service#atis_ground_fare', 'atis_meal', 'atis_quantity', 'atis_restriction']

查看模型参数:

输入:

print(model)

输出:

ClsBERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=128, out_features=512, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=128, out_features=512, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (activation): Tanh()
    )
  )
  (intent_classifier): IntentClassifier(
    (dropout): Dropout(p=0.1, inplace=False)
    (linear): Linear(in_features=128, out_features=22, bias=True)
  )
)

2. 加载数据、定义损失函数

输入:

# # 1. 定义dataset(torch) 
train_dataset = load_and_cache_examples(args, tokenizer, mode="train")

# torch自带的sampler类,功能是每次返回一个随机的样本索引
train_sampler = RandomSampler(train_dataset)

# # 2. 定义dataloader
# 使用dataloader输出batch
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

# # 3. 设置cpu或者gpu模式,按batch读取特征数据
device = "cpu"
for step, batch in enumerate(train_dataloader):
    
    if step > 1:
        continue
    
    batch = tuple(t.to(device) for t in batch) # 将batch上传到显卡
    inputs = {"input_ids": batch[0],
              "attention_mask": batch[1],
              "token_type_ids": batch[2],
              "intent_label_ids": batch[3],}
    
    input_ids = inputs["input_ids"]
    print("input_ids: ", input_ids)
    
    attention_mask = inputs["attention_mask"]
    token_type_ids = inputs["token_type_ids"]
    intent_label_ids = inputs["intent_label_ids"]
    
    outputs = model.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
    
    pooled_output = outputs[1]  # [CLS]  [4 * 128] 128维向量
    intent_logits = model.intent_classifier(pooled_output)
    print("intent_logits: ", intent_logits)   # [4 * 22]  22标签数量
    print("intent_logits: ", intent_logits.shape)
    
    intent_loss_fct = nn.CrossEntropyLoss()
    intent_loss = intent_loss_fct(intent_logits.view(-1, model.num_intent_labels), intent_label_ids.view(-1))
    print("intent_loss: ", intent_loss)   

输出:

tensor([[ 101, 1045, 2215,  ...,    0,    0,    0],
        [ 101, 2461, 4440,  ...,    0,    0,    0],
        [ 101, 2265, 2033,  ...,    0,    0,    0],
        ...,
        [ 101, 2425, 2033,  ...,    0,    0,    0],
        [ 101, 1045, 1005,  ...,    0,    0,    0],
        [ 101, 2003, 2045,  ...,    0,    0,    0]])
<class 'torch.Tensor'>
input_ids:  tensor([[  101,  2054,  2515,  1996, 13258,  3642,  1042,  1998,  1042,  2078,
          2812,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2265,  2033,  7599,  2013,  5759,  2000,  6278,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2265,  2033,  2035,  7599,  2029,  2681,  6278,  4826,  1998,
          7180,  1999,  6222,  2044,  1019,  1051,  1005,  5119,  7610,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2054,  2003,  1996, 10036,  4355, 13258,  2008,  1045,  2064,
          2131,  2090,  4407,  1998,  2624,  3799,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]])
intent_logits:  tensor([[-6.6785e-02, -2.6867e-01,  9.0055e-02,  2.8000e-01, -5.4302e-03,
          2.4999e-02, -9.4538e-03, -9.2926e-02,  6.5185e-02,  5.5395e-02,
          6.3910e-02,  2.0029e-01,  1.0393e-02, -3.6474e-03,  2.6379e-01,
         -1.2984e-01, -2.4014e-01,  1.9804e-01,  2.8028e-01, -1.1415e-01,
         -1.2588e-01, -1.6676e-01],
        [-3.2886e-02, -3.4531e-01,  1.6030e-01,  2.0762e-01,  1.3981e-02,
          7.5839e-02,  7.0054e-02, -1.1200e-01,  1.3198e-01,  1.2292e-01,
          9.7293e-02,  1.6416e-01,  1.4267e-01,  3.2855e-02,  2.9160e-01,
         -2.1743e-01, -2.5785e-01,  1.5410e-01,  2.6627e-01, -1.1073e-01,
         -1.1497e-01, -1.1996e-01],
        [-6.2165e-02, -3.4817e-01, -2.6278e-02,  2.3015e-01, -2.6621e-02,
          8.9725e-02, -1.9474e-04,  2.2616e-02,  2.0806e-01,  9.1309e-02,
          1.1127e-01,  2.1002e-01,  8.9848e-02,  7.9987e-02,  3.9927e-01,
         -1.8766e-01, -2.6148e-01,  9.6200e-02,  2.2902e-01, -1.9097e-01,
         -8.6905e-02, -7.7062e-02],
        [-8.3226e-02, -2.4569e-01,  1.0448e-02,  2.0898e-01,  1.5219e-02,
          8.2574e-02,  5.7453e-02, -3.6945e-02,  1.3960e-01,  1.3904e-01,
          1.1725e-01,  2.2389e-01,  1.2740e-01,  9.4694e-03,  3.2417e-01,
         -2.2192e-01, -2.8386e-01,  1.0251e-01,  2.8867e-01, -2.0414e-01,
         -9.8556e-02, -1.7579e-01]], grad_fn=<AddmmBackward>)
intent_logits:  torch.Size([4, 22])
intent_loss:  tensor(3.1330, grad_fn=<NllLossBackward>)
input_ids:  tensor([[  101,  2129,  2055,  7599,  2013,  5759,  2000,  5865,  2006,  9317,
          2851,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2029,  2142,  7608,  3462, 10029,  2013,  3731,  2000,  2624,
          3799,  1998,  3084,  1037,  2644,  7840,  1999,  4407,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2054,  2137,  7608,  7599,  2013,  6708,  2000,  9184, 18280,
          6708,  2044,  1020,  7610,  2006,  9317,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2265,  2033,  1996, 10036,  4355,  7599,  2013,  6222,  2000,
          5759,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]])
intent_logits:  tensor([[-0.0413, -0.2474, -0.0339,  0.2605, -0.0834,  0.0339,  0.0849, -0.0924,
          0.1840,  0.0764,  0.1212,  0.2463,  0.0661,  0.0312,  0.2384, -0.1402,
         -0.3652,  0.1549,  0.1569, -0.1146,  0.0442, -0.0687],
        [-0.0079, -0.2114, -0.0082,  0.2215, -0.0067,  0.0373,  0.1558, -0.1708,
          0.0856,  0.1550,  0.0960,  0.2478,  0.0917,  0.0946,  0.2417, -0.1291,
         -0.1098,  0.2406,  0.2189, -0.0778, -0.1112, -0.1197],
        [-0.0593, -0.2813,  0.1841,  0.2781,  0.1761,  0.0876,  0.1257, -0.0545,
          0.1210,  0.0936,  0.1363,  0.0835,  0.1384,  0.1019,  0.2470, -0.1459,
         -0.2072,  0.0371,  0.1957, -0.0550, -0.0101, -0.0650],
        [-0.1047, -0.3152,  0.0247,  0.2151, -0.0135,  0.0146,  0.1752,  0.0308,
          0.1882,  0.1092,  0.1121,  0.1614,  0.1212,  0.0079,  0.2502, -0.0955,
         -0.2104,  0.0865,  0.2006, -0.1559, -0.0587, -0.1162]],
       grad_fn=<AddmmBackward>)
intent_logits:  torch.Size([4, 22])
intent_loss:  tensor(3.0337, grad_fn=<NllLossBackward>)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-10 10:50:39  更:2021-09-10 10:51:42 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 19:55:57-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码