查看原文
其他

第5.5节 从零实现K近邻

空字符 月来客栈 2024-01-19

各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。

本期推送内容目录如下,如果你觉得本期内容对你所有帮助欢迎点个赞、关个注、下回更新不迷路

  • 5.5 从零实现K近邻
    • 5.5.1 kd树节点定义
    • 5.5.2 kd树构建
    • 5.5.3 kd构建示例
    • 5.5.4 kd树最近邻搜索
    • 5.5.5 kd树K近邻搜索
    • 5.5.6 KNN实现
    • 5.5.7 总结

5.5 从零实现K近邻

在前面几节内容中,笔者已经详细地介绍了KNN的基本思想与原理,但对于具体的实现细节并没有做过多的介绍。下面笔者就开始正式介绍如何从零实现kd树以及完成整个KNN的代码实现。

5.5.1 kd树节点定义

根据第5.4.1节内容介绍,kd树本质上也就等同于二叉搜索树,因此,首先我们需要定义kd树中的节点信息,以及kd树的构建与查询等。同时,由于在KNN的预测结果中需要根据训练样本给出每个预测样本的标签值,因此就需要知道每个训练样本的原始标签值,故需要在节点中保存每个样本索引。最终,kd树的节点信息定义如下:

1 class Node(object):
2     def __init__(self, data=None, index=-1):
3         self.data = data
4         self.left_child = None
5         self.right_child = None
6         self.index = index
7 
8     def __str__(self):
9         return f"data({self.data}),index({int(self.index)})"

在上述代码中,第2-6行定义了节点Node中保存的具体信息,包括样本点、左右子树以及在原始样本中的索引;第8-9行定义了__str__()方法,其作用是在使用print()函数时可以直接打印出节点的信息,而不必用node.data这样的形式来访问节点中的样本。

5.5.2 kd树构建

在完成kd树节点的定义之后,下一步就可以开始定义构建kd树的整个过程。首先,我们需要定义类的初始化函数:

1 class MyKDTree(object):
2     def __init__(self, points):
3         self.root = None
4         self.dim = points.shape[1]
5         points = np.hstack(([points, np.arange(0, len(points)).reshape(-11)]))
6         self.insert(points, order=0)  # 递归构建KD树
7 
8     def is_empty(self):
9         return not self.root

在上述代码中,第3行定义了kd树的根节点;第4行定义了原始样本的维度;第5行用于将在样本点(points)的最后一列附加上每个样本点的索引值;第6行则是调用insert()方法递归完成kd树的构建;第8-9行定义了一个方法来判断当前kd是否为空。

接下来便是完成insert()方法的实现过程,代码如下:

 1     def insert(self, data, order=0):
 2         if len(data) < 1:
 3             return
 4         data = sorted(data, key=lambda x: x[order % self.dim])  # 按某个维度进行排序
 5         idx = len(data) // 2
 6         node = Node(data[idx][:-1], data[idx][-1])
 7         left_data = data[:idx]
 8         right_data = data[idx + 1:]
 9         if self.is_empty():
10             self.root = node  # 整个kd树的根节点
11         node.left_child = self.insert(left_data, order + 1)  # 递归构建左子树
12         node.right_child = self.insert(right_data, order + 1)  # 递归构建右子树
13         return node

在上述代码中,第2-3行用来判断当前传入样本是否为空,如果为空则结束当前递归;第4行用于将当前样本按照某个维度的大小顺序进行排序,其中样本点维度的比较顺序为从左到右依次轮询,order在每次进行递归时都会累加;第5-6行用来获取并保存当前样本点排序后中间位置的样本并保存到一个新初始化的节点中;第7-8行则是分别取当前排序后样本的左边部分和右边部分,以此来分别作为当前节点的左右子树;第11-12行则是分别递归构建左右子树。

5.5.3 kd构建示例

在实现kd树的构建代码后,便可以通过如下方式来进行使用:

1 def test_kd_tree_build(points):
2     tree = MyKDTree(points)
3     tree.level_order()
4     
5 if __name__ == '__main__':
6     points = np.array([[25], [14], [33], [65], [102.], [73], [813], [89], [12]])
7     test_kd_tree_build(points)

在上述代码中,第2行便是根据传入的样本点来递归的构建kd树;第3行则是将构建完成的kd树以层次遍历的方式打印出来。

