ID3决策树
本文从计算数据集的信息熵、划分数据集、选择最优特征、递归训练一棵树、预测五个方面介绍怎样构建ID3决策树。 先要介绍信息熵和信息增益的这两个公式:
Ent
?
(
D
)
=
?
∑
k
=
1
∣
Y
∣
p
k
log
?
2
p
k
\operatorname{Ent}(D)=-\sum_{k=1}^{|\mathcal{Y}|} p_{k} \log _{2} p_{k}
Ent(D)=?k=1∑∣Y∣?pk?log2?pk?
Gain
?
(
D
,
a
)
=
Ent
?
(
D
)
?
∑
v
=
1
V
∣
D
v
∣
∣
D
∣
Ent
?
(
D
v
)
\operatorname{Gain}(D, a)=\operatorname{Ent}(D)-\sum_{v=1}^{V} \frac{\left|D^{v}\right|}{|D|} \operatorname{Ent}\left(D^{v}\right)
Gain(D,a)=Ent(D)?v=1∑V?∣D∣∣Dv∣?Ent(Dv) 具体也可参考这篇文章,详细介绍了决策树公式及知识框架。
计算数据集的信息熵
假设现在的数据集dataSet最后一列为样本标签,因为数据集的信息熵只与标签的纯度有关,所以需要取出数据集最后一列的类别及其数量到字典中,代入公式计算信息熵。
以下为计算数据集信息熵的函数:
def entropy(dataSet):
labelCounts = {}
length = len(dataSet)
for example in dataSet:
if example[-1] not in labelCounts: labelCounts[example[-1]] = 0
labelCounts[example[-1]] += 1
e = 0.0
for i in labelCounts.values():
p_k = float(i) / length
e -= p_k * log(p_k, 2)
return e
划分数据集
以下函数功能为按照特征的不同值将样本集划分为多个数据集。
def splitDataSet(dataSet, axis, value):
reDataSet = []
for example in dataSet:
if example[axis] == value:
reDataSet += [example[:axis] + example[axis + 1:]]
return reDataSet
选择最优特征
遍历各个特征,选取信息熵最小的作为最优划分属性。
def chooseBestFeature(dataSet):
bestFeature = -1
length = len(dataSet)
minEntropy = entropy(dataSet)
for axis in range(len(dataSet[0]) - 1):
newEntropy = 0.0
record = {}
for example in dataSet:
if example[axis] not in record: record[example[axis]] = 0
record[example[axis]] += 1
for i in record:
newEntropy += entropy(splitDataSet(dataSet, axis, i)) * record[i] / length
if newEntropy < minEntropy:
bestFeature = axis
minEntropy = newEntropy
return bestFeature
递归训练一棵树
当前结点都属于同一个标签时,结束递归。
def trainTree(dataSet,feature_name):
myTree = {}
k = chooseBestFeature(dataSet)
s = set()
for example in dataSet:
s.add(example[k])
tree = {}
for i in s:
newDataSet = splitDataSet(dataSet, k, i)
labelRecord = []
for example in newDataSet:
labelRecord.append(example[-1])
if labelRecord.count(labelRecord[0]) == len(labelRecord):
tree[i] = labelRecord[0]
else:
tree[i] = trainTree(newDataSet, feature_name)
myTree[feature_name[k]] = tree
return myTree
预测
当结果是字典时,继续递归,否则输出该值为预测结果。
def predict(inputTree,feature_name,testVec):
feature = list(inputTree.keys())[0]
k = feature_name.index(feature)
childTree = inputTree[feature]
label = childTree[testVec[k]]
if isinstance(label, dict): return predict(label, feature_name, testVec)
return label
sklearn实现决策树
除了ID3决策树,常用的还有C4.5决策树和CART决策树。
Python的sklearn库中提供了决策树的模型,可用于快速构建不同类型的决策树模型,具体可参考……待续
|