1.分析数据集
数据集:
链接:https://pan.baidu.com/s/1YY9HuDqCSr3-CHWON3NdKg 提取码:15eq
mnist_train.csv 数据集一共 (60000, 785) 行列 数据。 已知 28 * 28 = 784
- 第一列的值为标签值。范围(0, 9), 我们希望神经网络能够预测得到正确的标签值。
- 剩下的 784 = 28*28 列数据 是手写识别体的数字的像素值。
因此 我们可以把第一列作为标签值,剩下的 28*28 列 作为 变量。
import pandas as pd
import numpy as np
path = r'data\mnist_train.csv'
df = pd.read_csv(path, header=None)
df.head()
| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 775 | 776 | 777 | 778 | 779 | 780 | 781 | 782 | 783 | 784 |
---|
0 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
---|
1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
---|
2 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
---|
3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
---|
4 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
---|
5 rows × 785 columns
df.shape
(60000, 785)
df.describe()
| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 775 | 776 | 777 | 778 | 779 | 780 | 781 | 782 | 783 | 784 |
---|
count | 60000.000000 | 60000.0 | 60000.0 | 60000.0 | 60000.0 | 60000.0 | 60000.0 | 60000.0 | 60000.0 | 60000.0 | ... | 60000.000000 | 60000.000000 | 60000.000000 | 60000.000000 | 60000.000000 | 60000.0000 | 60000.0 | 60000.0 | 60000.0 | 60000.0 |
---|
mean | 4.453933 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.200433 | 0.088867 | 0.045633 | 0.019283 | 0.015117 | 0.0020 | 0.0 | 0.0 | 0.0 | 0.0 |
---|
std | 2.889270 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 6.042472 | 3.956189 | 2.839845 | 1.686770 | 1.678283 | 0.3466 | 0.0 | 0.0 | 0.0 | 0.0 |
---|
min | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0000 | 0.0 | 0.0 | 0.0 | 0.0 |
---|
25% | 2.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0000 | 0.0 | 0.0 | 0.0 | 0.0 |
---|
50% | 4.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0000 | 0.0 | 0.0 | 0.0 | 0.0 |
---|
75% | 7.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0000 | 0.0 | 0.0 | 0.0 | 0.0 |
---|
max | 9.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 254.000000 | 254.000000 | 253.000000 | 253.000000 | 254.000000 | 62.0000 | 0.0 | 0.0 | 0.0 | 0.0 |
---|
8 rows × 785 columns
读取数据的另外一种方法:
注:这种方法不常用,直接使用 pd.read_csv() 即可,非常方便。
data_file = open(r"data\mnist_train.csv", 'r')
data_list = data_file.readlines()
print(len(data_list))
data_file.close()
60000
data_list[0]
'5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,18,18,18,126,136,175,26,166,255,247,127,0,0,0,0,0,0,0,0,0,0,0,0,30,36,94,154,170,253,253,253,253,253,225,172,253,242,195,64,0,0,0,0,0,0,0,0,0,0,0,49,238,253,253,253,253,253,253,253,253,251,93,82,82,56,39,0,0,0,0,0,0,0,0,0,0,0,0,18,219,253,253,253,253,253,198,182,247,241,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,80,156,107,253,253,205,11,0,43,154,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,14,1,154,253,90,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,139,253,190,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,11,190,253,70,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,35,241,225,160,108,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,240,253,253,119,25,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,45,186,253,253,150,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,93,252,253,187,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,249,253,249,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,46,130,183,253,253,207,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,39,148,229,253,253,253,250,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,114,221,253,253,253,253,201,78,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,23,66,213,253,253,253,253,198,81,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,18,171,219,253,253,253,253,195,80,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,55,172,226,253,253,253,253,244,133,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,253,253,253,212,135,132,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n'
提取训练集的前100条数据
测试集的前10条数据
df.iloc[:100].to_csv(r"data\mnist_train_100.csv", index=False, header=False)
test = pd.read_csv(r'data\mnist_test.csv', header=None)
test.iloc[:10].to_csv(r'data\mnist_test_10.csv', index=False, header=False)
观察数据
train_path = r"data\mnist_train_100.csv"
test_path = r"data\mnist_test_10.csv"
data_file = open(train_path, 'r')
data_list = data_file.readlines()
data_file.close()
len(data_list)
100
data_list[0]
'5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,18,18,18,126,136,175,26,166,255,247,127,0,0,0,0,0,0,0,0,0,0,0,0,30,36,94,154,170,253,253,253,253,253,225,172,253,242,195,64,0,0,0,0,0,0,0,0,0,0,0,49,238,253,253,253,253,253,253,253,253,251,93,82,82,56,39,0,0,0,0,0,0,0,0,0,0,0,0,18,219,253,253,253,253,253,198,182,247,241,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,80,156,107,253,253,205,11,0,43,154,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,14,1,154,253,90,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,139,253,190,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,11,190,253,70,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,35,241,225,160,108,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,240,253,253,119,25,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,45,186,253,253,150,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,93,252,253,187,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,249,253,249,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,46,130,183,253,253,207,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,39,148,229,253,253,253,250,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,114,221,253,253,253,253,201,78,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,23,66,213,253,253,253,253,198,81,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,18,171,219,253,253,253,253,195,80,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,55,172,226,253,253,253,253,244,133,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,253,253,253,212,135,132,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n'
处理数据
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
all_values = data_list[0].split(',')
len(all_values)
image_array = np.asfarray(all_values[1:]).reshape(28, 28)
plt.imshow(image_array, cmap='Greys', interpolation=None)
<matplotlib.image.AxesImage at 0x12dcd8c0ca0>
inputs = (np.asfarray(all_values[1:]) / 255 * 0.99) + 0.01
inputs.shape
(784,)
inputs = np.array(inputs, ndmin=2).T
print(inputs)
inputs.shape
[[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.02164706]
[0.07988235]
[0.07988235]
[0.07988235]
[0.49917647]
[0.538 ]
[0.68941176]
[0.11094118]
[0.65447059]
[1. ]
[0.96894118]
[0.50305882]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.12647059]
[0.14976471]
[0.37494118]
[0.60788235]
[0.67 ]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.88352941]
[0.67776471]
[0.99223529]
[0.94952941]
[0.76705882]
[0.25847059]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.20023529]
[0.934 ]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.98447059]
[0.37105882]
[0.32835294]
[0.32835294]
[0.22741176]
[0.16141176]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.07988235]
[0.86023529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.77870588]
[0.71658824]
[0.96894118]
[0.94564706]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.32058824]
[0.61564706]
[0.42541176]
[0.99223529]
[0.99223529]
[0.80588235]
[0.05270588]
[0.01 ]
[0.17694118]
[0.60788235]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.06435294]
[0.01388235]
[0.60788235]
[0.99223529]
[0.35941176]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.54964706]
[0.99223529]
[0.74764706]
[0.01776471]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.05270588]
[0.74764706]
[0.99223529]
[0.28176471]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.14588235]
[0.94564706]
[0.88352941]
[0.63117647]
[0.42929412]
[0.01388235]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.32447059]
[0.94176471]
[0.99223529]
[0.99223529]
[0.472 ]
[0.10705882]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.18470588]
[0.73211765]
[0.99223529]
[0.99223529]
[0.59235294]
[0.11482353]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.07211765]
[0.37105882]
[0.98835294]
[0.99223529]
[0.736 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.97670588]
[0.99223529]
[0.97670588]
[0.25847059]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.18858824]
[0.51470588]
[0.72047059]
[0.99223529]
[0.99223529]
[0.81364706]
[0.01776471]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.16141176]
[0.58458824]
[0.89905882]
[0.99223529]
[0.99223529]
[0.99223529]
[0.98058824]
[0.71658824]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.10317647]
[0.45258824]
[0.868 ]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.79035294]
[0.31282353]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.09929412]
[0.26623529]
[0.83694118]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.77870588]
[0.32447059]
[0.01776471]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.07988235]
[0.67388235]
[0.86023529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.76705882]
[0.32058824]
[0.04494118]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.22352941]
[0.67776471]
[0.88741176]
[0.99223529]
[0.99223529]
[0.99223529]
[0.99223529]
[0.95729412]
[0.52635294]
[0.05270588]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.538 ]
[0.99223529]
[0.99223529]
[0.99223529]
[0.83305882]
[0.53411765]
[0.52247059]
[0.07211765]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]
[0.01 ]]
(784, 1)
output_nodes = 10
targets = np.zeros(output_nodes) + 0.01
targets[int(all_values[0])] = 0.99
print(targets.shape, targets)
targets = np.array(targets, ndmin=2).T
print(targets.shape, '\n', targets)
(10,) [0.01 0.01 0.01 0.01 0.01 0.99 0.01 0.01 0.01 0.01]
(10, 1)
[[0.01]
[0.01]
[0.01]
[0.01]
[0.01]
[0.99]
[0.01]
[0.01]
[0.01]
[0.01]]
下一条数据
all_values = data_list[1].split(',')
len(all_values)
image_array = np.asfarray(all_values[1:]).reshape(28, 28)
plt.imshow(image_array, cmap='Greys', interpolation=None)
<matplotlib.image.AxesImage at 0x12dce0cccd0>
? ?
以上功能可以使用 numpy 简化
train_path = r"data\mnist_train_100.csv"
data = pd.read_csv(train_path, header=None)
label = data[0].to_numpy(dtype=np.float64)
all_variabels = data.iloc[:, 1:]
all_variabels = all_variabels.to_numpy(dtype=np.float64)
plt.imshow(image_array, cmap='Greys', interpolation=None)
<matplotlib.image.AxesImage at 0x12dce122310>
? ?
把 训练数据 映射到指定的 区间
观察数据:我们知道 像素值 在 0~255 之间,在使用 神经网络 训练之前,我们把 该值 缩放到 0.01~1之间
注意:最小值为0.01,而不是为0,防止 像素值为0 后期 权重 更新失败。
scaled_input = (np.asfarray(all_values[1:])) / 255.5 * 0.99 + 0.01
print(all_variabels)
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]]
输出层
我们如何设置 输出层哪?
首先,输出结果是一个数字,数字的范围是 0~9,因此,该问题归纳为 多分类问题,输出神经元的个数设置为10.
onodes = 10
targets = np.zeros(onodes)+0.01
targets[int(all_values[0])] = 0.99
print(targets)
type(targets), targets.shape
[0.99 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01]
(numpy.ndarray, (10,))
all_values[0]
'0'
正态分布随机生成矩阵
hnodes, inodes = 3, 3
np.random.normal(0.0, pow(hnodes, -0.5), (hnodes, inodes))
array([[ 0.5685311 , -0.17127778, 0.67140503],
[ 0.61448826, 0.29478324, 0.38356441],
[-0.26157523, -0.43210937, -0.76723949]])
pow(16, 2)
256
mu, sigma = 0, 0.1
s = np.random.normal(mu, sigma, 3)
s
array([-0.1519728 , 0.1522495 , -0.08011677])
激活函数
import scipy.special as ss
activation_function = lambda x: ss.expit(x)
x = np.arange(-10, 10, 0.1)
y = activation_function(x)
plt.plot(x, y)
plt.show()
activation_function(0)
? ?
0.5
2.框架代码
- 初始化函数 :设定 输入层、隐藏层、输出层
- 训练 : 学习给定训练集样本后,优化权重
- 查询 : 给定输入,从输出节点给出答案
class NeuralNetwork:
def __init__():
pass
def train():
pass
def query():
pass
初始化方法如下:
def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
""":arg
inodes : 输入层 神经元个数
hnodes : 隐藏层 神经元个数
onodes : 输出层 神经元个数
lr : 学习率
"""
self.inodes = inputnodes
self.hnodes = hiddennodes
self.onodes = outputnodes
self.lr = learningrate
pass
权重
权重,刚开始进行 随机生成,然后根据每次训练的结果,我们计算 损失值,进而更新权重,以便下次训练时,损失值更小。
import numpy as np
a = np.zeros([3, 2])
a[0, 0] = 1
a[0, 1] = 2
a[1, 0] = 9
a[2, 1] = 12
a.shape
(3, 2)
%matplotlib inline
import matplotlib.pyplot as plt
plt.imshow(a, interpolation="nearest")
<matplotlib.image.AxesImage at 0x12dce5c76a0>
? ?
我们设计的网络结构:3层神经元。
包含 : 输入层、隐藏层、输出层 注:hidden_nodes :表示隐藏层神经元个数, input_nodes :表示输入层 神经元个数, output_nodes : 输出层神经元个数
- 设置输入层和隐藏层之间的连接权重矩阵为 $ W_{input_hidden}$, 大小:(hidden_nodes, input_nodes)
- 设隐藏层和输出层之间的连接权重矩阵 为
W
h
i
d
d
e
n
o
u
t
p
u
t
W_{hidden_output}
Whiddeno?utput?, 大小:{output_nodes, hidden_nodes}
初始化权重值
我们设置初始的权重的要求:
方法一
np.random.rand(3, 3)
array([[0.56577557, 0.4548674 , 0.37397099],
[0.7145576 , 0.37282089, 0.39927602],
[0.25294715, 0.93189228, 0.19266301]])
一般权重的范围在(-1.0, 1.0) 之间,我了简单起见,我们设置 权重范围(-0.5, 0.5)
np.random.rand(3, 3) - 0.5
array([[ 0.18886128, 0.14154925, 0.25791405],
[-0.49226421, -0.35166701, -0.20517272],
[-0.32607314, 0.41795557, 0.01489836]])
inodes, hnodes = 3, 3
wih = np.random.rand(hnodes, inodes) - 0.5
wih
array([[ 0.30452585, 0.49108575, 0.13111538],
[ 0.46734034, 0.1032404 , 0.02777094],
[ 0.04864065, -0.07211676, -0.48808873]])
onodes = 2
who = np.random.rand(onodes, hnodes) - 0.5
who
array([[0.18630746, 0.17767593, 0.47013643],
[0.4316961 , 0.22181227, 0.44532566]])
这里我们初始化权重时,使用正态分布函数,在 神经网络 类 实例化时进行初始化:
注意:
方法二
mu, sigma = 0, 0.1
s = np.random.normal(mu, sigma, size=(2, 2))
s
array([[ 0.07580048, -0.16289481],
[-0.03706102, 0.01746931]])
正态分布函数的 平均值mu 设为0, 标准差sigma 设置为 传入链接输入的开方,即 $ \frac{1}{sqrt(传入链接数目)} $
inodes, hnodes = 3, 3
onodes = 3
wih = np.random.normal(0.0, pow(hnodes, -0.5), (hnodes, inodes))
wih
who = np.random.normal(0.0, pow(onodes, -0.5), (onodes, hnodes))
who
array([[ 0.90693664, 0.45453867, -0.42490091],
[-0.15520574, -0.11698295, 0.23764649],
[-0.21942348, 0.1452418 , 0.06733478]])
train() 训练函数的编写
a = np.arange(10).reshape(2, 5)
print(a)
np.transpose(a)
[[0 1 2 3 4]
[5 6 7 8 9]]
array([[0, 5],
[1, 6],
[2, 7],
[3, 8],
[4, 9]])
print(all_values)
['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '51', '159', '253', '159', '50', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '48', '238', '252', '252', '252', '237', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '54', '227', '253', '252', '239', '233', '252', '57', '6', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '10', '60', '224', '252', '253', '252', '202', '84', '252', '253', '122', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '163', '252', '252', '252', '253', '252', '252', '96', '189', '253', '167', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '51', '238', '253', '253', '190', '114', '253', '228', '47', '79', '255', '168', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '48', '238', '252', '252', '179', '12', '75', '121', '21', '0', '0', '253', '243', '50', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '38', '165', '253', '233', '208', '84', '0', '0', '0', '0', '0', '0', '253', '252', '165', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '7', '178', '252', '240', '71', '19', '28', '0', '0', '0', '0', '0', '0', '253', '252', '195', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '57', '252', '252', '63', '0', '0', '0', '0', '0', '0', '0', '0', '0', '253', '252', '195', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '198', '253', '190', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '255', '253', '196', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '76', '246', '252', '112', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '253', '252', '148', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '230', '25', '0', '0', '0', '0', '0', '0', '0', '0', '7', '135', '253', '186', '12', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '223', '0', '0', '0', '0', '0', '0', '0', '0', '7', '131', '252', '225', '71', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '145', '0', '0', '0', '0', '0', '0', '0', '48', '165', '252', '173', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '86', '253', '225', '0', '0', '0', '0', '0', '0', '114', '238', '253', '162', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '249', '146', '48', '29', '85', '178', '225', '253', '223', '167', '56', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '85', '252', '252', '252', '229', '215', '252', '252', '252', '196', '130', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '28', '199', '252', '252', '253', '252', '252', '233', '145', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '25', '128', '252', '253', '252', '141', '37', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0\n']
input_list = np.asfarray(all_values[1:])/255*0.99 + 0.01
input_list.shape
(784,)
input_list
array([0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.208 , 0.62729412, 0.99223529,
0.62729412, 0.20411765, 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.19635294,
0.934 , 0.98835294, 0.98835294, 0.98835294, 0.93011765,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.21964706, 0.89129412, 0.99223529, 0.98835294,
0.93788235, 0.91458824, 0.98835294, 0.23129412, 0.03329412,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.04882353, 0.24294118, 0.87964706,
0.98835294, 0.99223529, 0.98835294, 0.79423529, 0.33611765,
0.98835294, 0.99223529, 0.48364706, 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.64282353, 0.98835294, 0.98835294, 0.98835294, 0.99223529,
0.98835294, 0.98835294, 0.38270588, 0.74376471, 0.99223529,
0.65835294, 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.208 , 0.934 , 0.99223529,
0.99223529, 0.74764706, 0.45258824, 0.99223529, 0.89517647,
0.19247059, 0.31670588, 1. , 0.66223529, 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.19635294,
0.934 , 0.98835294, 0.98835294, 0.70494118, 0.05658824,
0.30117647, 0.47976471, 0.09152941, 0.01 , 0.01 ,
0.99223529, 0.95341176, 0.20411765, 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.15752941, 0.65058824, 0.99223529, 0.91458824,
0.81752941, 0.33611765, 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.99223529, 0.98835294,
0.65058824, 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.03717647, 0.70105882,
0.98835294, 0.94176471, 0.28564706, 0.08376471, 0.11870588,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.99223529, 0.98835294, 0.76705882, 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.23129412, 0.98835294, 0.98835294, 0.25458824,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.99223529,
0.98835294, 0.76705882, 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.77870588,
0.99223529, 0.74764706, 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 1. , 0.99223529, 0.77094118,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.30505882, 0.96505882, 0.98835294, 0.44482353,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.99223529, 0.98835294, 0.58458824, 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.34 ,
0.98835294, 0.90294118, 0.10705882, 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.03717647, 0.53411765, 0.99223529, 0.73211765,
0.05658824, 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.34 , 0.98835294, 0.87576471,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.03717647, 0.51858824,
0.98835294, 0.88352941, 0.28564706, 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.34 , 0.98835294, 0.57294118, 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.19635294, 0.65058824, 0.98835294, 0.68164706, 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.34388235, 0.99223529,
0.88352941, 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.45258824, 0.934 , 0.99223529,
0.63894118, 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.34 , 0.98835294, 0.97670588, 0.57682353,
0.19635294, 0.12258824, 0.34 , 0.70105882, 0.88352941,
0.99223529, 0.87576471, 0.65835294, 0.22741176, 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.34 ,
0.98835294, 0.98835294, 0.98835294, 0.89905882, 0.84470588,
0.98835294, 0.98835294, 0.98835294, 0.77094118, 0.51470588,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.11870588, 0.78258824, 0.98835294,
0.98835294, 0.99223529, 0.98835294, 0.98835294, 0.91458824,
0.57294118, 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.10705882, 0.50694118, 0.98835294, 0.99223529,
0.98835294, 0.55741176, 0.15364706, 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 , 0.01 ,
0.01 , 0.01 , 0.01 , 0.01 ])
定义查询函数
接受神经网络的输入,返回网络的输出。
- 输入层和隐藏层的关系
公式:
X
h
i
d
d
e
n
=
W
i
n
p
u
t
_
h
i
d
d
e
n
?
I
X_{hidden} = W_{input\_hidden} \cdot I
Xhidden?=Winput_hidden??I
代码:
hidden_inputs = np.dot(self.wih, inputs)
公式:
O
h
i
d
d
e
n
=
s
i
g
m
o
i
d
(
X
h
i
d
d
e
n
)
O_{hidden} = sigmoid(X_{hidden})
Ohidden?=sigmoid(Xhidden?)
代码:
hidden_outputs = self.activation_function(hidden_inputs)
- 隐藏层和输出层之间权重和输入的处理
- 链接权重 X 隐藏层的输出值得到 输出层的输入信号 X
X
o
u
t
p
u
t
=
W
h
i
d
d
e
n
_
o
u
t
p
u
t
?
O
h
i
d
d
e
n
X_{output} = W_{hidden\_output} \cdot O_{hidden}
Xoutput?=Whidden_output??Ohidden?
O
o
u
t
p
u
t
=
s
i
g
m
o
i
d
(
X
o
u
t
p
u
t
)
O_{output} = sigmoid(X_{out_put})
Ooutput?=sigmoid(Xoutp?ut?)
final_inputs = np.dot(self.who, hidden_outputs)
final_outputs = self.activation_function(final_inputs)
import numpy as np
inputs = np.array(input_list, ndmin=2).T
inputs.shape
(784, 1)
np.zeros(10) + 0.01
array([0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01])
定义 train 函数
训练函数的功能:
- 前向传播 : 使用 权重 X 输入神经元的值
- 反向传播 : 根据 前向传播的 预测值,计算 误差,并使用梯度下降法反向更新权重
query(input_list) 函数已经把前向传播实现了,这里我们主要关注反向传播(backpropagation) 的实现。
反向传播稍微复杂,推导需要用到求导、链式法则等
def train(inputs_list, targets_list)
pass
我们需要传入:
- 要训练的样本 inputs_list
- 标签值 targets_list
标签值用来求误差,进而反向传播,更新权重,再前向传播得到优化后的值。
output_errors = targets - final_outputs
难点
e
r
r
o
r
s
h
i
d
d
e
n
=
w
e
i
g
h
t
s
h
i
d
d
e
n
_
o
u
t
p
u
t
T
?
e
r
r
o
r
s
o
u
t
p
u
t
errors_{hidden} = weights^{T}_{hidden\_output} \cdot errors_{output}
errorshidden?=weightshidden_outputT??errorsoutput?
hidden_errors = np.dot(self.who.T, output_errors)
因此, 对于 隐藏层 和 输出层 之间的权重,我们使用 output_errors 进行优化,
对于 输入层 和 隐藏层 之间的权重,我们使用 计算得到的 hidden_errors 进行优化。
我们根据 梯度下降算法 得到 更新节点 j 与 下一个 节点 k 之间 链接权重 的矩阵形式的表达式如下:
Δ
W
j
,
k
=
α
?
E
k
s
i
g
m
o
i
d
(
∑
j
w
j
,
k
?
O
j
)
?
(
1
?
s
i
g
m
o
i
d
(
∑
j
w
j
,
k
?
O
j
)
)
?
O
j
T
\Delta W_{j, k} = \alpha * E_{k} sigmoid(\sum \limits _{j}w_{j,k} \cdot O_j) * (1 - sigmoid(\sum \limits _{j}w_{j,k} \cdot O_j) ) \cdot O_j^T
ΔWj,k?=α?Ek?sigmoid(j∑?wj,k??Oj?)?(1?sigmoid(j∑?wj,k??Oj?))?OjT?
注:
α
\alpha
α 是学习率, sigmoid 是 激活函数 ,注意 * 表示正常的乘法,
?
\cdot
? 表示的是 矩阵点积。
Python 代码实现:
self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)), np.transpose(hidden_outputs))
self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))
梯度下降公式:
n
e
w
W
j
,
k
=
o
l
d
W
j
,
k
?
α
?
?
E
?
w
j
,
k
new W_{j, k} = old W_{j, k} - \alpha \cdot \frac{\partial E}{\partial w_{j,k}}
newWj,k?=oldWj,k??α??wj,k??E?
前面的公式实际上求的就是 偏导数的值。
a = np.arange(4).reshape(-1, 2)
b = a
a, b
(array([[0, 1],
[2, 3]]),
array([[0, 1],
[2, 3]]))
np.dot(a, b)
array([[ 2, 3],
[ 6, 11]])
a = np.arange(10).reshape(-1, 2)
print(a)
max_value = np.argmax(a)
max_value
[[0 1]
[2 3]
[4 5]
[6 7]
[8 9]]
9
需要全部代码可以关注微信公众号哈。(学长杨小杨)。
|