KD树实现鸢尾花分类(Numpy实现)

最近也是刚接触KD树,刚开始也是一头雾水,自己也是搜了很多资料,通过自己的理解以及老师的讲解,对KD树有了更深的理解,然后就写个博客来记录一下,也好能帮助其他人去了解KD树。

关于KD树的原理网上有很多,我就不再讲述原理了,需要的数据集我会放在文章末尾,本文我主要用用Numpy去实现,没有涉及sklearn。

首先,导入所需要的库

import pandas as pd
import numpy as np
from collections import Counter

from collections import Counter是用来统计各个标签出现的次数的,因为我取得是前k个最近距离 ,也可省去不要。

读取鸢尾花数据,建立测试标签(x为所有鸢尾花特征值,y为标签,simple为测试数据)

x = np.array(pd.read_csv('iris.csv', usecols=(0, 1, 2, 3), delimiter=',', header=0))  # 读取特征集合
y = np.array(pd.read_csv('iris.csv')['species'])  # 读取标签集

simple = np.array([3.5,2.4,0.3,2.5])

创建一个KD数节点的类,__str__函数用于输出KD树

class KDtreeNode:
    def __init__(self, val, label, dim, left=None, right=None):
        self.val = val  # 特征集
        self.dim = dim  # 维度
        self.label = label  # 标签
        self.left = left  # 左子树
        self.right = right  # 右子树

    def __str__(self):
        return f'特征是:{self.val}, 标签是:{self.label},划分维度:{self.dim}'

接下来要创建一颗KD树,代码中我注释的比较全面,也就不再解释了,看我的注释就行

def CreateKDtree(x, y, dim):
    if x.size == 0:
        return None
    else:
        nidx = np.argsort(x, axis=0)[:, dim]  # 按照dim这个维度排序
        center_num = x.shape[0] // 2  # 中位数的序号

        cut_idx = nidx[center_num]  # 根节点的索引号
        left_idx = nidx[:center_num]  # 左子树的索引号
        right_idx = nidx[center_num + 1:]  # 右子树的索引号

        node_tree = KDtreeNode(x[cut_idx], y[cut_idx], dim)  # KD树的根节点
        dim = (dim + 1) % x.shape[1]  # 更新维度dim值
        node_tree.left = CreateKDtree(x[left_idx], y[left_idx], dim)  # 递归左子树
        node_tree.right = CreateKDtree(x[right_idx], y[right_idx], dim)  # 递归右子树
        return node_tree  # 得到KD树

对KD树进行搜索,得到预测结果

def search_KDtree(simple, k):
    # 初始化距离,最近点为None,最近距离为无穷大
    nearest_knn = np.array([[None, float('inf')] for _ in range(k)])
    # 创建一个列表,用于存放从根节点到一个叶子结点的所有节点,找距离最近的点
    node_list = []
    # 得到KD树,node_tree是一颗KD树
    node_tree = CreateKDtree(x, y, 0)
    while node_tree:
        # 将所有可能的节点加入到列表中,加入的位置为列表的第一个元素
        node_list.insert(0, node_tree)
        dim = node_tree.dim
        if simple[dim] < node_tree.val[dim]:
            node_tree = node_tree.left
        else:
            node_tree = node_tree.right
    #从叶子结点开始,回溯
    for node in node_list:
        #计算欧几里得距离
        distance = np.linalg.norm(node.val - simple, ord=2)
        #np.where返回一个二维数组,及满足要求的位置坐标.less_index为距离小于inf的行的索引
        less_index = np.where(distance < nearest_knn[:,1])[0]
        #print(nearest_knn)
        if less_index.size > 0:
            #对nearest_knn进行更新
            nearest_knn = np.insert(nearest_knn, less_index[0], [node, distance], axis=0)[:k]  #只取前k个距离最短的
        radius = nearest_knn[:,1][k-1]                #radius为k个距离中最远的那个,欧几里得距离
        dis = simple[node.dim] - node.val[node.dim]   #所求点到超平面的距离
        if radius > abs(dis):                              #如果欧几里得距离大于到超平面的距离
            if dis > 0:                               #如果simple[node.dim] > node.val[node.dim],加入左子树
                append_node = node.left
            else:
                append_node = node.right              #否则,加入左右树
            if append_node is not None:
                node_list.append(append_node)
    return([lab[0].label for lab in nearest_knn if lab[0] is not None])

依据KD搜索的原理,我们要从根节点出发,一直找下去,直到叶子节点,将这些节点存放在列表中,这些节点都可能是距离最短的,KD树搜索时,考虑的因素有很多,当欧几里得距离大于到超平面的距离时,同根节点的另外一颗树也可能存在最近距离,所以当条件满足时,还要将另外一棵树的节点添加到列表中

下面是主函数

lb = search_KDtree(simple, 3)
print('预测结果为:'+Counter(lb).most_common(1)[0][0])

直接调用搜索KD树函数就行,我只测试了一个样例,所以比较简单

下边是数据集下载地址,按照上边代码的顺序,直接粘贴过去是可以直接用的啊

提取码为5912,希望能帮到各位

Original: https://blog.csdn.net/qq_51606646/article/details/124060391
Author: (ฅ]ω[ฅ)
Title: KD树实现鸢尾花分类(Numpy实现)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/759152/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球