传送门:官方文档 这里只想简单记录应用时遇到的一些坑。
keras模型的保存
假设有模型如下: keras模型的有三种形式:
- 使用
model.save 此种方式可以hdf5格式或者saved_model格式保存整个模型。
import tensorflow as tf
from tensorflow import keras
model = keras.applications.ResNet50()
model.save("model.hdf5", include_optimizer=False, save_format="h5)
model.save("saved_model/1", include_optimizer=False, save_format="tf")
- 使用
model.save_weights 这种方式只保存模型的参数,并不保存模型的结构。
import tensorflow as tf
from tensorflow import keras
model = keras.applications.ResNet50()
model.save_weights("weights.hdf5", save_format="h5")
model.save_weights("model.ckpt", save_format="tf")
- 使用
tf.saved_model.save 以saved_model格式保存整个模型
import tensorflow as tf
from tensorflow import keras
model = keras.applications.ResNet50()
tf.saved_model.save(model, "saved_model/2")
keras 模型的加载
对应保存方式的模型加载
保存 | 加载 |
---|
model.save | keras.models.load_model | model.save_weights | model.load_weights | tf.saved_model.save | tf.saved_model.load |
keras 模型保存成saved_model 格式,虽可以使用keras.models.load_model 或者tf.saved_model.load 进行加载,但都会报类似下面的warning ,并且会严重拖慢模型加载的速度,但h5 格式的模型通过keras.models.load_model 不会出现此类问题,很奇怪。
WARNING:tensorflow:Importing a function (__inference_block2a_expand_activation_layer_call_and_return_conditional_losses_xxxxx) with ops with custom gradients. Will likely fail if a gradient is requested.
|