我是靠谱客的博主 糟糕纸飞机,最近开发中收集的这篇文章主要介绍决策树的完整实现决策树,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

文章目录

  • 决策树
    • 决策树的构造
      • 信息增益
      • 划分数据集
      • 递归构建决策树
    • 在python中使用Matplotlib注解绘制树形图
      • Matplotlib注解
      • 构造注解树
    • 测试和存储分类器
      • 测试算法:使用决策树执行分类
      • 使用算法:决策树的存储
    • 示例:使用决策树预测隐形眼镜类型
    • 算法改进
      • 信息增益(ID3算法)
      • 信息增益率(C4.5算法)
      • 基尼指数(CART算法)
    • 剪枝
    • 实现树的可视化
    • 总结

决策树

决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。
机器学习中,决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。
决策树的生成过程主要分为以下几部分:

  1. 特征选择:特征选择是指从训练数据中众多的特征中选择一个特征作为当前节点的分裂标准,如何选择特征有着很多不同量化评估标准,从而衍生出不同的决策树算法。
  2. 决策树生成: 根据选择的特征评估标准,从上至下递归地生成子节点,直到数据集不可分则决策树停止生长。 树结构来说,递归结构是最容易理解的方式。
  3. 剪枝:剪枝是决策树停止分支的方法之一,剪枝技术有预剪枝和后剪枝两种。

决策树示意图:
在这里插入图片描述

决策树的构造

决策树的优缺点以及适用数据类型:

  1. 优点:计算复杂度不高,输出结果易于理解,对中间值的确实不敏感,可以处理不相关特征数据。
  2. 缺点:可能会产生过度匹配问题。
  3. 适用数据类型:数值型和标称型。

构造决策树时首先要解决当前数据集上哪个特征在划分数据分类时起决定性作用的问题,为了找到决定性特征,划分出最好的结果,我们需要评估每个特征,在完成测试之后,原始数据集就划分为了几个数据子集,这些数据子集会分布在第一个决策点的所有分支上,若某个分支下的数据属于同一类型,则无需进一步对数据集进行分割。若数据子集内的数据不属于同一类型,则需要重复划分数据子集的过程。划分数据子集的算法和划分原始数据集的方法相同,直到所有具有相同类型的数据均在一个数据子集内。
创建分支的伪代码函数createBranch如下所示:
检测数据集中的每个子项是否属于同一分类:

If so return 类标签
Else
	寻找划分数据集的最好特征
	划分数据集
	创建分支节点
		for每个划分的子集
			调用函数createBranch并增加返回结果到分支节点中
	return 分支节点

决策树的一般流程

  1. 收集数据:可以使用任何方法
  2. 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
  3. 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期
  4. 训练算法:构造树的数据结构
  5. 测试算法:使用经验树计算错误率
  6. 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义

信息增益

划分数据集的大原则:将无序的数据变得更加有序。
组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
集合信息的度量方式称为香农熵或简称为熵。
熵定义为信息的期望值,若待分类的事务可能划分在多个分类当中,则符号 x i x_i xi的信息定义为: l ( x i ) = − l o g 2 p ( x i ) l(x_i)=-log_2p(x_i) l(xi)=log2p(xi),其中 p ( x i ) p(x_i) p(xi)是选择该分类的概率。为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面公式得到: H = − Σ i = 1 n p ( x i ) l o g 2 p ( x i ) H=-Σ^n_{i=1}p(x_i)log_2p(x_i) H=Σi=1np(xi)log2p(xi),其中n是分类的数目。
计算给定数据集的香农熵的代码部分:

from math import log
import operator
def calcshannonEnt(dataSet): #计算给定数据集的香农熵
    nuMEntries = len(dataSet) #记录数据集行数
    labelCounts = {}  #记录每个标签出现次数的字典
    for featVec in dataSet:  #为所有可能分类创建字典
        currentLabel = featVec[-1] #提取标签信息
        if currentLabel not in labelCounts.keys(): #若标签未放入统计次数的字典,进行添加
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0 #经验熵
    for key in labelCounts:
        prob = float(labelCounts[key])/nuMEntries  #选择该分类的概率
        shannonEnt -= prob * log(prob,2) #,以2为底求对数,计算香农熵
    return shannonEnt

