我是靠谱客的博主 害羞巨人,最近开发中收集的这篇文章主要介绍机器学习笔记:回归树,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

回归树:

优点:可以对复杂和非线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型

提到回归树一定会联想到决策树。

ID3决策树

决策树是一种贪心算法,它要在给定的时间内做出最佳选择,但不关心能否达到全局最优,之前记录的决策树是ID3算法,ID3的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来切分。
1. 当一个特征被切分使用后,就不再起作用了,显然这样的划分只能保证局部最优的。
2. 对于连续值,ID3算法也需要预先将它转化为离散值,这样的话连续特征值包含的某些特性就被忽略了。

二元切分法

即每次把数据集切成两份,如果特征值大于给定值时,按照这个给定值划分整个数据集,大于的进入左子树,小于等于的进入右子树。假设:

a = matrix([[ 1.,  2.,  0.,  0.,  0.,  0.],
            [ 0.,  1.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  1.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  1.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  1.,  0.],
            [ 0.,  0.,  0.,  0.,  0.,  0.]])

对第四列特征,按照特征值大于0.5的数据集划分,可以得到:

 mat1 = matrix([[ 1.,  0.,  0.,  0.,  0.,  0.],
                [ 0.,  1.,  0.,  0.,  0.,  0.],
                [ 0.,  0.,  1.,  0.,  0.,  0.],
                [ 0.,  0.,  0.,  0.,  1.,  0.],
                [ 0.,  0.,  0.,  0.,  0.,  1.]])

 mat0=matrix([[ 0.,  0.,  0.,  1.,  0.,  0.]])

Python 3.6 实现回归树

《机器学习实战》书中的回归树一章存在相当多的错误以及Python3.6与2.x版本不兼容的情况,这里全部已经修改过了,并在注释中解释:

import numpy as np

#加载数据集
def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr:
        current = line.strip().split('t')
         #python3.6中 map返回值为object 需要使用list转换
        fltLine = list(map(float,current))
        dataMat.append(fltLine)        
    return dataMat


#切分整个数据集
def binSplitDataSet(dataSet, feature, value):
    '''
    mat0,mat1 = binSplitDataSet(a,3,0.9)

    mat1=matrix([[ 1.,  0.,  0.,  0.,  0.,  0.],
                 [ 0.,  1.,  0.,  0.,  0.,  0.],
                 [ 0.,  0.,  1.,  0.,  0.,  0.],
                 [ 0.,  0.,  0.,  0.,  1.,  0.],
                 [ 0.,  0.,  0.,  0.,  0.,  1.]])

    mat0=matrix([[ 0.,  0.,  0.,  1.,  0.,  0.]])
    也就是将整个数据集按照这个特征的某个阈值切分
    '''
    mat0 = dataSet[np.nonzero(dataSet[:,feature]>value)[0],:]
    mat1 = dataSet[np.nonzero(dataSet[:,feature]<=value)[0],:]
    return mat0,mat1

#计算目标变量---label的平均值
def regLeaf(dataSet):
    return np.mean(dataSet[:,-1])

#计算目标变量---label的总方差
def regErr(dataSet):
    return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]

#选择最好的切分阈值
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regLeaf, ops=(1,4)):
    tols = ops[0]; toln = ops[1];
    #判断目标变量---label是否只存在唯一的了
    #print(dataSet[:,-1])
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)

    #n为还存在的特征数
    m, n = np.shape(dataSet)
    #获取数据集中目标变量的总方差
    s = errType(dataSet)
    bests = np.inf;
    bestIndex = 0;
    bestValue = 0;
    for featIndex in range(n-1):
        #这里原本书上的这句代码是错误
        for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            #如果两个数据集一个数据量不足toln则换一个划分
            if (np.shape(mat0)[0] < toln) or (np.shape(mat1)[0] < toln):
                continue
            #得到两个数据集目标变量---label的总方差和
            news = errType(mat0) + errType(mat1)
            if news < bests:
                bestIndex = featIndex
                bestValue = splitVal
                bests = news
    #如果最终得到的误差不大, 则推出
    if (s - bests) < tols:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)

    #如果最终得到的数据集很小也推出
    if (np.shape(mat0)[0] < toln) or (np.shape(mat1)[0] < toln):
        return None, leafType(dataSet)

    return bestIndex, bestValue


#建树,#ops=(容许的误差下降值,切分的最少样本数)
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    #得到最佳index和最佳切分阈值
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None:
        #到达叶子节点,这里的val其实就是均值
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    #通过最佳index和最佳阈值得到切分后的两个数据集
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    #递归建树
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

树剪枝

#判断是否是树
def isTree(obj):
    return (type(obj).__name__ == 'dict')

#如果树的两个子树都存在,则计算子树平均值
def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['right']+tree['left'])/2

#剪枝
def prune(tree, testData):
    if np.shape(testData)[0] == 0:
        return getMean(tree)
    if (isTree(tree['left']) or isTree(tree['right'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], testData)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], testData)

    if (not isTree(tree['right'])) and (not isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(np.power(lSet[:,-1] - tree['left'],2))+sum(
                np.power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2
        errorMerge = sum(np.power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge:
            print('merging!')
            return treeMean
        else:
            return tree
    else:
        return tree

最后

以上就是害羞巨人为你收集整理的机器学习笔记:回归树的全部内容,希望文章能够帮你解决机器学习笔记:回归树所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(55)

评论列表共有 0 条评论

立即
投稿
返回
顶部