IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 决策树桩(Decision Stump) -> 正文阅读

[人工智能]决策树桩(Decision Stump)

1. 何为决策树桩?

?

单层决策树(decision stump),也称决策树桩,它是一种简单的决策树,通过给定的阈值进行分类。

?

从实际意义上来看,决策树桩根据一个属性的单个判断(但是实际上待判断的物体具有多个属性)就确定最终的分类结果。这种特性比较适合做集成学习中的弱学习器,因为其至少比随机的效果好一些,又计算较为容易。

2. 关键问题

根本目的:通过选择一个合适的决策树桩(弱学习器),使得物体类别识别准确率尽可能高。

怎么选择一个合适的决策树桩?

  1. 从所有属性中,选择那个属性作为决策树桩(弱学习器)
  2. 该决策树桩的阈值设定为何值(上图中是1.75)?
  3. 是小于阈值识别为1(yes),还是大于阈值识别为1(yes)。一般情况下设定为小于。

3. MATLAB代码实现

?

?

主函数test主要是数据的加载和决策树桩的调用, 并呈现最佳的决策树桩以及判断错误的概率。

首先看如何在整个数据集上获得最佳的决策树桩,使得判断错误的概率最小。

关键代码:err = sum(weight .* (predi_labels ~= label));

function [classifier, min_error, best_labels] = decision_stump(data, weight, label)
% decision_stump 确定最优决策树桩并返回分类函数
% data:num_row*num_col; weights:num_row*1(初始值均为1/num_row); labels:num_row*1(初始值为1或0)
% classifier有dim(哪一维特征识别率最高),thresh_val(阈值何值时候该维征识别率最高),thresh_ineq(是大于阈值识别高还是小于阈值识别高)
num_row = size(data, 1);
num_col = size(data, 2);
% 优化迭代次数
max_iter = 10;
% 初始化相关参数
min_error = Inf;
best_labels = ones(num_row , 1);
classifier.dim = 0;
classifier.thresh_val = 0;
classifier.thresh_ineq = '';
for i = 1:num_col
    cur_thresh_val = min(data(:, i));
    step_size = (max(data(:, i)) - min(data(:, i))) / max_iter;
    for j = 1:max_iter
        for k = ['l', 'g']
            thresh_val = cur_thresh_val + (j - 1) * step_size;
            predi_labels = decision(data, i, thresh_val, k);
            err = sum(weight .* (predi_labels ~= label));
            fprintf("iter %d dim %d, threshVal %.2f, thresh ineqal: %s, the weighted error is %.3f\n", j, i, thresh_val, k, err);
            if err < min_error
                %更新相关参数
                min_error = err;
                best_labels = predi_labels;
                classifier.dim = i;
                classifier.thresh_val = thresh_val;
                classifier.thresh_ineq = k;
            end
        end
    end
end
end

% 根据决策树桩的阈值来分类(1,0)
function predi_labels = decision(data, i_col, thresh_val, thresh_ineq)
% data:num_row*num_col,i_col:表示第i列,thresh_val:表示决策树桩的阈值,thresh_ineq:采用大于或者小于来比较阈值
% predi_labels:根据决策树桩的阈值来获得分类后的labels,num_row*1
num_row = size(data, 1);
% 初始化predi_labels为全1
predi_labels = ones(num_row, 1);
if thresh_ineq == 'l'
    % 如果小于阈值,则判定为0
    predi_labels(data(:, i_col) <= thresh_val) = 0;
elseif thresh_ineq == 'g'
    % 如果大于阈值,则判定为0
    predi_labels(data(:, i_col) > thresh_val) = 0;
end
end

测试一下:

clc;clear;clearvars;
% 加载数据[data, label]([X,Y])
[data, label] = loadData();
% 初始化每个数据点的相对权重,创建数值均为1/num_row的num_row*num_col数组
num_row = size(data, 1);
num_col = size(data, 2);
weight = repmat(1 / num_row, num_row, 1);
[classifier, min_error, best_labels] = decision_stump(data, weight, label);
% 输出最佳的决策树桩相关参数
fprintf("dim %d, threshVal %.2f, thresh ineqal: %s, the weighted error is %.3f\n", ...
    classifier.dim, classifier.thresh_val, classifier.thresh_ineq, min_error);
disp(best_labels);


function [data, label] = loadData()
% data:5*2,label:5*1,标签为0,1.
data = [1, 2.1; 1.5, 1.6; 1.3, 1; 1, 1; 2, 1];
label = [1; 1; 0 ; 0 ; 1];
end

