我是靠谱客的博主 温柔季节,最近开发中收集的这篇文章主要介绍机器学习--决策树,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

决策树学习1)采用自顶向下的递归方法

                  2)基本思想是以信息熵为度量,向下构造一颗熵值下降最快的树,到叶子结点处熵值为0.

                  3)属于有监督学习

决策树算法历史1)Quinlan在1986年提出的ID3算法和1993年提出的C4.5算法

                         2)Breiman等人在1984年提出的CART算法。

决策树呈树形结构,在分类问题中表示基于特征对实例进行分类的过程。

决策树的分类:1)离散型决策树:目标变量为离散型(CLS,ID3,C45)

                         2)连续型决策树:目标变量是连续型(CART)

决策树的构造过程:

1)特征选择:从若干特征中选择一个特征作为当前节点分裂的标准。

     方法:ID3(信息增益)

                C4.5(信息增益比)

                CART(Gini基尼指数)

2)决策树的生成

根据选择特征评估标准,从上到下递归地生成子节点,直到数据集不可分。目标是使某个特征划分后各个子集纯度更高,不确定性更小。

3)决策树的裁剪

决策树容易过拟合(over-fitting)通过剪枝来缩小结构规模、缓解过拟合。

剪枝方法有:预剪枝:在结点划分前进行预判断,如果划分后能够使子集纯度更纯则进行,反之不进行。

                      后剪枝:先生成一棵完整的树,然后自底向上对非叶结点考察是否替换子树为叶节点。

决策树的优缺点:

优点:可读性强,分类速度快;

缺点:容易出现过拟合,对未知的测试数据结果不一定好。可采用剪枝或者随机森林。

ID3算法

1)决策树中的每一个非叶子结点对应一个特征属性,树枝代表这个属性的值,叶节点代表最终分类属性值。

2)每一个非叶子结点与属性中具有最大信息量的特征属性相关联。

3)熵通常用于测量一个非叶子结点的信息量大小。

实现步骤:

1、创建数据集

2、createTree创建决策树

     1)判断生成叶子结点/结点

     2)选择最佳属性划分方式:选择最大信息增益(重点

emmmmmmmmmm---------------------------------

上python37代码

创建trees.py

#!/usr/bin/python
# -*- coding: UTF-8 -*-
from math import log
import operator
# 划分数据集,axis:按第几个属性划分,value:要返回的子集对应的属性值
def splitDataSet(dataSet,axis,value):
retDataSet=[] featVec=[]
for featVec in dataSet:
if featVec[axis]==value:
#将featVec[axis]单独分出去,剩下的数据集搞到reducedFeatVec中
reducedFeatVec=featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
# 计算信息熵
def calcShannonEnt(dataSet):
numEntries=len(dataSet)# 样本数
labelCounts={}
for featVec in dataSet:# 遍历每个样本
currentLabel=featVec[-1]# 当前样本的类别
if currentLabel not in labelCounts.keys():# 生成类别字典
labelCounts[currentLabel]=0#初始化新类别中0个样本
labelCounts[currentLabel]+=1#对当前样本类别进行计数
shannonEnt=0.0
for key in labelCounts:#计算信息熵
prob=float(labelCounts[key])/numEntries
shannonEnt=shannonEnt-prob*log(prob,2)
return shannonEnt
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures=len(dataSet[0])-1
# 属性个数
baseEntropy=calcShannonEnt(dataSet)
bestInfoGain=0.0
bestFeature=-1
for i in range(numFeatures):#对每个属性计算信息增益
featList=[example[i] for example in dataSet]
uniqueVals=set(featList)#该属性的取值集合
newEntropy=0.0
for value in uniqueVals:#对每一种取值计算信息增益
subDataSet=splitDataSet(dataSet,i,value)
prob=len(subDataSet)/float(len(dataSet))
newEntropy+=prob*calcShannonEnt(subDataSet)
infoGain=baseEntropy-newEntropy
if (infoGain>bestInfoGain):
bestInfoGain=infoGain
bestFeature=i
return bestFeature
# 递归构建决策树
# 通过排序返回出现次数最多的类别
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.items(),
key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
# 递归构建决策树
def createTree(dataSet,labels):
classList=[example[-1] for example in dataSet]#类别向量
if classList.count(classList[0])==len(classList):#如果只有一个类别返回
return classList[0]
if len(dataSet[0])==1:#如果所有特征都被遍历完了,返回出现次数最多的类别
return majorityCnt(classList)
bestFeat=chooseBestFeatureToSplit(dataSet)
bestFeatLabel=labels[bestFeat]#最优划分属性的标签
myTree={bestFeatLabel:{}}
del (labels[bestFeat])#已经选择的特征不再参与分类
featValues=[example[bestFeat]for example in dataSet]
uniqueValue=set(featValues)#该属性所有可能取值,节点的分支
for value in uniqueValue:
subLabels=labels[:]
myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
return myTree
创建Plotter.py
# -*- coding: cp936 -*-
import matplotlib.pyplot as plt
# 设置决策节点和叶节点的边框形状、边距和透明度,以及箭头的形状
decisionNode = dict(boxstyle="square,pad=0.5", fc="0.9")
leafNode = dict(boxstyle="round4, pad=0.5", fc="0.9")
arrow_args = dict(arrowstyle="<-", connectionstyle="arc3", shrinkA=0,
shrinkB=16)
# 获得树的叶子结点数目
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]# 获得当前第一个根节点
secondDict = myTree[firstStr] # 获取该根下的子树
for key in secondDict.keys(): # 获得所有子树的根节点进行遍历
if type(secondDict[key]).__name__ == 'dict': # 如果子节点是dict类型则不是子节点需要继续遍历
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
# 获得树的深度
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]# 获得当前第一个根节点
secondDict = myTree[firstStr] # 获取该根下的子树
for key in secondDict.keys(): # 获取所有子树节点,进行遍历
if type(secondDict[key]).__name__ == 'dict':# 如果子树类型为dict则不是叶子结点
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
# 计算父节点到子节点的中点坐标,在该点上标注txt信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# 给createPlot子节点绘图添加注释
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt,
xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

测试数据集

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import ID3
import json
import Plotter
fr = open(r'C:UsersLMQuntitledactivityData.txt')
listWm = [inst.strip().split('t') for inst in fr.readlines()]
labels = ['天气', '温度', '湿度', '风速']
Trees = ID3.createTree(listWm, labels)
print(json.dumps(Trees, ensure_ascii=False))
Plotter.createPlot(Trees)

PS:最终在Pycharm上生成的决策图上仍然没有注释

 

 

     

 

 

                          

最后

以上就是温柔季节为你收集整理的机器学习--决策树的全部内容,希望文章能够帮你解决机器学习--决策树所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(43)

评论列表共有 0 条评论

立即
投稿
返回
顶部