概述
sklearn 源码分析系列:neighbors(2)
我起初一直在纠结是否需要把kd_tree的实现也放在这一篇中讲,如果讲算法实现,就违背了源码分析的初衷,过早钻入细节,是阅读源码的大忌。算法和框架的分析应属两部分内容,所以最终决定,所有sklearn源码分析系列不涉及具体算法,而是保证每个方法调用的连通性,重点关注架构,以及一些必要的python实现细节。
Note:
这篇文章主要分析Neighbors包中的Unsupervised Nearest Neighbors相关接口,对应于官方文档1.6.1章节,详见文档。
Finding the Nearest Neighbors实操
详细实操代码可参考Github kaggle项目,详见链接。
在实现最近邻算法时,常用的算法有”kd_tree”,”ball_tree”,”brute”三种,它们对应于不同的应用场景,这里不再赘述。
数据生成与可视化
# 1.6.1 Unsupervised Nearest Neighbors
from sklearn.neighbors import NearestNeighbors
import numpy as np
import matplotlib.pyplot as plt
# 1.6.1.1 Finding the Nearest Neighbors
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
plt.figure()
plt.scatter(X[:,0],X[:,1])
plt.xlim(X[:,0].min()-1,X[:,0].max()+1)
plt.ylim(X[:,1].min()-1,X[:,1].max()+1)
plt.title("Unsupervised nearest neighbors")
plt.show()
# k个最近的点中包含自己
nbrs = NearestNeighbors(n_neighbors=3, algorithm='ball_tree').fit(X)
distances,indices = nbrs.kneighbors(X)
# k个最近点的下标,按升序排列
indices
输出:
array([[0, 1, 2],
[1, 0, 2],
[2, 1, 0],
[3, 4, 5],
[4, 3, 5],
[5, 4, 3]], dtype=int64)
# k个最近点的最短距离,按升序排列
distances
Out[2]:
array([[ 0.
,
1.
,
2.23606798],
[ 0.
,
1.
,
1.41421356],
[ 0.
,
1.41421356,
2.23606798],
[ 0.
,
1.
,
2.23606798],
[ 0.
,
1.
,
1.41421356],
[ 0.
,
1.41421356,
2.23606798]])
kneighbors(X)默认返回两个参数,其中k个最近邻中还包含了自己,距离和下标均按照升序排列。
# k个最近点生成的邻接矩阵
nbrs.kneighbors_graph(X).toarray()
Out [3]:
array([[ 1.,
1.,
1.,
0.,
0.,
0.],
[ 1.,
1.,
1.,
0.,
0.,
0.],
[ 1.,
1.,
1.,
0.,
0.,
0.],
[ 0.,
0.,
0.,
1.,
1.,
1.],
[ 0.,
0.,
0.,
1.,
1.,
1.],
[ 0.,
0.,
0.,
1.,
1.,
1.]])
# 1.6.1.2 KD Tree and Ball Tree Classes
from sklearn.neighbors import KDTree
import numpy as np
# 可直接用KDtree实现最近邻查找
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
kdt = KDTree(X, leaf_size=30, metric='euclidean')
kdt.query(X,k = 3,return_distance = False)
Out [4]:
array([[0, 1, 2],
[1, 0, 2],
[2, 1, 0],
[3, 4, 5],
[4, 3, 5],
[5, 4, 3]], dtype=int64)
源码剖析
我们先从整体上来看看,实现NearestNeighbors所需关联到的python文件及对应的文件结构是什么样子的。
相比于Neighbors(1)中的内容,它多了unsupervised.py
文件而已。所以,我们直接顺藤摸瓜开始分析。
unsupervised.py
class NearestNeighbors(NeighborsBase, KNeighborsMixin,
RadiusNeighborsMixin, UnsupervisedMixin):
def __init__(self, n_neighbors=5, radius=1.0,
algorithm='auto', leaf_size=30, metric='minkowski',
p=2, metric_params=None, n_jobs=1, **kwargs):
self._init_params(n_neighbors=n_neighbors,
radius=radius,
algorithm=algorithm,
leaf_size=leaf_size, metric=metric, p=p,
metric_params=metric_params, n_jobs=n_jobs, **kwargs)
这是一个明显的子类继承多个父类的情况,其中KNeighborsMixin
和RadiusNeighborsMixin
属于功能相同,但具体实现细节有所差异,只单独分析一例。
先来看看它的构造方法吧,构造方法中传入了,9个参数,都是带默认值的。但令人奇怪的是,它同样是空有型而无内容的【初始化类】,该类只与客户端打交道,而真正的参数初始化都交给了其中的某个父类的__init__params()
方法。为什么要这么做?不急,先看看到底是哪个父类完成了参数初始化。
所有父类集中在neighbors包下的base.py
文件中。
经过一番寻找总算找到了初始化参数方法,在类neighborsBase
中
class NeighborsBase(six.with_metaclass(ABCMeta, BaseEstimator)):
"""Base class for nearest neighbors estimators."""
@abstractmethod
def __init__(self):
pass
def _init_params(self, n_neighbors=None, radius=None,
algorithm='auto', leaf_size=30, metric='minkowski',
p=2, metric_params=None, n_jobs=1):
self.n_neighbors = n_neighbors
self.radius = radius
self.algorithm = algorithm
self.leaf_size = leaf_size
self.metric = metric
self.metric_params = metric_params
self.p = p
self.n_jobs = n_jobs
if algorithm not in ['auto', 'brute',
'kd_tree', 'ball_tree']:
raise ValueError("unrecognized algorithm: '%s'" % algorithm)
if algorithm == 'auto':
if metric == 'precomputed':
alg_check = 'brute'
else:
alg_check = 'ball_tree'
else:
alg_check = algorithm
if callable(metric):
if algorithm == 'kd_tree':
# callable metric is only valid for brute force and ball_tree
raise ValueError(
"kd_tree algorithm does not support callable metric '%s'"
% metric)
elif metric not in VALID_METRICS[alg_check]:
raise ValueError("Metric '%s' not valid for algorithm '%s'"
% (metric, algorithm))
if self.metric_params is not None and 'p' in self.metric_params:
warnings.warn("Parameter p is found in metric_params. "
"The corresponding parameter from __init__ "
"is ignored.", SyntaxWarning, stacklevel=3)
effective_p = metric_params['p']
else:
effective_p = self.p
if self.metric in ['wminkowski', 'minkowski'] and effective_p < 1:
raise ValueError("p must be greater than one for minkowski metric")
# 重点关注
self._fit_X = None
self._tree = None
self._fit_method = None
喔,原来NeighborsBase是要作为整个Neighbors最具领导力的类?起码这家伙拿到了全局信息吧,我的一个猜测是,除了unsupervised
需要用到这些参数之外,其他类也同样需要用这些参数做些有趣的事吧?所以既然大家都要复用这些参数!那就放在一个基类中吧,此处就叫NeighborsBase
吧。(待检验)
我们关注下方法本身中的参数:
1. self.n_neighbors = n_neighbors ## k近邻中的k
2. self.radius = radius ## 不知
3. self.algorithm = algorithm ## 使用何种k近邻算法,如’kd_tree’
4. self.leaf_size = leaf_size ## 生成’kd_tree’树需要传入的参数
5. self.metric = metric ## 计算其他各种形式的两点间距离
6. self.metric_params = metric_params ## 不知
7. self.p = p ## 不知
8. self.n_jobs = n_jobs ## 并发创建的线程数
除此之外,在初始化最后,还占了三个位:
1. self._fit_X = None ## fit_X 和传入的X之间有何关系?
2. self._tree = None ## _tree表示返回的树结构
3. self._fit_method = None ## fit传入的算法
NeighborsBase就这些内容,它还有一个_fit()
方法,稍后分析。总的来说,当客户端调用诸如nbrs = NearestNeighbors(n_neighbors=3, algorithm='kd_tree',leaf_size=30)
的构造方法时,NearestNeighbors什么都没做,把参数初始化任务交给了它的父类NeighborsBase(该小组的老大!),而这老大具体也没做什么具体的事,把该初始化的参数初始化,并做一些参数合法性的检查,完工。
模型参数初始完毕之后,自然到了fit
步骤,正如,客户端调用那样nbrs = NearestNeighbors(n_neighbors=3, algorithm='kd_tree',leaf_size=30).fit(X)
我把数据X,传给了谁?谁来拟合这些数据呢?
记得NearestNeighbors
中的几个父类吧,完成fit操作的是UnsupervisedMixin
类,接着来看看它的代码。
class UnsupervisedMixin(object):
def fit(self, X, y=None):
"""Fit the model using X as training data
Parameters
----------
X : {array-like, sparse matrix, BallTree, KDTree}
Training data. If array or matrix, shape [n_samples, n_features],
or [n_samples, n_samples] if metric='precomputed'.
"""
return self._fit(X)
非常简短,针对非监督的数据,全部交给了自己的self._fit(X)
方法,所以它又是个代理类?这个代理类更狠,什么都没做,直接转交给NearestNeighbors
中的某个父类来完成。调用_fit()
方法后,就又回到了NeighborsBase
中去了,所以当客户端要调用fit方法时,先交给了NeighborsBase
的手下UnsupervisedMixin
做一些前期的处理操作,但这手下学会了偷懒,什么都没做直接交给了领导,直接让领导来处理咯,真坏。那领导真的有功夫,有能力处理这个fit任务?领导也不傻,我们看看领导怎么做的。
def _fit(self, X):
......
# 做些必要的检查
X = check_array(X, accept_sparse='csr')
# 还是在做检查
n_samples = X.shape[0]
if n_samples == 0:
raise ValueError("n_samples must be greater than 0")
......
#前面占的位子给补上
self._fit_method = self.algorithm
self._fit_X = X
......
# 嘿,领导开始派发任务了
if self._fit_method == 'ball_tree':
self._tree = BallTree(X, self.leaf_size,
metric=self.effective_metric_,
**self.effective_metric_params_)
# 看到了熟悉的kd_tree了
elif self._fit_method == 'kd_tree':
self._tree = KDTree(X, self.leaf_size,
metric=self.effective_metric_,
**self.effective_metric_params_)
elif self._fit_method == 'brute':
self._tree = None
else:
raise ValueError("algorithm = '%s' not recognized"
% self.algorithm)
# 检查,为什么不放在一开始做?
if self.n_neighbors is not None:
if self.n_neighbors <= 0:
raise ValueError(
"Expected n_neighbors > 0. Got %d" %
self.n_neighbors
)
return self
唉,领导也没有干活啊,做了一些检查,根据来的参数,交给对应的具体执行者去做!但返回的还是自己,因为我要和客户端打交道。我们来分析下具体的执行者做了些什么操作。看如下代码,
elif self._fit_method == 'kd_tree':
self._tree = KDTree(X, self.leaf_size,
metric=self.effective_metric_,
**self.effective_metric_params_)
在NeighborsBase
的fit()
方法中,并不是返回某个模型对象,而是把模型对象内嵌到了NeighborsBase
中的self._tree
中去,这是为什么?kd_tree
模型本身有查询最近邻的方法,为什么不直接暴露给客户端呢?在这里我并不理解它这样做的用意是什么。(待解决)
所以对于数据真正的fit()
是交给具体算法来完成的,咱们接下来就看看kd_tree.py
吧。关于kd_tree
的算法细节,可以参考之前我的一篇博文【K近邻法学习笔记】。关于sklearn中kd_tree的具体分析,不作为本文内容,日后单独开辟一章来讲解。本文重点关注各接口的实现与内在联系。
所以当NeighborsBase
构造了kd_tree
时,就调用了它的构造方法,走。
def __init__(self, data, leafsize=10):
self.data = np.asarray(data)
self.n, self.m = np.shape(self.data)
self.leafsize = int(leafsize)
if self.leafsize < 1:
raise ValueError("leafsize must be at least 1")
self.maxes = np.amax(self.data,axis=0)
self.mins = np.amin(self.data,axis=0)
# 关键步骤
self.tree = self.__build(np.arange(self.n), self.maxes, self.mins)
前面也是做了一些初始化操作,接着开始构建kd_tree
的数据结构了。调用__build()
方法,由传入的数据的生成了对应的数据结构。到这里,数据到结构的映射完成了。
总结下,NearsetNeighbors
和客户端打交到,而NeighborsBase
统筹规划所有调度。
既然有了数据X到结构的映射,那自然要做真正的查询操作了(k近邻查询),我们继续来看看,客户端调用如下distances,indices = nbrs.kneighbors(X)
,在NearestNeighbors
中只要初始化方法,并没有kneighbors(X)
方法,该方法在它的另外一个父类KNeighborsMixin
中。
class KNeighborsMixin(object):
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
......
n_samples, _ = X.shape
sample_range = np.arange(n_samples)[:, None]
......
elif self._fit_method in ['ball_tree', 'kd_tree']:
if issparse(X):
raise ValueError(
"%s does not work with sparse matrices. Densify the data, "
"or set algorithm='brute'" % self._fit_method)
result = Parallel(n_jobs, backend='threading')(
delayed(self._tree.query, check_pickle=False)(
X[s], n_neighbors, return_distance)
for s in gen_even_slices(X.shape[0], n_jobs)
)
if return_distance:
dist, neigh_ind = tuple(zip(*result))
result = np.vstack(dist), np.vstack(neigh_ind)
else:
result = np.vstack(result)
很多东西都可以忽略不看,只需要关注一行代码就可以了。
result = Parallel(n_jobs, backend='threading')(
delayed(self._tree.query, check_pickle=False)(
X[s], n_neighbors, return_distance)
for s in gen_even_slices(X.shape[0], n_jobs)
)
前面它包了一个并发的类,咱们不去研究,在delay方法中,传入了self._tree.query
这是一个方法名,在之前KDTree
类的接口中,有相应的实现,也就是说KNeighborsMixin
类也不做任何查询操作,同样把查询交给了KDTree
来完成,的确如此,只有KDTree
中存放了相应的数据结构,不是它做查询谁来做查询,KNeighborsMixin
只是简单的把KDTree
返回的查询结果交给客户端就可以了,别无其他。
综上,整个关于数据X到kd_tree
的结构映射调用就完成了,也没有太多东西,理清各个类之间的关系就可以了。同样的,当要进行k近邻查询时,交给了NearestNeighbors
中的父类KNeighborsMixin
来代理查询,真正的查询操作还是kd_tree
来完成,前期都是些琐碎的调用流程,而算法的核心在于kd_tree
,起码数据在到kd_tree
之前,能够做很多前期处理,保证了算法对数据的要求。看来是时候研究下kd_tree
的核心算法了。
最后
以上就是俭朴背包为你收集整理的sklearn 源码分析系列:neighbors(2) sklearn 源码分析系列:neighbors(2) 的全部内容,希望文章能够帮你解决sklearn 源码分析系列:neighbors(2) sklearn 源码分析系列:neighbors(2) 所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复