Zhou Zhihua's *Machine Learning* (the "watermelon book") describes how decision trees handle continuous-valued attributes, using the density and sugar content of watermelons as examples.
Below, we take the densities of 17 watermelons as the feature:
[0.697,0.774,0.634,0.608,0.556,0.403,0.481,0.437,0.666,0.243,0.245,0.343,0.639,0.657,0.360,0.593,0.719]
How do we find the best split point?
1. Sort the 17 continuous values X, then average each adjacent pair in the sorted sequence to get 16 candidate split points split_X.
This is the bi-partition method, going from sort_X to split_X. Note that n continuous values do not necessarily yield n-1 split points: duplicate values produce duplicate midpoints, which should be dropped.
First sort X to get sort_X:
[0.243, 0.245, 0.343, 0.36, 0.403, 0.437, 0.481, 0.556, 0.593, 0.608, 0.634, 0.639, 0.657, 0.666, 0.697, 0.719, 0.774]
Then take the midpoints to get split_X:
[0.244, 0.29400000000000004, 0.35150000000000003, 0.3815, 0.42000000000000004, 0.45899999999999996, 0.5185, 0.5745, 0.6005, 0.621, 0.6365000000000001, 0.648, 0.6615, 0.6815, 0.708, 0.7464999999999999]
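The midpoint construction in step 1, including the deduplication the text mentions, can be sketched in Julia as follows (using the 17 densities from above):

```julia
# Sorted densities of the 17 watermelons (from the text above).
sort_X = sort([0.697, 0.774, 0.634, 0.608, 0.556, 0.403, 0.481, 0.437, 0.666,
               0.243, 0.245, 0.343, 0.639, 0.657, 0.360, 0.593, 0.719])

# Midpoints of adjacent sorted values are the candidate split points;
# unique() drops duplicates, so n values may yield fewer than n-1 candidates.
splits = unique((sort_X[1:end-1] .+ sort_X[2:end]) ./ 2)

println(length(splits))   # 16 here, since all 17 densities are distinct
println(splits[1])        # 0.244, the midpoint of 0.243 and 0.245
```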
2. For each of the 16 candidate split points, compute the information gain, giving a vector gains; the entry with the largest gain marks the best split point. In general, the larger the information gain, the greater the "purity improvement" obtained by splitting on that attribute or split point; other things being equal, higher purity means a better model. gains:
[0.05682352941176472, 0.11847797178090891, 0.18663565267767523, 0.26293671403544794, 0.09399614386760902, 0.03069956878979646, 0.004082532221190538, 0.0027244389091765076, 0.0027244389091765076, 0.004082532221190482, 0.030699568789796516, 0.006543942807450409, 0.0012672425232922446, 0.02458344666805945, 0.0008309129573794705, 0.0674593804343554]
As shown above, the maximum gain is 0.2629, attained at the split point 0.3815. So the best split point for the continuous attribute density is 0.3815 (each density value is compared against this threshold).
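As a sanity check, the gain at 0.3815 can be reproduced by hand: the four densities at or below 0.3815 (0.243, 0.245, 0.343, 0.360) all belong to bad melons (a pure subset with entropy 0), while the 13 above it contain 8 good and 5 bad. A minimal sketch, using the book's rounded value Ent(D) = 0.998:

```julia
# Entropy of the whole data set (8 good, 9 bad), rounded to 0.998 in the book.
ent_D = 0.998

# Subset with density <= 0.3815: all 4 samples are bad melons, entropy 0.
ent_less = 0.0

# Subset with density > 0.3815: 8 good and 5 bad melons.
ent_greater = -(8/13) * log2(8/13) - (5/13) * log2(5/13)

# Gain(D, t) = Ent(D) - sum over branches of |D_v|/|D| * Ent(D_v)
gain = ent_D - (4/17) * ent_less - (13/17) * ent_greater
println(gain)   # ≈ 0.2629, matching the maximum entry of gains above
```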
3. The Julia program that computes the split points and gains:
# Binary entropy from positive/negative counts; 0*log2(0) is taken as 0.
function entropy2(yes, no)
    n = yes + no
    n == 0 && return 0.0
    e = 0.0
    for c in (yes, no)
        if c > 0
            p = c / n
            e -= p * log2(p)
        end
    end
    return e
end

# Information gain for splitting (X, y) at each candidate point in splits_X.
# entrop_total is Ent(D), the entropy of the whole data set.
function get_gains(X, y, splits_X, entrop_total)::Vector{Float64}
    @assert length(X) == length(y)
    gains = zeros(length(splits_X))
    for (i, split_) in enumerate(splits_X)
        greater_yes_num = 0; greater_no_num = 0
        less_yes_num = 0; less_no_num = 0
        # Count positives and negatives on each side of the split.
        for j in eachindex(X)
            if X[j] > split_
                y[j] ? (greater_yes_num += 1) : (greater_no_num += 1)
            else
                y[j] ? (less_yes_num += 1) : (less_no_num += 1)
            end
        end
        greater_num = greater_yes_num + greater_no_num
        less_num = less_yes_num + less_no_num
        # Gain(D, t) = Ent(D) - |D>|/|D|*Ent(D>) - |D<=|/|D|*Ent(D<=)
        gains[i] = entrop_total -
                   greater_num / length(X) * entropy2(greater_yes_num, greater_no_num) -
                   less_num / length(X) * entropy2(less_yes_num, less_no_num)
    end
    return gains
end

X = [0.697, 0.774, 0.634, 0.608, 0.556, 0.403, 0.481, 0.437, 0.666,
     0.243, 0.245, 0.343, 0.639, 0.657, 0.360, 0.593, 0.719]
y = [true, true, true, true, true, true, true, true, false,
     false, false, false, false, false, false, false, false]
# Ent(D) = -(8/17)*log2(8/17) - (9/17)*log2(9/17) ≈ 0.998 (rounded, as in the book).
entrop_total = 0.998
sort_X = sort(X)
# Midpoints of adjacent sorted values; unique() drops duplicate candidates.
splits_X = unique((sort_X[1:end-1] .+ sort_X[2:end]) ./ 2)
gains = get_gains(X, y, splits_X, entrop_total)
println("gains : $(gains)")
println("gains max value : $(maximum(gains)) ")
Output:
gains : [0.05682352941176472, 0.11847797178090891, 0.18663565267767523, 0.26293671403544794, 0.09399614386760902, 0.03069956878979646, 0.004082532221190538, 0.0027244389091765076, 0.0027244389091765076, 0.004082532221190482, 0.030699568789796516, 0.006543942807450409, 0.0012672425232922446, 0.02458344666805945, 0.0008309129573794705, 0.0674593804343554]
gains max value : 0.26293671403544794
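With the split point found, a hypothetical one-node decision stump on density alone could look like the sketch below. The threshold 0.3815 and the majority-vote labels come from the data above (the side with density <= 0.3815 contains only bad melons; the side above it is 8 good vs 5 bad, so its majority class is "good"); this ignores the sugar-content attribute the book also uses.

```julia
# Hypothetical decision stump derived from the best split point:
# density > 0.3815 -> majority class "good"; otherwise -> "bad".
is_good(density) = density > 0.3815

println(is_good(0.697))   # true  — dense melon, predicted good
println(is_good(0.243))   # false — predicted bad
```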