在这部分计算给定数据集的香农熵的代码当中,首先,我们要计算出数据集中实例的总数,这里我们显示地声明一个变量保存实例总数。然后再创建一个数据字典,它的键值为最后一列的数值,若当前键值不存在,则扩展字典并将当前键值加入字典。这里每个键值都记录了当前类别出现的次数,最后再使用所有类标签的发生概率计算类别出现的概率。我们将用这个概率计算香农熵,统计所有类标签发生的次数。
利用createDataSet函数得到简单鱼鉴定数据集:

def createDataset():
    dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels

计算香农熵运行结果:
在这里插入图片描述
熵越高,混合数据越多,我们可以在数据集中添加更多分类,观察熵如何变化,我们增加第三个名为maybe的分类,测试熵的变化:

myData[0][-1]='maybe'
    print(myData)
    print(calcshannonEnt(myData))

测试结果:
在这里插入图片描述

另一个度量集合无需程度的方法是基尼不纯度,简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。以下为计算基尼不纯度的公式 I G ( f ) = ∑ i = 1 m f i ( 1 − f i ) = ∑ i = 1 m f i − ∑ i = 1 m f 2 i = 1 − ∑ i = 1 m f 2 i IG(f)=∑i=1mfi(1−fi)=∑i=1mfi−∑i=1mf2i=1−∑i=1mf2i IG(f)=i=1mfi(1fi)=i=1mfii=1mf2i=1i=1mf2i.基尼不纯度具有如下特点:

  1. 基尼不纯度越小,纯度越高,集合的有序程度越高,分类的效果越好;
  2. 基尼不纯度为 0 时,表示集合类别一致;
  3. 基尼不纯度最高(纯度最低)时,f1=f2=…=fm=1m,IG(f)=1−(1m)2×m=1−1m

划分数据集

分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断当前是否正确的划分了数据集。我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。
代码如下所示:

def splitDataSet(dataSet,axis,value): #参数一:待划分的数据集,参数二:划分数据集的特征,参数三:需要返回的特征值
    retDataSet = [] #创建新的list对象
    for featVec in dataSet: #遍历抽取符合特征的数据值
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

在该函数中,为了不修改原始数据集,我们创建了一个新的列表对象,数据集这个列表中的各个元素也是列表,我们要遍历数据集中的每个元素,一旦发现符合要求的值,则将其添加到新创建的列表中,在抽取符合要求的元素时,我们用到了extend和append方法,这里需要注意区分两种方式的区别:假设a=[1,2,3],b=[4,5,6],则a.append(b)=[1,2,3,[4,5,6]],而a.extend(b)=[1,2,3,4,5,6]。
测试函数splitDataSet:

   print(splitDataSet(myData,0,1))
    print(splitDataSet(myData,0,0))

测试结果:
在这里插入图片描述

接下来我们需要选择最好的数据集划分方式,代码如下所示:

def chooseBestFeatureTOSplit(dataSet):
    numFeatures = len(dataSet[0])-1 #特征数量
    baseEntropy = calcshannonEnt(dataSet) #计算数据集香农熵
    bestInfoGain = 0.0; #信息增益
    bestFeature = -1; #最优特征索引值
    for i in range(numFeatures): #创建唯一的分类标签列表
        featList = [example[i] for example in dataSet] #获取数据集的第i个所有特征
        uniqueVals = set(featList) #创建set集合,元素不可重复
        newEntropy = 0.0 #经验条件熵
        for value in uniqueVals: #计算每种划分方式的信息熵
            subDataSet = splitDataSet(dataSet,i,value) #划分后的子集
            prob = len(subDataSet)/float(len(dataSet)) #计算子集概率
            newEntropy += prob * calcshannonEnt(subDataSet) #计算经验条件熵
        infoGain = baseEntropy - newEntropy #信息增益
        if(infoGain>bestInfoGain): #计算最好的信息增益
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

