我是靠谱客的博主 彪壮悟空,最近开发中收集的这篇文章主要介绍sklearn包中K近邻分类器 KNeighborsClassifier的使用,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

1. KNN算法

K近邻(k-Nearest Neighbor,KNN)分类算法的核心思想是如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN算法可用于多分类,KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,作为预测值。

KNeighborsClassifier在scikit-learn 在sklearn.neighbors包之中。KNeighborsClassifier使用很简单,三步:

1)创建KNeighborsClassifier对象,

2)调用fit函数,

3)调用predict函数进行预测。

以下代码说明了用法。

例子一:

[python]  view plain  copy
  1. from sklearn.neighbors import KNeighborsClassifier  
  2.   
  3. X = [[0], [1], [2], [3],[4], [5],[6],[7],[8]]  
  4. y = [000111222]  
  5.   
  6. neigh = KNeighborsClassifier(n_neighbors=3)  
  7. neigh.fit(X, y)  
  8.   
  9. print(neigh.predict([[1.1]]))   #结果[0]
  10. print(neigh.predict([[1.6]]))   #结果[0]
  11. print(neigh.predict([[5.2]]))   #结果[1]
  12. print(neigh.predict([[5.8]]))   #结果[2]
  13. print(neigh.predict([[6.2]]))   #结果[3]

例子二:

from sklearn import datasets
from sklearn
import *
# from sklearn.neighbors import KNeighborsClassifier
# from sklearn.cross_validation import train_test_split
iris=datasets.load_iris()
iris_X=iris.data
iris_Y=iris.target
X_train,X_test,Y_train,Y_test = train_test_split(iris_X,iris_Y,test_size=0.3)
knn=KNeighborsClassifier()
knn.fit(X_train,Y_train)
print(knn.predict(X_test))
print(Y_test)

2. 实例

1)小麦种子数据集 (seeds)

七个特征,面积、周长、紧密度、谷粒的长度、谷粒的宽度、偏度系数和谷粒槽长度。数据格式如下:

[plain]  view plain  copy
  1. 15.26   14.84   0.871   5.763   3.312   2.221   5.22    Kama  
  2. 14.88   14.57   0.8811  5.554   3.333   1.018   4.956   Kama  
  3. 14.29   14.09   0.905   5.291   3.337   2.699   4.825   Kama  
  4. 13.84   13.94   0.8955  5.324   3.379   2.259   4.805   Kama  
  5. 16.14   14.99   0.9034  5.658   3.562   1.355   5.175   Kama  
  6. 14.38   14.21   0.8951  5.386   3.312   2.462   4.956   Kama  
  7. 14.69   14.49   0.8799  5.563   3.259   3.586   5.219   Kama  
  8. 14.11   14.1    0.8911  5.42    3.302   2.7     5.0     Kama  
  9. 16.63   15.46   0.8747  6.053   3.465   2.04    5.877   Kama  

2)代码

