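Classifying Reuters newswires into 46 mutually exclusive topics is an example of multiclass classification, because there are more than two classes. Because each data point belongs to exactly one category, it is more specifically a single-label, multiclass classification problem. If each data point could belong to several categories (topics), it would instead be a multilabel, multiclass classification problem.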
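One way to see the difference between the two settings is in how the targets are encoded. The following is a minimal sketch, not part of the original code: the helper names `encode_single_label` and `encode_multi_label` and the topic indices used are made up for illustration. A single-label target is a one-hot vector, while a multilabel target may have several entries set to 1.

```python
import numpy as np

NUM_CLASSES = 46  # number of Reuters topics


def encode_single_label(label, num_classes=NUM_CLASSES):
    """One-hot target: exactly one position is 1 (hypothetical helper)."""
    vec = np.zeros(num_classes, dtype="float32")
    vec[label] = 1.0
    return vec


def encode_multi_label(labels, num_classes=NUM_CLASSES):
    """Multi-hot target: every topic in `labels` is set to 1 (hypothetical helper)."""
    vec = np.zeros(num_classes, dtype="float32")
    vec[list(labels)] = 1.0
    return vec


single = encode_single_label(3)          # one topic per newswire
multi = encode_multi_label({3, 4, 19})   # made-up example with three topics

print(int(single.sum()))  # number of active topics in the single-label target
print(int(multi.sum()))   # number of active topics in the multilabel target
```

The Reuters task in this post only ever needs the first form, which is why a softmax output is appropriate.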
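1 Training a single-label multiclass model
The key points for a single-label multiclass problem are:
- Use ReLU as the activation function of the hidden layers
- Use softmax as the activation function of the output layer
- Use categorical cross-entropy (categorical_crossentropy) as the loss function
- Monitor accuracy as the metric, metrics=['accuracy'] (the code below uses the alias 'acc')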
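To make the roles of these pieces concrete, here is a small NumPy sketch of ReLU, softmax and categorical cross-entropy for a single sample. It is an illustration under simplifying assumptions, not the Keras implementation: the logits are made-up values for a 3-class toy problem, and Keras additionally averages the loss over the batch.

```python
import numpy as np


def relu(x):
    # Hidden-layer activation: negative values are clipped to zero.
    return np.maximum(0.0, x)


def softmax(logits):
    # Subtract the maximum for numerical stability; the result sums to 1.
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Cross-entropy between a one-hot target and a predicted distribution.
    y_pred = np.clip(y_pred, eps, 1.0)
    return float(-np.sum(y_true * np.log(y_pred)))


logits = np.array([2.0, 1.0, 0.1])   # made-up output-layer pre-activations
probs = softmax(logits)              # probability distribution over 3 classes
y_true = np.array([1.0, 0.0, 0.0])   # one-hot target: the true class is 0
loss = categorical_crossentropy(y_true, probs)

print(probs, loss)
```

With a one-hot target the loss reduces to the negative log of the probability assigned to the true class, so it shrinks toward zero as the model becomes more confident in the right answer.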
The training code is as follows:
import os
from tensorflow.keras.datasets import reuters
from tensorflow.keras.utils import to_categorical  # built-in alternative to to_one_hot below
import numpy as np
from tensorflow.keras import models
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# Map integer indices back to words so that a newswire can be decoded.
word_index = reuters.get_word_index()
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])


def decode_review(review_list):
    # Indices 0, 1 and 2 are reserved (padding, start of sequence, unknown),
    # so the word indices are offset by 3.
    return ' '.join([reverse_word_index.get(i - 3, '?') for i in review_list])


def vectorize_sequences(sequences, dimension=10000):
    # Multi-hot encode each sequence of word indices.
    results = np.zeros(shape=(len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.0
    return results


def to_one_hot(lables, dimension=46):
    # One-hot encode the integer topic labels.
    results = np.zeros(shape=(len(lables), dimension))
    for i, label in enumerate(lables):
        results[i, label] = 1.
    return results


if __name__ == '__main__':
    (train_data, train_lables), (test_data, test_lables) = reuters.load_data(num_words=10000)
    print(decode_review(train_data[0]))
    print(train_lables[0])

    train_data = vectorize_sequences(train_data)
    test_data = vectorize_sequences(test_data)
    train_lables = to_one_hot(train_lables, dimension=46)
    test_lables = to_one_hot(test_lables, dimension=46)

    # Hold out the first 1000 samples for validation.
    x_val = train_data[:1000]
    y_val = train_lables[:1000]
    x_train = train_data[1000:]
    y_train = train_lables[1000:]

    model = models.Sequential()
    model.add(layers.Dense(units=64, activation='relu', input_shape=(10000,)))
    model.add(layers.Dense(units=64, activation='relu'))
    model.add(layers.Dense(units=46, activation='softmax'))
    model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])

    history = model.fit(x=x_train,
                        y=y_train,
                        batch_size=128,
                        epochs=50,
                        validation_data=(x_val, y_val))
    print(history)

    # Plot training and validation loss.
    history_dict = history.history
    loss_values = history_dict['loss']
    val_loss_values = history_dict['val_loss']
    epochs = range(1, len(loss_values) + 1)
    plt.plot(epochs, loss_values, 'bo', label='Training loss')
    plt.plot(epochs, val_loss_values, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # Plot training and validation accuracy.
    plt.clf()
    acc = history_dict['acc']
    val_acc = history_dict['val_acc']
    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'b', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

    test_loss, test_acc = model.evaluate(x=test_data, y=test_lables)
    print(test_acc)

    # Make sure the output directory exists before saving the model.
    os.makedirs('./model', exist_ok=True)
    model.save(filepath='./model/reuters.h5')
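The input encoding is easy to overlook, so here is a toy run of the same `vectorize_sequences` function on two made-up sequences with a tiny vocabulary (`dimension=5` is chosen only to keep the result readable). Each row marks which word indices occur in the sequence; repeated words and word order are discarded.

```python
import numpy as np


def vectorize_sequences(sequences, dimension=10000):
    # Same logic as in the training script: one row per sequence,
    # with a 1.0 at every word index that appears in it.
    results = np.zeros(shape=(len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.0
    return results


toy_sequences = [[1, 3, 3], [0, 2]]   # made-up word indices
encoded = vectorize_sequences(toy_sequences, dimension=5)

# Row 0 has ones at positions 1 and 3; row 1 has ones at positions 0 and 2.
print(encoded)
```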
The output of the last training epochs looks like this:
Epoch 45/50
63/63 [==============================] - 2s 24ms/step - loss: 0.0734 - acc: 0.9599 - val_loss: 2.2102 - val_acc: 0.7730
Epoch 46/50
63/63 [==============================] - 1s 18ms/step - loss: 0.0713 - acc: 0.9583 - val_loss: 2.3841 - val_acc: 0.7690
Epoch 47/50
63/63 [==============================] - 1s 18ms/step - loss: 0.0704 - acc: 0.9598 - val_loss: 2.3362 - val_acc: 0.7730
Epoch 48/50
63/63 [==============================] - 1s 17ms/step - loss: 0.0707 - acc: 0.9568 - val_loss: 2.3780 - val_acc: 0.7750
Epoch 49/50
63/63 [==============================] - 1s 21ms/step - loss: 0.0693 - acc: 0.9588 - val_loss: 2.4743 - val_acc: 0.7720
Epoch 50/50
63/63 [==============================] - 1s 21ms/step - loss: 0.0686 - acc: 0.9602 - val_loss: 2.4135 - val_acc: 0.7750
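In the log above the training accuracy is about 0.96 while the validation accuracy stays near 0.77, and the validation loss keeps drifting upward, which suggests that the model is overfitting by this point. A simple way to decide where to stop is to look for the epoch with the lowest validation loss. The sketch below uses a hypothetical helper, `best_epoch`, and only the six `val_loss` values visible in the log; with the full run one would pass `history.history['val_loss']` instead.

```python
def best_epoch(val_losses, first_epoch=1):
    """Return (epoch, loss) for the lowest validation loss (hypothetical helper)."""
    idx = min(range(len(val_losses)), key=val_losses.__getitem__)
    return first_epoch + idx, val_losses[idx]


# val_loss for epochs 45 to 50, copied from the training log.
val_loss_tail = [2.2102, 2.3841, 2.3362, 2.3780, 2.4743, 2.4135]
epoch, loss = best_epoch(val_loss_tail, first_epoch=45)

print(epoch, loss)  # 45 2.2102
```

Among the epochs shown, the earliest one has the lowest validation loss, so the true minimum very likely lies earlier in the run. Training for fewer epochs would be a reasonable next experiment.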
[Figure: training and validation loss]
[Figure: training and validation accuracy]
[Figure: network structure of the model]
2 Using the single-label multiclass model for prediction
The inference code is as follows:
from tensorflow.keras.datasets import reuters
import numpy as np
from tensorflow.keras import models

# Map integer indices back to words so that a newswire can be decoded.
word_index = reuters.get_word_index()
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])


def decode_review(review_list):
    # Word indices are offset by 3 because 0, 1 and 2 are reserved.
    return ' '.join([reverse_word_index.get(i - 3, '?') for i in review_list])


def vectorize_sequences(sequences, dimension=10000):
    # Must match the encoding used during training.
    results = np.zeros(shape=(len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.0
    return results


if __name__ == '__main__':
    (train_data, train_lables), (test_data, test_lables) = reuters.load_data(num_words=10000)
    test_data_0 = vectorize_sequences(test_data)
    print(test_data_0.shape)

    # Load the model saved by the training script.
    model = models.load_model(filepath='./model/reuters.h5')
    model.summary()

    print(decode_review(test_data[5]))
    print('True label of this newswire: ' + str(test_lables[5]))

    # predict expects a batch, so reshape the single sample to shape (1, 10000).
    result = model.predict(test_data_0[5].reshape(1, -1))
    # The predicted class is the index of the largest softmax probability.
    result = result.argmax(axis=1)
    print('Predicted class: ' + str(result[0]))
The program produces the following output:
? shr 12 cts vs 15 cts net 282 000 vs 360 000 revs 5 261 000 vs 5 348 000 avg shrs 2 336 000 vs 2 335 000 year shr 91 cts vs 1
04 dlrs net 2 149 000 vs 2 075 000 revs 28 2 mln vs 28 3 mln avg shrs 2 356 000 vs 2 001 000 note 1986 quarter net includes 72
000 dlr charge from ? of investment tax credit reuter 3
True label of this newswire: 3
2022-01-27 18:27:07.987209: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Predicted class: 3
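Taking the argmax discards how confident the model was. As a hedged illustration, the sketch below lists the top-k classes from a probability vector such as one row of the array returned by `model.predict`. The helper name `top_k_classes` and the five probabilities are made up for this example; the real model outputs 46 values per sample.

```python
import numpy as np


def top_k_classes(probs, k=3):
    """Return the k most probable (class_index, probability) pairs (hypothetical helper)."""
    probs = np.asarray(probs)
    order = np.argsort(probs)[::-1][:k]   # indices sorted by descending probability
    return [(int(i), float(probs[i])) for i in order]


# Made-up softmax output for a toy 5-class problem.
probs = np.array([0.05, 0.10, 0.02, 0.70, 0.13])
top = top_k_classes(probs, k=2)

print(top)
```

The first entry always agrees with `argmax`; the remaining entries show which other topics the model considered plausible.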