IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 决策树算法总结 -> 正文阅读

[人工智能]决策树算法总结

决策树算法常用于解决分类问题,该方法的优势在于其数据形式非常容易理解。

概述

决策树(decision tree)是一类常见的机器学习方法.以二分类任务为例,我们希望从给定训练数据集学得一个模型用以对新示例进行分类,这个把样本 分类的任务,可看作对“当前样本属于正类吗?”这个问题的“决策”或“判定”过程。决策树是基于树结构来进行决策的,这是人类在面临决策问题时的一种很自然的处理机制。

例如:我们要对“这是好瓜吗?”这样的问题进行决策时,通常会进行一系列的判断或“子决策”:我们先看“它是什么颜色?”,如果是青绿色,我们再看它的根蒂是什么形态,如果是蜷缩,我们再看它敲起来是什么声音,最后得出最终决策:这是个好瓜,这个决策的过程如下图:

图1.1 西瓜问题的一颗决策树

决策过程的最终结论与我们所希望的判定结果是对应的,决策过程中提出的每个判定问题都是对某个属性的测试,每个测试的结果导出进一步的判定问题,其考虑范围是在上次决策结果的限定范围之内。

一棵决策树包含一个根结点、若干个内部结点和若干个叶结点,叶结点对应于决策结果,其他每个结点则对应于一个属性测试,每个结点包含的样本集合根据属性测试的结果被划分到子结点中,根结点包含样本全集。从根结点到每个叶结点的路径对应了一个判定测试序列,决策树学习的目的是为了产生一棵泛化能力强,即处理未见示例能力强的决策树

决策树算法的基本流程

在这里插入图片描述
决策树算法从数据中提取出一系列规则,在创建规则的过程,就是决策树学习的过程。

选择最优划分属性(ID3)

从上面的流程中可以看出:当前数据集上选择哪个特征作为划分属性是该算法的关键。一般来说,我们希望随着划分过程的不断进行,我们希望决策树的分支结点所包含的样本尽可能属于同一类别,即结点的“纯度”越来越来。1970年,Quinlan J.R找到了用信息论中的熵来度量决策树的决策选择过程,这种方法的简洁和高效在当时就引起了轰动,并命名为ID3算法。
在这里插入图片描述
在这里插入图片描述

决策树算法的优缺点

优点:计算复杂度不高,输出结果容易理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题
使用数据类型:数值型和标称型

决策树分类示例

在信贷问题上的应用示例

下面通过所给的训练数据学习一个贷款申请的决策树,用以对未来的贷款申请进行分类,即当新的客户提出贷款申请时,根据申请人的特征利用决策树决定是否批准贷款申请,样本数据如下表所示
在这里插入图片描述
对于上述样本数据,经过对属性值进行离散化处理,对于年龄离散化表示为[0, 1, 2],其中0,1,2分别对应青年,中年与老年;对于有工作一栏离散化表示为[0, 1],其中0,1分别对应无工作和有工作,同样对有房子这一属性进行相似的离散化;对于信贷情况按照上述方式离散化为[0, 1, 2] ,其中0,1,2分别对应一般,好,非常好;对于类别用yes和no表示。

对该问题利用复现的ID3算法构建决策树,得到的决策树最终为:
在这里插入图片描述
对结果进行验证,年龄这一列的数据,也就是特征A1,一共有三个类别,分别是:青年、中年和老年。其中年龄是青年的数据一共有5个,所以年龄是青年的数据在训练数据集出现的概率是三分之一。同理,年龄是中年和老年的数据在训练数据集出现的概率也都是三分之一。而年龄是青年的数据的最终得到贷款的概率为五分之二,因为在五个数据中,只有两个数据显示拿到了最终的贷款,同理,年龄是中年和老年的数据最终得到贷款的概率分别为五分之三、五分之四。所以计算年龄的信息增益,过程如下:
在这里插入图片描述

同理可以得到 , , ,
因此利用属性A3是否有房子进行划分得到的信息增益最大,所以首先选择有房子作为划分属性。它将训练集D划分为两个子集D1(A3取值为“是”)和D2(A3取值为“否”)。由于D1只有同一类的样本点,所以它成为一个叶结点,结点的类标记为“是”。

对D2则需要从特征A1(年龄),A2(有工作)和A4(信贷情况)中选择新的特征,计算各个特征的信息增益
在这里插入图片描述

根据计算,选择信息增益最大的特征A2(有工作)作为结点的特征。由于A2有两个可能取值,从这一结点引出两个子结点:一个对应“是”(有工作)的子结点,包含3个样本,它们属于同一类,所以这是一个叶结点,类标记为“是”;另一个是对应“否”(无工作)的子结点,包含6个样本,它们也属于同一类,所以这也是一个叶结点,类标记为“否”。这样就生成了一个决策树,该决策树只用了两个属性就构建好了。

