概述
k-近邻 (KNN)思想:
一个样本数据集合(亦称训练样本集),并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数,最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
代码
python3版本代码(对于小白,在每次测试时候,可以打断点到测试那里,更容易理解):
from numpy import *
import operator
#导入数据和标签,此处为手写,也可以直接导入其他文本数据集
def createDataSet():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
group, labels = createDataSet()
#测试
print("group:")
print(group)
print("labels:")
print(labels)
print("---------"*10)
###################################################################################################################
#K紧邻算法
'''
1、计算已知类别数据集中的点与当前点的距离;
2、按照距离递增次序排序;
3、选取与当前点距离最小的k个点;
4、确定前k个点所在类别的出现频率
5、返回前k个点出现频率最高的类别作为当前点的预测分类
'''
def KNN(inX, dataSet, labels, k):
"""
:param inX: 输入点
:param dataSet: 数据集
:param labels: 标签集
:param k: 选择最近邻居的数目
:return:分类标签A或者B
"""
dataSetSize = dataSet.shape[0]#获取数据集的行数为 4
#1、计算距离:欧式计算公式d=sqrt( (x1-x2)^2+(y1-y2)^2 ),下面计算全部是矩阵计算
diffMat = tile(inX, (dataSetSize, 1)) - dataSet# 相当于x1-x2,但是这是矩阵运算,而tile方法是扩充成和Setdata一样行数
sqDiffMat = diffMat**2#这个是平方
sqDistances = sqDiffMat.sum(axis=1)#这个是横向求和
distances =sqDistances**0.5#开根号
#查看该点与各个数据集的距离
print("距离各点距离:")
print(distances)
# 2、返回的是升序排序后的distances的下标索引
sortedDistIndices = distances.argsort()
#查看排序后索引
print("排序后索引:")
print(sortedDistIndices)
#3、4、找出距离inX最近的k个标签,就是排序后,sortedDistIndices前k个值的索引,并确定频率
classCount = {}#字典dict类型,存放标签的个数
for i in range(k):
voteIlabel = labels[sortedDistIndices[i]]#找到索引对应的标签 A或者B
count = classCount.get(voteIlabel, 0) #dict.get方法,找到voteIlabel的值,如果不存在则返回一个0并存入dict,存在则不用管
classCount[voteIlabel] = count + 1#统计距离inX近的标签
#查看标签个数,即频率
print("标签的频率:")
print(classCount)
#5、对频率进行排序,python3以上用classCount.items(),3以下classCount.iteritems()
#这里使用了sorted()函数sorted(iterable, cmp=None, key=None, reverse=False),items()将dict分解为元组列表,operator.itemgetter(1)表示按照第二个元素的次序对元组进行排序,注意sort()的区别
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
#可以查看一下这是sorted()后的元组
print("转换后的元组合:")
print(sortedClassCount)
print("---------"*10)
#返回第一个就是频率最高的那一个
return sortedClassCount[0][0]
#测试KNN,对于点随便自己手动输入,我这里输入的是[1,3]
print("分类为:"+KNN([1, 3], group, labels, 3))
结果:
分类为:A
最后
以上就是高贵银耳汤为你收集整理的机器学习——分类算法1:k-近邻 (KNN) 思想和代码k-近邻 (KNN)思想:代码的全部内容,希望文章能够帮你解决机器学习——分类算法1:k-近邻 (KNN) 思想和代码k-近邻 (KNN)思想:代码所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复