机器学习入门(五):K近邻

0x00 介绍

这一节,我们会介绍K近邻算法.K近邻是一种基本的机器学习算法.直观,简单,容易理解.下面只介绍分类的算法.
K近邻的思想是: 如果要预测样本的类别,那么就看和样本最像的K个样本是哪个类别.

K近邻没有目标函数.不像感知机从训练集训练出一个超平面函数,K近邻不会训练出一个函数来拟合数据.如果训练好一个感知机模型,那么你只要将预测样本输入后,看函数的输出就可以了.但是K近邻不同,K近邻没有目标函数,所以每一次预测样本都是需要将预测样本和所有数据进行比较.

0x01 正文

K近邻算法

输入:训练数据集T=(x1,y1),(x2,y2),…(xN,yN)T=(x1,y1),(x2,y2),…(xN,yN)
输出:实例x所属的类y
算法步骤:
(1)根据给定的距离度量,在训练集T中找出与x最近邻的k个点,涵盖这k个点的x的邻域记作Nk(x)Nk(x)
(2)在Nk(x)Nk(x)中根据分类决策规则,如多数表决决定x的类别y。

K近邻关键步骤:

  1. K近邻需要计算预测样本和每一个训练样本的相似度

    由于K近邻算法是在所有训练数据中找到与其最为相似的数据,那么我们需要保存所有数据,以此应对于每一次计算.

  2. 如何定义两个样本的相似度

    我们通过距离度量两个样本的相似度,有很多种距离度量可以选,一般我们选择欧氏距离, 也就是L2范式,还有L1范式.

  3. K值的选择

    一般K值选5个,如果K值较小,噪音对预测结果影响较大,如果K值较大,那么有些和K值不太像的样本也混进来了,使得误差变大.

  4. 分类决策

    选定K个与之最像的样本之后,我们通过多数表决的方式确定预测样本的分类.(其实也是经验风险最小化)

近似误差和估计误差

近似误差指的是目标点对于其原样本点的可信度,误差越小,对于原样本点的信任度越高,也就是说,目标点可能只需要对最近的点确认一次就可以标注自己的标签,而无需去询问其他目标点。而估计误差则是原模型本身的真实性,也就是说,该模型所表现出的分类特性,是不是就是真实的分类特性,比如有噪点影响,有错误数据记录,或者本身数据分布就不是很好,都会是影响估计误差的因素,而询问的点越多,那么这些坏点对于目标点的标签影响就越小

Kd树

我们通过上述描述,应该清楚了一个K近邻算法的基本运作思想,由于没有训练过程,没有预测模型,使得K近邻算法的计算量十分巨大,因为它需要把所有的样本点都和目标点进行一次距离度量,很难适应大规模的数据样本。那么kd树就应运而生了
kd树,指的是k-dimensional tree,是一种分割K维数据空间的数据结构,主要用于多维空间关键数据的搜索。kd树是二进制空间分割树的特殊情况。

0x02 代码

这是最简单的K近邻代码部分

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import operator
import struct



def classify(inX, dataSet, labels, k):
    dataSetSize  = dataSet.shape[0]    # 获取样本数量
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet   # tile函数使得 inX在行上重复datasetsize次 列1次,
                                                         # 以此和样本数据相减
    sqDiffMat = diffMat**2   # 距离的平方
    sqDistances = sqDiffMat.sum(axis=1)   # 距离的和
    distances = sqDistances**0.5  # 距离的开根
    sortedDistance = distances.argsort()  # 以距离由小到大排列,并返回索引值给y
    classCount = {}
    for i in range(k):   # 距离最小的k个点
        votelabel = labels[sortedDistance[i]]  # 依次用距离最小的值的索引值得出类别
        classCount[votelabel] = classCount.get(votelabel, 0) + 1  # 类别出现数依次增加
    sortedClasscount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # 倒序排列

    return sortedClasscount[0][0]  # 输出出现次数最大类别


