IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 与Bert结构不完全相同的模型从.ckpt 转换为.bin.报错:AttributeError: ‘BertForPreTraining‘ object has noattribute ‘bias“ -> 正文阅读

[人工智能]与Bert结构不完全相同的模型从.ckpt 转换为.bin.报错:AttributeError: ‘BertForPreTraining‘ object has noattribute ‘bias“

如果模型与bert结构一致,或是transformers中的其他模型,都可以用transformer官方库提供的转换方式进行转换。
1)vim convert.py
2) 使用命令行

python convert.py --tf_checkpoint_path /Users/sunrui/Desktop/cbert/bert_model.ckpt --bert_config_file /Users/sunrui/Desktop/cbert/bert_config.json --pytorch_dump_path /Users/sunrui/Desktop/cbert/pytorch_model.bin
import argparse

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

但此时会报错:
AttributeError: ‘BertForPreTraining’ object has no attribute ‘bias’
原因是tensorflow保存的ckpt的key与pytorch transformers的key 不相符
这时,需要将tf ckpt中的key全部打印出来,与BertModel key进行比较。

可以采用以下命令打印tf ckpt中的key:

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('./cbert', "bert_model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) #tf.train.NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
  print("tensor_name: ", key)

我的打印出来是这样
bert/bert/encoder/layer_9/attention/self/query/kernel/adam_m
应该将bert/bert变为bert/,多了一个bert
于是采用一个rename代码,将ckpt中的key进行rename
rename代码如下:
并使用命令行,将bert/bert 变为bert

python rename.py --checkpoint_dir /Users/sunrui/Desktop/cbert/bert_model.ckpt  --replace_from bert/bert  --replace_to bert

import getopt
import sys

import tensorflow as tf

usage_str = ('python tensorflow_rename_variables.py '
             '--checkpoint_dir=path/to/dir/ --replace_from=substr '
             '--replace_to=substr --add_prefix=abc --dry_run')
find_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ --find_str=[\'!\']substr')
comp_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ '
                  '--checkpoint_dir2=path/to/dir/')


def print_usage_str():
    print('Please specify a checkpoint_dir. Usage:')
    print('%s\nor\n%s\nor\n%s' % (usage_str, find_usage_str, comp_usage_str))
    print('Note: checkpoint_dir should be a *DIR*, not a file')


def compare(checkpoint_dir, checkpoint_dir2):
    import difflib
    with tf.Session():
        list1 = [el1 for (el1, el2) in
                 tf.contrib.framework.list_variables(checkpoint_dir)]
        list2 = [el1 for (el1, el2) in
                 tf.contrib.framework.list_variables(checkpoint_dir2)]
        for k1 in list1:
            if k1 in list2:
                continue
            else:
                print('{} close matches: {}'.format(
                    k1, difflib.get_close_matches(k1, list2)))


def find(checkpoint_dir, find_str):
    with tf.Session():
        negate = find_str.startswith('!')
        if negate:
            find_str = find_str[1:]
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            if negate and find_str not in var_name:
                print('%s missing from %s.' % (find_str, var_name))
            if not negate and find_str in var_name:
                print('Found %s in %s.' % (find_str, var_name))


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name
            if None not in [replace_from, replace_to]:
                new_name = var_name
                if replace_from in var_name:
                    new_name = new_name.replace(replace_from, replace_to)
                    if add_prefix:
                        new_name = add_prefix + new_name
                    if dry_run:
                        print('%s would be renamed to %s.' % (var_name,
                                                              new_name))
                    else:
                        print('Renaming %s to %s.' % (var_name, new_name))
                # Create the variable, potentially renaming it
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save the variables
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, '/Users/sunrui/Desktop/cbert/bert_model.ckpt')


def main(argv):
    checkpoint_dir = None
    checkpoint_dir2 = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    find_str = None

    try:
        opts, args = getopt.getopt(argv, 'h', ['help=', 'checkpoint_dir=',
                                               'replace_from=', 'replace_to=',
                                               'add_prefix=', 'dry_run',
                                               'find_str=',
                                               'checkpoint_dir2='])
    except getopt.GetoptError as e:
        print(e)
        print_usage_str()
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--checkpoint_dir2':
            checkpoint_dir2 = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
        elif opt == '--find_str':
            find_str = arg

    if not checkpoint_dir:
        print_usage_str()
        sys.exit(2)

    if checkpoint_dir2:
        compare(checkpoint_dir, checkpoint_dir2)
    elif find_str:
        find(checkpoint_dir, find_str)
    else:
        rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])

此外,我打印出来的ckpt中还有这样几项:/loss/output_weights/adam_v
这才BertModel中是不需要的
所以我在anaconda3/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py
中做了这样的修改:
在load_tf_weights_in_bert中,

    for name, array in zip(names, arrays):
        name = name.split("/")
        print(name)
        if name[0] == "loss":
            continue
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-06 11:03:14  更:2022-05-06 11:05:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 7:21:41-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码