IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> sklearn入门——回归树 -> 正文阅读

[人工智能]sklearn入门——回归树

重要参数、属性及接口

  • criterion
    回归树衡量分支质量的指标,支持的标准有三种:
    1)输入“mse”使用均方误差(mean squared error),父节点和子节点之间的均方误差的差额来作为特征选择的标准,通过使用叶子节点的均值来最小化L2损失。
    2)输入"friedman_mse"使用费尔德曼均方误差,这种指标使用费尔德曼针对潜在分支中的问题改进后的均方误差。
    3)输入"mae" 使用绝对平方误差,使用节点的中值来最小化L1损失
    其他属性也包括feature_importances_,接口有apply,fit,predict,score等核心的。
    M S E = 1 N ∑ i = 1 N ( f i ? y i ) 2 MSE=\frac{1}{N}\sum_{i=1}^{N}(f_i-y_i)^2 MSE=N1?i=1N?(fi??yi?)2

    在回归树中MSE既是分支质量的衡量标准也是回归树回归质量的标准,在回归中,MSE越小越好,而接口score返回的是R2,R2定义是:
    R 2 = 1 ? u v , u = ∑ ( f i ? y i ) 2 , ( 残 差 平 方 和 ) v = ∑ ( y i ? y h a t ) ( 总 平 方 和 ) R^2=1-\frac{u}{v},u=\sum(f_i-y_i)^2,(残差平方和)v=\sum(y_i-y_{hat})(总平方和) R2=1?vu?,u=(fi??yi?)2,()v=(yi??yhat?)(),其中y-hat是所有标签的平均值。
    虽然均方误差恒为非负,但考虑到作为损失,变成了负的。

交叉验证

交叉验证是用来观察模型的稳定性的一种方法,将数据划分为n份,依次使用其中的一份作为测试集,其他n-1份作为训练集,多次计算模型的准确性来评估模型的平均准确程度。(这吃瓜的时候学过,现在听到居然毫无印象。。。。)
相关操作:

from sklearn import tree
# 交叉验证的库
from sklearn.medel_selection import cross_val_score
# 随便搞一个数据集
from sklearn.datasets import load_boston

# 实例化一个回归模型
regressor = tree.DecesionTreeRegressor(random_state=0)
# 参数分别是实例化的模型,样本数据矩阵,样本标签,验证次数,衡量标准(这里是均方误差的相反数)
cross_val_score(regressor, boston.data, boston.target, cv=10,scoring='neg_mean_squared_error')

回归树的例子

用回归树拟合正弦函数曲线

import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
rng = np.random.RandomState(1) # 随机数种子
X = np.sort(5*rng.rand(80,1), axis=0)
# y数据必须是一维的,用ravel()展平
# numpy数组中一维数组不分行和列
y = np.sin(X).ravel()

plt.figure()
# plt.scatter专门画散点图
plt.scatter(X, y, s=20, edgecolor='black', c='darkorange', label='data')
plt.legend()

# np.random.rand(数组结构),生成随机数组的函数
# 加噪声
y[::5] += 3*(0.5-rng.rand(16))

plt.scatter(X, y, s=20, edgecolor='black', c='darkorange', label='data')
plt.legend()

regressor1 = tree.DecisionTreeRegressor(max_depth=3)
regressor2 = tree.DecisionTreeRegressor(max_depth=5)
regressor1.fit(X, y)
regressor2.fit(X, y)

X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regressor1.predict(X_test)
y_2 = regressor2.predict(X_test)

plt.figure()
plt.scatter(X, y, s=20, edgecolor='black', c='darkorange', label='data')
plt.plot(X_test, y_1, color='cornflowerblue', label='max_depth=3', linewidth=2)
plt.plot(X_test, y_2, color='yellowgreen', label='max_depth=5', linewidth=2)
plt.xlabel('data')
plt.ylabel('Decision Tree Regression')
plt.legend()
plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章           查看所有文章
加:2021-08-05 17:21:26  更:2021-08-05 17:24:33 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/17 22:30:10-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码