在开始划分数据集前,我们需要计算整个数据集的原始香农熵,保存最初的无序度量值,用于与划分后的数据集计算的熵值进行比较。第一个for循环先遍历数据集中的所有特征,然后创建新的列表,并将数据集中所有第i个特征值或所有可能存在的值写入列表中,然后从列表中创建集合得到列表中唯一元素值。遍历当前特征中的所有唯一属性值,对每一个唯一属性值划分一次数据集,然后计算数据集的新熵值,并对所有唯一特征值得到的熵求和,最后比较所有特征中的信息增益,返回最好特征划分的索引值。
测试实际输出结果:

myData,labels = createDataset()
print(myData)
print(chooseBestFeatureTOSplit(myData))

输出结果:
在这里插入图片描述
输出结果显示:第0个特征是最好的用于划分数据集的特征。
若我们按照第一个特征属性划分数据,即第一个特征是1的放在第一组,第一个特征是0的放在另一个组,第一个特征为1的海洋生物分组将有两个属于鱼类,一个属于非鱼类;另一个分组则全部属于非鱼类;
若按照第二个特征分组,第一个海洋动物分组将有两个属于鱼类,两个属于非鱼类;另一个分组则只有一个非鱼类。
可以看出第一种划分很好的处理了相关数据。

递归构建决策树

工作原理: 得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分,第一次划分之后,数据将被向下传递到树分支的下一节点,在这个节点上,我们可以再次划分数据,因此我们可以采用递归的原则处理数据集。
递归结束的条件: 程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类,若所有实例具有相同的分类,则得到一个叶子节点或者终止块,任何到达叶子节点的数据必然属于叶子节点的分类。
若数据集已处理完所有属性后,但类标签依然不是唯一的,我们就需要决定如何定义该叶子节点,我们通常采用多数表决的方法决定该叶子节点的分类,代码如下所示:

def majorityCnt(classList):
    classCount={}
    for vote in classList: #统计classList中每个元素出现的次数
        if vote not in classCount.keys():classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True) #排序字典
    return sortedClassCount[0][0]

创建树的代码如下所示:

def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):#类别完全相同则停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1: #遍历完所有特征时返回出现次数最多的类别
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureTOSplit(dataSet) #选择最优特征
    bestFeatLabel = labels[bestFeat] #最优特征标签
    myTree = {bestFeatLabel:{}}
    del (labels[bestFeat]) #得到列表包含的所有属性值
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals: #创建决策树
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

首先创建classList的列表变量用于包含数据集的所有类标签。递归函数的第一个停止条件是所有类标签完全相同,则直接返回该类标签,第二个停止条件是使用完所有特征后仍不能将数据集划分成仅包含唯一类别的分组,由于第二个条件无法简单返回唯一的类标签,因此这里使用majorityCnt函数挑选出现次数最多的类别作为返回值。创建树时使用字典变量myTree存储树的所有信息,当前数据集选取的最好特征存储在变量bestFeat中,得到列表包含的所有属性值,最后遍历当前选择特征包含的所有属性值,在每个数据集划分上递归调用函数createTree,得到的返回值插入到字典变量myTree中,在函数终止执行时,字典中将嵌套很多代表叶子节点信息的字典数据。subLabels = labels[:]这行代码复制了类标签并将其存储在新列表变量subLabels中,因为当函数参数是列表类型时,参数是按照引用方式传递的,为保证每次调用函数createTree时不改变原始列表的内容,使用新变量subLabels代替原始列表。
测试输出结果:

    myData,labels = createDataset()
    print(createTree(myData,labels))

测试结果:
在这里插入图片描述
变量myTree包含很多代表树结构信息的嵌套字典,从左边开始,第一个关键字no surfacing是第一个划分数据集的特征名称,该关键字的值是另一个数据字典。第二个关键字是no surfacing特征划分的数据集,这些关键字的值是no surfacing节点的子节点。这些值可能是类标签,也可能是另一个数据字典,若值是类标签,则该子节点是叶子节点;若值是另一个数据字典,则子节点是一个判断节点,这种格式结构不断重复就构成了整棵树。这棵树包含了3个叶子节点和2个判断节点。

在python中使用Matplotlib注解绘制树形图

Matplotlib注解

在从数据集中创建树时,字典的表示形式非常不易于理解,而且直接绘制图形也比较困难。决策树的主要优点就是直观易于理解,若不能将其直观的显示出来,就无法发挥其优势,这里我们就将使用Matplotlib库创建树形图。