在鱼类问题上的应用示例(含代码)

下面通过所给的训练数据学习一个海洋生物分类的决策树,根据海洋生物的特征利用决策树决定是否为鱼类,样本数据如下表所示
在这里插入图片描述
对于上述样本数据,经过对属性值进行离散化处理,用1表示“是”,0表示“否”。对于类别用“yes”与“no”表示是否为鱼类。

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']  # 特征的名称
    return dataSet, labels

生成决策树的主函数:
使用字典存储决策树mytree的结构;使用一个二维的list变量dataSet来保存数据集;使用一维的list变量labels来保存特征名称; 使用一维的list变量featLabels来保存选择的特征。

def createTree(dataSet, labels, featLabels):
    classList = [example[-1] for example in dataSet]      #取分类标签
    if classList.count(classList[0]) == len(classList):      #如果类别完全相同则停止划分
        return classList[0]
    if len(dataSet[0]) == 1:    #遍历完所有特征时返回出现次数最多的类标签
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)                #选择最优特征
    bestFeatLabel = labels[bestFeat]                            #最优特征的标签
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel:{}}                     #根据最优特征的标签生成树
    del(labels[bestFeat])                          #删除已经使用特征标签
    featValues = [example[bestFeat] for example in dataSet]  #得到训练集中所有最优特征的属性值
    uniqueVals = set(featValues)            #去掉重复的属性值
    for value in uniqueVals:                 #遍历特征,创建决策树。                       
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
    return myTree

其中majorityCnt 用来统计出现次数最大的类标签,chooseBestFeatureToSplit函数选择最大的信息增益来选择当前最优划分属性。

def majorityCnt(classList):
    classCount = {}

    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
        # if vote not in classCount.keys():
        #     classCount[vote] = 0
        # classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  # 从大到小排序
    return sortedClassCount[0][0]  # 返回出现次数最多的类别名称
def chooseBestFeatureToSplit(dataSet):
    # 特征数量
    numFeatures = len(dataSet[0]) - 1
    # 计算数据集的香农熵
    baseEntropy = calcShannonEnt(dataSet)
    # 信息增益
    bestInfoGain = 0.0
    # 最优特征的索引值
    bestFeature = -1
    # 遍历所有特征
    for i in range(numFeatures):
        # 获取dataSet的第i个所有特征存在featList这个列表中
        featList = [example[i] for example in dataSet]
        # 创建set集合{} 
        uniqueVals = set(featList)
        # 经验条件熵
        newEntropy = 0.0
        # 计算信息增益
        for value in uniqueVals:
            # subDataSet划分后的子集
            subDataSet = splitDataSet(dataSet, i, value)
            # 计算子集的概率
            prob = len(subDataSet) / float(len(dataSet))
            # 根据公式计算信息熵
            newEntropy += prob * calcShannonEnt(subDataSet)
        # 信息增益
        infoGain = baseEntropy - newEntropy
        # 比较得到最大信息增益
        if(infoGain > bestInfoGain):
            # 更新信息增益,找到最大的信息增益
            bestInfoGain = infoGain
            # 记录信息增益最大的特征的索引值
            bestFeature = i
    # 返回信息增益最大的特征的索引值
    return bestFeature

其中在chooseBestFeatureToSplit选择最优划分属性的函数中,利用calcShannonEnt计算信息熵,利用splitDataSet根据选择的属性划分数据集

# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # 获得标签
        # 构造存放标签的字典
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # 计算香农熵
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
# 划分数据集,三个参数为带划分的数据集,划分数据集的特征的索引位置,特征的值
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # 将相同特征值的数据集的其他特征集合抽取出来
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet  # 返回一个列表

最终构建的决策树如下:
在这里插入图片描述

构建好了决策树之后,则可以使用决策树进行分类

# 分类的函数,输入训练好的决策树,特征的label,以及待验证的数据,最后输出类别
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    # 获取下一组字典
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    if 'classLabel' not in dir():
        return "maybe"
    return classLabel

完整代码

# coding=utf-8
'''
'''
from math import log
import operator
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
from collections import Counter
import pickle

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']  # 分类的属性
    return dataSet, labels


# 计算给定数据的香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # 获得标签
        # 构造存放标签的字典
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # 计算香农熵
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


# 划分数据集,三个参数为带划分的数据集,划分数据集的特征的索引位置,特征的值
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # 将相同特征值的数据集的其他特征集合抽取出来
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet  # 返回一个列表


# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeature = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    beatFeature = -1
    for i in range(numFeature):
        featureList = [example[i] for example in dataSet]  # 获取样本中所有的第i个特征的值
        uniqueVals = set(featureList)  # 从列表中创建集合,得到不重复的所有可能取值,也就是第i个特征所有的可能取值集合
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # 以i索引位置的特征为划分属性,value为特征的值来划分数据集
            prob = len(subDataSet) / float(len(dataSet))  # 特征值为value所占的样本比例
            newEntropy += prob * calcShannonEnt(subDataSet)  # 计算特征值为value子集的信息熵,并累加计算最终划分后的信息熵
        infoGain = baseEntropy - newEntropy  # 计算信息增益
        # 选择信息增益最大的划分属性
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


# 递归构建决策树
def majorityCnt(classList):
    classCount = {}

    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
        # if vote not in classCount.keys():
        #     classCount[vote] = 0
        # classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  # 从大到小排序
    return sortedClassCount[0][0]  # 返回出现次数最多的类别名称


# 创建树的函数代码
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # 类别完全相同则停止划分
        return classList[0]
    if len(dataSet[0]) == 1:  # 遍历完所有特征值时返回出现次数最多的
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 选择最好的数据集划分方式
    bestFeatLabel = labels[bestFeat]  # 得到对应的标签值
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])  # 删掉labels[bestFeat],保证划分的数据集的属性索引和labels对应
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)  # 获取当前划分属性所有属性值来划分数据集并递归创建树
    for value in uniqueVals:
        subLabels = labels[:]
        # 递归调用创建决策树函数
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


# ==================下面是绘制决策树的函数==============

def getNumLeafs(myTree):
    # 初始化叶子
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    # 获取下一组字典
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # 测试该结点是否为字典,如果不是字典,代表此节点为叶子结点
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    # 初始化决策树深度
    maxDepth = 0
    # python3中myTree.keys()返回的是dict_keys,不是list,所以不能用
    # myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
    # next() 返回迭代器的下一个项目 next(iterator[, default])
    firstStr = list(myTree.keys())[0]
    # 获取下一个字典
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # 测试该结点是否为字典,如果不是字典,代表此节点为叶子结点
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        # 更新最深层数
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    # 返回决策树的层数
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # 定义箭头格式
    arrow_args = dict(arrowstyle="<-")
    # 设置中文字体
    font = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc", size=14)
    # 绘制结点createPlot.ax1创建绘图区
    # annotate是关于一个数据点的文本
    # nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args, FontProperties=font)


def plotMidText(cntrPt, parentPt, txtString):
    # 计算标注位置(箭头起始位置的中点处)
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    # 设置结点格式boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    # 设置叶结点格式
    leafNode = dict(boxstyle="round4", fc="0.8")
    # 获取决策树叶结点数目,决定了树的宽度
    numLeafs = getNumLeafs(myTree)
    # 获取决策树层数
    depth = getTreeDepth(myTree)
    # 下个字典
    firstStr = next(iter(myTree))
    # 中心位置
    cntrPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yoff)
    # 标注有向边属性值
    plotMidText(cntrPt, parentPt, nodeTxt)
    # 绘制结点
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # 下一个字典,也就是继续绘制结点
    secondDict = myTree[firstStr]
    # y偏移
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
        if type(secondDict[key]).__name__ == 'dict':
            # 不是叶结点,递归调用继续绘制
            plotTree(secondDict[key], cntrPt, str(key))
        # 如果是叶结点,绘制叶结点,并标注有向边属性值
        else:
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
            plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD


def createPlot(inTree):
    # 创建fig
    fig = plt.figure(1, facecolor="white")
    # 清空fig
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # 去掉x、y轴
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # 获取决策树叶结点数目
    plotTree.totalW = float(getNumLeafs(inTree))
    # 获取决策树层数
    plotTree.totalD = float(getTreeDepth(inTree))
    # x偏移
    plotTree.xoff = -0.5 / plotTree.totalW
    plotTree.yoff = 1.0
    # 绘制决策树
    plotTree(inTree, (0.5, 1.0), '')
    # 显示绘制结果
    plt.show()


# 分类的函数,输入训练好的决策树,特征的label,以及待验证的数据,最后输出类别
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    # 获取下一组字典
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    if 'classLabel' not in dir():
        return "maybe"
    return classLabel


if __name__ == "__main__":
    dataSet, labels = createDataSet()
    clabels = labels[:]
    myTree = createTree(dataSet, clabels)
    print(classify(myTree, labels, [1, 0]))  # 若测试集中出现了训练集中未出现的属性值,则异常,这个点需要改
    createPlot(myTree)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-04 15:33:38  更:2022-03-04 15:37:51 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 1:58:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码