TL;DR

  • 场景:小样本、低维特征的监督分类,用”相似样本投票”快速出结果
  • 结论:KNN核心是距离度量+K个邻居投票;特征尺度决定距离可信度,K影响偏差/方差
  • 产出:一套可复现实验(葡萄酒二分类)、距离计算/排序/投票流程、KNN函数封装

监督学习算法

KNN/K近邻算法

K近邻算法(K-Nearest Neighbors, KNN)的核心思想是基于样本间的距离度量来判断相似性。具体来说,它通过计算待分类样本与训练集中各个样本之间的特征空间距离(常用的距离度量包括欧式距离、曼哈顿距离或余弦相似度等),如果两个样本在特征空间中的距离足够接近,就认为它们具有较高的相似度,很可能属于同一类别。

在实际应用中,仅依靠单个最近邻样本进行分类容易受到噪声和异常值的影响,导致分类结果不稳定。因此KNN算法会选取距离最近的K个样本(即K个最近邻),这些近邻样本构成了待分类样本的局部邻域。算法会统计这些近邻样本的类别标签分布情况(标签代表样本所属的真实类别,如”猫”、“狗”等分类结果),然后采用投票机制:将K个近邻中出现次数最多的类别作为待分类样本的预测结果。

举例来说,在图像分类任务中,假设K=5,待分类图片的5个最近邻中包含3张”猫”和2张”狗”的图片,那么算法就会将该图片分类为”猫”。这种基于局部邻域投票的机制使得KNN算法对噪声数据具有较好的鲁棒性,同时K值的选择(通常通过交叉验证确定)会直接影响算法的分类性能。

实现过程

假设X_test待标记的数据样本,X_train为已标记的数据集。

  1. 遍历已标记数据集中的所有样本,计算每个样本与待标记的点的距离,并把距离保存在Distance数组中。
  2. 对Distance数组进行排序,取距离最近的K个点,记为X_knn。
  3. 在X_knn中统计每个类别的个数,即class0在X_knn中有几个样本,class1在X_knn中有几个样本。
  4. 待标记样本的类别,就是在X_knn中样本个数最多的那个类别。

距离的确定

该算法的【距离】在二维坐标轴就表示两点之间的距离,计算距离的公式有很多。我们常说的欧拉公式,即”欧氏距离”。

欧氏距离公式: $$d(A, B) = \sqrt{(x_1-x_2)^2 + (y_1-y_2)^2}$$

当特征数量有很多个形式多维空间时,N维空间中有两个点A和B,它们坐标分别为: $$d(A, B) = \sqrt{\sum_{i=1}^{n}(a_i - b_i)^2}$$

在机器学习中,坐标轴上的x1、x2、x3等,正是我们样本上的N个特征。

算法优点

  1. k值的作用机制

    • 当k=1时,模型仅考虑最近的一个样本点,这会使决策边界变得非常复杂
    • 随着k值增大,模型会考虑更多邻居的投票结果,使得决策边界趋于平滑
  2. k值与模型偏差的关系

    • 较大的k值(如k=15)会使模型偏差增大,因为决策会基于更多样本的平均值
    • 这会使模型对个别噪声数据点(如标注错误的样本)的敏感度降低
    • 极端情况下,当k接近训练集大小时,模型会简单地预测多数类
  3. k值与模型方差的关系

    • 较小的k值(如k=1或3)会使模型方差增大
    • 模型容易捕捉到训练数据中的随机波动和噪声
  4. 参数选择的实践经验

    • 通常从k=5开始尝试,这是经验法则
    • 可以通过交叉验证来寻找最优k值,常用方法是绘制k值与准确率的曲线图
    • 在sklearn中,可以使用GridSearchCV进行k值调优
  5. 不同场景下的k值选择

    • 对于噪声较多的数据集(如传感器数据),建议使用较大的k值(7-15)
    • 对于清晰可分的数据(如MNIST手写数字),较小的k值(3-5)可能更合适
    • 当类别分布不平衡时,k值不宜过小,否则容易受到少数类样本的影响

算法变种

变种1:加权KNN

默认情况下,在计算距离时,权重都是相同的,但实际上可以针对不同的邻居指定不同的距离权重,比如距离越近权重越高。可以通过指定算法的weights参数来实现。

