"You jump, I jump" is the famous line from the classic romance Titanic: as Rose stands at the ship's rail, about to throw herself into the sea, Jack saves her by delivering exactly those words. When a complete stranger is willing to die for her for no reason at all, Rose, just as inexplicably, falls in love with him. None of that has much bearing on this tutorial, of course. Here we will use a neural network to predict Jack's and Rose's chances of survival. I kept publishing straight through the National Day holiday, which was no small feat; if you need the dataset, message me directly or join the study group. Thanks for your support!
I. The Dataset
1. Read the dataset
import pandas as pd
df = pd.read_excel('titanic3.xls')
df.describe()
|       | pclass | survived | age | sibsp | parch | fare | body |
|-------|--------|----------|-----|-------|-------|------|------|
| count | 1309.000000 | 1309.000000 | 1046.000000 | 1309.000000 | 1309.000000 | 1308.000000 | 121.000000 |
| mean  | 2.294882 | 0.381971 | 29.881135 | 0.498854 | 0.385027 | 33.295479 | 160.809917 |
| std   | 0.837836 | 0.486055 | 14.413500 | 1.041658 | 0.865560 | 51.758668 | 97.696922 |
| min   | 1.000000 | 0.000000 | 0.166700 | 0.000000 | 0.000000 | 0.000000 | 1.000000 |
| 25%   | 2.000000 | 0.000000 | 21.000000 | 0.000000 | 0.000000 | 7.895800 | 72.000000 |
| 50%   | 3.000000 | 0.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 | 155.000000 |
| 75%   | 3.000000 | 1.000000 | 39.000000 | 1.000000 | 0.000000 | 31.275000 | 256.000000 |
| max   | 3.000000 | 1.000000 | 80.000000 | 8.000000 | 9.000000 | 512.329200 | 328.000000 |
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 14 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 pclass 1309 non-null int64
1 survived 1309 non-null int64
2 name 1309 non-null object
3 sex 1309 non-null object
4 age 1046 non-null float64
5 sibsp 1309 non-null int64
6 parch 1309 non-null int64
7 ticket 1309 non-null object
8 fare 1308 non-null float64
9 cabin 295 non-null object
10 embarked 1307 non-null object
11 boat 486 non-null object
12 body 121 non-null float64
13 home.dest 745 non-null object
dtypes: float64(3), int64(4), object(7)
memory usage: 143.3+ KB
df.head()
|   | pclass | survived | name | sex | age | sibsp | parch | ticket | fare | cabin | embarked | boat | body | home.dest |
|---|--------|----------|------|-----|-----|-------|-------|--------|------|-------|----------|------|------|-----------|
| 0 | 1 | 1 | Allen, Miss. Elisabeth Walton | female | 29.0000 | 0 | 0 | 24160 | 211.3375 | B5 | S | 2 | NaN | St Louis, MO |
| 1 | 1 | 1 | Allison, Master. Hudson Trevor | male | 0.9167 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | 11 | NaN | Montreal, PQ / Chesterville, ON |
| 2 | 1 | 0 | Allison, Miss. Helen Loraine | female | 2.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
| 3 | 1 | 0 | Allison, Mr. Hudson Joshua Creighton | male | 30.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | 135.0 | Montreal, PQ / Chesterville, ON |
| 4 | 1 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | female | 25.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
2. Preprocess the dataset
selected_cols = ['survived','name','pclass','sex','age','sibsp','parch','fare','embarked']
df_selected = df[selected_cols]
df = df_selected.copy()  # work on a copy to avoid SettingWithCopyWarning
df.head()
|   | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
|---|----------|------|--------|-----|-----|-------|-------|------|----------|
| 0 | 1 | Allen, Miss. Elisabeth Walton | 1 | female | 29.0000 | 0 | 0 | 211.3375 | S |
| 1 | 1 | Allison, Master. Hudson Trevor | 1 | male | 0.9167 | 1 | 2 | 151.5500 | S |
| 2 | 0 | Allison, Miss. Helen Loraine | 1 | female | 2.0000 | 1 | 2 | 151.5500 | S |
| 3 | 0 | Allison, Mr. Hudson Joshua Creighton | 1 | male | 30.0000 | 1 | 2 | 151.5500 | S |
| 4 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | 1 | female | 25.0000 | 1 | 2 | 151.5500 | S |
df.isnull().any()
survived False
name False
pclass False
sex False
age True
sibsp False
parch False
fare True
embarked True
dtype: bool
df.isnull().sum()
survived 0
name 0
pclass 0
sex 0
age 263
sibsp 0
parch 0
fare 1
embarked 2
dtype: int64
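Raw counts are easier to judge as percentages of the 1309 rows. A minimal sketch of the same computation on a toy frame (column names here are just for illustration):

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({
    'age':  [29.0, np.nan, 2.0, np.nan],   # 2 of 4 missing
    'fare': [211.3, 151.6, np.nan, 7.9],   # 1 of 4 missing
})

# Missing count per column, converted to a percentage
missing = toy.isnull().sum()
percent = 100 * missing / len(toy)
print(percent)  # age 50.0, fare 25.0
```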
df[df.isnull().any(axis=1)]  # rows with at least one missing value
|      | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
|------|----------|------|--------|-----|-----|-------|-------|------|----------|
| 15   | 0 | Baumann, Mr. John D | 1 | male | NaN | 0 | 0 | 25.9250 | S |
| 37   | 1 | Bradley, Mr. George ("George Arthur Brayton") | 1 | male | NaN | 0 | 0 | 26.5500 | S |
| 40   | 0 | Brewe, Dr. Arthur Jackson | 1 | male | NaN | 0 | 0 | 39.6000 | C |
| 46   | 0 | Cairns, Mr. Alexander | 1 | male | NaN | 0 | 0 | 31.0000 | S |
| 59   | 1 | Cassebeer, Mrs. Henry Arthur Jr (Eleanor Genev... | 1 | female | NaN | 0 | 0 | 27.7208 | C |
| ...  | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1293 | 0 | Williams, Mr. Howard Hugh "Harry" | 3 | male | NaN | 0 | 0 | 8.0500 | S |
| 1297 | 0 | Wiseman, Mr. Phillippe | 3 | male | NaN | 0 | 0 | 7.2500 | S |
| 1302 | 0 | Yousif, Mr. Wazli | 3 | male | NaN | 0 | 0 | 7.2250 | C |
| 1303 | 0 | Yousseff, Mr. Gerious | 3 | male | NaN | 0 | 0 | 14.4583 | C |
| 1305 | 0 | Zabour, Miss. Thamine | 3 | female | NaN | 1 | 0 | 14.4542 | C |

266 rows × 9 columns
age_mean = df['age'].mean()
df['age'] = df['age'].fillna(age_mean)
df['age'].isnull().any()
False
fare_mean = df['fare'].mean()
df['fare'] = df['fare'].fillna(fare_mean)  # fill fare with the fare mean, not age_mean
df['embarked'] = df['embarked'].fillna('S')  # 'S' (Southampton) is the most common port
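A side note on imputation: fare is heavily right-skewed (mean 33.30 vs. median 14.45 in the describe() output above), so the median is often a more robust fill value than the mean. A small sketch on toy data (the numbers are made up for illustration):

```python
import numpy as np
import pandas as pd

fares = pd.Series([7.25, 7.90, 14.45, 31.27, np.nan, 512.33])

filled_mean = fares.fillna(fares.mean())      # pulled upward by the 512.33 outlier
filled_median = fares.fillna(fares.median())  # robust to the outlier

print(filled_mean.iloc[4], filled_median.iloc[4])  # 114.64 vs. 14.45
```

Whether mean or median works better here is an empirical question; the tutorial sticks with the mean.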
df.isnull().any()
survived False
name False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
df = df.drop(['name'],axis=1)
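Mapping embarked to 0/1/2 imposes an artificial ordering on three unordered ports. An alternative worth trying is one-hot encoding with pd.get_dummies, sketched below; whether it improves this particular model is an empirical question:

```python
import pandas as pd

toy = pd.DataFrame({'embarked': ['S', 'C', 'Q', 'S']})

# One indicator column per port instead of a single ordinal code
dummies = pd.get_dummies(toy['embarked'], prefix='embarked')
print(dummies.columns.tolist())  # ['embarked_C', 'embarked_Q', 'embarked_S']
```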
df.head()
|   | survived | pclass | sex | age | sibsp | parch | fare | embarked |
|---|----------|--------|-----|-----|-------|-------|------|----------|
| 0 | 1 | 1 | 0 | 29.0000 | 0 | 0 | 211.3375 | 2 |
| 1 | 1 | 1 | 1 | 0.9167 | 1 | 2 | 151.5500 | 2 |
| 2 | 0 | 1 | 0 | 2.0000 | 1 | 2 | 151.5500 | 2 |
| 3 | 0 | 1 | 1 | 30.0000 | 1 | 2 | 151.5500 | 2 |
| 4 | 0 | 1 | 0 | 25.0000 | 1 | 2 | 151.5500 | 2 |
3. Split features and labels
data = df.values
features = data[:,1:]
labels = data[:,0]
labels.shape
(1309,)
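The slicing convention (column 0 is the label, everything else is a feature) can be verified on a toy array:

```python
import numpy as np

# 4 toy rows; the first column plays the role of `survived`
data = np.array([[1, 3, 0],
                 [0, 1, 1],
                 [1, 2, 0],
                 [0, 3, 1]])

features = data[:, 1:]  # all columns except the first
labels = data[:, 0]     # the first column only

print(features.shape, labels.shape)  # (4, 2) (4,)
```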
4. Define a preprocessing function
def prepare_data(df):
    df = df.drop(['name'], axis=1)
    age_mean = df['age'].mean()
    df['age'] = df['age'].fillna(age_mean)
    fare_mean = df['fare'].mean()
    df['fare'] = df['fare'].fillna(fare_mean)  # use the fare mean here, not age_mean
    df['embarked'] = df['embarked'].fillna('S')
    df['sex'] = df['sex'].map({'female':0,'male':1}).astype(int)
    df['embarked'] = df['embarked'].map({'C':0,'Q':1,'S':2}).astype(int)
    print(df.isnull().any())
    data = df.values
    features = data[:,1:]
    labels = data[:,0]
    return features, labels
5. Split into training and test sets
shuffle_df = df_selected.sample(frac=1)
x_data,y_data = prepare_data(shuffle_df)
x_data.shape,y_data.shape
survived False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
((1309, 7), (1309,))
shuffle_df.head()
|     | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
|-----|----------|------|--------|-----|-----|-------|-------|------|----------|
| 58  | 0 | Case, Mr. Howard Brown | 1 | male | 49.0 | 0 | 0 | 26.0000 | S |
| 666 | 0 | Barbara, Mrs. (Catherine David) | 3 | female | 45.0 | 0 | 1 | 14.4542 | C |
| 781 | 0 | Drazenoic, Mr. Jozef | 3 | male | 33.0 | 0 | 0 | 7.8958 | C |
| 480 | 0 | Laroche, Mr. Joseph Philippe Lemercier | 2 | male | 25.0 | 1 | 2 | 41.5792 | C |
| 459 | 0 | Jacobsohn, Mr. Sidney Samuel | 2 | male | 42.0 | 1 | 0 | 27.0000 | S |
test_split = 0.2
train_num = int((1 - test_split) * x_data.shape[0])
x_train = x_data[:train_num]
y_train = y_data[:train_num]
x_test = x_data[train_num:]
y_test = y_data[train_num:]
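The same shuffle-and-split can be done in one call with scikit-learn's train_test_split; a sketch on toy arrays (the random_state value is arbitrary):

```python
import numpy as np
from sklearn.model_selection import train_test_split

x_data = np.arange(20).reshape(10, 2)  # 10 toy samples, 2 features
y_data = np.arange(10)

# Shuffles and splits in one step; 20% of the rows go to the test set
x_train, x_test, y_train, y_test = train_test_split(
    x_data, y_data, test_size=0.2, random_state=42)

print(x_train.shape, x_test.shape)  # (8, 2) (2, 2)
```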
6. Normalization
from sklearn import preprocessing
minmax_scale = preprocessing.MinMaxScaler(feature_range=(0,1))
x_train = minmax_scale.fit_transform(x_train)  # fit the scaler on training data only
x_test = minmax_scale.transform(x_test)        # reuse the training min/max for the test set
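A common pitfall is calling fit_transform on the test set as well, which refits the scaler to the test distribution and leaks information. MinMaxScaler stores the training min/max at fit time; the sketch below shows that test values scaled with the training scaler can legitimately fall outside [0, 1]:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

x_train = np.array([[10.0], [20.0], [30.0]])
x_test = np.array([[25.0], [40.0]])  # 40 lies outside the training range

scaler = MinMaxScaler(feature_range=(0, 1))
x_train_s = scaler.fit_transform(x_train)  # learns min=10, max=30
x_test_s = scaler.transform(x_test)        # reuses the training min/max

print(x_test_s.ravel())  # [0.75 1.5] -- a value > 1 is expected and correct here
```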
II. The Model
import tensorflow as tf
tf.__version__
'2.6.0'
1. Build a sequential model
model = tf.keras.models.Sequential()
2. Add hidden layers
model.add(tf.keras.layers.Dense(units=64,
                                use_bias=True,
                                activation='relu',
                                input_dim=7,
                                bias_initializer='zeros',
                                kernel_initializer='normal'))
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(tf.keras.layers.Dense(units=32,
                                activation='sigmoid',
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))
model.add(tf.keras.layers.Dropout(rate=0.2))
3. Add the output layer
model.add(tf.keras.layers.Dense(units=1,
                                activation='sigmoid',
                                bias_initializer='zeros',
                                kernel_initializer='uniform'))
model.summary()
Model: "sequential_23"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_68 (Dense) (None, 64) 512
_________________________________________________________________
dropout_6 (Dropout) (None, 64) 0
_________________________________________________________________
dense_69 (Dense) (None, 32) 2080
_________________________________________________________________
dropout_7 (Dropout) (None, 32) 0
_________________________________________________________________
dense_70 (Dense) (None, 1) 33
=================================================================
Total params: 2,625
Trainable params: 2,625
Non-trainable params: 0
_________________________________________________________________
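The parameter counts in the summary follow directly from (inputs + 1 bias) × units for each Dense layer; a quick check in plain Python:

```python
def dense_params(n_in, n_units):
    # weights: n_in * n_units, plus one bias per unit
    return (n_in + 1) * n_units

p1 = dense_params(7, 64)   # first hidden layer: 512
p2 = dense_params(64, 32)  # second hidden layer: 2080
p3 = dense_params(32, 1)   # output layer: 33

print(p1, p2, p3, p1 + p2 + p3)  # 512 2080 33 2625
```

Dropout layers add no parameters, so the total matches the 2,625 reported by model.summary().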
III. Training
1. Train the model
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.003),loss='binary_crossentropy',metrics=['accuracy'])
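binary_crossentropy is just −[y·log(p) + (1−y)·log(1−p)] averaged over samples; a numpy sketch of the same quantity (the epsilon clipping mirrors what Keras does internally to avoid log(0)):

```python
import numpy as np

def binary_crossentropy(y_true, y_pred):
    eps = 1e-7  # clip predictions away from 0 and 1 before taking logs
    p = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))

y_true = np.array([1.0, 0.0, 1.0, 0.0])
y_pred = np.array([0.9, 0.1, 0.8, 0.3])
print(binary_crossentropy(y_true, y_pred))  # about 0.1976
```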
train_epochs = 100
batch_size = 40
train_history = model.fit(x=x_train,
                          y=y_train,
                          validation_split=0.2,
                          epochs=train_epochs,
                          batch_size=batch_size,
                          verbose=2)
Epoch 1/100
21/21 - 1s - loss: 0.6780 - accuracy: 0.5854 - val_loss: 0.6464 - val_accuracy: 0.6429
Epoch 2/100
21/21 - 0s - loss: 0.6623 - accuracy: 0.6057 - val_loss: 0.6293 - val_accuracy: 0.6429
Epoch 3/100
21/21 - 0s - loss: 0.6306 - accuracy: 0.6069 - val_loss: 0.5861 - val_accuracy: 0.6667
Epoch 4/100
21/21 - 0s - loss: 0.5771 - accuracy: 0.7336 - val_loss: 0.5199 - val_accuracy: 0.7905
Epoch 5/100
21/21 - 0s - loss: 0.5364 - accuracy: 0.7646 - val_loss: 0.4939 - val_accuracy: 0.7952
Epoch 6/100
21/21 - 0s - loss: 0.5200 - accuracy: 0.7670 - val_loss: 0.4847 - val_accuracy: 0.8143
Epoch 7/100
21/21 - 0s - loss: 0.5118 - accuracy: 0.7718 - val_loss: 0.4771 - val_accuracy: 0.8143
Epoch 8/100
21/21 - 0s - loss: 0.5060 - accuracy: 0.7766 - val_loss: 0.4738 - val_accuracy: 0.8095
Epoch 9/100
21/21 - 0s - loss: 0.4934 - accuracy: 0.7861 - val_loss: 0.4670 - val_accuracy: 0.7952
Epoch 10/100
21/21 - 0s - loss: 0.4966 - accuracy: 0.7814 - val_loss: 0.4637 - val_accuracy: 0.8000
Epoch 11/100
21/21 - 0s - loss: 0.4928 - accuracy: 0.7766 - val_loss: 0.4635 - val_accuracy: 0.7905
Epoch 12/100
21/21 - 0s - loss: 0.4995 - accuracy: 0.7670 - val_loss: 0.4691 - val_accuracy: 0.7905
Epoch 13/100
21/21 - 0s - loss: 0.4886 - accuracy: 0.7957 - val_loss: 0.4620 - val_accuracy: 0.8095
Epoch 14/100
21/21 - 0s - loss: 0.4790 - accuracy: 0.7838 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 15/100
21/21 - 0s - loss: 0.4877 - accuracy: 0.7766 - val_loss: 0.4576 - val_accuracy: 0.8095
Epoch 16/100
21/21 - 0s - loss: 0.4839 - accuracy: 0.7897 - val_loss: 0.4560 - val_accuracy: 0.8095
Epoch 17/100
21/21 - 0s - loss: 0.4813 - accuracy: 0.7814 - val_loss: 0.4614 - val_accuracy: 0.8095
Epoch 18/100
21/21 - 0s - loss: 0.4812 - accuracy: 0.7742 - val_loss: 0.4553 - val_accuracy: 0.8095
Epoch 19/100
21/21 - 0s - loss: 0.4762 - accuracy: 0.7885 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 20/100
21/21 - 0s - loss: 0.4784 - accuracy: 0.7802 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 21/100
21/21 - 0s - loss: 0.4794 - accuracy: 0.7885 - val_loss: 0.4626 - val_accuracy: 0.8000
Epoch 22/100
21/21 - 0s - loss: 0.4824 - accuracy: 0.7838 - val_loss: 0.4567 - val_accuracy: 0.7857
Epoch 23/100
21/21 - 0s - loss: 0.4786 - accuracy: 0.7849 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 24/100
21/21 - 0s - loss: 0.4801 - accuracy: 0.7742 - val_loss: 0.4735 - val_accuracy: 0.7905
Epoch 25/100
21/21 - 0s - loss: 0.4752 - accuracy: 0.7849 - val_loss: 0.4571 - val_accuracy: 0.7905
Epoch 26/100
21/21 - 0s - loss: 0.4688 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8000
Epoch 27/100
21/21 - 0s - loss: 0.4624 - accuracy: 0.7873 - val_loss: 0.4577 - val_accuracy: 0.8048
Epoch 28/100
21/21 - 0s - loss: 0.4656 - accuracy: 0.7993 - val_loss: 0.4602 - val_accuracy: 0.8000
Epoch 29/100
21/21 - 0s - loss: 0.4649 - accuracy: 0.7969 - val_loss: 0.4546 - val_accuracy: 0.8000
Epoch 30/100
21/21 - 0s - loss: 0.4645 - accuracy: 0.7849 - val_loss: 0.4638 - val_accuracy: 0.8000
Epoch 31/100
21/21 - 0s - loss: 0.4635 - accuracy: 0.7921 - val_loss: 0.4603 - val_accuracy: 0.7952
Epoch 32/100
21/21 - 0s - loss: 0.4646 - accuracy: 0.7909 - val_loss: 0.4567 - val_accuracy: 0.7952
Epoch 33/100
21/21 - 0s - loss: 0.4664 - accuracy: 0.7909 - val_loss: 0.4583 - val_accuracy: 0.7952
Epoch 34/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7921 - val_loss: 0.4575 - val_accuracy: 0.8000
Epoch 35/100
21/21 - 0s - loss: 0.4660 - accuracy: 0.7838 - val_loss: 0.4582 - val_accuracy: 0.7952
Epoch 36/100
21/21 - 0s - loss: 0.4577 - accuracy: 0.8005 - val_loss: 0.4567 - val_accuracy: 0.8000
Epoch 37/100
21/21 - 0s - loss: 0.4648 - accuracy: 0.7909 - val_loss: 0.4585 - val_accuracy: 0.7952
Epoch 38/100
21/21 - 0s - loss: 0.4613 - accuracy: 0.7921 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 39/100
21/21 - 0s - loss: 0.4643 - accuracy: 0.7921 - val_loss: 0.4687 - val_accuracy: 0.8000
Epoch 40/100
21/21 - 0s - loss: 0.4696 - accuracy: 0.7814 - val_loss: 0.4601 - val_accuracy: 0.8048
Epoch 41/100
21/21 - 0s - loss: 0.4589 - accuracy: 0.7933 - val_loss: 0.4562 - val_accuracy: 0.7952
Epoch 42/100
21/21 - 0s - loss: 0.4587 - accuracy: 0.7885 - val_loss: 0.4594 - val_accuracy: 0.8000
Epoch 43/100
21/21 - 0s - loss: 0.4601 - accuracy: 0.7981 - val_loss: 0.4563 - val_accuracy: 0.7905
Epoch 44/100
21/21 - 0s - loss: 0.4639 - accuracy: 0.7897 - val_loss: 0.4594 - val_accuracy: 0.8048
Epoch 45/100
21/21 - 0s - loss: 0.4569 - accuracy: 0.7957 - val_loss: 0.4587 - val_accuracy: 0.8000
Epoch 46/100
21/21 - 0s - loss: 0.4619 - accuracy: 0.7957 - val_loss: 0.4556 - val_accuracy: 0.8048
Epoch 47/100
21/21 - 0s - loss: 0.4661 - accuracy: 0.7861 - val_loss: 0.4563 - val_accuracy: 0.8000
Epoch 48/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7969 - val_loss: 0.4538 - val_accuracy: 0.8000
Epoch 49/100
21/21 - 0s - loss: 0.4550 - accuracy: 0.7873 - val_loss: 0.4572 - val_accuracy: 0.8048
Epoch 50/100
21/21 - 0s - loss: 0.4603 - accuracy: 0.7909 - val_loss: 0.4584 - val_accuracy: 0.8000
Epoch 51/100
21/21 - 0s - loss: 0.4575 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8095
Epoch 52/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.8029 - val_loss: 0.4584 - val_accuracy: 0.8048
Epoch 53/100
21/21 - 0s - loss: 0.4594 - accuracy: 0.7909 - val_loss: 0.4558 - val_accuracy: 0.8000
Epoch 54/100
21/21 - 0s - loss: 0.4588 - accuracy: 0.8065 - val_loss: 0.4523 - val_accuracy: 0.8000
Epoch 55/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.8029 - val_loss: 0.4593 - val_accuracy: 0.8048
Epoch 56/100
21/21 - 0s - loss: 0.4578 - accuracy: 0.8100 - val_loss: 0.4614 - val_accuracy: 0.8048
Epoch 57/100
21/21 - 0s - loss: 0.4549 - accuracy: 0.8041 - val_loss: 0.4580 - val_accuracy: 0.8095
Epoch 58/100
21/21 - 0s - loss: 0.4568 - accuracy: 0.7909 - val_loss: 0.4597 - val_accuracy: 0.8095
Epoch 59/100
21/21 - 0s - loss: 0.4567 - accuracy: 0.7981 - val_loss: 0.4532 - val_accuracy: 0.8095
Epoch 60/100
21/21 - 0s - loss: 0.4532 - accuracy: 0.7993 - val_loss: 0.4569 - val_accuracy: 0.7952
Epoch 61/100
21/21 - 0s - loss: 0.4543 - accuracy: 0.7969 - val_loss: 0.4555 - val_accuracy: 0.8000
Epoch 62/100
21/21 - 0s - loss: 0.4472 - accuracy: 0.8053 - val_loss: 0.4543 - val_accuracy: 0.8048
Epoch 63/100
21/21 - 0s - loss: 0.4458 - accuracy: 0.8100 - val_loss: 0.4534 - val_accuracy: 0.8095
Epoch 64/100
21/21 - 0s - loss: 0.4497 - accuracy: 0.8005 - val_loss: 0.4593 - val_accuracy: 0.8000
Epoch 65/100
21/21 - 0s - loss: 0.4511 - accuracy: 0.8053 - val_loss: 0.4522 - val_accuracy: 0.8095
Epoch 66/100
21/21 - 0s - loss: 0.4506 - accuracy: 0.8005 - val_loss: 0.4592 - val_accuracy: 0.7952
Epoch 67/100
21/21 - 0s - loss: 0.4533 - accuracy: 0.8005 - val_loss: 0.4545 - val_accuracy: 0.8000
Epoch 68/100
21/21 - 0s - loss: 0.4481 - accuracy: 0.7909 - val_loss: 0.4545 - val_accuracy: 0.7952
Epoch 69/100
21/21 - 0s - loss: 0.4555 - accuracy: 0.7981 - val_loss: 0.4551 - val_accuracy: 0.8000
Epoch 70/100
21/21 - 0s - loss: 0.4440 - accuracy: 0.8029 - val_loss: 0.4552 - val_accuracy: 0.7952
Epoch 71/100
21/21 - 0s - loss: 0.4584 - accuracy: 0.8029 - val_loss: 0.4530 - val_accuracy: 0.7952
Epoch 72/100
21/21 - 0s - loss: 0.4480 - accuracy: 0.7933 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 73/100
21/21 - 0s - loss: 0.4554 - accuracy: 0.7981 - val_loss: 0.4536 - val_accuracy: 0.7952
Epoch 74/100
21/21 - 0s - loss: 0.4438 - accuracy: 0.8029 - val_loss: 0.4532 - val_accuracy: 0.7952
Epoch 75/100
21/21 - 0s - loss: 0.4483 - accuracy: 0.8053 - val_loss: 0.4515 - val_accuracy: 0.8095
Epoch 76/100
21/21 - 0s - loss: 0.4408 - accuracy: 0.8041 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 77/100
21/21 - 0s - loss: 0.4470 - accuracy: 0.8017 - val_loss: 0.4531 - val_accuracy: 0.8000
Epoch 78/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.8053 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 79/100
21/21 - 0s - loss: 0.4456 - accuracy: 0.8053 - val_loss: 0.4526 - val_accuracy: 0.8048
Epoch 80/100
21/21 - 0s - loss: 0.4459 - accuracy: 0.8100 - val_loss: 0.4573 - val_accuracy: 0.7952
Epoch 81/100
21/21 - 0s - loss: 0.4496 - accuracy: 0.7981 - val_loss: 0.4573 - val_accuracy: 0.8095
Epoch 82/100
21/21 - 0s - loss: 0.4515 - accuracy: 0.8053 - val_loss: 0.4502 - val_accuracy: 0.8095
Epoch 83/100
21/21 - 0s - loss: 0.4503 - accuracy: 0.8100 - val_loss: 0.4546 - val_accuracy: 0.7952
Epoch 84/100
21/21 - 0s - loss: 0.4386 - accuracy: 0.8065 - val_loss: 0.4540 - val_accuracy: 0.8048
Epoch 85/100
21/21 - 0s - loss: 0.4371 - accuracy: 0.8088 - val_loss: 0.4552 - val_accuracy: 0.8095
Epoch 86/100
21/21 - 0s - loss: 0.4420 - accuracy: 0.8053 - val_loss: 0.4553 - val_accuracy: 0.8048
Epoch 87/100
21/21 - 0s - loss: 0.4437 - accuracy: 0.8112 - val_loss: 0.4550 - val_accuracy: 0.7952
Epoch 88/100
21/21 - 0s - loss: 0.4432 - accuracy: 0.7969 - val_loss: 0.4565 - val_accuracy: 0.8095
Epoch 89/100
21/21 - 0s - loss: 0.4396 - accuracy: 0.8065 - val_loss: 0.4552 - val_accuracy: 0.8000
Epoch 90/100
21/21 - 0s - loss: 0.4477 - accuracy: 0.8088 - val_loss: 0.4554 - val_accuracy: 0.8048
Epoch 91/100
21/21 - 0s - loss: 0.4412 - accuracy: 0.8017 - val_loss: 0.4507 - val_accuracy: 0.8048
Epoch 92/100
21/21 - 0s - loss: 0.4484 - accuracy: 0.7957 - val_loss: 0.4531 - val_accuracy: 0.8048
Epoch 93/100
21/21 - 0s - loss: 0.4433 - accuracy: 0.8017 - val_loss: 0.4519 - val_accuracy: 0.8048
Epoch 94/100
21/21 - 0s - loss: 0.4415 - accuracy: 0.7957 - val_loss: 0.4524 - val_accuracy: 0.8095
Epoch 95/100
21/21 - 0s - loss: 0.4399 - accuracy: 0.8065 - val_loss: 0.4549 - val_accuracy: 0.8048
Epoch 96/100
21/21 - 0s - loss: 0.4387 - accuracy: 0.8065 - val_loss: 0.4546 - val_accuracy: 0.8095
Epoch 97/100
21/21 - 0s - loss: 0.4463 - accuracy: 0.7945 - val_loss: 0.4542 - val_accuracy: 0.8048
Epoch 98/100
21/21 - 0s - loss: 0.4447 - accuracy: 0.7993 - val_loss: 0.4542 - val_accuracy: 0.8143
Epoch 99/100
21/21 - 0s - loss: 0.4368 - accuracy: 0.8041 - val_loss: 0.4551 - val_accuracy: 0.8048
Epoch 100/100
21/21 - 0s - loss: 0.4395 - accuracy: 0.8053 - val_loss: 0.4501 - val_accuracy: 0.8095
2. Visualize the training history
import matplotlib.pyplot as plt
def show_train_history(train_history, train_metric, validation_metric):
    plt.plot(train_history[train_metric])
    plt.plot(train_history[validation_metric])
    plt.title('Train History')
    plt.ylabel(train_metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
show_train_history(train_history.history,'loss','val_loss')
show_train_history(train_history.history,'accuracy','val_accuracy')
3. Evaluate the model
loss,acc = model.evaluate(x_test,y_test)
9/9 [==============================] - 0s 2ms/step - loss: 0.3703 - accuracy: 0.8435
loss,acc
(0.3702643811702728, 0.8435114622116089)
IV. Prediction
Jack_info = [0,'Jack',3,'male',23,1,0,5.0000,'S']
Rose_info = [1,'Rose',1,'female',20,1,0,100.0000,'S']
x_pre = pd.DataFrame([Jack_info,Rose_info],columns=selected_cols)
x_pre
|   | survived | name | pclass | sex | age | sibsp | parch | fare | embarked |
|---|----------|------|--------|-----|-----|-------|-------|------|----------|
| 0 | 0 | Jack | 3 | male | 23 | 1 | 0 | 5.0 | S |
| 1 | 1 | Rose | 1 | female | 20 | 1 | 0 | 100.0 | S |
x_pre_features,y = prepare_data(x_pre)
# Reuse the MinMaxScaler fitted on the training data; fitting a fresh
# scaler on just two passengers would map each feature to its own
# min/max and distort the inputs.
x_pre_features = minmax_scale.transform(x_pre_features)
y_pre = model.predict(x_pre_features)
survived False
pclass False
sex False
age False
sibsp False
parch False
fare False
embarked False
dtype: bool
x_pre.insert(len(x_pre.columns), 'surv_probability', y_pre)
x_pre
|   | survived | name | pclass | sex | age | sibsp | parch | fare | embarked | surv_probability |
|---|----------|------|--------|-----|-----|-------|-------|------|----------|------------------|
| 0 | 0 | Jack | 3 | male | 23 | 1 | 0 | 5.0 | S | 0.058498 |
| 1 | 1 | Rose | 1 | female | 20 | 1 | 0 | 100.0 | S | 0.975978 |
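The model outputs survival probabilities, not class labels; with a sigmoid output, the conventional cut-off is 0.5. A small sketch of turning the two probabilities above into predicted classes:

```python
import numpy as np

y_pre = np.array([[0.058498], [0.975978]])  # Jack, Rose

# Threshold the sigmoid outputs at 0.5 to get hard labels
predicted_class = (y_pre > 0.5).astype(int).ravel()
print(predicted_class)  # [0 1] -> Jack does not survive, Rose does
```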