If your model shares BERT's architecture, or is any other model covered by transformers, it can be converted with the conversion script provided by the official transformers library. 1) Create convert.py (vim convert.py) with the code shown below. 2) Run it from the command line:
python convert.py --tf_checkpoint_path /Users/sunrui/Desktop/cbert/bert_model.ckpt --bert_config_file /Users/sunrui/Desktop/cbert/bert_config.json --pytorch_dump_path /Users/sunrui/Desktop/cbert/pytorch_model.bin
import argparse

import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging

logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise the PyTorch model from the BERT config
    config = BertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = BertForPreTraining(config)

    # Load weights from the TF checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save the PyTorch model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
At this point, however, the script fails with: AttributeError: 'BertForPreTraining' object has no attribute 'bias'. The cause is that the variable names (keys) in the TensorFlow checkpoint do not match the keys that PyTorch transformers expects. To diagnose this, print all keys in the TF checkpoint and compare them against the keys of the PyTorch BertModel.
The keys in the TF checkpoint can be printed with the following snippet (a sketch for listing the PyTorch-side keys follows right after it):
import os

from tensorflow.python import pywrap_tensorflow  # TF1-style checkpoint reader

checkpoint_path = os.path.join('./cbert', "bert_model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
In my case the TF keys printed like this: bert/bert/encoder/layer_9/attention/self/query/kernel/adam_m. There is one bert too many: bert/bert should just be bert/. So I used a rename script to rewrite the keys inside the ckpt. The rename code is shown after the command; invoke it from the command line to change bert/bert to bert:
python rename.py --checkpoint_dir /Users/sunrui/Desktop/cbert/bert_model.ckpt --replace_from bert/bert --replace_to bert
import getopt
import sys

import tensorflow as tf  # requires TensorFlow 1.x (uses tf.contrib and tf.Session)

usage_str = ('python tensorflow_rename_variables.py '
             '--checkpoint_dir=path/to/dir/ --replace_from=substr '
             '--replace_to=substr --add_prefix=abc --dry_run')
find_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ --find_str=[\'!\']substr')
comp_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ '
                  '--checkpoint_dir2=path/to/dir/')


def print_usage_str():
    print('Please specify a checkpoint_dir. Usage:')
    print('%s\nor\n%s\nor\n%s' % (usage_str, find_usage_str, comp_usage_str))
    print('Note: checkpoint_dir should be a *DIR*, not a file')


def compare(checkpoint_dir, checkpoint_dir2):
    import difflib
    with tf.Session():
        list1 = [el1 for (el1, el2) in
                 tf.contrib.framework.list_variables(checkpoint_dir)]
        list2 = [el1 for (el1, el2) in
                 tf.contrib.framework.list_variables(checkpoint_dir2)]
        for k1 in list1:
            if k1 in list2:
                continue
            else:
                print('{} close matches: {}'.format(
                    k1, difflib.get_close_matches(k1, list2)))


def find(checkpoint_dir, find_str):
    with tf.Session():
        negate = find_str.startswith('!')
        if negate:
            find_str = find_str[1:]
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            if negate and find_str not in var_name:
                print('%s missing from %s.' % (find_str, var_name))
            if not negate and find_str in var_name:
                print('Found %s in %s.' % (find_str, var_name))


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)  # unused; the save path below is hardcoded
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable from the checkpoint
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
            if None not in [replace_from, replace_to]:
                # Compute the new name
                new_name = var_name
                if replace_from in var_name:
                    new_name = new_name.replace(replace_from, replace_to)
                if add_prefix:
                    new_name = add_prefix + new_name

                if dry_run:
                    print('%s would be renamed to %s.' % (var_name, new_name))
                else:
                    print('Renaming %s to %s.' % (var_name, new_name))
                    # Re-create the variable under the new name
                    var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save all renamed variables; note the save path is hardcoded,
            # adjust it to your own checkpoint prefix
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, '/Users/sunrui/Desktop/cbert/bert_model.ckpt')


def main(argv):
    checkpoint_dir = None
    checkpoint_dir2 = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    find_str = None

    try:
        opts, args = getopt.getopt(argv, 'h', ['help=', 'checkpoint_dir=',
                                               'replace_from=', 'replace_to=',
                                               'add_prefix=', 'dry_run',
                                               'find_str=',
                                               'checkpoint_dir2='])
    except getopt.GetoptError as e:
        print(e)
        print_usage_str()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--checkpoint_dir2':
            checkpoint_dir2 = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
        elif opt == '--find_str':
            find_str = arg

    if not checkpoint_dir:
        print_usage_str()
        sys.exit(2)

    if checkpoint_dir2:
        compare(checkpoint_dir, checkpoint_dir2)
    elif find_str:
        find(checkpoint_dir, find_str)
    else:
        rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])
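Since the script overwrites the checkpoint in place, it is prudent to preview the renames first with the script's own --dry_run flag, for example:

python rename.py --checkpoint_dir /Users/sunrui/Desktop/cbert/bert_model.ckpt --replace_from bert/bert --replace_to bert --dry_run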
In addition, the printed checkpoint keys included a few entries such as loss/output_weights/adam_v, which BertModel does not need. So I made the following change in anaconda3/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py, inside load_tf_weights_in_bert:
for name, array in zip(names, arrays):
    name = name.split("/")
    print(name)
    # Skip the task-specific loss/ variables; BertModel has no counterpart for them
    if name[0] == "loss":
        continue
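After these fixes the conversion runs to completion. As a final sanity check, the converted weights can be loaded back into a fresh model; this is a minimal sketch, assuming the config and dump paths used throughout this walkthrough:

import torch
from transformers import BertConfig, BertForPreTraining

config = BertConfig.from_json_file("/Users/sunrui/Desktop/cbert/bert_config.json")
model = BertForPreTraining(config)
# strict=False reports any leftover key mismatches instead of raising
missing, unexpected = model.load_state_dict(
    torch.load("/Users/sunrui/Desktop/cbert/pytorch_model.bin"), strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)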