import matplotlib.pyplot as plt
from pylab import mpl
decisionNode = dict(boxstyle='sawtooth',fc="0.8")#定义文本框和箭头格式
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt,centerPt,parentPt,nodeType):#绘制带箭头的注解
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
def createPlot():
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111,frameon=False)
    mpl.rcParams['font.sans-serif'] = ['SimHei']  # 黑体
    plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

该部分代码定义了描述树节点格式的常量,定义了plotNode函数执行实际的绘图功能,该函数需要一个绘图区,该区域由全局变量createPlot.ax1定义。createPlot函数首先创建了一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节。
测试输出结果:
createPlot()
在这里插入图片描述

构造注解树

在绘制一棵完整的树时,我们虽然有x,y坐标,但如何放置所有的树节点是个问题,我们必须知道有多少个叶节点,以便可以正确的确定x轴的长度,我们还需要知道树有多少层,以便可以正确确定y轴的高度,这里我们定义两个函数getNumLeafs和getTreeDepth,来获取叶节点的数目和树的层数。

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict': #测试节点的数据类型是否为字典
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:
            maxDepth = thisDepth
    return maxDepth

第一个关键字是第一次划分数据集的类别标签,附带的数值表示子节点的取值,从第一个关键字出发,我们可以遍历整棵树的所有子节点,并使用type函数判断子节点是否为字典类型,若子节点是字典类型,则该节点为一个判断节点,需要递归调用getNumLeafs函数,getNumLeafs函数遍历整棵树,累计叶子节点的个数,并返回该数值。第二个函数getTreeDepth计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一。为了节省时间,我们通过retrieveTree函数输出预先存储的树信息,避免每次测试代码时都要从数据中创建树的麻烦。

def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

输出测试结果:

print(retrieveTree(1))
myTree=retrieveTree(0)
print(getNumLeafs(myTree))
print(getTreeDepth(myTree))

在这里插入图片描述
retrieveTree函数主要用于测试,返回预定义的树结构,getNumLeafs函数返回值为3,等于树0的叶子节点数,调用getTreeDepth函数也能正确返回树的层数。
现在开始绘制一棵完整的树:

def plotMidText(cntrPt,parentPt,txtString):#在父子节点间填充文本信息
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)#计算宽与高
    depth = getTreeDepth(myTree)
    firstStr = list(myTree)[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)#标记子节点属性值
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD#减少y偏移
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff = 1.0;
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

全局变量plotTree.totalW存储树的宽度,plotTree.totalD存储树的高度,我们使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。树的宽度用于计算放置判断节点的位置,主要计算原则是将它放在所有叶子节点中间,而不仅是它子节点的中间,同时用plotTree.xOff和plotTree.yOff追踪已绘制的节点位置,以及放置下一个节点的恰当位置。并且我们按照叶子节点的数目将x轴划分为若干部分,这样无需关心实际输出图形的大小,一旦图形大小发生变化,函数会自动按照图形大小重新绘制。接着,绘出子节点的特征值,或沿此分支向下的数据实例必须具有的特征值,然后在父节点和子节点的中间位置添加简单的文本标签信息。然后,按比例减少全局变量plotTree.yOff,并标注此处将要绘制子节点,节点可以是叶子节点也可以是判断节点,只需保存绘制图形的轨迹,由于我们是自上向下绘制图形,因此需依次递减y坐标值。然后采用函数getNumLeafs和getTreeDepth以相同的方式递归遍历整棵树,若节点是叶子节点,则在图像上画出节点,若不是叶子节点,则递归调用plotTree函数,绘制所有子节点后,增加全局变量Y的偏移。

测试输出结果:

myTree=retrieveTree(0)
createPlot(myTree)

在这里插入图片描述

myTree['no surfacing'][3]='maybe'
print(myTree)
createPlot(myTree)

在这里插入图片描述
在这里插入图片描述

测试和存储分类器

测试算法:使用决策树执行分类

依靠训练数据构造完决策树后,我们将它用于实际数据分类,在执行数据分类时,需要使用决策树以及用于构造决策树的标签向量,然后比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点,最后将测试数据定义为叶子节点所属的类型。

