在 Matplotlib 中绘制 k-NN 决策边界

在 Matplotlib 中绘制 k-NN 决策边界

机器学习中的 k-NN(k-Nearest Neighbors)算法是一种简单有效的分类算法之一,其原理是找到样本中和新样本最相似的 k 个样本点,然后将新样本分类为出现次数最多的那一类。在这篇文章中,我们将利用 Python 的 Matplotlib 库来绘制 k-NN 决策边界。

准备数据集

为了实现 k-NN 决策边界,我们需要准备一组数据集。这里我们使用 scikit-learn 库中的 make_blobs 函数来生成一些随机的数据点。make_blobs 函数可以生成一个 n_samples 行 2 列的矩阵,其中每一行代表一个数据点,n_features 和 centers 参数分别指定了数据点的特征数和中心点数。

from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=100, n_features=2, centers=2, random_state=42)

上面的代码会生成一个包含 100 个数据点和 2 个中心点的数据集。其中 X 是一个 100×2 的矩阵,y 是一个代表每个数据点类别的向量,其值为 0 或 1。

绘制散点图

绘制数据集的散点图是理解决策边界的重要一步。我们可以使用 Matplotlib 库中的 scatter 函数来实现。

import matplotlib.pyplot as plt

plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm")
plt.show()

上面的代码中,X[:, 0] 和 X[:, 1] 分别代表了 X 矩阵中第一列和第二列的值。c 参数则指定了每个点的颜色。由于 y 的值只有 0 和 1,所以我们需要将它们映射为两种颜色。cmap 参数指定了采用哪种颜色映射。上面的代码会生成一个散点图,其中红色代表一类,蓝色代表另一类。

实现 k-NN 算法

接下来,我们需要实现 k-NN 算法。为了避免重复计算,我们可以使用 KDTree 数据结构对数据集进行预处理。KDTree 是一种二叉树结构,可以加速查找数据点中距离目标点最近的 k 个数据点。

from sklearn.neighbors import KDTree

tree = KDTree(X)

上面的代码会生成一个 KDTree 对象,其中 X 是数据矩阵。

接下来,我们可以实现 k-NN 算法,其伪代码如下:

  1. 预处理数据集 X,生成 KDTree 对象
  2. 对于每个测试样本 x,使用 KDTree.query 函数查找距离 x 最近的 k 个数据点,获取它们所在的类别
  3. 将 x 分类为 k 个数据点中出现次数最多的类别
import numpy as np

def knn(X_train, y_train, x_test, k):
    tree = KDTree(X_train)
    dists, indices = tree.query(x_test.reshape((1, -1)), k=k)
    nn_labels = y_train[indices][0]
    return np.bincount(nn_labels).argmax()

上面的代码中,X_train 和 y_train 分别是训练数据集的特征矩阵和类别向量,x_test 是测试样本特征向量。k 是 k-NN 算法中的超参数 k,表示最近邻居个数。dists 和 indices 分别是 KDTree.query 返回的距离和下标矩阵。nn_labels 是这 k 个数据点所属的类别向量。np.bincount函数会统计每个类别出现的次数,然后返回出现次数最多的类别。

绘制决策边界

有了 k-NN 算法,我们就可以根据它来绘制决策边界。我们可以用 Meshgrid 函数生成网格点,然后对每个网格点调用 k-NN 算法得到分类结果,最后将网格点和分类结果绘制成等高线图。在 Matplotlib 中,我们可以使用 contourf 函数来绘制等高线图。

xx, yy = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))
Z = np.zeros(xx.shape)
for i in range(xx.shape[0]):
    for j in range(xx.shape[1]):
        Z[i, j] = knn(X, y, [xx[i, j], yy[i, j]], k=5)
plt.contourf(xx, yy, Z, cmap="coolwarm", alpha=0.5)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm")
plt.show()

上面的代码可以生成一个包含决策边界的散点图。其中 alpha 参数指定了等高线图的透明度,使得散点图能够在其上方显示出来。

结论

在本文中,我们利用 Python 的 Matplotlib 库实现了 k-NN 算法并绘制了决策边界。通过本文的学习,你可以更加深入地了解 k-NN 算法的原理和实现过程,同时也能够通过 Matplotlib 库将其可视化。此外,你也可以尝试调整数据集和 k 的取值,观察决策边界的变化。

Camera课程

Python教程

Java教程

Web教程

数据库教程

图形图像教程

办公软件教程

Linux教程

计算机教程

大数据教程

开发工具教程