TensorFlow 离线加载预训练模型
TensorFlow 内置了很多预训练的模型。如果要使用这些模型,一般可使用如下两种方式:
第一种方式:直接加载
VGG16_MODEL = tf.keras.applications.VGG16(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')
此时,模型会被下载到以下几个位置:
- Windows :C:\Users\Administrator.keras\models
- 其他待验证;
第二种方式:离线加载
因为有时候,直接下载时的网络环境不好,因此需要提前下载好预训练模型,并在使用的时候加载,此时可采用如下方式;
# 已下载好的预训练模型的存放位置
path_weights = r"D:\XX\vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5"
VGG16_MODEL = tf.keras.applications.VGG16(input_shape=IMG_SHAPE,
include_top=False,
weights=path_weights) # 在此处,将预训练模型作为权重进行加载
|