import numpy as np
import matplotlib.pyplot as plt
import sklearn.preprocessing
import tensorflow as tf
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score, accuracy_score
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from network import *
from load import *
from load_2 import *
# import load_2
# Training hyper-parameters.
BS: int = 300                  # batch-size constant; train() passes batch_size=100 to fit() directly
EPOCHS: int = 500              # number of epochs passed to fit()
NUM_C: int = 21                # number of classes (currently unused by train())
d_LEARNING_RATE: float = 5e-3  # Adam learning rate
BETA_1: float = 0.5            # Adam beta_1 (first-moment decay)
def train():
    """Train the CNN regressor, report test metrics, and save plots and weights.

    Steps:
      1. Load train/test arrays via ``load_2()``.
      2. Build the network with ``CNN_focus_predict_model()`` and compile it
         with mean-squared-error loss and an Adam optimizer.
      3. Fit for ``EPOCHS`` epochs, then print the test loss and MAE.
      4. Write ``huigui_loss.png`` (training-loss curve) and
         ``huigui_sample.png`` (test targets vs. predictions per sample).
      5. Save the trained weights to ``./generator_weight.h5``.
    """
    # Load the network's input data.
    X_train, y_train, X_test, y_test = load_2()

    # Build and compile the CNN.
    d = CNN_focus_predict_model()
    # `lr` is a deprecated alias that newer TF/Keras optimizers reject;
    # `learning_rate` is accepted by every TF 2.x release.
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=d_LEARNING_RATE, beta_1=BETA_1)
    # This is a regression task (MSE loss), so classification "accuracy" is
    # meaningless; track mean absolute error instead.
    d.compile(loss='mean_squared_error', optimizer=d_optimizer, metrics=['mae'])

    # batch_size is fixed at 100 here, independent of the BS constant.
    H = d.fit(X_train, y_train, batch_size=100, epochs=EPOCHS, verbose=2)

    # Evaluate on the held-out set; with one metric, score is [loss, mae].
    score = d.evaluate(X_test, y_test, verbose=0)
    print('Test score:', score[0])
    print('Test MAE:', score[1])

    # pyplot style is global state, so setting it once covers both figures.
    plt.style.use("ggplot")
    _plot_loss(H.history["loss"], 'huigui_loss.png')

    Y_pred = d.predict(X_test)
    _plot_predictions(y_test, Y_pred, 'huigui_sample.png')

    # Save the trained weights, overwriting any existing file.
    # NOTE(review): Keras 3 requires a ".weights.h5" suffix for this call;
    # this path only works with tf.keras (Keras 2) — confirm target version.
    d.save_weights("./generator_weight.h5", overwrite=True)


def _plot_loss(losses, path):
    """Plot the per-epoch training loss and save the figure to *path*."""
    # Derive the x-axis from the recorded history rather than EPOCHS, so the
    # plot stays correct if training ever stops early.
    epochs = np.arange(len(losses))
    fig, ax = plt.subplots()
    ax.plot(epochs, losses, label="train_loss")
    ax.set_title("Training Loss")
    ax.set_xlabel("Epoch #")
    ax.set_ylabel("Loss")
    ax.legend()
    fig.savefig(path)
    # pyplot keeps every open figure alive until it is closed explicitly.
    plt.close(fig)


def _plot_predictions(y_true, y_pred, path):
    """Scatter test targets and model predictions against sample index."""
    # Flatten so (n, 1) model outputs line up with the 1-D index axis.
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    idx = np.arange(len(y_true))
    fig, ax = plt.subplots()
    ax.scatter(idx, y_true, label="sample", color='red', marker=".", s=10)
    ax.scatter(idx, y_pred, label="predict", color='blue', marker=".", s=10)
    ax.set_title("Sample points and predicted results")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    train()