IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 在自己的数据集上重新训练BERT(附代码) -> 正文阅读

[人工智能]在自己的数据集上重新训练BERT(附代码)

最近有需要在新的领域进一步训练BERT,因此参照了hugging face官方文档写了相应的代码。本文采用的是hugging face提供的checkpoint,并在相应的task special领域进行了微调。由于项目的保密协议代码数据不便全部公开,下面只给出关键的部分。

BERT MLM

重新训练BERT主要是在自己的数据集上实现Masked Language Model的预测任务。我忘记了在哪篇论文里看到Next Sentence Prediction对下游的任务的增益其实并不大(如果有误还请指出),并且本次重新训练是基于短句子语料的,所以只考虑MLM任务。还是用李宏毅老师的PPT中的例子说明MLM的目标:
在这里插入图片描述
在给定一个句子,以一定的概率随机mask其中的token(BERT中使用15%的概率),MLM的目标是在整个BERT的词表空间中预测[MASK]的词的概率分布,也就是会产出一个 ∣ V ∣ |V| V的概率向量, V V V表示词表,经过 s o f t m a x ( P ∣ V ∣ ) softmax(P_{|V|}) softmax(PV?)之后就可以获取到最可能的预测结果。MLM旨在让BERT通过self-attention熟悉相应的上下文。

词表扩充

BERT中的token表的大小是有限的,如果领域包含词表中未录入的词,则会产生[UNK]。为此,需要对词表进行扩充,如get_new_tokens所示:

class DataLoader:
    def __init__(self, in_dir='../初始数据.csv',
                 out_dir='checkpoints/policybert',
                 bert_dir='checkpoints/bert',
                 train_source='title',
                 batch_size: int = 64,
                 max_len: int = 64,
                 shuffle: bool = True,
                 mask_token='[MASK]',
                 mask_rate=0.15):
        super(DataLoader, self).__init__()
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.bert_dir = bert_dir
        self.train_source = train_source
        if len(os.listdir(out_dir)) == 0:   # 如果没有保存好的checkpoints,那么就使用BERT的tokenizer
            self.tokenizer = BertTokenizer.from_pretrained(self.bert_dir)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(self.out_dir)
        self.get_data()
        # 定义BERT数据加载的迭代器
        self.bert_iter = BERTMLMDataIter(datas=self.datas, tokenizer=self.tokenizer,
                                         max_len=max_len, batch_size=batch_size)
        self.model = BertForMaskedLM.from_pretrained(self.bert_dir)
        # 注意在使用之前resize bemedding大小
        self.model.resize_token_embeddings(len(self.tokenizer))

    def get_data(self):
        ''' 加载源数据获取raw的文本,文本是以excel(csv)形式存放的,并且只加载'title'字段的文本进行训练
        :return:
        '''
        self.datas = []
        frame = list(pd.read_csv(self.in_dir)[self.train_source].values)
        self.datas.extend(frame)

    def get_new_tokens(self):
        '''
        为BERT词表添加新的tokens
        :return:
        '''
        self.new_tokens = []
        for data in self.datas:
            # tokens = self.tokenizer.tokenize(data)
            # print(tokens)
            for word in data:
                if word not in self.tokenizer.vocab:  
                	# 由于是中文的模型,因此这里剔除一些非中文的特殊字符
                    if u'\u4e00' <= word <=u'\u9fff' and word not in self.new_tokens:
                        self.new_tokens.append(word)
        
        self.tokenizer.add_tokens(self.new_tokens)
        self.tokenizer.save_pretrained(self.out_dir)  #保存增加的词表

在词表扩充完毕下一次加载的时候,需要注意,因为BERT的第一层是Embedding层,参数量依赖于词表的大小,此时词表已经发生了变化,因此需要对其进行resize,也就是:self.model.resize_token_embeddings(len(self.tokenizer))。

构建迭代器

由于数据加载使用的csv中包含了几十万条数据,为了不为难显存,所以我选择了进行数据的动态加载,也就是构建下述迭代器:

