概述
决策树(Decision Tree)
决策树是一种常见的机器学习方法,它是从根节点开始,一步一步决策,直到走到叶子节点。
最终,所有的样本数据都会落到叶子节点,显然,决策过程的最终结论对应了我们所希望的判定结果 。
它即可以做分类问题,也可以做回归问题。
决策树组成
一般,一棵决策树包含一个根节点,若干个内部节点(非叶子节点)和若干个叶子节点。
1.根节点:第一个选择点
2.内部节点(非叶子节点):中间决策过程
3.叶子节点:最终的决策结果
决策树的训练流程
如何从给点定的训练集中去构造一棵决策树呢?
其实,决策树的构建就是一个递归过程,从根节点开始,根据选择的特征,将原始数据集切分为几个分支,然后依次遍历每个分支,在剩下的特征集中继续选择一个特征,进行对应划分,就是这个流程,一直递归下去,直到不可再分。
那么,什么情况下会不可再分呢?有三种情况:
1.当前节点包含的样本全属于同一类别,不需要再划分
2.当前属性集已为空,或者 当前节点中的所有样本在当前所剩属性集中取值相同,无法再分
3.当前节点包含的样本集为空,无法划分
具体的大致算法流程如下:
我们可以发现,上述算法中,最重要的一点,就是在每个节点上,选择最优的划分特征,也就说,每次在节点处划分时,我们都需要考虑,选择剩余特征集中的哪个属性进行划分,可以更好的划分数据呢?这个,也就是决策树算法的核心
决策树划分选择
通过上述,我们直到,决策树算法的核心是,如何选择最优的划分特征,我们希望随着划分的进行,我们经过每次划分后的分支节点所包含的样本尽可能的属于同一类别,也就是节点中所包含的样本纯度越来越高。从而,我们引入信息熵这个衡量标准
信息熵
信息熵表示的是随机变量不确定性的度量,熵越大,不确定性越强,也就是说纯度越低;
熵越小,不确定性越弱,纯度越高
设置样本集合D中总共有K类样本,其中第k类样本所占的比例为Pk(k=1,2,3,…,K),则的信息熵定义为:
对于一个二分类问题,我们可以画出信息熵的图像看看
上述图像中,X轴表示正样本的概率,Y轴表示对应信息熵。可以看到,当概率为0.5时,对应信息熵最大,也就是说此时不确定度最大,大于0.5,或者小于0.5时,信息熵都会减小。
信息增益
那么,通过信息熵,我们如何进行决策树划分选择的衡量呢,我们引入信息增益这个概念。
我们假设特征集中有一个离散特征a,它有V个可能的取值(a1,a2,…,aV),
如果使用特征a来对样本D进行划分,那么会产V个分支节点,其中第v个分支节点中包含的样本集。我们记为Dv。
于是,可计算出特征a对样本集D进行划分所获得的信息增益为:
解释下上面公式,其实特征a对样本集D进行划分所获得的信息增益 即为 样本集D的信息熵 减去 经过划分后,各个分支的信息熵之和。由于每个分支节点,所包含的样本数不同,所有在计算每个分支的信息熵时,需要乘上对应权重|D^v|除以|D|,即样本数越多的分支节点对应的影响越大
下面,我们具体看个例子:
上那个那个是某人在某月的1到14号的打球记录,我们看下,对应特征有四个,分别为天气(outlook,我们就认为是天气吧- -),温度,湿度,是否有风。输出值为是否打球。
显然,我们的样本总类K=2,其中正例占比p1=914,负例占比p1=514,根节点所包含的样本集D对应的信息熵为:
然后,我们需要计算当前特征集合(天气,温度,湿度,风级)中每个特征的信息增益。
以天气这个特征为例,如果以天气划分,则可将数据集D划分为三个子集,分别标记为:D^1 (outlook = sunny)
D^2(outlook = overcast)
D^3(outlook = rainy)
划分以后,三个分支节点的熵值分别为:
然后,我们可以算出,特征outlook(天气)对应的信息增益是:
同样的,我们可以依次算出其他特征所对应的信息增益,然后判断哪个信息增益最大,则就以此特征来作为当前节点的划分。
假设最后算得,采用outlook来进行当前根节点的划分,则对于生成的三个节点分支,依次再对应每个分支节点进行上述流程(算在此分支节点的数据集上,剩余的特征集合中哪个信息增益最大,作为当前分支节点的分割特征,一直递归下去)。
这其实就是ID3算法,以信息增益作为准则来进行划分特征。
信息增益率
我们思考下,上面说的以信息增益作为准则来进行划分属性,有什么缺点没?
假设对于上面的数据集,我们增加一列特征,为 data(日期),针对上面的14个样本数据,对应的值为(1,2,3,4,5,6,7,8,9,10,11,12,13,14),根据上式可计算出,data(日期)所对应的的信息增益为:Gain(D,data)=0.940,
我们发现,它所对应的信息增益远大于其他特征,所以我们要以data特征,作为第一个节点的划分依据吗?这样划分的话,将产生14个分支,每个分支对应只包含一个样本,可以看到,每个分支节点的纯度已达到最大,也就是说每个分支节点的结果都非常确定。但是,这样的决策树,肯定不是我们想要的,因为它根本不具备任何泛化能力。
这就是ID3算法,也就是信息增益准则的一个缺点,它在选择最优划分特征时,对可取数目比较多的特征有所偏好,如何避免这个问题呢,我们引入增益率这个概念,改为使用增益率来作为最优划分特征的标准,同样,选择增益率最大的那个特征来作为最优划分特征,这也就是C4.5决策树算法
同样假设有数据集D,以及特征a,它有V个可能的取值(a1,a2,…,aV),
如果数据集D在以特征作为划分特征时,增益率定义为:
其中
我们来看下上述增益率公式,其实 IV(a)就是特征a本身的信息熵,也就说对应根据特征a的可能取值,所对应求得的信息熵,
举个例子,对于outlook这个特征,总共有三个类别(sunny,overcast,rainy),所对应类别的数据的个数为为(5,4,5)
则outlook本身的信息熵为:
特征a的对应种类越多,也就是说V越大,则IV(a)的值通常会越大,从而增益率越小。这样,就可以避免信息增益中对可取数目比较多的特征有所偏好的缺点
那直接以信息增益率作为划分的衡量标准,有没有什么缺点呢,其实也有,增益率准则一般对可取数目较少的属性有所偏好。
所以,可以先从当前所有特征中找出信息增益高于平均值的的部分特征,再从中选择增益率最高的作为最优划分特征。
基尼指数
还有一种决策树算法,称为CART决策树,它是使用基尼值来作为衡量标准的。具体流程其实和信息增益的衡量标准类似,只是将信息熵,改为了基尼值
Gini(D)反映了从数据集D中随机抽取两个样本,其类别标记不一样的概率。故Gini(D)越小,则数据集的纯度越高
连续型特征处理
前面我们所讲的都是基于离散型特征进行划分生成决策树,那对于连续性特征,我们需要怎么来处理呢?这个时候就需要用到连续型特征离散化的方法。最简单的即为二分法。下面我们来具体看下:
给定样本集D和连续特征a, 假设特征a在样本集中总共有n个不同的取值。
1.将个n取值进行从小到大排序,记为A(a1,a2,...,an)
2基于一个划分点t,将A划分为两部分,其中不大于t的部分对应的数据集为Dt-,大于的部分对应的数据集为Dt+
3.我们知道,对于将A进行二分,我们有n-1种分法,另外对于相邻的属性取值与来说,t在区间(ai,a(i+1))中取任意值产生的划分结果相同,t的取值集合为:
4.然后,对于每个划分点,我们进行信息增 益的计算,选择最大的信息增益对应的那个划分点,作为连续型特征a的划分点。公 式为:
决策树剪枝操作
我们想想,如果我们不加限制,最后训练出来的决策树,每个叶子节点的数据都会分成纯净的,这样真的好吗?要知道,我们是希望训练出的决策树模型,对于新给的数据,能够准确预测出对应结果。
所以,决策树,非常容易出现过拟合问题。为了避免这个问题,提供决策树的泛化能力,我们需要进行剪枝操作。一般有两种剪枝操作,“预剪枝”和“后剪枝”
预剪枝
预剪枝即是指在决策树的构造过程中,对每个节点在划分前需要根据不同的指标进行估计,如果已经满足对应指标了,则不再进行划分,否则继续划分。
那么,具体指标都有哪些呢?
1.直接指定树的深度
2.直接指定叶子节点个数
3.直接指定叶子节点的样本数
4.对应的信息增益量
5.拿验证集中的数据进行验证,看分割前后,精度是否有提高。
由于预剪枝是在构建决策树的同时进行剪枝处理,所以其训练时间开销较少,同时可以有效的降低过拟合的风险
但是,预剪枝有一个问题,会给决策树带来欠拟合的风险,1,2,3,4指标,不用过多解释,对于5指标来说,
虽然,当前划分不能导致性能提高,但是,或许在此基础上的后续划分,却能使性能显著提高呢?
后剪枝
后剪枝则是先根据训练集生成一颗完整的决策树,然后根据相关方法进行剪枝。
常用的一种方法是,自底向上,对非叶子节点进行考察,同样拿验证集中的数据,来根据精度进行考察。看该节点划分前和划分后,精度是否有提高,如果划分后精度没有提高,则剪掉此子树,将其替换为叶子节点。
相对于预剪枝来说,后剪枝的欠拟合风险很小,同时,泛化能力往往要优于预剪枝,但是,因为后剪枝先要生成整个决策树后,然后才自底向上依次考察每个非叶子节点,所以训练时间长。
完整代码:
package machinelearning.decisiontree;
import java.io.FileReader;
import java.util.Arrays;
import weka.core.*;
/**
* The ID3 decision tree inductive algorithm.
*
* @author Rui Chen 1369097405@qq.com.
*/
public class ID3 {
/**
* The data.
*/
Instances dataset;
/**
* Is this dataset pure (only one label)?
*/
boolean pure;
/**
* The number of classes. For binary classification it is 2.
*/
int numClasses;
/**
* Available instances. Other instances do not belong this branch.
*/
int[] availableInstances;
/**
* Available attributes. Other attributes have been selected in the path
* from the root.
*/
int[] availableAttributes;
/**
* The selected attribute.
*/
int splitAttribute;
/**
* The children nodes.
*/
ID3[] children;
/**
* My label. Inner nodes also have a label. For example, <outlook = sunny,
* humidity = high> never appear in the training data, but <humidity = high>
* is valid in other cases.
*/
int label;
/**
* The prediction, including queried and predicted labels.
*/
int[] predicts;
/**
* Small block cannot be split further.
*/
static int smallBlockThreshold = 3;
/**
********************
* The constructor.
*
* @param paraFilename
* The given file.
********************
*/
public ID3(String paraFilename) {
dataset = null;
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception ee) {
System.out.println("Cannot read the file: " + paraFilename + "rn" + ee);
System.exit(0);
} // Of try
dataset.setClassIndex(dataset.numAttributes() - 1);
numClasses = dataset.classAttribute().numValues();
availableInstances = new int[dataset.numInstances()];
for (int i = 0; i < availableInstances.length; i++) {
availableInstances[i] = i;
} // Of for i
availableAttributes = new int[dataset.numAttributes() - 1];
for (int i = 0; i < availableAttributes.length; i++) {
availableAttributes[i] = i;
} // Of for i
// Initialize.
children = null;
// Determine the label by simple voting.
label = getMajorityClass(availableInstances);
// Determine whether or not it is pure.
pure = pureJudge(availableInstances);
}// Of the first constructor
/**
********************
* The constructor.
*
* @param paraDataset
* The given dataset.
********************
*/
public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
// Copy its reference instead of clone the availableInstances.
dataset = paraDataset;
availableInstances = paraAvailableInstances;
availableAttributes = paraAvailableAttributes;
// Initialize.
children = null;
// Determine the label by simple voting.
label = getMajorityClass(availableInstances);
// Determine whether or not it is pure.
pure = pureJudge(availableInstances);
}// Of the second constructor
/**
**********************************
* Is the given block pure?
*
* @param paraBlock
* The block.
* @return True if pure.
**********************************
*/
public boolean pureJudge(int[] paraBlock) {
pure = true;
for (int i = 1; i < paraBlock.length; i++) {
if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0])
.classValue()) {
pure = false;
break;
} // Of if
} // Of for i
return pure;
}// Of pureJudge
/**
**********************************
* Compute the majority class of the given block for voting.
*
* @param paraBlock
* The block.
* @return The majority class.
**********************************
*/
public int getMajorityClass(int[] paraBlock) {
int[] tempClassCounts = new int[dataset.numClasses()];
for (int i = 0; i < paraBlock.length; i++) {
tempClassCounts[(int) dataset.instance(paraBlock[i]).classValue()]++;
} // Of for i
int resultMajorityClass = -1;
int tempMaxCount = -1;
for (int i = 0; i < tempClassCounts.length; i++) {
if (tempMaxCount < tempClassCounts[i]) {
resultMajorityClass = i;
tempMaxCount = tempClassCounts[i];
} // Of if
} // Of for i
return resultMajorityClass;
}// Of getMajorityClass
/**
**********************************
* Select the best attribute.
*
* @return The best attribute index.
**********************************
*/
public int selectBestAttribute() {
splitAttribute = -1;
double tempMinimalEntropy = 10000;
double tempEntropy;
for (int i = 0; i < availableAttributes.length; i++) {
tempEntropy = conditionalEntropy(availableAttributes[i]);
if (tempMinimalEntropy > tempEntropy) {
tempMinimalEntropy = tempEntropy;
splitAttribute = availableAttributes[i];
} // Of if
} // Of for i
return splitAttribute;
}// Of selectBestAttribute
/**
**********************************
* Compute the conditional entropy of an attribute.
*
* @param paraAttribute
* The given attribute.
*
* @return The entropy.
**********************************
*/
public double conditionalEntropy(int paraAttribute) {
// Step 1. Statistics.
int tempNumClasses = dataset.numClasses();
int tempNumValues = dataset.attribute(paraAttribute).numValues();
int tempNumInstances = availableInstances.length;
double[] tempValueCounts = new double[tempNumValues];
double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];
int tempClass, tempValue;
for (int i = 0; i < tempNumInstances; i++) {
tempClass = (int) dataset.instance(availableInstances[i]).classValue();
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
tempValueCounts[tempValue]++;
tempCountMatrix[tempValue][tempClass]++;
} // Of for i
// Step 2.
double resultEntropy = 0;
double tempEntropy, tempFraction;
for (int i = 0; i < tempNumValues; i++) {
if (tempValueCounts[i] == 0) {
continue;
} // Of if
tempEntropy = 0;
for (int j = 0; j < tempNumClasses; j++) {
tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
if (tempFraction == 0) {
continue;
} // Of if
tempEntropy += -tempFraction * Math.log(tempFraction);
} // Of for j
resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
} // Of for i
return resultEntropy;
}// Of conditionalEntropy
/**
**********************************
* Split the data according to the given attribute.
*
* @return The blocks.
**********************************
*/
public int[][] splitData(int paraAttribute) {
int tempNumValues = dataset.attribute(paraAttribute).numValues();
// System.out.println("Dataset " + dataset + "rn");
// System.out.println("Attribute " + paraAttribute + " has " +
// tempNumValues + " values.rn");
int[][] resultBlocks = new int[tempNumValues][];
int[] tempSizes = new int[tempNumValues];
// First scan to count the size of each block.
int tempValue;
for (int i = 0; i < availableInstances.length; i++) {
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
tempSizes[tempValue]++;
} // Of for i
// Allocate space.
for (int i = 0; i < tempNumValues; i++) {
resultBlocks[i] = new int[tempSizes[i]];
} // Of for i
// Second scan to fill.
Arrays.fill(tempSizes, 0);
for (int i = 0; i < availableInstances.length; i++) {
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
// Copy data.
resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
tempSizes[tempValue]++;
} // Of for i
return resultBlocks;
}// Of splitData
/**
**********************************
* Build the tree recursively.
**********************************
*/
public void buildTree() {
if (pureJudge(availableInstances)) {
return;
} // Of if
if (availableInstances.length <= smallBlockThreshold) {
return;
} // Of if
selectBestAttribute();
int[][] tempSubBlocks = splitData(splitAttribute);
children = new ID3[tempSubBlocks.length];
// Construct the remaining attribute set.
int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
for (int i = 0; i < availableAttributes.length; i++) {
if (availableAttributes[i] < splitAttribute) {
tempRemainingAttributes[i] = availableAttributes[i];
} else if (availableAttributes[i] > splitAttribute) {
tempRemainingAttributes[i - 1] = availableAttributes[i];
} // Of if
} // Of for i
// Construct children.
for (int i = 0; i < children.length; i++) {
if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
children[i] = null;
continue;
} else {
// System.out.println("Building children #" + i + " with
// instances " + Arrays.toString(tempSubBlocks[i]));
children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
// Important code: do this recursively
children[i].buildTree();
} // Of if
} // Of for i
}// Of buildTree
/**
**********************************
* Classify an instance.
*
* @param paraInstance
* The given instance.
* @return The prediction.
**********************************
*/
public int classify(Instance paraInstance) {
if (children == null) {
return label;
} // Of if
ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
if (tempChild == null) {
return label;
} // Of if
return tempChild.classify(paraInstance);
}// Of classify
/**
**********************************
* Test on a testing set.
*
* @param paraDataset
* The given testing data.
* @return The accuracy.
**********************************
*/
public double test(Instances paraDataset) {
double tempCorrect = 0;
for (int i = 0; i < paraDataset.numInstances(); i++) {
if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) {
tempCorrect++;
} // Of i
} // Of for i
return tempCorrect / paraDataset.numInstances();
}// Of test
/**
**********************************
* Test on the training set.
*
* @return The accuracy.
**********************************
*/
public double selfTest() {
return test(dataset);
}// Of selfTest
/**
*******************
* Overrides the method claimed in Object.
*
* @return The tree structure.
*******************
*/
public String toString() {
String resultString = "";
String tempAttributeName = dataset.attribute(splitAttribute).name();
if (children == null) {
resultString += "class = " + label;
} else {
for (int i = 0; i < children.length; i++) {
if (children[i] == null) {
resultString += tempAttributeName + " = "
+ dataset.attribute(splitAttribute).value(i) + ":" + "class = " + label
+ "rn";
} else {
resultString += tempAttributeName + " = "
+ dataset.attribute(splitAttribute).value(i) + ":" + children[i]
+ "rn";
} // Of if
} // Of for i
} // Of if
return resultString;
}// Of toString
/**
*************************
* Test this class.
*
* @param args
* Not used now.
*************************
*/
public static void id3Test() {
ID3 tempID3 = new ID3("D:/data/weather.arff");
// ID3 tempID3 = new ID3("D:/data/mushroom.arff");
ID3.smallBlockThreshold = 3;
tempID3.buildTree();
System.out.println("The tree is: rn" + tempID3);
double tempAccuracy = tempID3.selfTest();
System.out.println("The accuracy is: " + tempAccuracy);
}// Of id3Test
/**
*************************
* Test this class.
*
* @param args
* Not used now.
*************************
*/
public static void main(String[] args) {
id3Test();
}// Of main
}// Of class ID3
最后
以上就是奋斗豌豆为你收集整理的机器学习之决策树的全部内容,希望文章能够帮你解决机器学习之决策树所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复