决策树算法常用于解决分类问题,该方法的优势在于其数据形式非常容易理解。
概述
决策树(decision tree)是一类常见的机器学习方法.以二分类任务为例,我们希望从给定训练数据集学得一个模型用以对新示例进行分类,这个把样本 分类的任务,可看作对“当前样本属于正类吗?”这个问题的“决策”或“判定”过程。决策树是基于树结构来进行决策的,这是人类在面临决策问题时的一种很自然的处理机制。
例如:我们要对“这是好瓜吗?”这样的问题进行决策时,通常会进行一系列的判断或“子决策”:我们先看“它是什么颜色?”,如果是青绿色,我们再看它的根蒂是什么形态,如果是蜷缩,我们再看它敲起来是什么声音,最后得出最终决策:这是个好瓜,这个决策的过程如下图:
决策过程的最终结论与我们所希望的判定结果是对应的,决策过程中提出的每个判定问题都是对某个属性的测试,每个测试的结果导出进一步的判定问题,其考虑范围是在上次决策结果的限定范围之内。
一棵决策树包含一个根结点、若干个内部结点和若干个叶结点,叶结点对应于决策结果,其他每个结点则对应于一个属性测试,每个结点包含的样本集合根据属性测试的结果被划分到子结点中,根结点包含样本全集。从根结点到每个叶结点的路径对应了一个判定测试序列,决策树学习的目的是为了产生一棵泛化能力强,即处理未见示例能力强的决策树
决策树算法的基本流程
决策树算法从数据中提取出一系列规则,在创建规则的过程,就是决策树学习的过程。
选择最优划分属性(ID3)
从上面的流程中可以看出:当前数据集上选择哪个特征作为划分属性是该算法的关键。一般来说,我们希望随着划分过程的不断进行,我们希望决策树的分支结点所包含的样本尽可能属于同一类别,即结点的“纯度”越来越来。1970年,Quinlan J.R找到了用信息论中的熵来度量决策树的决策选择过程,这种方法的简洁和高效在当时就引起了轰动,并命名为ID3算法。
决策树算法的优缺点
优点:计算复杂度不高,输出结果容易理解,对中间值的缺失不敏感,可以处理不相关特征数据。 缺点:可能会产生过度匹配问题 使用数据类型:数值型和标称型
决策树分类示例
在信贷问题上的应用示例
下面通过所给的训练数据学习一个贷款申请的决策树,用以对未来的贷款申请进行分类,即当新的客户提出贷款申请时,根据申请人的特征利用决策树决定是否批准贷款申请,样本数据如下表所示 对于上述样本数据,经过对属性值进行离散化处理,对于年龄离散化表示为[0, 1, 2],其中0,1,2分别对应青年,中年与老年;对于有工作一栏离散化表示为[0, 1],其中0,1分别对应无工作和有工作,同样对有房子这一属性进行相似的离散化;对于信贷情况按照上述方式离散化为[0, 1, 2] ,其中0,1,2分别对应一般,好,非常好;对于类别用yes和no表示。
对该问题利用复现的ID3算法构建决策树,得到的决策树最终为: 对结果进行验证,年龄这一列的数据,也就是特征A1,一共有三个类别,分别是:青年、中年和老年。其中年龄是青年的数据一共有5个,所以年龄是青年的数据在训练数据集出现的概率是三分之一。同理,年龄是中年和老年的数据在训练数据集出现的概率也都是三分之一。而年龄是青年的数据的最终得到贷款的概率为五分之二,因为在五个数据中,只有两个数据显示拿到了最终的贷款,同理,年龄是中年和老年的数据最终得到贷款的概率分别为五分之三、五分之四。所以计算年龄的信息增益,过程如下:
同理可以得到 因此利用属性A3是否有房子进行划分得到的信息增益最大,所以首先选择有房子作为划分属性。它将训练集D划分为两个子集D1(A3取值为“是”)和D2(A3取值为“否”)。由于D1只有同一类的样本点,所以它成为一个叶结点,结点的类标记为“是”。
对D2则需要从特征A1(年龄),A2(有工作)和A4(信贷情况)中选择新的特征,计算各个特征的信息增益
根据计算,选择信息增益最大的特征A2(有工作)作为结点的特征。由于A2有两个可能取值,从这一结点引出两个子结点:一个对应“是”(有工作)的子结点,包含3个样本,它们属于同一类,所以这是一个叶结点,类标记为“是”;另一个是对应“否”(无工作)的子结点,包含6个样本,它们也属于同一类,所以这也是一个叶结点,类标记为“否”。这样就生成了一个决策树,该决策树只用了两个属性就构建好了。
在鱼类问题上的应用示例(含代码)
下面通过所给的训练数据学习一个海洋生物分类的决策树,根据海洋生物的特征利用决策树决定是否为鱼类,样本数据如下表所示 对于上述样本数据,经过对属性值进行离散化处理,用1表示“是”,0表示“否”。对于类别用“yes”与“no”表示是否为鱼类。
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
生成决策树的主函数: 使用字典存储决策树mytree的结构;使用一个二维的list变量dataSet来保存数据集;使用一维的list变量labels来保存特征名称; 使用一维的list变量featLabels来保存选择的特征。
def createTree(dataSet, labels, featLabels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
return myTree
其中majorityCnt 用来统计出现次数最大的类标签,chooseBestFeatureToSplit函数选择最大的信息增益来选择当前最优划分属性。
def majorityCnt(classList):
classCount = {}
for vote in classList:
classCount[vote] = classCount.get(vote, 0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
其中在chooseBestFeatureToSplit选择最优划分属性的函数中,利用calcShannonEnt计算信息熵,利用splitDataSet根据选择的属性划分数据集
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec)
return retDataSet
最终构建的决策树如下:
构建好了决策树之后,则可以使用决策树进行分类
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
if 'classLabel' not in dir():
return "maybe"
return classLabel
完整代码
'''
'''
from math import log
import operator
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
from collections import Counter
import pickle
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
numFeature = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
beatFeature = -1
for i in range(numFeature):
featureList = [example[i] for example in dataSet]
uniqueVals = set(featureList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
classCount = {}
for vote in classList:
classCount[vote] = classCount.get(vote, 0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
del (labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
arrow_args = dict(arrowstyle="<-")
font = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc", size=14)
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va='center', ha='center', bbox=nodeType,
arrowprops=arrow_args, FontProperties=font)
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = next(iter(myTree))
cntrPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yoff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor="white")
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xoff = -0.5 / plotTree.totalW
plotTree.yoff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
if 'classLabel' not in dir():
return "maybe"
return classLabel
if __name__ == "__main__":
dataSet, labels = createDataSet()
clabels = labels[:]
myTree = createTree(dataSet, clabels)
print(classify(myTree, labels, [1, 0]))
createPlot(myTree)
|