def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree)[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr) #将标签字符串转化为索引
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__=='dict':#若值为字典,则该节点为判断节点
                classLabel = classify(secondDict[key],featLabels,testVec)
            else:#若值是类标签,则该子节点为叶子节点
                classLabel = secondDict[key]
    return classLabel

测试输出结果:

myData,labels = createDataset()
    print(labels)
    myTree = retrieveTree(0)
    print(myTree)
    print(classify(myTree,labels,[1,0]))
    print(classify(myTree,labels,[1,1]))

在这里插入图片描述
第一节点为no surfacing,它有两个子节点:一个名为0的叶子节点,类标签为no;另一个名为flippers的判断节点,此处进入递归调用,flippers节点有两个子节点,以前绘制的树形图和此处代表树的数据结构完全相同。

使用算法:决策树的存储

为了节省时间,最好能够在每次执行分类时调用已构造好的决策树,为了解决这个问题,我们需要使用pickle序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来,任何对象都可以执行序列化操作,字典对象也不例外。

def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr = open(filename,'rb')
    return pickle.load(fr)

测试输出结果:

    myTree = retrieveTree(0)
    storeTree(myTree,'classifierStorage.txt')
    print(grabTree('classifierStorage.txt'))

在这里插入图片描述
在这里插入图片描述

示例:使用决策树预测隐形眼镜类型

隐形眼镜数据集是非常著名的数据集,它包含很多患者眼部状况的观察条件和医生推荐的隐形眼镜类型,隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜,数据来源于UCI数据库。

	fr = open('lenses.txt')
    lenses=[inst.strip().split('t') for inst in fr.readlines()]
    lensesLabels= ['age','prescript','astigmatic','tearRate']
    lensesTree = createTree(lenses,lensesLabels)
    print(lensesTree)

    createPlot(lensesTree)

测试输出结果:
在这里插入图片描述
从树形图可以看出,沿着决策树的不同分支,我们可以得到不同患者需要佩戴的隐形眼镜类型,我们可以发现,医生最多需要问四个问题就能确定患者需要佩戴哪种类型的隐形眼镜。
决策树很好的匹配了实验数据,然而这些匹配选项太多了,我们将其称为过度匹配,为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要的叶子节点,若叶子节点只能增加少许信息,则可以删除节点,将它并入到其他叶子节点中。
ID3算法无法直接处理数值型数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数据,但若存在太多的特征划分,ID3算法仍然会面临其他问题。

算法改进

信息增益(ID3算法)

ID3算法是一个好的算法,但仍存在着不足:信息增益偏向取值更多的特征。当特征取值较多时,根据此特征划分获得纯度更高的子集。比如将编号作为一个特征取值的话,每个节点只有一种可能取值,熵值为0,以此类推,总体熵值也为0,因此信息增益更大,决策树就会把编号作为根节点,但实际上并没有意义。

信息增益率(C4.5算法)

C4.5算法采用了一个启发式方法:先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。
可定义增益率:
G a i n Gain Gain_ r a t i o ( D , a ) = G a i n ( D , a ) / I V ( a ) ratio(D,a)=Gain(D,a)/IV(a) ratio(D,a)=Gain(D,a)/IV(a)

I V ( a ) = − Σ v = 1 V ∣ D v ∣ / ∣ D ∣ l o g 2 ∣ D v ∣ / ∣ D ∣ IV(a)=-Σ^V_{v=1}|D^v|/|D|log_2|D^v|/|D| IV(a)=Σv=1VDv/Dlog2Dv/D称为属性a的固有值,属性a的可能取值数目越多,即V越大,则IV(a)越大。
但该算法对可取值数目较少的属性有所偏好,当某个属性的可取值数目较少时,IV(a)较小,则增益率较大。

基尼指数(CART算法)

