在阅读本文之前,推荐先看 TensorFlow 官网的“迁移学习和微调”指南,那里讲得更清楚,也更详细,网址见本段末尾。但是官网示例中的预训练模型却是从网络上下载的,不一定符合实际需求:有些情况下,我们想加载自己预训练的模型并进行微调,这正是本文的重点。https://tensorflow.google.cn/guide/keras/transfer_learning?hl=zh_cn
大致的内容步骤如下
1、保存自己预训练的模型
2、加载自己的模型
3、进行微调
首先是第一步:保存自己预训练的模型。
我们需要先训练模型,然后调用 model.save_weights() 这个函数保存模型的权重。数据集请根据自己的情况自行调整。
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Conv1D,BatchNormalization,Dense,Activation,GlobalAveragePooling1D
from tensorflow.keras.regularizers import l2
# Load two arrays from local .npy files, keeping only the first 30 rows
# (all columns) of each.
# NOTE(review): presumably the training inputs (x) and targets (y); the file
# names are specific to the author's dataset, so substitute your own.
auto_train_x_a = np.load("U_process_train_x.npy")[0:30,:]
auto_train_y = np.load("auto_process_data_train_y.npy")[0:30,:]
def auto():
    """Build the dense 128 -> 32 -> 128 network used as the base model.

    The model takes a 128-feature input and stacks three Dense layers with
    128, 32 and 128 units. Each layer has an L2 kernel regularizer of 1e-4
    and is created without an explicit activation argument.

    Returns:
        tf.keras.Model: the uncompiled functional model.
    """
    x_input = tf.keras.Input(shape=(128,))
    x1 = tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x_input)
    x2 = tf.keras.layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x1)
    x3 = tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x2)
    # The keyword is `outputs` (plural); `output=` is rejected by tf.keras.Model.
    model = tf.keras.Model(inputs=x_input, outputs=x3)
    return model
# Build the network that will be pre-trained.
base_model = auto()
# Placeholder training setup: the optimizer and loss below are only examples,
# replace them (and add epochs / batch_size) to suit your task.
# NOTE(review): assumes auto_train_y has 128 columns to match the model's
# 128-unit output layer -- confirm against your data.
base_model.compile(optimizer="adam", loss="mse")
base_model.fit(auto_train_x_a, auto_train_y)
# Persist only the weights; the architecture is rebuilt from code when loading.
# NOTE(review): this path uses the TF2 / Keras 2 checkpoint format. Keras 3
# requires weight file names ending in `.weights.h5` -- adjust if needed.
checkpoint_path = "auto/data_a/cp.ckpt"
base_model.save_weights(checkpoint_path)
接下来是第二步:加载自己训练的模型。需要先用相同的代码重建网络结构,再调用 load_weights() 载入之前保存的权重。
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Conv1D,BatchNormalization,Dense,Activation,GlobalAveragePooling1D
from tensorflow.keras.regularizers import l2
# Load two arrays from local .npy files, keeping only the first 30 rows
# (all columns) of each.
# NOTE(review): presumably the training inputs (x) and targets (y); the file
# names are specific to the author's dataset, so substitute your own.
auto_train_x_a = np.load("U_process_train_x.npy")[0:30,:]
auto_train_y = np.load("auto_process_data_train_y.npy")[0:30,:]
def auto():
    """Build the dense 128 -> 32 -> 128 network used as the base model.

    The model takes a 128-feature input and stacks three Dense layers with
    128, 32 and 128 units. Each layer has an L2 kernel regularizer of 1e-4
    and is created without an explicit activation argument.

    Returns:
        tf.keras.Model: the uncompiled functional model.
    """
    x_input = tf.keras.Input(shape=(128,))
    x1 = tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x_input)
    x2 = tf.keras.layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x1)
    x3 = tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x2)
    # The keyword is `outputs` (plural); `output=` is rejected by tf.keras.Model.
    model = tf.keras.Model(inputs=x_input, outputs=x3)
    return model
# Location of the weights file to restore.
checkpoint_path = "auto/data_a/cp.ckpt"
# Rebuild the network from code, then load the stored weights into it.
base_model = auto()
base_model.load_weights(checkpoint_path)
最后是第三步:加载自己的模型后,先将其冻结,再在其后添加新的层,然后编译并训练新模型以完成微调。
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Conv1D,BatchNormalization,Dense,Activation,GlobalAveragePooling1D
from tensorflow.keras.regularizers import l2
# Load two arrays from local .npy files, keeping only the first 30 rows
# (all columns) of each.
# NOTE(review): presumably the training inputs (x) and targets (y); the file
# names are specific to the author's dataset, so substitute your own.
auto_train_x_a = np.load("U_process_train_x.npy")[0:30,:]
auto_train_y = np.load("auto_process_data_train_y.npy")[0:30,:]
def auto():
    """Build the dense 128 -> 32 -> 128 network used as the base model.

    The model takes a 128-feature input and stacks three Dense layers with
    128, 32 and 128 units. Each layer has an L2 kernel regularizer of 1e-4
    and is created without an explicit activation argument.

    Returns:
        tf.keras.Model: the uncompiled functional model.
    """
    x_input = tf.keras.Input(shape=(128,))
    x1 = tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x_input)
    x2 = tf.keras.layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x1)
    x3 = tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x2)
    # The keyword is `outputs` (plural); `output=` is rejected by tf.keras.Model.
    model = tf.keras.Model(inputs=x_input, outputs=x3)
    return model
# Rebuild the base network and restore the pre-trained weights into it.
base_model = auto()
checkpoint_path = "auto/data_a/cp.ckpt"
base_model.load_weights(checkpoint_path)
# Freeze the base model so only the newly added layer is updated by fit().
base_model.trainable = False
x_input = tf.keras.Input(shape=(128,))
# training=False keeps the frozen base in inference mode, as the TensorFlow
# transfer-learning guide recommends; it does not change how Dense layers compute.
x = base_model(x_input, training=False)
# New trainable head stacked on top of the frozen base.
x = tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x)
# The keyword is `outputs` (plural); `output=` is rejected by tf.keras.Model.
model2 = tf.keras.Model(inputs=x_input, outputs=x)
# Placeholder training setup: the optimizer and loss below are only examples,
# replace them (and add epochs / batch_size) to suit your task.
# NOTE(review): assumes auto_train_y has 128 columns to match the new
# 128-unit output layer -- confirm against your data.
model2.compile(optimizer="adam", loss="mse")
model2.fit(auto_train_x_a, auto_train_y)
|