概述
回归树:
优点:可以对复杂和非线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型
提到回归树一定会联想到决策树。
ID3决策树
决策树是一种贪心算法,它要在给定的时间内做出最佳选择,但不关心能否达到全局最优,之前记录的决策树是ID3算法,ID3的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来切分。
1. 当一个特征被切分使用后,就不再起作用了,显然这样的划分只能保证局部最优的。
2. 对于连续值,ID3算法也需要预先将它转化为离散值,这样的话连续特征值包含的某些特性就被忽略了。
二元切分法
即每次把数据集切成两份,如果特征值大于给定值时,按照这个给定值划分整个数据集,大于的进入左子树,小于等于的进入右子树。假设:
a = matrix([[ 1., 2., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0.],
[ 0., 0., 1., 0., 0., 0.],
[ 0., 0., 0., 1., 0., 0.],
[ 0., 0., 0., 0., 1., 0.],
[ 0., 0., 0., 0., 0., 0.]])
对第四列特征,按照特征值大于0.5的数据集划分,可以得到:
mat1 = matrix([[ 1., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0.],
[ 0., 0., 1., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0.],
[ 0., 0., 0., 0., 0., 1.]])
mat0=matrix([[ 0., 0., 0., 1., 0., 0.]])
Python 3.6 实现回归树
《机器学习实战》书中的回归树一章存在相当多的错误以及Python3.6与2.x版本不兼容的情况,这里全部已经修改过了,并在注释中解释:
import numpy as np
#加载数据集
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr:
current = line.strip().split('t')
#python3.6中 map返回值为object 需要使用list转换
fltLine = list(map(float,current))
dataMat.append(fltLine)
return dataMat
#切分整个数据集
def binSplitDataSet(dataSet, feature, value):
'''
mat0,mat1 = binSplitDataSet(a,3,0.9)
mat1=matrix([[ 1., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0.],
[ 0., 0., 1., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0.],
[ 0., 0., 0., 0., 0., 1.]])
mat0=matrix([[ 0., 0., 0., 1., 0., 0.]])
也就是将整个数据集按照这个特征的某个阈值切分
'''
mat0 = dataSet[np.nonzero(dataSet[:,feature]>value)[0],:]
mat1 = dataSet[np.nonzero(dataSet[:,feature]<=value)[0],:]
return mat0,mat1
#计算目标变量---label的平均值
def regLeaf(dataSet):
return np.mean(dataSet[:,-1])
#计算目标变量---label的总方差
def regErr(dataSet):
return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]
#选择最好的切分阈值
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regLeaf, ops=(1,4)):
tols = ops[0]; toln = ops[1];
#判断目标变量---label是否只存在唯一的了
#print(dataSet[:,-1])
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
#n为还存在的特征数
m, n = np.shape(dataSet)
#获取数据集中目标变量的总方差
s = errType(dataSet)
bests = np.inf;
bestIndex = 0;
bestValue = 0;
for featIndex in range(n-1):
#这里原本书上的这句代码是错误
for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
#如果两个数据集一个数据量不足toln则换一个划分
if (np.shape(mat0)[0] < toln) or (np.shape(mat1)[0] < toln):
continue
#得到两个数据集目标变量---label的总方差和
news = errType(mat0) + errType(mat1)
if news < bests:
bestIndex = featIndex
bestValue = splitVal
bests = news
#如果最终得到的误差不大, 则推出
if (s - bests) < tols:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
#如果最终得到的数据集很小也推出
if (np.shape(mat0)[0] < toln) or (np.shape(mat1)[0] < toln):
return None, leafType(dataSet)
return bestIndex, bestValue
#建树,#ops=(容许的误差下降值,切分的最少样本数)
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
#得到最佳index和最佳切分阈值
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None:
#到达叶子节点,这里的val其实就是均值
return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
#通过最佳index和最佳阈值得到切分后的两个数据集
lSet, rSet = binSplitDataSet(dataSet, feat, val)
#递归建树
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
树剪枝
#判断是否是树
def isTree(obj):
return (type(obj).__name__ == 'dict')
#如果树的两个子树都存在,则计算子树平均值
def getMean(tree):
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['right']+tree['left'])/2
#剪枝
def prune(tree, testData):
if np.shape(testData)[0] == 0:
return getMean(tree)
if (isTree(tree['left']) or isTree(tree['right'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], testData)
if isTree(tree['right']):
tree['right'] = prune(tree['right'], testData)
if (not isTree(tree['right'])) and (not isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(np.power(lSet[:,-1] - tree['left'],2))+sum(
np.power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2
errorMerge = sum(np.power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print('merging!')
return treeMean
else:
return tree
else:
return tree
最后
以上就是害羞巨人为你收集整理的机器学习笔记:回归树的全部内容,希望文章能够帮你解决机器学习笔记:回归树所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复