iter 1 dim 1, threshVal 1.00, thresh ineqal: l, the weighted error is 0.400
iter 1 dim 1, threshVal 1.00, thresh ineqal: g, the weighted error is 0.600
iter 2 dim 1, threshVal 1.10, thresh ineqal: l, the weighted error is 0.400
iter 2 dim 1, threshVal 1.10, thresh ineqal: g, the weighted error is 0.600
iter 3 dim 1, threshVal 1.20, thresh ineqal: l, the weighted error is 0.400
iter 3 dim 1, threshVal 1.20, thresh ineqal: g, the weighted error is 0.600
iter 4 dim 1, threshVal 1.30, thresh ineqal: l, the weighted error is 0.200
iter 4 dim 1, threshVal 1.30, thresh ineqal: g, the weighted error is 0.800
iter 5 dim 1, threshVal 1.40, thresh ineqal: l, the weighted error is 0.200
iter 5 dim 1, threshVal 1.40, thresh ineqal: g, the weighted error is 0.800
iter 6 dim 1, threshVal 1.50, thresh ineqal: l, the weighted error is 0.400
iter 6 dim 1, threshVal 1.50, thresh ineqal: g, the weighted error is 0.600
iter 7 dim 1, threshVal 1.60, thresh ineqal: l, the weighted error is 0.400
iter 7 dim 1, threshVal 1.60, thresh ineqal: g, the weighted error is 0.600
iter 8 dim 1, threshVal 1.70, thresh ineqal: l, the weighted error is 0.400
iter 8 dim 1, threshVal 1.70, thresh ineqal: g, the weighted error is 0.600
iter 9 dim 1, threshVal 1.80, thresh ineqal: l, the weighted error is 0.400
iter 9 dim 1, threshVal 1.80, thresh ineqal: g, the weighted error is 0.600
iter 10 dim 1, threshVal 1.90, thresh ineqal: l, the weighted error is 0.400
iter 10 dim 1, threshVal 1.90, thresh ineqal: g, the weighted error is 0.600
iter 1 dim 2, threshVal 1.00, thresh ineqal: l, the weighted error is 0.200
iter 1 dim 2, threshVal 1.00, thresh ineqal: g, the weighted error is 0.800
iter 2 dim 2, threshVal 1.11, thresh ineqal: l, the weighted error is 0.200
iter 2 dim 2, threshVal 1.11, thresh ineqal: g, the weighted error is 0.800
iter 3 dim 2, threshVal 1.22, thresh ineqal: l, the weighted error is 0.200
iter 3 dim 2, threshVal 1.22, thresh ineqal: g, the weighted error is 0.800
iter 4 dim 2, threshVal 1.33, thresh ineqal: l, the weighted error is 0.200
iter 4 dim 2, threshVal 1.33, thresh ineqal: g, the weighted error is 0.800
iter 5 dim 2, threshVal 1.44, thresh ineqal: l, the weighted error is 0.200
iter 5 dim 2, threshVal 1.44, thresh ineqal: g, the weighted error is 0.800
iter 6 dim 2, threshVal 1.55, thresh ineqal: l, the weighted error is 0.200
iter 6 dim 2, threshVal 1.55, thresh ineqal: g, the weighted error is 0.800
iter 7 dim 2, threshVal 1.66, thresh ineqal: l, the weighted error is 0.400
iter 7 dim 2, threshVal 1.66, thresh ineqal: g, the weighted error is 0.600
iter 8 dim 2, threshVal 1.77, thresh ineqal: l, the weighted error is 0.400
iter 8 dim 2, threshVal 1.77, thresh ineqal: g, the weighted error is 0.600
iter 9 dim 2, threshVal 1.88, thresh ineqal: l, the weighted error is 0.400
iter 9 dim 2, threshVal 1.88, thresh ineqal: g, the weighted error is 0.600
iter 10 dim 2, threshVal 1.99, thresh ineqal: l, the weighted error is 0.400
iter 10 dim 2, threshVal 1.99, thresh ineqal: g, the weighted error is 0.600
dim 1, threshVal 1.30, thresh ineqal: l, the weighted error is 0.200
? ? ?0
? ? ?1
? ? ?0
? ? ?0
? ? ?1

可以看到最佳佳决策树桩选择的是第一个特征,阈值1.30,识别错误的概率是0.2,第一个识别错误。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-13 11:44:22  更:2022-05-13 11:44:47 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 5:39:27-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码