变种2:半径近邻

使用一定半径内的点取代距离最近的k个点。在scikit-learn中,RadiusNeighborsClassifier实现了这种算法的变种。当数据采样不均匀时,该算法变种可以获得更好的性能。

代码实现

导入相关包

# 全部行都能输出
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 解决坐标轴刻度负号乱码
plt.rcParams['axes.unicode_minus'] = False
# 解决中文乱码问题
plt.rcParams['font.sans-serif'] = ['Simhei']
plt.style.use('ggplot')

构建已经分类好的原始数据集

rowdata = {
    '颜色深度': [14.13,13.2,13.16,14.27,13.24,12.07,12.43,11.79,12.37,12.04],
    '酒精浓度': [5.64,4.28,5.68,4.80,4.22,2.76,3.94,3.1,2.12,2.6],
    '品种': [0,0,0,0,0,1,1,1,1,1]
}
# 0 代表 "黑皮诺",1 代表 "赤霞珠"
wine_data = pd.DataFrame(rowdata)

数据探索与可视化

X = np.array(wine_data.iloc[:,0:2]) #我们把特征(酒的属性)放在X
y = np.array(wine_data.iloc[:,-1]) #把标签(酒的类别)放在Y
# 探索数据,假如我们给出新数据[12.03,4.1] ,你能猜出这杯红酒是什么类别么?
new_data = np.array([12.03,4.1])
plt.scatter(X[y==1,0], X[y==1,1], color='red', label='赤霞珠')
plt.scatter(X[y==0,0], X[y==0,1], color='purple', label='黑皮诺')
plt.scatter(new_data[0],new_data[1], color='yellow')
plt.xlabel('酒精浓度')
plt.ylabel('颜色深度')
plt.legend(loc='lower right')
plt.savefig('葡萄酒样本.png')

计算已知类别数据集中的点与当前之间的距离

我们使用欧式距离公式,计算新数据点new_data与现存的X数据集每一个点的距离:

from math import sqrt
distance = [sqrt(np.sum((x-new_data)**2)) for x in X]
distance

执行结果:

[2.6041505332833594,
 1.1837651794169315,
 1.9424983912477256,
 2.3468276459936295,
 1.2159358535712326,
 1.3405968819895113,
 0.4308131845707605,
 1.0283968105745949,
 2.0089798406156287,
 1.500033332962971]

将距离升序排列 选取距离最小的K个点

sort_dist = np.argsort(distance)
sort_dist

确定前k个点所在类别的计数

k = 3
topK = [y[i] for i in sort_dist[:k]]
topK

# 投票统计
pd.Series(topK).value_counts().index[0]

封装函数

def KNN(new_data,dataSet,k):
    '''
    函数功能:KNN分类器
    参数说明:
    new_data: 需要预测分类的数据集
    dataSet: 已知分类标签的数据集
    k: k-近邻算法参数,选择距离最小的k个点
    return:
    result: 分类结果
    '''
    from math import sqrt
    from collections import Counter
    import numpy as np
    import pandas as pd

    result = []
    distance = [sqrt(np.sum((x-new_data)**2)) for x in np.array(dataSet.iloc[:,0:2])]
    sort_dist = np.argsort(distance)
    topK = [dataSet.iloc[:,-1][i] for i in sort_dist[:k]]
    result.append(pd.Series(topK).value_counts().index[0])
    return result

# 测试函数
new_data = np.array([12.03,4.1])
k = 3
KNN(new_data,wine_data,k)

错误速查

症状根因修复方案
中文/负号乱码字体未安装或字体名不匹配安装/指定可用中文字体(如SimHei/微软雅黑);保留plt.rcParams['axes.unicode_minus']=False
散点图坐标轴含义对不上X列顺序与xlabel/ylabel语义不一致统一”X的第0/1列”与坐标轴标签
KNN结果不稳定/偏某一类K过小/过大;类别不平衡;距离被尺度主导做特征缩放;用交叉验证选K;必要时改用加权KNN
高维数据效果明显变差距离集中现象(维度灾难),欧氏距离区分度下降降维/特征选择;换距离度量
半径近邻找不到邻居半径设置不合理、采样不均匀调整radius;设置回退策略