【统计学习方法】k近邻 kd树的python实现

前言

代码可在Github上下载:代码下载

k近邻可以算是机器学习中易于理解、实现的一个算法了,《机器学习实战》的第一章便是以它作为介绍来入门。而k近邻的算法可以简述为通过遍历数据集的每个样本进行距离测量,并找出距离最小的k个点。但是这样一来一旦样本数目庞大的时候,就容易造成大量的计算。

所以需要将数据用树形结构存储,以便快速检索,这也就是本文要阐述的kd树。

实现

分为两部分,一个是kd树建立,一个是kd树的搜索。

 

kd树建立

# --*-- coding:utf-8 --*--
import numpy as np

先定义一下字符集还有包。

首先我们先实现一个结点类,用来表示kd。

class Node:
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild

一个结点包含着结点域,左孩子,右孩子。(如果不熟二叉树的话建议先看一些数据结构二叉树的相关知识,以及先序遍历,中序遍历还有后序遍历的相关代码)

二叉树相关代码(C语言实现)

然后是创建kd树的代码,主要根据P41,算法3.2来实现的。

def create(self, dataSet, depth):   #创建kd树,返回根结点
        if (len(dataSet) > 0):
            m, n = np.shape(dataSet)    #求出样本行,列
            midIndex = m / 2 #中间数的索引位置
            axis = depth % n    #判断以哪个轴划分数据,对应书中算法3.2(2)公式j()
            sortedDataSet = self.sort(dataSet, axis) #进行排序
            node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本
            # print sortedDataSet[midIndex]
            leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2个副本
            rightDataSet = sortedDataSet[midIndex+1 :]
            print leftDataSet
            print rightDataSet
            node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
            node.rchild = self.create(rightDataSet, depth+1)
            return node
        else:
            return None

以上的代码通过看注释应该可以了解一二,其中需要按轴j(mod k)+1,也就是【depth(深度) mod n(特征数)+1】为轴划分中位数,然后决定插入数据到左结点,右结点。然后注意一下为什么上面的按轴划分的公式是【depth(深度) mod n(特征数)】,这是因为python的数组下标是从0开始的。

def sort(self, dataSet, axis):  #采用冒泡排序,利用aixs作为轴进行划分
        sortDataSet = dataSet[:]    #由于不能破坏原样本,此处建立一个副本
        m, n = np.shape(sortDataSet)
        for i in range(m):
            for j in range(0, m - i - 1):
                if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
                    temp = sortDataSet[j]
                    sortDataSet[j] = sortDataSet[j+1]
                    sortDataSet[j+1] = temp
        print sortDataSet
        return sortDataSet

创建树的时候为了找中位数,需要按轴(某一维度)排序,找出中间那个数。这里我用了冒泡排序。

def preOrder(self, node):
        if node != None:
            print "tttt->%s" % node.data
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

当然我选择了先序遍历来简单检查下树的创建有没有问题。(看下这棵树能否正常遍历,这步可忽略)

 

kd树搜索

    def search(self, tree, x):  #搜索
        self.nearestPoint = None    #保存最近的点
        self.nearestValue = 0   #保存最近的值
        def travel(node, depth = 0):    #递归搜索
            if node != None:    #递归终止条件
                n = len(x)  #特征数
                axis = depth % n    #计算轴
                if x[axis] < node.data[axis]:   #如果数据小于结点,则往左结点找
                    travel(node.lchild, depth+1)
                else:
                    travel(node.rchild, depth+1)

                #以下是递归完毕,对应算法3.3(3)
                distNodeAndX = self.dist(x, node.data)  #目标和节点的距离判断
                if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif (self.nearestValue > distNodeAndX):
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  #确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth+1)
                    else:
                        travel(node.lchild, depth + 1)
        travel(tree)
        return self.nearestPoint

    def dist(self, x1, x2): #欧式距离的计算
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

搜索树的时候比较麻烦,首先先说下原理吧。

(1) 在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树。若目标点当前维的坐标值小于切分点的坐标值,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止;
(2) 以此叶结点为“当前最近点”;
(3) 递归的向上回退,在每个结点进行以下操作:
  (a) 如果该结点保存的实例点比当前最近点距目标点更近,则以该实例点为“当前最近点”;
  (b) 当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一个子结点对应的区域是否有更近的点。具体的,检查另一个子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。如果相交,可能在另一个子结点对应的区域内存在距离目标更近的点,移动到另一个子结点。接着,递归的进行最近邻搜索。如果不相交,向上回退。
(4) 当回退到根结点时,搜索结束。最后的“当前最近点”即为x的最近邻点。

注意了,先按步骤找到叶结点,然后回朔的时候要做两件事,(a)是更新最新点,(b)是检查是否需要检查父结节点的另外一个结点的区域。

                if x[axis] < node.data[axis]:   #如果数据小于结点,则往左结点找
                    travel(node.lchild, depth+1)
                else:
                    travel(node.rchild, depth+1)

这段是类似于二叉查找树的过程,直至查找到叶子节点。

                #以下是递归完毕后,往父结点方向回朔,对应算法3.3(3)
                distNodeAndX = self.dist(x, node.data)  #目标和节点的距离判断
                if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif (self.nearestValue > distNodeAndX):
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  #确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth+1)
                    else:
                        travel(node.lchild, depth + 1)

这段代码,就是P43算法3.3(3)中的内容。

(a)容易实现,但是(b)的原理是判断目标点和最近的一个点的距离为半径画一个圆(就如书本P44图3.5,目标点S和当前最近点D形成了一个圆),是否跟父结点按轴分的那条线(也就是圆内的那条直线)有交集。

说白了,就是公式:|目标值(按轴读值) - 父节点(按轴读值)| < 最近的值(圆的半径),这里按轴读取就是P44图3.5中的x的y轴的值,然后减去相交的那条直线y轴的值,看是否小于半径。

注意:评论里有说这里的node.data不知道是指示哪个结点。这里要说明的是,这个node并不是父节点,而是当前结点。这里如果你对数据结构的二叉树不太熟的话,是不太容易get到这个点的。我只能稍微说下。

“这里应该了解下二叉查找树的过程”

如果找到了的话,把另一结点重新递归一次就好了。对应以下代码:

travel(node.rchild, depth+1)

最后在github贴出全部代码(如果方便的话麻烦给个赞吧,您的支持就是我前进的动力),然后来运行一下代码(这段代码在python3.5下成功运行)。

KNN(KDtree)代码下载

结果输出(5,4)

你可能感兴趣的