class BERTMLMDataIter():
    '''
    BertForMaskedLM的数据加载工具,其输入的格式为:The capital of France is [MASK]转化之后的ids,
    输出则为[Mask]的预测
    '''
    def __init__(self, datas:list, tokenizer: BertTokenizer, batch_size: int = 32,
                 max_len: int = 128, shuffle:bool=True, mask_token='[MASK]', mask_rate=0.15):
        super(BERTMLMDataIter).__init__()
        self.datas = datas
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_len = max_len
        self.shuffle = shuffle
        self.Mask_id = self.tokenizer.convert_tokens_to_ids(mask_token)
        self.mask_rate = mask_rate
        # 首次初始化
        self.reset()
        self.ipts = 0

    def reset(self):
        print("dataiter reset, 读取数据")
        if self.shuffle:
            random.shuffle(self.datas)
        self.data_iter = iter(self.datas)

    def random_mask(self, tokens, rate):
        '''
        :param token: 需要mask的初始字符串
        :param rate:
        :return:
        '''
        mask_tokens, label = [], []
        for word in tokens:
            mmm = random.random()
            if mmm <= rate:
                mask_tokens.append(self.Mask_id)
                label.append(word)
            else:
                mask_tokens.append(word)
                label.append(-100)    # -100表示计算损失函数的时候不计算该值
        return mask_tokens, label

    def get_data(self):
        ''' 获取mask的data数据以及标签的程序
        :return:
        '''
        data_ids = []
        labels = []
        att_masks = []
        for data in self.datas:
            data_id = self.tokenizer.encode(data)
            masked_data, label = self.random_mask(data_id, self.mask_rate)
            att_mask = [1]*len(masked_data)+[0]*(self.max_len-len(masked_data))
            masked_data = masked_data + [0]*(self.max_len-len(masked_data))
            label = label + [-100]*(self.max_len-len(label))
            data_ids.append(masked_data)
            labels.append(label)
            att_masks.append(att_mask)

    def get_batch_data(self):
        ''''''
        batch_data = []
        for i in self.data_iter:
            batch_data.append(i)
            if len(batch_data) == self.batch_size:
                break
        if len(batch_data) < 1:
            return None
        data_ids = []
        labels = []
        att_masks = []
        for data in batch_data:
            data_id = self.tokenizer.encode(data)
            masked_data, label = self.random_mask(data_id, self.mask_rate)
            if len(masked_data) < self.max_len:
                att_mask = [1] * len(masked_data) + [0] * (self.max_len - len(masked_data))
            else:
                att_mask = [1]*self.max_len
            masked_data = masked_data[:self.max_len]
            label = label[:self.max_len]
            att_mask = att_mask[:self.max_len]
            masked_data = masked_data + [0] * (self.max_len - len(masked_data))
            label = label + [-100] * (self.max_len - len(label))
            data_ids.append(masked_data)
            labels.append(label)
            att_masks.append(att_mask)
        batch_ipts = {}
        batch_ipts['ids'] = torch.LongTensor(data_ids)
        batch_ipts['mask'] = torch.LongTensor(att_masks)
        batch_ipts['label'] = torch.LongTensor(labels)
        return batch_ipts

    def __iter__(self):
        return self

    def __next__(self):
        if self.ipts is None:
            self.reset()
        self.ipts = self.get_batch_data()
        if self.ipts is None:
            raise StopIteration
        else:
            return self.ipts

get_batch_data用于在训练时候每次处理一个batch的数据并放入模型训练,因此不用预先将full-batch的数据都进行预处理。
每一个batch的数据处理过程中,都采用random_mask进行15%的随机mask。
这里,主要使用了BertForMaskedLM进行预训练。BertForMaskedLM的输入是带有mask的token_ids,其实就是一串id数字,[MASK]对应的id是103。然后是attention_mask,用于告知模型哪些token需要参与到self-attention的计算,那些不需要。以及labels,对应的标签,没有被mask的token对应的标签是-100,表明该位置的未mask词不参与到损失函数的计算过程中。

训练

训练的时候就比较简单了,调用封装好的train方法即可:

    def train_my_bert(self):
        ttt = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.cuda.set_device(0)
        CE = torch.nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=0)
        best_loss = 10000000
        for epoch in tqdm(range(self.epoch)):
            total_loss = 0.0
            # self.loader.bert_iter是BERTMaskedLM专用的数据迭代器
            for step, ipt in tqdm(enumerate(self.loader.bert_iter)):
            	# 获取一个batch的训练数据
                ipt = {k: v.to(device) for k, v in ipt.items()}
                out = self.model(input_ids=ipt['ids'],
                                 attention_mask=ipt['mask'],
                                 labels=ipt['label'])
                loss = out[0]
                total_loss += loss.data.item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                optimizer.step()
                self.model.zero_grad()
                if step % 10 == 0:
                    print('current batch loss:{}'.format(loss.data.item()))

            if total_loss < best_loss:   # 保存最优模型(loss最小的模型)
                print('save best model to {}!!!'.format(self.out_dir))
                best_loss = total_loss
                torch.save(self.model, self.out_dir+'/pytorch_model_{}.bin'.format(ttt))

值得注意的是BERTMaskedLM的返回结果中就就包含loss了,因此自定义的损失函数没有用上,具体可以参见BERTMaskedLM的相关文档,写得十分清晰。稍微看一下训练的log输出,其实训练的速度还是挺快的(但是架不住数据集太大):
在这里插入图片描述

总结

训练BERT还是需要挺大量的数据集的,目前我们的工作中对下游任务进行re-train之后的参数效果是否更好,还有待测试。本文只是提供一个相应的思路。如果在小规模数据集上做微调,那么还是推荐使用主任务+MLM辅助任务的形式,让BERT更适配于当前的任务。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-06-01 15:13:42  更:2022-06-01 15:15:03 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/1 23:00:32-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码