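Preface
This article introduces the mathematical principles of gradient descent (Gradient Descent) and stochastic gradient descent (Stochastic Gradient Descent) as used in deep learning, and implements both in Python.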
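Mathematical principles
Gradient descent: every iteration steps in the direction in which the cost decreases fastest, namely along the negative gradient of the cost averaged over all samples, so each step follows the exact gradient of the full cost. Stochastic gradient descent: every iteration randomly picks a single sample, computes the gradient on that sample alone, and uses it to update the weight, instead of accumulating the gradient over all samples and averaging it.
The formulas of the two methods differ as follows (using a linear model as the example):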
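The update rules below are derived from the code in the next section, so treat them as a sketch of what that code computes. The model is y_hat = w * x with no bias, there are N training pairs (x_n, y_n), and the learning rate alpha is called `lr` in the code.

```latex
% Gradient descent: the gradient is averaged over all N samples
\mathrm{cost}(w) = \frac{1}{N}\sum_{n=1}^{N}\left(w x_n - y_n\right)^2, \qquad
\frac{\partial\,\mathrm{cost}}{\partial w} = \frac{1}{N}\sum_{n=1}^{N} 2x_n\left(w x_n - y_n\right), \qquad
w \leftarrow w - \alpha\,\frac{\partial\,\mathrm{cost}}{\partial w}

% Stochastic gradient descent: one randomly chosen sample n per update
\mathrm{loss}(w) = \left(w x_n - y_n\right)^2, \qquad
\frac{\partial\,\mathrm{loss}}{\partial w} = 2x_n\left(w x_n - y_n\right), \qquad
w \leftarrow w - \alpha\,\frac{\partial\,\mathrm{loss}}{\partial w}
```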
Python code implementation
Implementation of gradient descent
```python
import matplotlib.pyplot as plt

# Training data follows y = 2x, so the ideal weight is w = 2
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0    # initial weight
lr = 0.01  # learning rate

def forward(x):
    """Linear model without bias: y_hat = w * x."""
    return w * x

def cost(xs, ys):
    """Mean squared error over the whole training set."""
    total = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        total += (y_pred - y) ** 2
    return total / len(xs)

def gradient(xs, ys):
    """Derivative of cost() with respect to w, averaged over all samples."""
    grad = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        grad += (y_pred - y) * 2 * x
    return grad / len(xs)

epoch_list = []
cost_list = []
print('Input x before training: {}, prediction before training: {}\n'.format(4.0, forward(4.0)))
print("***************************Start training***************************")
for epoch in range(100):
    cost_val = cost(x_data, y_data)      # cost at the current w
    grad_val = gradient(x_data, y_data)  # full-batch gradient at the current w
    w -= lr * grad_val                   # one update per epoch, using all samples
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, cost_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(cost_val)
print("***************************Training finished***************************\n")
print('Input x after training: {}, prediction after training: {}'.format(4.0, forward(4)))

plt.plot(epoch_list, cost_list)
plt.ylabel('Cost')
plt.xlabel('Epoch')
plt.show()
```
The output is as follows:
Input x before training: 4.0, prediction before training: 4.0
***************************Start training***************************
Epoch:0, w=1.0933333333333333, loss=4.666666666666667, grad=-9.333333333333334
Epoch:1, w=1.1779555555555554, loss=3.8362074074074086, grad=-8.462222222222222
Epoch:2, w=1.2546797037037036, loss=3.1535329869958857, grad=-7.6724148148148155
Epoch:3, w=1.3242429313580246, loss=2.592344272332262, grad=-6.956322765432099
Epoch:4, w=1.3873135910979424, loss=2.1310222071581117, grad=-6.30706597399177
Epoch:5, w=1.4444976559288012, loss=1.7517949663820642, grad=-5.718406483085872
Epoch:6, w=1.4963445413754464, loss=1.440053319920117, grad=-5.184688544664522
Epoch:7, w=1.5433523841804047, loss=1.1837878313441108, grad=-4.700784280495834
Epoch:8, w=1.5859728283235668, loss=0.9731262101573632, grad=-4.262044414316223
Epoch:9, w=1.6246153643467005, loss=0.7999529948031382, grad=-3.864253602313377
Epoch:10, w=1.659651263674342, loss=0.6575969151946154, grad=-3.503589932764129
Epoch:11, w=1.6914171457314033, loss=0.5405738908195378, grad=-3.1765882057061425
Epoch:12, w=1.7202182121298057, loss=0.44437576375991855, grad=-2.8801066398402355
Epoch:13, w=1.7463311789976905, loss=0.365296627844598, grad=-2.6112966867884806
Epoch:14, w=1.7700069356245727, loss=0.3002900634939416, grad=-2.3675756626882225
Epoch:15, w=1.7914729549662791, loss=0.2468517784170642, grad=-2.1466019341706555
Epoch:16, w=1.8109354791694263, loss=0.2029231330489788, grad=-1.9462524203147282
Epoch:17, w=1.8285815011136133, loss=0.16681183417217407, grad=-1.764602194418688
Epoch:18, w=1.8445805610096762, loss=0.1371267415488235, grad=-1.5999059896062764
Epoch:19, w=1.8590863753154396, loss=0.11272427607497944, grad=-1.4505814305763567
Epoch:20, w=1.872238313619332, loss=0.09266436490145864, grad=-1.31519383038923
Epoch:21, w=1.8841627376815275, loss=0.07617422636521683, grad=-1.1924424062195684
Epoch:22, w=1.8949742154979183, loss=0.06261859959338009, grad=-1.081147781639076
Epoch:23, w=1.904776622051446, loss=0.051475271914629306, grad=-0.9802406553527626
Epoch:24, w=1.9136641373266443, loss=0.04231496130368814, grad=-0.888751527519838
Epoch:25, w=1.9217221511761575, loss=0.03478477885657844, grad=-0.8058013849513194
Epoch:26, w=1.9290280837330496, loss=0.02859463421027894, grad=-0.7305932556891969
Epoch:27, w=1.9356521292512983, loss=0.023506060193480772, grad=-0.6624045518248707
Epoch:28, w=1.9416579305211772, loss=0.01932302619282764, grad=-0.6005801269878838
Epoch:29, w=1.9471031903392007, loss=0.015884386331668398, grad=-0.5445259818023471
Epoch:30, w=1.952040225907542, loss=0.01305767153735723, grad=-0.49370355683412726
Epoch:31, w=1.9565164714895047, loss=0.010733986344664803, grad=-0.44762455819627417
Epoch:32, w=1.9605749341504843, loss=0.008823813841374291, grad=-0.40584626609795665
Epoch:33, w=1.9642546069631057, loss=0.007253567147113681, grad=-0.3679672812621462
Epoch:34, w=1.9675908436465492, loss=0.005962754575689583, grad=-0.33362366834434704
Epoch:35, w=1.970615698239538, loss=0.004901649272531298, grad=-0.3024854592988742
Epoch:36, w=1.9733582330705144, loss=0.004029373553099482, grad=-0.27425348309764513
Epoch:37, w=1.975844797983933, loss=0.0033123241439168096, grad=-0.24865649134186527
Epoch:38, w=1.9780992835054327, loss=0.0027228776607060357, grad=-0.22544855214995874
Epoch:39, w=1.980143350378259, loss=0.002238326453885249, grad=-0.20440668728262779
Epoch:40, w=1.9819966376762883, loss=0.001840003826269386, grad=-0.185328729802915
Epoch:41, w=1.983676951493168, loss=0.0015125649231412608, grad=-0.1680313816879758
Epoch:42, w=1.9852004360204722, loss=0.0012433955919298103, grad=-0.1523484527304313
Epoch:43, w=1.9865817286585614, loss=0.0010221264385926248, grad=-0.13812926380892523
Epoch:44, w=1.987834100650429, loss=0.0008402333603648631, grad=-0.12523719918675966
Epoch:45, w=1.9889695845897222, loss=0.0006907091659248264, grad=-0.11354839392932907
Epoch:46, w=1.9899990900280147, loss=0.0005677936325753796, grad=-0.10295054382926017
Epoch:47, w=1.9909325082920666, loss=0.0004667516012495216, grad=-0.09334182640519595
Epoch:48, w=1.9917788075181404, loss=0.000383690560742734, grad=-0.08462992260737945
Epoch:49, w=1.9925461188164473, loss=0.00031541069384432885, grad=-0.07673112983068957
Epoch:50, w=1.9932418143935788, loss=0.0002592816085930997, grad=-0.06956955771315876
Epoch:51, w=1.9938725783835114, loss=0.0002131410058905752, grad=-0.06307639899326374
Epoch:52, w=1.994444471067717, loss=0.00017521137977565514, grad=-0.0571892684205603
Epoch:53, w=1.9949629871013967, loss=0.0001440315413480261, grad=-0.05185160336797523
Epoch:54, w=1.9954331083052663, loss=0.0001184003283899171, grad=-0.0470121203869646
Epoch:55, w=1.9958593515301082, loss=9.733033217332803e-05, grad=-0.042624322484180986
Epoch:56, w=1.9962458120539648, loss=8.000985883901657e-05, grad=-0.0386460523856574
Epoch:57, w=1.9965962029289281, loss=6.57716599593935e-05, grad=-0.035039087496328225
Epoch:58, w=1.9969138906555615, loss=5.406722767150764e-05, grad=-0.03176877266333733
Epoch:59, w=1.997201927527709, loss=4.444566413387458e-05, grad=-0.02880368721475879
Epoch:60, w=1.9974630809584561, loss=3.65363112808981e-05, grad=-0.026115343074715636
Epoch:61, w=1.9976998600690001, loss=3.0034471708953996e-05, grad=-0.02367791105440838
Epoch:62, w=1.9979145397958935, loss=2.4689670610172655e-05, grad=-0.02146797268933165
Epoch:63, w=1.9981091827482769, loss=2.0296006560253656e-05, grad=-0.01946429523832638
Epoch:64, w=1.9982856590251044, loss=1.6684219437262796e-05, grad=-0.01764762768274834
Epoch:65, w=1.9984456641827613, loss=1.3715169898293847e-05, grad=-0.016000515765691798
Epoch:66, w=1.9985907355257035, loss=1.1274479219506377e-05, grad=-0.014507134294228674
Epoch:67, w=1.9987222668766378, loss=9.268123006398985e-06, grad=-0.013153135093433596
Epoch:68, w=1.9988415219681517, loss=7.61880902783969e-06, grad=-0.011925509151381094
Epoch:69, w=1.9989496465844576, loss=6.262999634617916e-06, grad=-0.010812461630584766
Epoch:70, w=1.9990476795699081, loss=5.1484640551938914e-06, grad=-0.009803298545062233
Epoch:71, w=1.9991365628100501, loss=4.232266273994499e-06, grad=-0.008888324014190227
Epoch:72, w=1.999217150281112, loss=3.479110977946351e-06, grad=-0.008058747106198657
Epoch:73, w=1.999290216254875, loss=2.859983851026929e-06, grad=-0.00730659737628736
Epoch:74, w=1.9993564627377531, loss=2.3510338359374262e-06, grad=-0.006624648287833749
Epoch:75, w=1.9994165262155628, loss=1.932654303533636e-06, grad=-0.00600634778096983
Epoch:76, w=1.999470983768777, loss=1.5887277332523938e-06, grad=-0.005445755321414225
Epoch:77, w=1.9995203586170245, loss=1.3060048068548734e-06, grad=-0.004937484824748761
Epoch:78, w=1.9995651251461022, loss=1.0735939958924364e-06, grad=-0.004476652907771476
Epoch:79, w=1.9996057134657994, loss=8.825419799121559e-07, grad=-0.004058831969712499
Epoch:80, w=1.9996425135423248, loss=7.254887315754342e-07, grad=-0.003680007652538434
Epoch:81, w=1.999675878945041, loss=5.963839812987369e-07, grad=-0.003336540271635139
Epoch:82, w=1.999706130243504, loss=4.902541385825727e-07, grad=-0.0030251298462834106
Epoch:83, w=1.9997335580874436, loss=4.0301069098738336e-07, grad=-0.002742784393962546
Epoch:84, w=1.9997584259992822, loss=3.312926995781724e-07, grad=-0.002486791183860415
Epoch:85, w=1.9997809729060159, loss=2.723373231729343e-07, grad=-0.002254690673365515
Epoch:86, w=1.9998014154347876, loss=2.2387338352920307e-07, grad=-0.0020442528771848303
Epoch:87, w=1.9998199499942075, loss=1.8403387118941732e-07, grad=-0.0018534559419821999
Epoch:88, w=1.9998367546614149, loss=1.5128402140063082e-07, grad=-0.0016804667207292272
Epoch:89, w=1.9998519908930161, loss=1.2436218932547864e-07, grad=-0.0015236231601270707
Epoch:90, w=1.9998658050763347, loss=1.0223124683409346e-07, grad=-0.00138141833184946
Epoch:91, w=1.9998783299358769, loss=8.403862850836479e-08, grad=-0.0012524859542084599
Epoch:92, w=1.9998896858085284, loss=6.908348768398496e-08, grad=-0.0011355872651486187
Epoch:93, w=1.9998999817997325, loss=5.678969725349543e-08, grad=-0.0010295991204016808
Epoch:94, w=1.9999093168317574, loss=4.66836551287917e-08, grad=-0.0009335032024962627
Epoch:95, w=1.9999177805941268, loss=3.8376039345125727e-08, grad=-0.000846376236931512
Epoch:96, w=1.9999254544053418, loss=3.154680994333735e-08, grad=-0.0007673811214832978
Epoch:97, w=1.9999324119941766, loss=2.593287985380858e-08, grad=-0.0006957588834774301
Epoch:98, w=1.9999387202080534, loss=2.131797981222471e-08, grad=-0.000630821387685554
Epoch:99, w=1.9999444396553017, loss=1.752432687141379e-08, grad=-0.0005719447248348312
***************************Training finished***************************
Input x after training: 4.0, prediction after training: 7.999777758621207
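The toy data satisfy y = 2x exactly, so this log can be checked by hand. The gradient simplifies to 2 * (w - 2) * mean(x²), which means each epoch multiplies the error w - 2 by the constant factor 1 - 2 * lr * mean(x²) = 1 - 0.02 * 14/3 ≈ 0.9067.

Below is a small plain-Python sketch comparing this closed form with the iteration. The names `run_gd`, `factor`, `w_iter` and `w_closed` are illustrative and do not come from the original code.

```python
# Closed-form check of the full-batch gradient descent log (toy data y = 2x).
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def run_gd(w, lr, epochs):
    """Full-batch gradient descent on y_hat = w * x, same update as the loop above."""
    n = len(x_data)
    for _ in range(epochs):
        grad = sum(2 * x * (w * x - y) for x, y in zip(x_data, y_data)) / n
        w -= lr * grad
    return w

lr = 0.01
# Because y = 2x exactly, the gradient equals 2 * (w - 2) * mean(x**2),
# so every epoch multiplies the error (w - 2) by this constant factor.
factor = 1 - 2 * lr * sum(x * x for x in x_data) / len(x_data)

w_iter = run_gd(1.0, lr, 100)
w_closed = 2 + (1.0 - 2) * factor ** 100  # w0 = 1.0, so the initial error is -1

print(factor)            # about 0.9067
print(w_iter, w_closed)  # both about 1.9999444
```

Starting from w = 1.0, the remaining error after 100 epochs is 0.9067^100 ≈ 5.56e-5. This matches the final w = 1.99994443... in the log.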
Implementation of stochastic gradient descent
```python
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0    # initial weight
lr = 0.01  # learning rate

def forward(x):
    return w * x

"""
# Define the per-sample loss function
def loss(x, y):
    y_pred = forward(x)  # predicted value y_hat
    return (y_pred - y) ** 2
"""

def loss(xs, ys):
    # The return statement sits after the loop, so only the last pair
    # (x=3.0, y=6.0) is actually used; this is what produces the log below.
    for x, y in zip(xs, ys):
        y_pred = forward(x)
    return (y_pred - y) ** 2

"""
# Define the per-sample gradient
def gradient(x, y):
    y_pred = forward(x)
    return 2 * x * (y_pred - y)
"""

def gradient(xs, ys):
    # Same structure as loss(): only the last pair contributes.
    for x, y in zip(xs, ys):
        y_pred = forward(x)
    return 2 * x * (y_pred - y)

epoch_list = []
cost_list = []
print('Input x before training: {}, prediction before training: {}\n'.format(4.0, forward(4.0)))
print("***************************Start training***************************")

"""
# Start training
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        loss_val = loss(x, y)      # loss of the prediction
        grad_val = gradient(x, y)  # gradient of the prediction
        w -= lr * grad_val         # w = w - lr * gradient(w), the core of gradient descent
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, loss_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(loss_val)
"""

for epoch in range(100):
    loss_val = loss(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= lr * grad_val
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, loss_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(loss_val)
print("***************************Training finished***************************\n")
print('Input x after training: {}, prediction after training: {}'.format(4.0, forward(4)))

plt.plot(epoch_list, cost_list)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()
```
The output is as follows:
Input x before training: 4.0, prediction before training: 4.0
***************************Start training***************************
Epoch:0, w=1.18, loss=9.0, grad=-18.0
Epoch:1, w=1.3276, loss=6.0516, grad=-14.76
Epoch:2, w=1.448632, loss=4.069095840000001, grad=-12.103200000000001
Epoch:3, w=1.54787824, loss=2.7360600428160007, grad=-9.924624000000001
Epoch:4, w=1.6292601568, loss=1.8397267727894793, grad=-8.138191680000002
Epoch:5, w=1.695993328576, loss=1.237032282023645, grad=-6.6733171775999995
Epoch:6, w=1.75071452943232, loss=0.8317805064326984, grad=-5.472120085631998
Epoch:7, w=1.7955859141345025, loss=0.5592892125253471, grad=-4.487138470218241
Epoch:8, w=1.832380449590292, loss=0.3760660665020425, grad=-3.679453545578953
Epoch:9, w=1.8625519686640395, loss=0.2528668231159735, grad=-3.0171519073747426
Epoch:10, w=1.8872926143045123, loss=0.17002765186318064, grad=-2.474064564047289
Epoch:11, w=1.9075799437297, loss=0.1143265931128029, grad=-2.0287329425187792
Epoch:12, w=1.924215553858354, loss=0.07687320120904852, grad=-1.6635610128653973
Epoch:13, w=1.9378567541638503, loss=0.05168954049296429, grad=-1.3641200305496266
Epoch:14, w=1.9490425384143573, loss=0.03475604702746934, grad=-1.1185784250506963
Epoch:15, w=1.9582148814997729, loss=0.02336996602127023, grad=-0.917234308541568
Epoch:16, w=1.9657362028298138, loss=0.01571396515270217, grad=-0.7521321330040873
Epoch:17, w=1.9719036863204473, loss=0.010566070168677008, grad=-0.6167483490633536
Epoch:18, w=1.9769610227827668, loss=0.007104625581418294, grad=-0.5057336462319455
Epoch:19, w=1.9811080386818687, loss=0.004777150240945764, grad=-0.4147015899101998
Epoch:20, w=1.9845085917191323, loss=0.0032121558220119784, grad=-0.3400553037263663
Epoch:21, w=1.9872970452096885, loss=0.0021598535747208276, grad=-0.27884534905561864
Epoch:22, w=1.9895835770719446, loss=0.0014522855436422575, grad=-0.22865318622560515
Epoch:23, w=1.9914585331989945, loss=0.0009765167995450728, grad=-0.18749561270499804
Epoch:24, w=1.9929959972231754, loss=0.0006566098960140951, grad=-0.153746402418097
Epoch:25, w=1.994256717723004, loss=0.00044150449407990444, grad=-0.12607204998284338
Epoch:26, w=1.9952905085328632, loss=0.0002968676218193051, grad=-0.10337908098592763
Epoch:27, w=1.9961382169969477, loss=0.00019961378891129975, grad=-0.08477084640846044
Epoch:28, w=1.9968333379374972, loss=0.00013422031166397195, grad=-0.06951209405494119
Epoch:29, w=1.9974033371087476, loss=9.024973756285169e-05, grad=-0.056999917125050814
Epoch:30, w=1.997870736429173, loss=6.068392353726093e-05, grad=-0.046739932042541454
Epoch:31, w=1.9982540038719219, loss=4.080387018645243e-05, grad=-0.03832674427488314
Epoch:32, w=1.9985682831749758, loss=2.7436522313372105e-05, grad=-0.03142793030540503
Epoch:33, w=1.9988259922034801, loss=1.844831760351415e-05, grad=-0.02577090285043404
Epoch:34, w=1.9990373136068538, loss=1.2404648756606792e-05, grad=-0.021132140337359218
Epoch:35, w=1.9992105971576202, loss=8.340885823940048e-06, grad=-0.017328355076632107
Epoch:36, w=1.9993526896692486, loss=5.608411628015101e-06, grad=-0.014209251162835557
Epoch:37, w=1.999469205528784, loss=3.7710959786780433e-06, grad=-0.011651585953526222
Epoch:38, w=1.9995647485336028, loss=2.535684936061193e-06, grad=-0.009554300481887879
Epoch:39, w=1.9996430937975542, loss=1.7049945510078245e-06, grad=-0.0078345263951487
Epoch:40, w=1.9997073369139944, loss=1.1464383360988404e-06, grad=-0.006424311644025238
Epoch:41, w=1.9997600162694753, loss=7.708651371929851e-07, grad=-0.005267935548101121
Epoch:42, w=1.9998032133409698, loss=5.183297182483074e-07, grad=-0.004319707149441854
Epoch:43, w=1.9998386349395951, loss=3.4852490255049745e-07, grad=-0.0035421598625440254
Epoch:44, w=1.999867680650468, loss=2.343481444747825e-07, grad=-0.002904571087285035
Epoch:45, w=1.9998914981333837, loss=1.5757569234467452e-07, grad=-0.0023817482915724497
Epoch:46, w=1.9999110284693746, loss=1.0595389553262853e-07, grad=-0.0019530335990900483
Epoch:47, w=1.9999270433448872, loss=7.124339935627219e-08, grad=-0.0016014875512553317
Epoch:48, w=1.9999401755428075, loss=4.790406172717297e-08, grad=-0.0013132197920295852
Epoch:49, w=1.9999509439451022, loss=3.221069110544675e-08, grad=-0.0010768402294658586
Epoch:50, w=1.9999597740349837, loss=2.1658468699093255e-08, grad=-0.0008830089881577408
Epoch:51, w=1.9999670147086868, loss=1.456315435353612e-08, grad=-0.0007240673702959555
Epoch:52, w=1.9999729520611231, loss=9.792264987127843e-09, grad=-0.0005937352436369281
Epoch:53, w=1.9999778206901209, loss=6.584318977344762e-09, grad=-0.00048686289978228103
Epoch:54, w=1.999981812965899, loss=4.427296080402076e-09, grad=-0.00039922757782306917
Epoch:55, w=1.9999850866320372, loss=2.976913884501124e-09, grad=-0.00032736661381704835
Epoch:56, w=1.9999877710382705, loss=2.001676895938556e-09, grad=-0.00026844062332997964
Epoch:57, w=1.9999899722513819, loss=1.3459275448290849e-09, grad=-0.0002201213111305833
Epoch:58, w=1.999991777246133, loss=9.050016811366642e-10, grad=-0.00018049947512643882
Epoch:59, w=1.999993257341829, loss=6.08523130391911e-10, grad=-0.00014800956960314693
Epoch:60, w=1.9999944710203, loss=4.091709529057039e-10, grad=-0.0001213678470790569
Epoch:61, w=1.999995466236646, loss=2.7512654872200957e-10, grad=-9.952163460269503e-05
Epoch:62, w=1.9999962823140496, loss=1.8499509135681353e-10, grad=-8.160774037335727e-05
Epoch:63, w=1.9999969514975207, loss=1.2439069943862355e-10, grad=-6.691834710892408e-05
Epoch:64, w=1.9999975002279669, loss=8.364030629083358e-11, grad=-5.4873044625480816e-05
Epoch:65, w=1.999997950186933, loss=5.623974196274511e-11, grad=-4.4995896598010177e-05
Epoch:66, w=1.999998319153285, loss=3.781560249181731e-11, grad=-3.689663520844988e-05
Epoch:67, w=1.9999986217056938, loss=2.5427211109406962e-11, grad=-3.0255240867305133e-05
Epoch:68, w=1.9999988697986688, loss=1.709725675158115e-11, grad=-2.4809297512362605e-05
Epoch:69, w=1.9999990732349084, loss=1.1496195438197204e-11, grad=-2.0343623958751778e-05
Epoch:70, w=1.999999240052625, loss=7.730041815607078e-12, grad=-1.66817716493739e-05
Epoch:71, w=1.9999993768431525, loss=5.197680114303315e-12, grad=-1.3679052749182574e-05
Epoch:72, w=1.999999489011385, loss=3.4949201104515554e-12, grad=-1.1216823256887665e-05
Epoch:73, w=1.9999995809893356, loss=2.349984281723007e-12, grad=-9.197795069582071e-06
Epoch:74, w=1.9999996564112552, loss=1.5801294318344075e-12, grad=-7.5421919589757636e-06
Epoch:75, w=1.9999997182572293, loss=1.0624790294527732e-12, grad=-6.184597404867986e-06
Epoch:76, w=1.9999997689709281, loss=7.144108997944157e-13, grad=-5.071369873377307e-06
Epoch:77, w=1.9999998105561612, loss=4.803698890710119e-13, grad=-4.158523296382555e-06
Epoch:78, w=1.999999844656052, loss=3.230007131690541e-13, grad=-3.409989101754718e-06
Epoch:79, w=1.9999998726179626, loss=2.1718567983289397e-13, grad=-2.796191065357334e-06
Epoch:80, w=1.9999998955467293, loss=1.4603565098387233e-13, grad=-2.2928766725272e-06
Epoch:81, w=1.999999914348318, loss=9.81943716992902e-14, grad=-1.880158871259141e-06
Epoch:82, w=1.9999999297656208, loss=6.60258955032161e-14, grad=-1.5417302741127514e-06
Epoch:83, w=1.999999942407809, loss=4.439581219624794e-14, grad=-1.2642188256251075e-06
Epoch:84, w=1.9999999527744035, loss=2.985174434787262e-14, grad=-1.0366594409561003e-06
Epoch:85, w=1.9999999612750108, loss=2.0072312657907758e-14, grad=-8.500607364680945e-07
Epoch:86, w=1.9999999682455087, loss=1.3496623055941361e-14, grad=-6.97049804543326e-07
Epoch:87, w=1.9999999739613172, loss=9.075129393581547e-15, grad=-5.715808413242485e-07
Epoch:88, w=1.99999997864828, loss=6.1021170708499814e-15, grad=-4.686962924438376e-07
Epoch:89, w=1.9999999824915897, loss=4.103063445617242e-15, grad=-3.843309563933417e-07
Epoch:90, w=1.9999999856431034, loss=2.7588998272437547e-15, grad=-3.151513823240748e-07
Epoch:91, w=1.999999988227345, loss=1.8550843478908236e-15, grad=-2.5842414075327724e-07
Epoch:92, w=1.999999990346423, loss=1.2473586828983886e-15, grad=-2.1190779264657067e-07
Epoch:93, w=1.9999999920840668, loss=8.387239228207162e-16, grad=-1.737643842147918e-07
Epoch:94, w=1.9999999935089348, loss=5.639580129513639e-16, grad=-1.4248680102468825e-07
Epoch:95, w=1.9999999946773266, loss=3.79205343002729e-16, grad=-1.1683917300331359e-07
Epoch:96, w=1.9999999956354078, loss=2.54977656183392e-16, grad=-9.580811877185624e-08
Epoch:97, w=1.9999999964210344, loss=1.7144698904287566e-16, grad=-7.856266037720161e-08
Epoch:98, w=1.9999999970652482, loss=1.1528095085501517e-16, grad=-6.44213802303284e-08
Epoch:99, w=1.9999999975935034, loss=7.751490791422244e-17, grad=-5.282553061647377e-08
***************************Training finished***************************
Input x after training: 4.0, prediction after training: 7.999999990374014
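This log can also be checked in closed form. With y = 2x, one update on a single pair (x, y) multiplies the error w - 2 by 1 - 2 * lr * x². For x = 3.0 alone that factor is 1 - 2 * 0.01 * 9 = 0.82. This reproduces the final w in the log, which is consistent with the loop reusing the last pair (x=3.0, y=6.0) at every epoch.

Below is a small plain-Python check. The variable names are illustrative.

```python
# Closed-form check of the "SGD" log: only the last pair (3.0, 6.0) is used.
lr = 0.01
x_last, y_last = 3.0, 6.0

w = 1.0
for _ in range(100):
    grad = 2 * x_last * (w * x_last - y_last)  # same update that loss()/gradient() produce
    w -= lr * grad

factor = 1 - 2 * lr * x_last ** 2  # the error (w - 2) shrinks by this factor per epoch
w_closed = 2 - factor ** 100       # w0 = 1.0, so the initial error is -1

print(factor)       # about 0.82
print(w, w_closed)  # both about 1.9999999976
```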
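Food for thought
The comparison below puts the zip loop in two different places: inside the loss function in one version, and in the training loop where the data is fed in in the other. The two versions print different loss values at every epoch. Do you know why? If you do, leave your answer in the comments!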
```python
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0    # initial weight
lr = 0.01  # learning rate

def forward(x):
    return w * x

def loss(x, y):
    """Squared error of a single sample."""
    y_pred = forward(x)
    return (y_pred - y) ** 2

def gradient(x, y):
    """Gradient of the single-sample loss with respect to w."""
    y_pred = forward(x)
    return 2 * x * (y_pred - y)

epoch_list = []
cost_list = []
print('Input x before training: {}, prediction before training: {}\n'.format(4.0, forward(4.0)))
print("***************************Start training***************************")
for epoch in range(100):
    # w is updated once per sample inside this inner loop
    for x, y in zip(x_data, y_data):
        loss_val = loss(x, y)
        grad_val = gradient(x, y)
        w -= lr * grad_val
    # Only the values from the last sample of the epoch are printed and recorded
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, loss_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(loss_val)

"""
# Start training
for epoch in range(100):
    loss_val = loss(x_data, y_data)      # loss of the prediction
    grad_val = gradient(x_data, y_data)  # gradient of the prediction
    w -= lr * grad_val                   # w = w - lr * gradient(w), the core of gradient descent
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, loss_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(loss_val)
"""

print("***************************Training finished***************************\n")
print('Input x after training: {}, prediction after training: {}'.format(4.0, forward(4)))

plt.plot(epoch_list, cost_list)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()
```
The output is as follows:
Input x before training: 4.0, prediction before training: 4.0
***************************Start training***************************
Epoch:0, w=1.260688, loss=7.315943039999998, grad=-16.2288
Epoch:1, w=1.453417766656, loss=3.9987644858206908, grad=-11.998146585599997
Epoch:2, w=1.5959051959019805, loss=2.1856536232765476, grad=-8.87037374849311
Epoch:3, w=1.701247862192685, loss=1.1946394387269013, grad=-6.557973756745939
Epoch:4, w=1.7791289594933983, loss=0.6529686924601721, grad=-4.848388694047353
Epoch:5, w=1.836707389300983, loss=0.3569010862285927, grad=-3.584471942173538
Epoch:6, w=1.8792758133988885, loss=0.195075792793724, grad=-2.650043120512205
Epoch:7, w=1.910747160155559, loss=0.10662496249654511, grad=-1.9592086795121197
Epoch:8, w=1.9340143044689266, loss=0.05827931013158195, grad=-1.4484664872674653
Epoch:9, w=1.9512159834655312, loss=0.03185443548946761, grad=-1.0708686556346834
Epoch:10, w=1.9639333911678687, loss=0.017411068491745587, grad=-0.7917060475345892
Epoch:11, w=1.9733355232910992, loss=0.009516580701123755, grad=-0.5853177814148953
Epoch:12, w=1.9802866323953892, loss=0.005201593933418656, grad=-0.4327324596134101
Epoch:13, w=1.9854256707695, loss=0.0028430988290765965, grad=-0.3199243001817109
Epoch:14, w=1.9892250235079405, loss=0.0015539873076143675, grad=-0.2365238742159388
Epoch:15, w=1.9920339305797026, loss=0.000849381853184108, grad=-0.17486493849433593
Epoch:16, w=1.994110589284741, loss=0.00046425703027525234, grad=-0.12927974740812687
Epoch:17, w=1.9956458879852805, loss=0.0002537546444534916, grad=-0.09557806861579543
Epoch:18, w=1.9967809527381737, loss=0.00013869778028681822, grad=-0.07066201306448505
Epoch:19, w=1.9976201197307648, loss=7.580974250901852e-05, grad=-0.052241274202728505
Epoch:20, w=1.998240525958391, loss=4.143625836981207e-05, grad=-0.03862260091336722
Epoch:21, w=1.99869919972735, loss=2.2648322641186416e-05, grad=-0.028554152326460525
Epoch:22, w=1.9990383027488265, loss=1.2379170770717257e-05, grad=-0.021110427464781978
Epoch:23, w=1.9992890056818404, loss=6.766234806804613e-06, grad=-0.01560719234984198
Epoch:24, w=1.999474353368653, loss=3.6983037320320918e-06, grad=-0.011538584590544687
Epoch:25, w=1.9996113831376856, loss=2.021427113440456e-06, grad=-0.008530614050808794
Epoch:26, w=1.9997126908902887, loss=1.104876146204831e-06, grad=-0.006306785335127074
Epoch:27, w=1.9997875889274812, loss=6.039056715601388e-07, grad=-0.00466268207967957
Epoch:28, w=1.9998429619451539, loss=3.300841106907982e-07, grad=-0.0034471768136938863
Epoch:29, w=1.9998838998815958, loss=1.8041811041234527e-07, grad=-0.0025485391844828342
Epoch:30, w=1.9999141657892625, loss=9.861333372463779e-08, grad=-0.0018841656015560204
Epoch:31, w=1.9999365417379913, loss=5.390029618456292e-08, grad=-0.0013929862392156878
Epoch:32, w=1.9999530845453979, loss=2.946094426616231e-08, grad=-0.0010298514424817995
Epoch:33, w=1.9999653148414271, loss=1.6102828713572706e-08, grad=-0.0007613815296476645
Epoch:34, w=1.999974356846045, loss=8.801520081617991e-09, grad=-0.0005628985014531906
Epoch:35, w=1.9999810417085633, loss=4.810754502894822e-09, grad=-0.0004161576169003922
Epoch:36, w=1.9999859839076413, loss=2.6294729403166827e-09, grad=-0.0003076703200690645
Epoch:37, w=1.9999896377347262, loss=1.43722319226415e-09, grad=-0.00022746435967313516
Epoch:38, w=1.999992339052936, loss=7.855606621992112e-10, grad=-0.00016816713067413502
Epoch:39, w=1.9999943361699042, loss=4.293735011907528e-10, grad=-0.00012432797771566584
Epoch:40, w=1.9999958126624442, loss=2.3468792720119317e-10, grad=-9.191716585732479e-05
Epoch:41, w=1.999996904251097, loss=1.2827625139303397e-10, grad=-6.795546372551087e-05
Epoch:42, w=1.999997711275687, loss=7.011351996364471e-11, grad=-5.0240289795056015e-05
Epoch:43, w=1.9999983079186507, loss=3.832280433642867e-11, grad=-3.714324913239864e-05
Epoch:44, w=1.9999987490239537, loss=2.0946563973304985e-11, grad=-2.7460449796734565e-05
Epoch:45, w=1.9999990751383971, loss=1.1449019716984442e-11, grad=-2.0301840059744336e-05
Epoch:46, w=1.9999993162387186, loss=6.257830771044664e-12, grad=-1.5009393983689279e-05
Epoch:47, w=1.9999994944870796, loss=3.4204191158124504e-12, grad=-1.109662508014253e-05
Epoch:48, w=1.9999996262682318, loss=1.8695403191196464e-12, grad=-8.20386808086937e-06
Epoch:49, w=1.999999723695619, loss=1.0218575233353146e-12, grad=-6.065218119744031e-06
Epoch:50, w=1.9999997957248556, loss=5.585291664185541e-13, grad=-4.484088535150477e-06
Epoch:51, w=1.9999998489769344, loss=3.0528211874783223e-13, grad=-3.3151404608133817e-06
Epoch:52, w=1.9999998883468353, loss=1.6686178282138566e-13, grad=-2.4509231284497446e-06
Epoch:53, w=1.9999999174534755, loss=9.120368570648034e-14, grad=-1.811996877876254e-06
Epoch:54, w=1.999999938972364, loss=4.9850314593866976e-14, grad=-1.3396310407642886e-06
Epoch:55, w=1.9999999548815364, loss=2.7247296013817913e-14, grad=-9.904052991061008e-07
Epoch:56, w=1.9999999666433785, loss=1.4892887826055098e-14, grad=-7.322185204827747e-07
Epoch:57, w=1.9999999753390494, loss=8.140187918760348e-15, grad=-5.413379398078177e-07
Epoch:58, w=1.9999999817678633, loss=4.449282094197275e-15, grad=-4.002176350326181e-07
Epoch:59, w=1.9999999865207625, loss=2.4318985397157373e-15, grad=-2.9588569994132286e-07
Epoch:60, w=1.999999990034638, loss=1.3292325355918982e-15, grad=-2.1875184863517916e-07
Epoch:61, w=1.9999999926324883, loss=7.265349176868077e-16, grad=-1.617258700292723e-07
Epoch:62, w=1.99999999455311, loss=3.971110830662586e-16, grad=-1.195658771990793e-07
Epoch:63, w=1.9999999959730488, loss=2.1705387408049341e-16, grad=-8.839649012770678e-08
Epoch:64, w=1.9999999970228268, loss=1.186377771034419e-16, grad=-6.53525820126788e-08
Epoch:65, w=1.9999999977989402, loss=6.484530240933061e-17, grad=-4.8315948575350376e-08
Epoch:66, w=1.9999999983727301, loss=3.544328347681514e-17, grad=-3.5720557178819945e-08
Epoch:67, w=1.9999999987969397, loss=1.937267496512019e-17, grad=-2.6408640607655798e-08
Epoch:68, w=1.999999999110563, loss=1.0588762836876607e-17, grad=-1.9524227568012975e-08
Epoch:69, w=1.9999999993424284, loss=5.78763006728202e-18, grad=-1.4434496264925656e-08
Epoch:70, w=1.9999999995138495, loss=3.1634161455080883e-18, grad=-1.067159693945996e-08
Epoch:71, w=1.9999999996405833, loss=1.7290652800585402e-18, grad=-7.88963561149103e-09
Epoch:72, w=1.999999999734279, loss=9.45076322860815e-19, grad=-5.832902161273523e-09
Epoch:73, w=1.9999999998035491, loss=5.165625480570949e-19, grad=-4.31233715403323e-09
Epoch:74, w=1.9999999998547615, loss=2.8234328489219424e-19, grad=-3.188159070077745e-09
Epoch:75, w=1.9999999998926234, loss=1.5432429879714383e-19, grad=-2.3570478902001923e-09
Epoch:76, w=1.9999999999206153, loss=8.435055999638128e-20, grad=-1.7425900722400911e-09
Epoch:77, w=1.9999999999413098, loss=4.610497285725064e-20, grad=-1.2883241140571045e-09
Epoch:78, w=1.9999999999566096, loss=2.520026245264157e-20, grad=-9.524754318590567e-10
Epoch:79, w=1.9999999999679208, loss=1.377407569393045e-20, grad=-7.041780492045291e-10
Epoch:80, w=1.9999999999762834, loss=7.528673013117128e-21, grad=-5.206075570640678e-10
Epoch:81, w=1.999999999982466, loss=4.115053962078213e-21, grad=-3.8489211817704927e-10
Epoch:82, w=1.9999999999870368, loss=2.2492314589577438e-21, grad=-2.845563784603655e-10
Epoch:83, w=1.999999999990416, loss=1.229387284413716e-21, grad=-2.1037571684701106e-10
Epoch:84, w=1.9999999999929146, loss=6.719695441682722e-22, grad=-1.5553425214420713e-10
Epoch:85, w=1.9999999999947617, loss=3.6730159234427135e-22, grad=-1.1499068364173581e-10
Epoch:86, w=1.9999999999961273, loss=2.0073851892168518e-22, grad=-8.500933290633839e-11
Epoch:87, w=1.999999999997137, loss=1.0972931813778698e-22, grad=-6.285105769165966e-11
Epoch:88, w=1.9999999999978835, loss=5.996996411023123e-23, grad=-4.646416584819235e-11
Epoch:89, w=1.9999999999984353, loss=3.2777893208522223e-23, grad=-3.4351188560322043e-11
Epoch:90, w=1.9999999999988431, loss=1.791878298003441e-23, grad=-2.539835008974478e-11
Epoch:91, w=1.9999999999991447, loss=9.796529104915932e-24, grad=-1.8779644506139448e-11
Epoch:92, w=1.9999999999993676, loss=5.353229824352417e-24, grad=-1.3882228699912957e-11
Epoch:93, w=1.9999999999995324, loss=2.926260595255618e-24, grad=-1.0263789818054647e-11
Epoch:94, w=1.9999999999996543, loss=1.5996332109454424e-24, grad=-7.58859641791787e-12
Epoch:95, w=1.9999999999997444, loss=8.746960714572049e-25, grad=-5.611511255665391e-12
Epoch:96, w=1.999999999999811, loss=4.774848841557949e-25, grad=-4.1460168631601846e-12
Epoch:97, w=1.9999999999998603, loss=2.6081713678869703e-25, grad=-3.064215547965432e-12
Epoch:98, w=1.9999999999998967, loss=1.4248800100554526e-25, grad=-2.2648549702353193e-12
Epoch:99, w=1.9999999999999236, loss=7.82747233205549e-26, grad=-1.6786572132332367e-12
***************************Training finished***************************
Input x after training: 4.0, prediction after training: 7.9999999999996945
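None of the training loops above draws samples at random. They either visit the samples in a fixed order or reuse only the last one. For comparison, here is a minimal sketch of SGD as defined in the math section, where each update uses one randomly chosen sample.

The sketch makes a few assumptions that are not in the original code:
- It uses plain Python with `random.Random` and a fixed seed, so runs are repeatable.
- One epoch means len(x_data) updates.
- The function `sgd` and its parameters are illustrative names.

```python
import random

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def sgd(w=1.0, lr=0.01, epochs=100, seed=0):
    """Stochastic gradient descent for the model y_hat = w * x.

    Each epoch performs len(x_data) updates; every update draws one sample
    index uniformly at random (with replacement) and uses only that sample's
    gradient 2 * x * (w * x - y).
    """
    rng = random.Random(seed)  # fixed seed so the run is repeatable
    n = len(x_data)
    loss_history = []
    for _ in range(epochs):
        for _ in range(n):
            i = rng.randrange(n)        # pick one sample at random
            x, y = x_data[i], y_data[i]
            grad = 2 * x * (w * x - y)  # gradient of (w*x - y)**2
            w -= lr * grad
        # Mean squared error over all data, recorded once per epoch for plotting
        loss_history.append(sum((w * x - y) ** 2 for x, y in zip(x_data, y_data)) / n)
    return w, loss_history

w_final, loss_history = sgd()
print('w = {}, prediction for x = 4.0: {}'.format(w_final, w_final * 4.0))
```

The path of w now depends on the seed. With lr = 0.01 and this data, every single-sample update multiplies the error w - 2 by 0.98, 0.92 or 0.82. So the run converges toward w = 2 whichever samples are drawn.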
That about wraps up this article. If it helped you, I am glad. How far a person can go depends on who walks alongside them.