PyTorch to ONNX
Prerequisites: PyTorch 1.8.0 or later; install Transformers with its ONNX extras:
pip install transformers[onnx]
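To confirm that the installed PyTorch meets the 1.8.0 minimum, a small stdlib-only sketch can read the installed version through importlib.metadata and compare release numbers. The helper names (parse_release, torch_is_supported) are hypothetical, and the parsing is deliberately simplified: it drops local tags such as +cu117 and reads pre-releases such as 2.0.0rc1 as their final release.

```python
from importlib.metadata import PackageNotFoundError, version

MIN_TORCH = (1, 8, 0)  # minimum PyTorch release named in the prerequisites


def parse_release(ver: str) -> tuple:
    """Turn a version string such as '1.13.1+cu117' into (1, 13, 1).

    Simplified on purpose: local tags after '+' are dropped, and parsing stops
    at the first non-digit, so '2.0.0rc1' is read as (2, 0, 0).
    """
    release = ver.split("+", 1)[0]
    parts = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
        if digits != piece:  # hit a pre-release marker; ignore the rest
            break
    while len(parts) < 3:  # pad '1.8' to (1, 8, 0) so tuple comparison works
        parts.append(0)
    return tuple(parts[:3])


def torch_is_supported(installed=None) -> bool:
    """Return True if the given (or installed) torch version is >= 1.8.0."""
    if installed is None:
        try:
            installed = version("torch")
        except PackageNotFoundError:
            return False  # torch is not installed in this environment
    return parse_release(installed) >= MIN_TORCH
```

Calling torch_is_supported() with no argument checks the torch distribution in the current environment and returns False if it is missing; passing a string such as "1.7.1" checks that version instead.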
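First, save the pretrained model and its tokenizer to a local PyTorch checkpoint folder (save_pretrained writes a directory containing config.json and the weights, not a single .pt file):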
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Download the tokenizer and PyTorch weights from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
pt_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
# Write both into the same local folder
tokenizer.save_pretrained("local-pt-checkpoint")
pt_model.save_pretrained("local-pt-checkpoint")
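Because save_pretrained produces a folder rather than one file, a quick sanity check can confirm the folder holds the pieces the export step reads. This is a minimal sketch with hypothetical helper names (check_files, describe_checkpoint). The weight file name is an assumption: older Transformers releases write pytorch_model.bin, newer ones default to model.safetensors, and sharded checkpoints (several weight files plus an index) are not covered.

```python
from pathlib import Path

# Assumption: save_pretrained writes one of these weight files, depending on the
# transformers version (newer releases default to safetensors). Sharded
# checkpoints (several weight files plus an index) are not handled here.
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")


def check_files(names) -> dict:
    """Report which expected checkpoint files appear among the given file names."""
    names = set(names)
    return {
        "config": "config.json" in names,
        "weights": any(w in names for w in WEIGHT_FILES),
        "tokenizer": "tokenizer_config.json" in names,
    }


def describe_checkpoint(path) -> dict:
    """List a save_pretrained() output folder and check it with check_files()."""
    folder = Path(path)
    if not folder.is_dir():
        raise FileNotFoundError(f"{folder} is not a directory")
    return check_files(p.name for p in folder.iterdir())
```

After the save step, describe_checkpoint("local-pt-checkpoint") should report True for all three keys; a False entry points at the file the export is likely to complain about.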
Then convert the model in the local-pt-checkpoint folder to ONNX, writing the result into the onnx folder:
python -m transformers.onnx --model=local-pt-checkpoint onnx/
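To try the exported graph, a sketch along these lines loads it with ONNX Runtime and feeds it tokenizer output. It assumes onnxruntime is installed separately (pip install onnxruntime) and that the export wrote onnx/model.onnx; the helper names (build_feed, run_onnx) are hypothetical. The output names depend on which task the export targeted: with the CLI's default feature the export is expected to cover the base model (hidden states) rather than the classification head, and the transformers.onnx --feature option (for example sequence-classification) is the place to look if logits are needed. For that reason the sketch reads input and output names from the session instead of hard-coding them.

```python
def build_feed(encoding, input_names):
    """Keep only the tokenizer outputs that the ONNX graph declares as inputs."""
    missing = [name for name in input_names if name not in encoding]
    if missing:
        raise KeyError(f"tokenizer did not produce inputs: {missing}")
    return {name: encoding[name] for name in input_names}


def run_onnx(model_path="onnx/model.onnx", tokenizer_dir="local-pt-checkpoint",
             text="Using DistilBERT with ONNX Runtime!"):
    # Imported lazily so build_feed can be used without these packages installed.
    from onnxruntime import InferenceSession
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    session = InferenceSession(model_path, providers=["CPUExecutionProvider"])
    # ONNX Runtime expects NumPy arrays, so ask the tokenizer for return_tensors="np".
    encoding = tokenizer(text, return_tensors="np")
    input_names = [inp.name for inp in session.get_inputs()]
    output_names = [out.name for out in session.get_outputs()]
    # Passing None as output_names returns every graph output, in declaration order.
    outputs = session.run(None, build_feed(encoding, input_names))
    return dict(zip(output_names, outputs))
```

run_onnx() returns a dict mapping each graph output name to a NumPy array, which makes it easy to see whether the export produced hidden states or task logits.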