我是靠谱客的博主 正直夕阳,最近开发中收集的这篇文章主要介绍scikit-learn k-近邻分类算法(kNN),觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

kNN 的作用机制为 在目标周围选取最近k个点,这k个点哪种占比最大,就可以把这个目标分类到那个分类,即有分到相似属性多的类别。
该算法和回归,决策树不同之处是,回归和决策树是通过训练集确定参数,参数一旦确定直接就能拿来进行测试,而kNN不同,它的分类要凭借训练数据,或者说并没有训练这一过程。

数据

应用sklearn的莺尾花数据集

# 加载数据
iris = load_iris()
# 为了绘制3d图少取了一项数据
X = iris.data[:, :3]
y = iris.target
print(iris.feature_names)
print(X[:3])
print(y[:3])

结果:

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
[[ 5.1  3.5  1.4]
 [ 4.9  3.   1.4]
 [ 4.7  3.2  1.3]]
[0 0 0]

TRAIN

看一下k为多少,准确率最高

for i in range(5,20):
    knn = KNeighborsClassifier(n_neighbors=i, weights='uniform')    
    knn.fit(X, y)
    print("准确率", knn.score(X, y))

结果:

5 准确率 0.96
6 准确率 0.96
7 准确率 0.966666666667
8 准确率 0.96
9 准确率 0.973333333333
10 准确率 0.96
11 准确率 0.966666666667
12 准确率 0.973333333333
13 准确率 0.966666666667
14 准确率 0.966666666667
15 准确率 0.966666666667
16 准确率 0.946666666667
17 准确率 0.953333333333
18 准确率 0.953333333333
19 准确率 0.96

相差不大,9,和12大,选9就好了。

TEST

全数据分类预测,绘制3d散点图:

import plotly.graph_objs as go
from plotly.graph_objs import *
import plotly

Z = knn.predict(X)

trace = go.Scatter3d(
    x=X[:, 2], y=X[:, 1], z=X[:, 0], mode='markers',
    marker=dict(
        size='10',
        color=Z,
        colorscale='Jet',
        showscale=False,
        line=Line(color='black',width=2),
    ),
)

plotly.offline.plot([trace], 's3d.html')

这里写图片描述

将特征改为2维,绘制分类图及数据集数据点:

import plotly.graph_objs as go
import plotly
import numpy as np
h = .02
cmap_light =[[0, 'rgba(255, 192, 203,0.7)'], [0.5, 'rgba(0, 229, 238, 0.7)'], [1, 'rgba(124, 252, 0, 0.7)']]
cmap_bold = [[0, '#FF0000'], [0.5, '#0000FF'], [1, '#00FF00']]
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
x_ = np.arange(x_min, x_max, h)
y_ = np.arange(y_min, y_max, h)
xx, yy = np.meshgrid(x_, y_)
Z = knn.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
trace1 = go.Heatmap(x=x_, y=y_, z=Z,
                    showscale=False,
                    colorscale=cmap_light)
trace2 = go.Scatter(
    x=X[:, 0], y=X[:, 1], mode='markers',
    marker=dict(size='10', color=y, colorscale=cmap_bold, showscale=False,
        line=dict(color='black', width=1)
    ),
)
data = [trace1,trace2]
plotly.offline.plot(data,'s1.html')

这里写图片描述

代码在这。

最后

以上就是正直夕阳为你收集整理的scikit-learn k-近邻分类算法(kNN)的全部内容,希望文章能够帮你解决scikit-learn k-近邻分类算法(kNN)所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(40)

评论列表共有 0 条评论

立即
投稿
返回
顶部