以上代码运行结束后便会有类似如下信息输出:

 1 当前待划分样本点:[[1. 4. 1.], [1. 2. 8.], [2. 5. 0.], [3. 3. 2.],
 2 [6. 5. 3.], [7. 3. 5.], [ 8. 13. 6.], [8. 9. 7.], [10. 2. 4.]]
 3 父节点:[6. 5. 3.]
 4 左子树: [[1. 4. 1.], [1. 2. 8.], [2. 5. 0.], [3. 3. 2.]]
 5 右子树: [[7. 3. 5.], [ 8. 13.  6.], [8. 9. 7.], [10. 2. 4.]]
 6 ============
 7 当前待划分样本点:[[1. 2. 8.], [3. 3. 2.], [1. 4. 1.], [2. 5. 0.]]
 8 父节点:[1. 4. 1.]
 9 左子树: [[1. 2. 8.], [3. 3. 2.]]
10 右子树: [[2. 5. 0.]]
11 ============
12 ......
13 层次遍历结果为:
14 第1层的节点为:<[6. 5.], idx(3)>
15 第2层的节点为:<[1. 4.], idx(1)> <[8. 9.], idx(7)>
16 第3层的节点为:<[3. 3.], idx(2)> <[2. 5.], idx(0)> <[10.  2.], idx(4)> <[ 8. 13.], idx(6)>
17 第4层的节点为:<[1. 2.], idx(8)> <[7. 3.], idx(5)>

在上述输出结果中,第1-12行便是kd树在构建过程中所输出的信息,需要再次提醒的是样本点的最后一个维度为当前样本点在原始样本中的索引;第13-17行则是构建完成后kd树的层次遍历结果。

根据层次遍历以及构建输出结果,也可以还原得到如图5-9所示的kd树。

图 5-9. kd树构建结果图

5.5.4 kd树最近邻搜索

在实现最近邻的搜索过程之前首先需要根据式(5-1)中的定义来实现两个点之间距离的计算,实现代码如下所示:

1 def distance(p1, p2, p=2):
2     return np.sum((p1 - p2) ** p) ** (1 / p)

在上述代码中,当时就是我们熟悉的欧式距离。

进一步,根据第5.4.2节中kd树最近邻搜索的伪代码实现过程,其对应的代码实现如下所示:

 1     def nearest_search(self, point):
 2         best_node = None
 3         best_dist = np.inf
 4         visited = []  # 用来记录哪些节点被访问过
 5         point = point.reshape(-1)
 6 
 7         def nearest_node_search(point, curr_node, order=0):
 8             nonlocal best_node, best_dist, visited  # 声明这三个变量不是局部变量
 9             logging.debug(f"当前访问节点为:{curr_node}")
10             visited.append(curr_node)
11             if curr_node is None:
12                 return None
13             dist = distance(curr_node.data, point)
14             logging.debug(f"当前访问节点到被搜索点的距离为:{round(dist, 3)},"
15                           f"全局最佳距离为:{round(best_dist, 3)}, 全局最佳点为:{best_node}\n")
16             if dist < best_dist:
17                 best_dist = dist
18                 best_node = curr_node
19             cmp_dim = order % self.dim
20             if point[cmp_dim] < curr_node.data[cmp_dim]:
21                 nearest_node_search(point, curr_node.left_child, order + 1)
22             else:
23                 nearest_node_search(point, curr_node.right_child, order + 1)
24             if np.abs(curr_node.data[cmp_dim] - point[cmp_dim]) < best_dist:
25                 child = curr_node.left_child if curr_node.left_child not in 
26                                 visited else curr_node.right_child
27                 nearest_node_search(point, child, order + 1)
28         nearest_node_search(point, self.root, 0)
29         return best_node, best_dist

在上述代码中,第2-4行定义了相关的全局记录变量;第8行则是用来声明这3个变量不是局部变量而是上面定义的全局变量;第10行则是用来记录当前哪些节点已经访问过;第13行是计算当前节点到被搜索点的距离;第16-18行则是判断是否要更新当前最佳节点;第19行是计算得到进入左右子树时的判断维度;第20-23行是根据维度比较信息递归遍历相应的左子树或右子树;第24-27行则是根据第5.4.2节中的子空间排除原理来判断当前节点左右子树中未访问过的节点是否存在最佳节点,并进行递归遍历。

继续滑动看下一个

第5.5节 从零实现K近邻

空字符 月来客栈
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存