1. 何为决策树桩?
?
单层决策树(decision stump),也称决策树桩,它是一种简单的决策树,通过给定的阈值进行分类。
?
从实际意义上来看,决策树桩根据一个属性的单个判断(但是实际上待判断的物体具有多个属性)就确定最终的分类结果。这种特性比较适合做集成学习中的弱学习器,因为其至少比随机的效果好一些,又计算较为容易。
2. 关键问题
根本目的:通过选择一个合适的决策树桩(弱学习器),使得物体类别识别准确率尽可能高。
怎么选择一个合适的决策树桩?
- 从所有属性中,选择那个属性作为决策树桩(弱学习器)
- 该决策树桩的阈值设定为何值(上图中是1.75)?
- 是小于阈值识别为1(yes),还是大于阈值识别为1(yes)。一般情况下设定为小于。
3. MATLAB代码实现
?
?
主函数test主要是数据的加载和决策树桩的调用, 并呈现最佳的决策树桩以及判断错误的概率。
首先看如何在整个数据集上获得最佳的决策树桩,使得判断错误的概率最小。
关键代码:err = sum(weight .* (predi_labels ~= label));
function [classifier, min_error, best_labels] = decision_stump(data, weight, label)
% decision_stump 确定最优决策树桩并返回分类函数
% data:num_row*num_col; weights:num_row*1(初始值均为1/num_row); labels:num_row*1(初始值为1或0)
% classifier有dim(哪一维特征识别率最高),thresh_val(阈值何值时候该维征识别率最高),thresh_ineq(是大于阈值识别高还是小于阈值识别高)
num_row = size(data, 1);
num_col = size(data, 2);
% 优化迭代次数
max_iter = 10;
% 初始化相关参数
min_error = Inf;
best_labels = ones(num_row , 1);
classifier.dim = 0;
classifier.thresh_val = 0;
classifier.thresh_ineq = '';
for i = 1:num_col
cur_thresh_val = min(data(:, i));
step_size = (max(data(:, i)) - min(data(:, i))) / max_iter;
for j = 1:max_iter
for k = ['l', 'g']
thresh_val = cur_thresh_val + (j - 1) * step_size;
predi_labels = decision(data, i, thresh_val, k);
err = sum(weight .* (predi_labels ~= label));
fprintf("iter %d dim %d, threshVal %.2f, thresh ineqal: %s, the weighted error is %.3f\n", j, i, thresh_val, k, err);
if err < min_error
%更新相关参数
min_error = err;
best_labels = predi_labels;
classifier.dim = i;
classifier.thresh_val = thresh_val;
classifier.thresh_ineq = k;
end
end
end
end
end
% 根据决策树桩的阈值来分类(1,0)
function predi_labels = decision(data, i_col, thresh_val, thresh_ineq)
% data:num_row*num_col,i_col:表示第i列,thresh_val:表示决策树桩的阈值,thresh_ineq:采用大于或者小于来比较阈值
% predi_labels:根据决策树桩的阈值来获得分类后的labels,num_row*1
num_row = size(data, 1);
% 初始化predi_labels为全1
predi_labels = ones(num_row, 1);
if thresh_ineq == 'l'
% 如果小于阈值,则判定为0
predi_labels(data(:, i_col) <= thresh_val) = 0;
elseif thresh_ineq == 'g'
% 如果大于阈值,则判定为0
predi_labels(data(:, i_col) > thresh_val) = 0;
end
end
测试一下:
clc;clear;clearvars;
% 加载数据[data, label]([X,Y])
[data, label] = loadData();
% 初始化每个数据点的相对权重,创建数值均为1/num_row的num_row*num_col数组
num_row = size(data, 1);
num_col = size(data, 2);
weight = repmat(1 / num_row, num_row, 1);
[classifier, min_error, best_labels] = decision_stump(data, weight, label);
% 输出最佳的决策树桩相关参数
fprintf("dim %d, threshVal %.2f, thresh ineqal: %s, the weighted error is %.3f\n", ...
classifier.dim, classifier.thresh_val, classifier.thresh_ineq, min_error);
disp(best_labels);
function [data, label] = loadData()
% data:5*2,label:5*1,标签为0,1.
data = [1, 2.1; 1.5, 1.6; 1.3, 1; 1, 1; 2, 1];
label = [1; 1; 0 ; 0 ; 1];
end
iter 1 dim 1, threshVal 1.00, thresh ineqal: l, the weighted error is 0.400 iter 1 dim 1, threshVal 1.00, thresh ineqal: g, the weighted error is 0.600 iter 2 dim 1, threshVal 1.10, thresh ineqal: l, the weighted error is 0.400 iter 2 dim 1, threshVal 1.10, thresh ineqal: g, the weighted error is 0.600 iter 3 dim 1, threshVal 1.20, thresh ineqal: l, the weighted error is 0.400 iter 3 dim 1, threshVal 1.20, thresh ineqal: g, the weighted error is 0.600 iter 4 dim 1, threshVal 1.30, thresh ineqal: l, the weighted error is 0.200 iter 4 dim 1, threshVal 1.30, thresh ineqal: g, the weighted error is 0.800 iter 5 dim 1, threshVal 1.40, thresh ineqal: l, the weighted error is 0.200 iter 5 dim 1, threshVal 1.40, thresh ineqal: g, the weighted error is 0.800 iter 6 dim 1, threshVal 1.50, thresh ineqal: l, the weighted error is 0.400 iter 6 dim 1, threshVal 1.50, thresh ineqal: g, the weighted error is 0.600 iter 7 dim 1, threshVal 1.60, thresh ineqal: l, the weighted error is 0.400 iter 7 dim 1, threshVal 1.60, thresh ineqal: g, the weighted error is 0.600 iter 8 dim 1, threshVal 1.70, thresh ineqal: l, the weighted error is 0.400 iter 8 dim 1, threshVal 1.70, thresh ineqal: g, the weighted error is 0.600 iter 9 dim 1, threshVal 1.80, thresh ineqal: l, the weighted error is 0.400 iter 9 dim 1, threshVal 1.80, thresh ineqal: g, the weighted error is 0.600 iter 10 dim 1, threshVal 1.90, thresh ineqal: l, the weighted error is 0.400 iter 10 dim 1, threshVal 1.90, thresh ineqal: g, the weighted error is 0.600 iter 1 dim 2, threshVal 1.00, thresh ineqal: l, the weighted error is 0.200 iter 1 dim 2, threshVal 1.00, thresh ineqal: g, the weighted error is 0.800 iter 2 dim 2, threshVal 1.11, thresh ineqal: l, the weighted error is 0.200 iter 2 dim 2, threshVal 1.11, thresh ineqal: g, the weighted error is 0.800 iter 3 dim 2, threshVal 1.22, thresh ineqal: l, the weighted error is 0.200 iter 3 dim 2, threshVal 1.22, thresh ineqal: g, the weighted error is 0.800 iter 4 dim 2, threshVal 1.33, thresh ineqal: l, the weighted error is 0.200 iter 4 dim 2, threshVal 1.33, thresh ineqal: g, the weighted error is 0.800 iter 5 dim 2, threshVal 1.44, thresh ineqal: l, the weighted error is 0.200 iter 5 dim 2, threshVal 1.44, thresh ineqal: g, the weighted error is 0.800 iter 6 dim 2, threshVal 1.55, thresh ineqal: l, the weighted error is 0.200 iter 6 dim 2, threshVal 1.55, thresh ineqal: g, the weighted error is 0.800 iter 7 dim 2, threshVal 1.66, thresh ineqal: l, the weighted error is 0.400 iter 7 dim 2, threshVal 1.66, thresh ineqal: g, the weighted error is 0.600 iter 8 dim 2, threshVal 1.77, thresh ineqal: l, the weighted error is 0.400 iter 8 dim 2, threshVal 1.77, thresh ineqal: g, the weighted error is 0.600 iter 9 dim 2, threshVal 1.88, thresh ineqal: l, the weighted error is 0.400 iter 9 dim 2, threshVal 1.88, thresh ineqal: g, the weighted error is 0.600 iter 10 dim 2, threshVal 1.99, thresh ineqal: l, the weighted error is 0.400 iter 10 dim 2, threshVal 1.99, thresh ineqal: g, the weighted error is 0.600 dim 1, threshVal 1.30, thresh ineqal: l, the weighted error is 0.200 ? ? ?0 ? ? ?1 ? ? ?0 ? ? ?0 ? ? ?1
可以看到最佳佳决策树桩选择的是第一个特征,阈值1.30,识别错误的概率是0.2,第一个识别错误。
|