def file2matirx(filename):
    fr = open(filename)
    arrayLines = fr.readlines()
    numberOfLines = len(arrayLines)  # 获得行数
    returnMat = np.zeros((numberOfLines, 3))  # 创建一个矩阵,用来存储之后的特征
    classLabelVector = []
    index = 0
    for line in arrayLines:
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index, :] = listFromLine[0:3]  # 将前三个特征量写入到矩阵中
        classLabelVector.append((int(listFromLine[-1])))  # 将标签类别写入到标类中
        index += 1
    return returnMat, classLabelVector


def autoNorm(dataSet):   # 特征值归一化处理
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    m = dataSet.shape[0]
    normDataSet = dataSet - np.tile(minVals, (m, 1))  # 矩阵每一项都减去最小值
    normDataSet = normDataSet/np.tile(ranges, (m, 1))
    return normDataSet


def datingClassTest():
    hoRatio = 0.10
    datingDataMat, datingLabels = file2matirx('test.txt')  # 获得特征矩阵和标量
    normMat, ranges, minVals = autoNorm(datingDataMat) # 特征归一化
    m = normMat.shape[0]  # 取样本数量
    numTestVecs = int(m*hoRatio)  # 获得测试样本数量
    errorcount = 1.0
    for i in range(numTestVecs):    # 以前n个测试量进行测试
        classifierResult = classify(normMat[i, :], datingDataMat[numTestVecs:m, :],  # 利用KNN算法对测试集测试
                                    datingLabels[numTestVecs:m], 3)
        print('the classifier came back whth : %d, the real answer is : %d'
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]: errorcount += 1.0
    print('the total error rate is :%f' % (errorcount/float(numTestVecs)))


def handwritingClassTest():    # 手写数字利用KNN识别
    '''
    trainfile_X 
    trainfile_y 
    testfile_X 
    testfile_y  
    '''

    m = Testimages.shape[0]
    errorcount = 0
    for i in range(10000):   # 此次只测试100个
        classResult = classify(Testimages[i], Trainimages, Trainlabels, 7)
        print("the classifier came back with: %d ,the real ansuer is: %d"
              % (classResult, Testlabels[i]))
        if (classResult != Testlabels[i]): errorcount += 1.0

    print('the total number of error is : %d' % errorcount)
    print('the total correct rate is : %f ' % (1 - (errorcount/float(10000))))


def outImg(arrX, arrY):
    """
    根据生成的特征和数字标号,输出png的图像
    """
    # 每张图是28*28=784Byte
    for i in range(1):
        img = np.array(arrX)
        img = img.reshape(28, 28)
        plt.figure()
        plt.imshow(img, cmap='binary')  # 将图像黑白显示

handwritingClassTest()

这是Kd树

# kd-tree每个结点中主要包含的数据结构如下
class KdNode(object):
    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt  # k维向量节点(k维空间中的一个样本点)
        self.split = split  # 整数(进行分割维度的序号)
        self.left = left  # 该结点分割超平面左子空间构成的kd-tree
        self.right = right  # 该结点分割超平面右子空间构成的kd-tree


class KdTree(object):
    def __init__(self, data):
        k = len(data[0])  # 数据维度

        def CreateNode(split, data_set):  # 按第split维划分数据集exset创建KdNode
            if not data_set:  # 数据集为空
                return None
            # key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
            # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
            # data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
            data_set.sort(key=lambda x: x[split])
            split_pos = len(data_set) // 2  # //为Python中的整数除法
            median = data_set[split_pos]  # 中位数分割点
            split_next = (split + 1) % k  # 在k维度中循环选取

            # 递归的创建kd树
            return KdNode(median, split,
                          CreateNode(split_next, data_set[:split_pos]),  # 创建左子树
                          CreateNode(split_next, data_set[split_pos + 1:]))  # 创建右子树

        self.root = CreateNode(0, data)  # 从第0维分量开始构建kd树,返回根节点


# KDTree的中序遍历
def preorder(root):
    if root.left:  # 节点不为空
        preorder(root.left)
    print(root.dom_elt)
    if root.right:
        preorder(root.right)


if __name__ == "__main__":
    data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
    kd = KdTree(data)
    preorder(kd.root)

0x03 总结

这一节我们简单的讲了K近邻,还有提了一下Kd树,关于KD树详细资料可以上网搜一搜