
TensorFlow from Beginner to Mastery (2): Boston Housing Price Prediction

  • One way to guard against overfitting is to set aside a validation set: after the model has been fit on the training set, the validation set is used to tune and select hyperparameters, and the test set is used only for the final performance measurement (this split is carried out in section 1 below).

import tensorflow as tf
print(tf.__version__)
2.5.1

1. Preparing the Dataset

import pandas as pd
df = pd.read_csv('boston.csv')
df.head()
      CRIM    ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD  TAX  PTRATIO  LSTAT  MEDV
0  0.00632  18.0   2.31     0  0.538  6.575  65.2  4.0900    1  296     15.3   4.98  24.0
1  0.02731   0.0   7.07     0  0.469  6.421  78.9  4.9671    2  242     17.8   9.14  21.6
2  0.02729   0.0   7.07     0  0.469  7.185  61.1  4.9671    2  242     17.8   4.03  34.7
3  0.03237   0.0   2.18     0  0.458  6.998  45.8  6.0622    3  222     18.7   2.94  33.4
4  0.06905   0.0   2.18     0  0.458  7.147  54.2  6.0622    3  222     18.7   5.33  36.2
df.tail()
        CRIM   ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD  TAX  PTRATIO  LSTAT  MEDV
501  0.06263  0.0  11.93     0  0.573  6.593  69.1  2.4786    1  273     21.0   9.67  22.4
502  0.04527  0.0  11.93     0  0.573  6.120  76.7  2.2875    1  273     21.0   9.08  20.6
503  0.06076  0.0  11.93     0  0.573  6.976  91.0  2.1675    1  273     21.0   5.64  23.9
504  0.10959  0.0  11.93     0  0.573  6.794  89.3  2.3889    1  273     21.0   6.48  22.0
505  0.04741  0.0  11.93     0  0.573  6.030  80.8  2.5050    1  273     21.0   7.88  11.9
df.describe()
             CRIM          ZN       INDUS        CHAS         NOX          RM         AGE         DIS         RAD         TAX     PTRATIO       LSTAT        MEDV
count  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000
mean     3.613524   11.363636   11.136779    0.069170    0.554695    6.284634   68.574901    3.795043    9.549407  408.237154   18.455534   12.653063   22.532806
std      8.601545   23.322453    6.860353    0.253994    0.115878    0.702617   28.148861    2.105710    8.707259  168.537116    2.164946    7.141062    9.197104
min      0.006320    0.000000    0.460000    0.000000    0.385000    3.561000    2.900000    1.129600    1.000000  187.000000   12.600000    1.730000    5.000000
25%      0.082045    0.000000    5.190000    0.000000    0.449000    5.885500   45.025000    2.100175    4.000000  279.000000   17.400000    6.950000   17.025000
50%      0.256510    0.000000    9.690000    0.000000    0.538000    6.208500   77.500000    3.207450    5.000000  330.000000   19.050000   11.360000   21.200000
75%      3.677082   12.500000   18.100000    0.000000    0.624000    6.623500   94.075000    5.188425   24.000000  666.000000   20.200000   16.955000   25.000000
max     88.976200  100.000000   27.740000    1.000000    0.871000    8.780000  100.000000   12.126500   24.000000  711.000000   22.000000   37.970000   50.000000
data = df.values
data.shape
(506, 13)
x_data = data[:,:12] # the 12 feature columns (CRIM ... LSTAT)
y_data = data[:,12]  # the target column MEDV (median home value)
from sklearn.preprocessing import scale
# Split the data into training, validation, and test sets
train_num = 300 # size of the training set
valid_num = 100 # size of the validation set

# training set
x_train = x_data[:train_num]
y_train = y_data[:train_num]

# validation set
x_valid = x_data[train_num:train_num+valid_num]
y_valid = y_data[train_num:train_num+valid_num]

# test set
x_test = x_data[train_num+valid_num:]
y_test = y_data[train_num+valid_num:]
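The split above simply takes the first 300 rows for training, the next 100 for validation, and the remaining 106 for testing. Because boston.csv is stored in a fixed row order, the three subsets can end up with rather different distributions, which may contribute to the weak test performance seen later. A minimal sketch of a shuffled split (the permutation-based indexing is my own addition, not part of the original notebook):

import numpy as np

# Hypothetical alternative: shuffle the row indices once before slicing,
# so that train/validation/test are drawn from the same distribution.
rng = np.random.default_rng(seed=42)   # fixed seed for reproducibility
perm = rng.permutation(data.shape[0])  # random permutation of 0..505
x_shuf, y_shuf = x_data[perm], y_data[perm]

x_train, y_train = x_shuf[:train_num], y_shuf[:train_num]
x_valid, y_valid = x_shuf[train_num:train_num+valid_num], y_shuf[train_num:train_num+valid_num]
x_test,  y_test  = x_shuf[train_num+valid_num:], y_shuf[train_num+valid_num:]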


# Convert the data type
'''
tf.cast can convert a NumPy array directly into a tensor
'''
x_train = tf.cast(scale(x_train),dtype=tf.float32)
x_valid = tf.cast(scale(x_valid),dtype=tf.float32)
x_test = tf.cast(scale(x_test),dtype=tf.float32)
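Note that scale() standardizes each split with its own mean and standard deviation. A common alternative (my suggestion, not something the original does) is to fit a StandardScaler on the training features only and reuse those statistics for the validation and test sets, so the held-out data does not influence the preprocessing:

from sklearn.preprocessing import StandardScaler

# Fit the scaler on the training rows only, then apply the same
# transformation to the validation and test rows.
scaler = StandardScaler().fit(x_data[:train_num])
x_train = tf.cast(scaler.transform(x_data[:train_num]), dtype=tf.float32)
x_valid = tf.cast(scaler.transform(x_data[train_num:train_num+valid_num]), dtype=tf.float32)
x_test  = tf.cast(scaler.transform(x_data[train_num+valid_num:]), dtype=tf.float32)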

2. The Model

import numpy as np
def line_model(x,w,b):
    # linear model: y_hat = x @ w + b, with x of shape (N, 12) and w of shape (12, 1)
    return tf.matmul(x,w) + b
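As a quick sanity check (my addition, not in the original), the model applied to the scaled training features should return one prediction per row:

# Expect a (300, 1) tensor: one predicted price for each training sample.
w0 = tf.random.normal(shape=(12, 1))
b0 = tf.zeros(1)
print(line_model(x_train, w0, b0).shape)  # (300, 1)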

3. Training

W = tf.Variable(tf.random.normal(shape=(12,1))) # weights, one per feature
B = tf.Variable(tf.zeros(1),dtype=tf.float32)   # bias
def loss(x,y,w,b):
    # mean squared error between predictions and labels
    loss_ = tf.square(line_model(x,w,b) - y)
    return tf.reduce_mean(loss_)

def grad(x,y,w,b):
    # record the forward pass and return d(loss)/dW and d(loss)/dB
    with tf.GradientTape() as tape:
        loss_ = loss(x,y,w,b)
    return tape.gradient(loss_,[w,b])
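One detail worth pointing out (my observation, not discussed in the original): line_model returns a tensor of shape (N, 1), while the label slices passed in as y have shape (N,), so the subtraction inside loss broadcasts to an (N, N) matrix before the mean is taken. A sketch of a variant that compares predictions and labels element-wise:

def loss_elementwise(x, y, w, b):
    # Reshape the labels to a column vector so predictions (N, 1) and
    # labels (N, 1) are subtracted element-wise instead of broadcast.
    y_col = tf.reshape(tf.cast(y, tf.float32), (-1, 1))
    return tf.reduce_mean(tf.square(line_model(x, w, b) - y_col))

Under the broadcast version, the loss is smallest when every prediction equals the average label, which may help explain why the predicted value reported below stays close to the bias term and why the test R2 comes out negative.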
training_epochs = 50
learning_rate = 0.001
batch_size = 10 # number of samples per gradient update
optimizer = tf.keras.optimizers.SGD(learning_rate)
total_steps = train_num // batch_size # gradient updates per epoch
loss_train_list = []
loss_valid_list = []
for epoch in range(training_epochs):
    for step in range(total_steps):
        xs = x_train[step*batch_size:(step+1)*batch_size]
        ys = y_train[step*batch_size:(step+1)*batch_size]
        
        grads = grad(xs,ys,W,B)
        optimizer.apply_gradients(zip(grads,[W,B]))
    loss_train = loss(x_train,y_train,W,B).numpy()
    loss_valid = loss(x_valid,y_valid,W,B).numpy()
    loss_train_list.append(loss_train)
    loss_valid_list.append(loss_valid)
    print(f'epoch {epoch+1} loss_train:{loss_train} loss_valid:{loss_valid}')
epoch 1 loss_train:665.7977905273438 loss_valid:477.3428039550781
epoch 2 loss_train:593.6646728515625 loss_valid:413.1649475097656
epoch 3 loss_train:533.6004028320312 loss_valid:362.72601318359375
epoch 4 loss_train:482.3865661621094 loss_valid:321.768310546875
epoch 5 loss_train:438.0271911621094 loss_valid:287.7205810546875
epoch 6 loss_train:399.220703125 loss_valid:258.9722900390625
epoch 7 loss_train:365.0664367675781 loss_valid:234.4657745361328
epoch 8 loss_train:334.9011535644531 loss_valid:213.4668426513672
epoch 9 loss_train:308.20770263671875 loss_valid:195.4357147216797
epoch 10 loss_train:284.56365966796875 loss_valid:179.9535675048828
epoch 11 loss_train:263.6124572753906 loss_valid:166.6811065673828
epoch 12 loss_train:245.04588317871094 loss_valid:155.3340301513672
epoch 13 loss_train:228.59408569335938 loss_valid:145.66885375976562
epoch 14 loss_train:214.01869201660156 loss_valid:137.47401428222656
epoch 15 loss_train:201.10826110839844 loss_valid:130.56398010253906
epoch 16 loss_train:189.67495727539062 loss_valid:124.77549743652344
epoch 17 loss_train:179.55177307128906 loss_valid:119.96448516845703
epoch 18 loss_train:170.58995056152344 loss_valid:116.00358581542969
epoch 19 loss_train:162.65756225585938 loss_valid:112.78060913085938
epoch 20 loss_train:155.6372833251953 loss_valid:110.1964340209961
epoch 21 loss_train:149.42471313476562 loss_valid:108.16374969482422
epoch 22 loss_train:143.92759704589844 loss_valid:106.60588836669922
epoch 23 loss_train:139.0637664794922 loss_valid:105.45535278320312
epoch 24 loss_train:134.7605743408203 loss_valid:104.65306091308594
epoch 25 loss_train:130.95376586914062 loss_valid:104.14726257324219
epoch 26 loss_train:127.58616638183594 loss_valid:103.89268493652344
epoch 27 loss_train:124.60723114013672 loss_valid:103.84986114501953
epoch 28 loss_train:121.97221374511719 loss_valid:103.98436737060547
epoch 29 loss_train:119.64164733886719 loss_valid:104.26632690429688
epoch 30 loss_train:117.58037567138672 loss_valid:104.66973876953125
epoch 31 loss_train:115.75743103027344 loss_valid:105.17220306396484
epoch 32 loss_train:114.14533233642578 loss_valid:105.7542495727539
epoch 33 loss_train:112.71986389160156 loss_valid:106.39917755126953
epoch 34 loss_train:111.45946502685547 loss_valid:107.09259796142578
epoch 35 loss_train:110.34514617919922 loss_valid:107.82219696044922
epoch 36 loss_train:109.360107421875 loss_valid:108.57747650146484
epoch 37 loss_train:108.4894790649414 loss_valid:109.3494873046875
epoch 38 loss_train:107.72004699707031 loss_valid:110.13068389892578
epoch 39 loss_train:107.04024505615234 loss_valid:110.91461181640625
epoch 40 loss_train:106.43974304199219 loss_valid:111.69599914550781
epoch 41 loss_train:105.909423828125 loss_valid:112.47042846679688
epoch 42 loss_train:105.44117736816406 loss_valid:113.23422241210938
epoch 43 loss_train:105.0279312133789 loss_valid:113.9843978881836
epoch 44 loss_train:104.66329193115234 loss_valid:114.7186508178711
epoch 45 loss_train:104.34174346923828 loss_valid:115.43498992919922
epoch 46 loss_train:104.05831146240234 loss_valid:116.13207244873047
epoch 47 loss_train:103.80862426757812 loss_valid:116.80872344970703
epoch 48 loss_train:103.58882141113281 loss_valid:117.46420288085938
epoch 49 loss_train:103.39546966552734 loss_valid:118.0979995727539
epoch 50 loss_train:103.22554779052734 loss_valid:118.70991516113281
# Visualize the losses
import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(loss_train_list)
plt.plot(loss_valid_list)
[<matplotlib.lines.Line2D at 0x7efbb0f696d8>]

[Figure output_18_1.png: training and validation loss per epoch. The training loss keeps falling for all 50 epochs, while the validation loss reaches its minimum around epoch 27 and then slowly rises, the overfitting pattern the validation set is meant to reveal.]

4. Prediction

W.numpy()
array([[ 0.04078581],
       [ 0.42111152],
       [-1.266555  ],
       [ 0.7909989 ],
       [-0.7962972 ],
       [ 1.6725844 ],
       [-0.23721075],
       [-0.900601  ],
       [ 0.45604745],
       [-0.51313245],
       [-2.141279  ],
       [-0.713085  ]], dtype=float32)
B.numpy()
array([24.265219], dtype=float32)
# Pick one house at random from the test set
house_id = np.random.randint(0,x_data.shape[0]-train_num-valid_num)
# true value
y = y_test[house_id]
# predicted value (int() truncates the prediction to a whole number)
y_pre = int(line_model(x_test,W.numpy(),B.numpy())[house_id])

print(f'true value: {y}, predicted value: {y_pre}')
true value: 10.2, predicted value: 24
plt.plot(y_test,'r')
plt.plot(line_model(x_test,W.numpy(),B.numpy()),'g')
[<matplotlib.lines.Line2D at 0x7efbb069a0b8>]

[Figure output_23_1.png: true MEDV values (red) and model predictions (green) for the test set.]

plt.plot(y_train,'r')
plt.plot(line_model(x_train,W.numpy(),B.numpy()),'g')
[<matplotlib.lines.Line2D at 0x7efbb0643828>]

[Figure output_24_1.png: true MEDV values (red) and model predictions (green) for the training set.]

plt.plot(y_valid,'r')
plt.plot(line_model(x_valid,W.numpy(),B.numpy()),'g')
[<matplotlib.lines.Line2D at 0x7efbb0610550>]

[Figure output_25_1.png: true MEDV values (red) and model predictions (green) for the validation set.]

import sklearn.metrics as sm # model performance evaluation module
print('R2_score:', sm.r2_score(y_test,line_model(x_test,W.numpy(),B.numpy())))
R2_score: -2.399917872887695
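For reference (this explanation is my addition), r2_score is 1 - SS_res/SS_tot: a value below zero means the model fits the test set worse than a constant predictor that always outputs the mean of y_test. A minimal check of that definition against the value above:

import numpy as np

y_pred_test = line_model(x_test, W.numpy(), B.numpy()).numpy().flatten()
ss_res = np.sum((y_test - y_pred_test) ** 2)    # residual sum of squares
ss_tot = np.sum((y_test - y_test.mean()) ** 2)  # total sum of squares
print('manual R2:', 1 - ss_res / ss_tot)        # should agree with sm.r2_score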
The prediction results are not very good… a little disappointing. It comes down to the model itself: a plain linear model cannot really solve this house-price prediction problem. This shows that choosing the right model matters a great deal: if you pick the wrong model, no amount of parameter tuning will help.