[python]  view plain  copy
  1. # -*- coding:utf-8 -*-  
  2. import numpy as np  
  3. from matplotlib import pyplot as plt  
  4. from matplotlib.colors import ListedColormap  
  5. from sklearn.neighbors import KNeighborsClassifier  
  6. from sklearn.cross_validation import KFold, cross_val_score  
  7.   
  8. feature_names = [  
  9.     'area',  
  10.     'perimeter',  
  11.     'compactness',  
  12.     'length of kernel',  
  13.     'width of kernel',  
  14.     'asymmetry coefficien',  
  15.     'length of kernel groove',  
  16. ]  
  17.   
  18. COLOUR_FIGURE = False  
  19.   
  20.   
  21. def plot_decision(features, labels, num_neighbors=3):  
  22.     y_min, y_max = features[:, 2].min() * .9, features[:, 2].max() * 1.1  
  23.     x_min, x_max = features[:, 0].min() * .9, features[:, 0].max() * 1.1  
  24.     X, Y = np.meshgrid(np.linspace(x_min, x_max, 1000), np.linspace(y_min, y_max, 1000))  
  25.   
  26.     model = KNeighborsClassifier(num_neighbors)  
  27.     model.fit(features[:, (0,2)], labels)  
  28.     C = model.predict(np.vstack([X.ravel(), Y.ravel()]).T).reshape(X.shape)  
  29.     if COLOUR_FIGURE:  
  30.         cmap = ListedColormap([(1., .7, .7), (.71., .7), (.7, .71.)])  
  31.     else:  
  32.         cmap = ListedColormap([(1.1.1.), (.2, .2, .2), (.6, .6, .6)])  
  33.     fig,ax = plt.subplots()  
  34.     ax.set_xlim(x_min, x_max)  
  35.     ax.set_ylim(y_min, y_max)  
  36.     ax.set_xlabel(feature_names[0])  
  37.     ax.set_ylabel(feature_names[2])  
  38.     ax.pcolormesh(X, Y, C, cmap=cmap)  
  39.     if COLOUR_FIGURE:  
  40.         cmap = ListedColormap([(1., .0, .0), (.1, .6, .1), (.0, .01.)])  
  41.         ax.scatter(features[:, 0], features[:, 2], c=labels, cmap=cmap)  
  42.     else:  
  43.         for lab, ma in zip(range(3), "Do^"):  
  44.             ax.plot(features[labels == lab, 0],  
  45.                     features[labels == lab, 2],  
  46.                     ma,  
  47.                     c=(1.1.1.),  
  48.                     ms=6)  
  49.     return fig, ax  
  50.   
  51.   
  52. def load_csv_data(filename):  
  53.     data = []  
  54.     labels = []  
  55.     datafile = open(filename)  
  56.     for line in datafile:  
  57.         fields = line.strip().split('t')  
  58.         data.append([float(field) for field in fields[:-1]])  
  59.         labels.append(fields[-1])  
  60.     data = np.array(data)  
  61.     labels = np.array(labels)  
  62.     return data, labels  
  63.   
  64.   
  65. def accuracy(test_labels, pred_lables):  
  66.     correct = np.sum(test_labels == pred_lables)  
  67.     n = len(test_labels)  
  68.     return float(correct) / n  
  69.   
  70.   
  71. if __name__ == '__main__':  
  72.     opt = input("raw_inputp[1 or 2]: ")  
  73.     features, labels = load_csv_data('data/seeds.tsv')  
  74.     if opt == '1':  
  75.         knn = KNeighborsClassifier(n_neighbors=5)  
  76.         kf = KFold(len(features), n_folds=3, shuffle=True)  
  77.         result_set = [(knn.fit(features[train], labels[train]).predict(features[test]), test) for train, test in kf]  
  78.         score = [accuracy(labels[result[1]], result[0]) for result in result_set]  
  79.         print(score)  
  80.     elif opt == '2':  
  81.         names = sorted(set(labels))  
  82.         labels = np.array([names.index(ell) for ell in labels])  
  83.         fig, ax = plot_decision(features, labels)  
  84.         plt.show()  
  85.     else:  
  86.         print('input 1 or 2 !')  

代码简要说明 

load_csv_data 从数据文件,读取数据。

accuracy 计算预测的准确度。

plot_decision 画决策边界图,挑两个特征。这个函数要注意pcolormesh。

主程序:输入1进行预测,输入2画图。第一个选项中,

a)首先生成分类器,

b)调用KFold来生产学习数据和测试数据,

3)训练和预测,

4)计算精度。

这里充分利用了“列表解析”和“向量”使代码简洁。



最后

以上就是彪壮悟空为你收集整理的sklearn包中K近邻分类器 KNeighborsClassifier的使用的全部内容,希望文章能够帮你解决sklearn包中K近邻分类器 KNeighborsClassifier的使用所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(53)

评论列表共有 0 条评论

立即
投稿
返回
顶部