概述
KNN(k最近邻)
K最近邻(K-Nearest Neighbor,KNN)算法,是著名的模式识别统计学方法,在机器学习分类算法中占有相当大的地位。它是一个理论上比较成熟的方法。既是最简单的机器学习算法之一,也是基于实例的学习方法中最基本的,也是常见的文本分类算法之一。
基本思想
如果一个实例在特征空间中的K个最相似(即特征空间中最近邻)的实例中的大多数属于某一个类别,则该实例也属于这个类别。所选择的邻居都是已经正确分类的实例。
该算法假定所有的实例对应于N维欧式空间中的点。通过计算一个点与其他所有点之间的距离,取出与该点最近的K个点,然后统计这K个点里面所属分类比例最大的,则这个点属于该分类。
该算法涉及3个主要因素:实例集、距离或相似的衡量、k的大小。
Tensorflow实现knn:k取1
参考:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py
这个例子使用的数据集是MNIST,这是手写数字识别的数据集,通过识别手写的数字0到9,也就是一共是十个类别。输入为图片,图片像素28*28=784,表示成向量形式。
使用样本形式为(X,label),X为784维的向量,label为0-9之中的一个类别。
import numpy as np
import tensorflow as tf
#最近邻算法,此代码实现类似1-NN
#导入输入数据MNIST
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
#这个例子限制了样本的数目
Xtr, Ytr = mnist.train.next_batch(1000) #1000 条候选样本,测试样本跟候选样本比较,得到最近的K个样本,然后k个样本的标签大多数为某类,测试样本就为某类
Xte, Yte = mnist.test.next_batch(200) #200 条测试样本
# tf Graph Input,占位符,用来feed数据
xtr = tf.placeholder("float", [None, 784])
xte = tf.placeholder("float", [784])
# 最近邻计算距离使用 L1 距离
# 计算L1距离
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
# 预测: 获取离测试样本具有最小L1距离的样本(1-NN),此样本的类别作为test样本的类别
pred = tf.arg_min(distance, 0)
accuracy = 0.
# 初始化图
init = tf.global_variables_initializer()
# 发布图
with tf.Session() as sess:
sess.run(init)
#循环测试集
for i in range(len(Xte)):
# Get nearest neighbor
nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]}) #每次循环feed数据,候选Xtr全部,测试集Xte一次循环输入一条
# 获得与测试样本最近样本的类别,计算与真实类别的误差
print("Test", i, "Prediction:", np.argmax(Ytr[nn_index]),
"True Class:", np.argmax(Yte[i]))
# 计算误差率
if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
accuracy += 1. / len(Xte)
print("Done!")
print("Accuracy:", accuracy)
当我候选样本选1000时,结果:Accuracy: 0.8650000000000007
当我候选样本选5000时,结果:Accuracy: 0.9250000000000007
可见,候选样本对精确度影响还是比较大的。
最后
以上就是唠叨篮球为你收集整理的Tensorflow--实现KNN的全部内容,希望文章能够帮你解决Tensorflow--实现KNN所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复