IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> KBQA-Bert学习记录-构建BERT-CRF模型 -> 正文阅读

[人工智能]KBQA-Bert学习记录-构建BERT-CRF模型

目录

1.__init__方法

2.forward方法


将bert和crf模型结合起来,简单来说就是,设置好Bert模型,以及参数,得到的输出结果给CRF模型即可。

1.__init__方法

这里面主要是bert的参数的定义及导入,还有bert模型的导入。

MODEL_NAME = 'bert-base-chinese-model.bin'
CONFIG_NAME = 'bert-base-chinese-config.json'
VOB_NAME = 'bert-base-chinese-vocab.txt'


class BertCrf(nn.Module):
    def __init__(self, config_name: str, model_name:str = None, num_tags: int = 2, batch_first: bool = True) -> None:
        self.batch_first = batch_first
        # 模型配置文件、模型预训练参数文件判断
        if not os.path.exists(config_name):
            raise ValueError(
                "未找到模型配置文件 '{}'".format(config_name)
            )
        else:
            self.config_name = config_name
        if model_name is not None:
            if not os.path.exists(model_name):
                raise ValueError(
                    "未找到模型预训练参数文件 '{}'".format(model_name)
                )
            else:
                self.model_name = model_name
        else:
            self.model_name = None
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()

        # 配置bert的config文件
        self.bert_config = BertConfig.from_pretrained(self.config_name)
        self.bert_config.num_labels = num_tags
        self.model_kwargs = {'config': self.bert_config}

        # 如果模型不存在
        if self.model_name is not None:
            self.bertModel = BertForTokenClassification.from_pretrained(self.model_name, **self.model_kwargs)
        else:
            self.bertModel = BertForTokenClassification(self.bert_config)
        self.crf_model = CRF(num_tags=num_tags, batch_first=batch_first)

2.forward方法

输出的结果,经过处理后,输入CRF函数,返回loss即可。

    def forward(self, input_ids: torch.Tensor,
                tags: torch.Tensor = None,
                attention_mask: Optional[torch.ByteTensor] = None,
                token_type_ids=torch.Tensor,
                decode:bool = True,
                reduction: str = 'mean')->List:
        emissions = self.bertModel(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]

        # 去掉开头的[CLS]以及结尾,结尾可能有两种情况:1、<pad> 2、[SEP]
        new_emissions = emissions[:, 1:-1]
        new_mask = attention_mask[:, 2:].bool()

        # tags为None, 是预测过程,不能求loss
        if tags is None:
            loss = None
            pass
        else:
            new_tags = tags[:, 1:-1]
            loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction)

        if decode:
            tag_list = self.crf_model.decode(emissions=new_emissions, mask=new_mask)
            return [loss, tag_list]
        return [loss]
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-16 17:40:43  更:2021-12-16 17:42:32 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 21:04:48-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码