在分类问题中,假设D有K个类,样本点属于第k个类的概率为 p k p_k pk,则概率分布的基尼指定义为:
G i n i ( D ) = Σ k = 1 K p k ( 1 − p k ) = 1 − Σ k = 1 K p k 2 Gini(D)=Σ^K_{k=1}p_k(1-p_k)=1-Σ^K_{k=1}p^2_k Gini(D)=Σk=1Kpk(1pk)=1Σk=1Kpk2
Gini(D)越小,数据集D的纯度越高;反映了随机抽取两个样本,其类别标记不一致的概率。
给定数据集D,属性a的基尼指数定义为:
G i n i i n d e x ( D , a ) = Σ v = 1 V ∣ D v ∣ / ∣ D ∣ G i n i ( D v ) Gini_{index}(D,a)=Σ^V_{v=1}|D^v|/|D|Gini(D^v) Giniindex(D,a)=Σv=1VDv/DGini(Dv)
在侯选属性集合A中,选择那个使得划分后基尼指数最小的属性作为最有划分属性。

剪枝

为什么要进行剪枝?

  1. “剪枝”是决策树学习算法对付“过拟合”的主要手段
  2. 可通过“剪枝”来一定程度避免因决策分支过多,以致于把训练集
    自身的一些特点当做所有数据都具有的一般性质而导致的过拟合

剪枝的策略:

  1. 预剪枝
  2. 后剪枝

预剪枝:通过提前停止树的构建而对树剪枝,主要方法有:
1.当决策树达到预设的高度时就停止决策树的生长
2.达到某个节点的实例具有相同的特征向量,即使这些实例不属
于同一类,也可以停止决策树的生长。
3.定义一个阈值,当达到某个节点的实例个数小于阈值时就可以
停止决策树的生长。
4.通过计算每次扩张对系统性能的增益,决定是否停止决策树的
生长。
上述不足:阈值属于超参数,很难找到过拟合–欠拟合的trade-off

预剪枝的优缺点
•优点
–降低过拟合风险
–显著减少训练时间和测试时间开销。
•缺点
–欠拟合风险:有些分支的当前划分虽然不能提升泛化性能,但在其基础上进行的后续划分却有可能显著提高性能。预剪枝基于“贪心”本质禁止这些分支展开,带来了欠拟合风险。

后剪枝:先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行分析计算,若将该结点对应的子树替换为叶结点能带来决策树泛化性
能提升,则将该子树替换为叶结点。

后剪枝的优缺点
•优点
–后剪枝比预剪枝保留了更多的分支,欠拟合风险小,泛化性能往往优于预剪枝决策树
•缺点
–训练时间开销大:后剪枝过程是在生成完全决策树之后进行的,需要自底向上对所有非叶结点逐一计算

实现树的可视化

	import matplotlib.pyplot as plt
	from sklearn import tree
	from sklearn.datasets import load_iris
	from sklearn.tree import plot_tree
	iris = load_iris()
    x = iris.data[:,2:]
    y = iris.target
    clf = tree.DecisionTreeClassifier(random_state=0,max_depth=2)#训练树模型,树深度为2
    clf = clf.fit(x,y)
    plt.figure()
    plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
    plt.show()

实现树的可视化可以直接使用sklearn.tree自带的plot_tree()方法,也可以使用Graphviz,还可以使用pydotplus模块,这里我使用的是sklearn自带的数据集以及plot_tree()方法。
测试输出结果:
在这里插入图片描述
精度评估:
计算该决策树的准确率:

	from sklearn.metrics import accuracy_score
	from sklearn.model_selection import train_test_split
	iris = load_iris()
    x = iris.data[:,2:]
    y = iris.target
    feature_train, feature_test, target_train, target_test = train_test_split(x,y, test_size=0.3)
    clf = tree.DecisionTreeClassifier(random_state=0,max_depth=2)#训练树模型,树深度为2
    clf = clf.fit(feature_train,target_train)
    predict_results = clf.predict(feature_test)
    print(accuracy_score(predict_results, target_test))

输出结果:
在这里插入图片描述
当树的深度max_depth为2时,得到的准确率为0.98,当改变树的深度时,准确率也会发生相应改变。

总结

通过这次实验,我将书上的代码实现了一遍,并且对决策树的构造原理,流程,算法等都有了进一步的理解。另外,我也使用sklearn的数据集实现了决策树的可视化,同时对ID3算法,C4.5算法以及CART算法都进行了学习。

最后

以上就是糟糕纸飞机为你收集整理的决策树的完整实现决策树的全部内容,希望文章能够帮你解决决策树的完整实现决策树所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(53)

评论列表共有 0 条评论

立即
投稿
返回
顶部