首页 > 编程语言 >【算法】KNN、SVM算法详解!

【算法】KNN、SVM算法详解!

时间:2022-10-15 14:02:26浏览次数:53  
标签:KNN SVM 函数 res clf 距离 算法 超平面 omega

什么是KNN算法

在这里插入图片描述

寻找未知分类数据的离它最近的n个已知数据,通过已知数据的分类来推断这个未知数据的分类

KNN的原理

步骤

  1. 计算距离(常用欧几里得距离或马氏距离)
  2. 升序排列(最近的排前面,最远的排后面)
  3. 取前K个
  4. 加权平均

K的选取(算法的核心)

K太大:导致分类模糊

K太小:受个例影响,波动较大

如何取K

靠经验或者慢慢尝试

均方根误差

在这里插入图片描述

实战应用

以一个癌症检测数据集为例

1. 载入数据

在这里插入图片描述

2. 打乱数据,分组,分为测试集和训练集

将2/3的数据作为训练数据,1/3的数据作为测试数据 在这里插入图片描述

3. KNN函数实现

  1. 计算距离(该测试数据与所有训练数据之间的距离),采用欧式距离计算(各项指标差的平方和再开方

在这里插入图片描述

  1. 按照距离升序排序

在这里插入图片描述

  1. 取前K个
res2 = res[0:K]  #此时K = 5
  1. 加权平均(距离小的权重大,距离大的权重小),先测得总距离,利用1-(该测试数据的距离/总距离)作为该测试数据的权重 在这里插入图片描述

4. 对测试数据进行测试输出准确率

利用准确数/总测试数据个数来计算准确率

在这里插入图片描述

5. 输出结果

在这里插入图片描述

6. 代码

import csv

#读取
import random

with open("Prostate_Cancer.csv","r") as file:
    reader = csv.DictReader(file)
    datas = [row for row in reader]

#分组,分为训练集和测试集
random.shuffle(datas)
n = len(datas) // 3

test_set = datas[0:n]
train_set = datas[n:]


#KNN
#距离
def distance(d1,d2):
    res = 0

    for key in ("radius","texture","perimeter","area","smoothness","compactness","symmetry","fractal_dimension"):
        res += (float(d1[key]) - float(d2[key])) ** 2

    return res ** 0.5

K = 5
def KNN(data):
    #1.距离
    res = [
        {"result":train["diagnosis_result"],"distance":distance(data,train)}
        for train in train_set
    ]

    #2.升序排序
    res = sorted(res,key=lambda item:item["distance"])

    #3.取前K个
    res2 = res[0:K]

    #4.加权平均
    result = {'B':0,'M':0}

    #总距离
    sum = 0
    for r in res2:
        sum += r["distance"]

    #计算权重
    for r in res2:
        result[r["result"]] += 1-r["distance"]/sum

    #结果
    if result['B'] > result['M']:
        return 'B'
    else:
        return 'M'

#测试
correct = 0
for test in test_set:
    result = test["diagnosis_result"]
    result2 = KNN(test)

    if result == result2:
        correct += 1

print("准确率:{:.2f}%".format(100 * correct / len(test_set)))

什么是SVM算法

SVM(support vector machine)支持向量机,是一个有监督的学习模型,通常用来进行模式识别、分类(异常值检测)以及回归分析。

Hard margin

将两类通过一个阈值而分类开,对于二维来说就是找一条线,三维找一个面,多维找一个超平面

Hard margin:距离超平面最近的点的间隔最大

在这里插入图片描述

最优线:

在SVM中最优分割面(超平面)就是:能使支持向量和超平面最小距离的最大值

在样本空间中,划分超平面可通过一个线性方程来描述: $$ \omega ^ Tx + b = 0 $$ 其中$\omega$=($\omega_1$;$\omega_2$;...;$\omega_3$)为法向量,决定了超平面的方向,b为位移项,决定了超平面与原点之间的距离,划分超平面可被法向量$\omega$和位移b确定

样本空间中任意一点x到超平面($\omega$,b)的距离可写为

在这里插入图片描述

若超平面对应方程为$\omega ^ Tx + b = 0$

在这里插入图片描述

若超平面能够将训练样本正确分类,对于任意($x_i$,$y_i$),若$y_i$ = +1,则有$\omega ^ Tx_i + b > 0$;若$y_i$ = -1,则有$\omega ^ Tx_i + b < 0$

在这里插入图片描述

距离超平面最近的这几个训练样本点使得上式成立,它们被称为"支持向量"(support vector),两个异类支持向量到超平面的距离之和为

在这里插入图片描述

它们被称为“间隔”(margin)

求最大间隔,也就是要找在满足参数$\omega$​和b($y_i(\omega ^ Tx_i + b) >= 1$​)的同时,使得$\gamma$​最大

通过转化:

在满足参数$\omega$和b($y_i(\omega ^ Tx_i + b) >= 1$)的同时,使得$\omega^2/2$​最小​

求解:拉格朗日乘子法

拉格朗日乘子法

假如有方程:

$x^2y=3$

图像: 在这里插入图片描述 求其上的点与原点的最小距离 请添加图片描述

请添加图片描述

即梯度向量平行,用数学符号表示:

请添加图片描述

因此:

请添加图片描述

也就是函数f在g的约束下的极值问题可表示为:

请添加图片描述

可列出方程求解:

请添加图片描述

这就是拉格朗日乘子法

类似地:如果有多个约束条件 请添加图片描述

即可求得解

以上在高等数学拉格朗日求极值有详解

KKT条件

请添加图片描述

请添加图片描述

Soft Margin

在Hard margin的基础上允许有一点错误(loss) 采用Soft Margin可以防止过拟合 在这里插入图片描述

折页损失(high loss)在这里插入图片描述

一般当z<1时分类错误,允许有一点损失,loss=1-yi(wTxi + b) 当z>=1时分类正确,loss = 0

线性分类:

一般地像一维、二维、三维这些可以通过阈值、直线、平面或超平面就能将数据划分的被称为线性分类

非线性分类

数据大多数情况都不可能是线性的,那如何分割非线性数据呢? 在这里插入图片描述 方法就是将数据处理后放到更高的维度上进行分割: 在这里插入图片描述 当f(x)=x时,这组数据是个直线,如上半部分,但是当我把这组数据变为f(x)=x^2时,这组数据就变成了下半部分的样子,也就可以被红线所分割。

比如说,我这里有一组三维的数据X=(x1,x2,x3),线性不可分割,因此我需要将他转换到六维空间去。因此我们可以假设六个维度分别是:x1,x2,x3,x1^2,x1x2,x1x3,当然还能继续展开,但是六维的话这样就足够了。 新的决策超平面:d(Z)=WZ+b,解出W和b后带入方程,因此这组数据的超平面应该是:d(Z)=w1x1+w2x2+w3x3+w4*x1^2+w5x1x2+w6x1x3+b但是又有个新问题,转换高纬度一般是以内积(dot product)的方式进行的,但是内积的算法复杂度非常大。

几种常用核函数:

  1. h度多项式核函数(Polynomial Kernel of Degree h)
  2. 高斯径向基和函数(Gaussian radial basis function Kernel)
  3. S型核函数(Sigmoid function Kernel)

图像分类,通常使用高斯径向基和函数,因为分类较为平滑,文字不适用高斯径向基和函数。没有标准的答案,可以尝试各种核函数,根据精确度判定。

SVM与其他机器学习算法对比

在这里插入图片描述

SVM算法具有以下特征:

  1. SVM可以表示为凸优化问题,因此可以利用已知的有效算法发现目标函数的全局最小值。而其他分类方法都采用一种基于贪心学习的策略来搜索假设空间,这种方法一般只能获得局部最优解。
  2. SVM通过最大化决策边界的边缘来实现控制模型的能力。尽管如此,用户必须提供其他参数,如使用核函数类型和引入松弛变量等。
  3. SVM一般只能用在二类问题,对于多类问题效果不好。

四种核函数的分类效果(代码)

from sklearn import svm
import numpy as np
import matplotlib.pyplot as plt

# 设置子图数量
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))
ax0, ax1, ax2, ax3 = axes.flatten()

# 准备训练样本
x = [[1, 8], [3, 20], [1, 15], [3, 35], [5, 35], [4, 40], [7, 80], [6, 49]]
y = [1, 1, -1, -1, 1, -1, -1, 1]

# 设置子图的标题
titles = ['LinearSVC (linear kernel)',
          'SVC with polynomial (degree 3) kernel',
          'SVC with RBF kernel',  # 这个是默认的
          'SVC with Sigmoid kernel']
# 生成随机试验数据(15行2列)
rdm_arr = np.random.randint(1, 15, size=(15, 2))


def drawPoint(ax, clf, tn):
    # 绘制样本点
    for i in x:
        ax.set_title(titles[tn])
        res = clf.predict(np.array(i).reshape(1, -1))
        if res > 0:
            ax.scatter(i[0], i[1], c='r', marker='*')
        else:
            ax.scatter(i[0], i[1], c='g', marker='*')
    # 绘制实验点
    for i in rdm_arr:
        res = clf.predict(np.array(i).reshape(1, -1))
        if res > 0:
            ax.scatter(i[0], i[1], c='r', marker='.')
        else:
            ax.scatter(i[0], i[1], c='g', marker='.')


if __name__ == "__main__":
    # 选择核函数
    for n in range(0, 4):
        if n == 0:
            clf = svm.SVC(kernel='linear').fit(x, y)
            drawPoint(ax0, clf, 0)
        elif n == 1:
            clf = svm.SVC(kernel='poly', degree=3).fit(x, y)
            drawPoint(ax1, clf, 1)
        elif n == 2:
            clf = svm.SVC(kernel='rbf').fit(x, y)
            drawPoint(ax2, clf, 2)
        else:
            clf = svm.SVC(kernel='sigmoid').fit(x, y)
            drawPoint(ax3, clf, 3)
    plt.show()

结果: 在这里插入图片描述 注意: 核函数(这里简单介绍了sklearn中svm的四个核函数,还有precomputed及自定义的)

  1. LinearSVC:主要用于线性可分的情形。参数少,速度快,对于一般数据,分类效果已经很理想
  2. RBF:主要用于线性不可分的情形。参数多,分类结果非常依赖于参数
  3. polynomial:多项式函数,degree 表示多项式的程度-----支持非线性分类
  4. Sigmoid:在生物学中常见的S型的函数,也称为S型生长曲线

标签:KNN,SVM,函数,res,clf,距离,算法,超平面,omega
From: https://blog.51cto.com/u_15623229/5759158

相关文章

  • Problem P28. [算法课回溯] 电话号码的字母组合
    回溯,唯一麻烦的是要建立一个字典,键值对为数字字符对应英文字符串#include<iostream>#include<bits/stdc++.h>#include<cstdio>#include<string>usingnamespaces......
  • 【图像压缩】基于蚁群算法优化小波变换实现图像压缩附matlab代码
    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。......
  • 谁说算法工程师不会写代码
    大家好,我是阿星。我的新书《大规模推荐系统实战》前段时间上市了,收到很多反馈,看着那些或感谢、或认可、或鼓励的句子,我很有成就感,在这里感谢大家的支持!其实上学的时候就有涉......
  • 【算法训练营day3】LeetCode203. 移除链表元素 707. 设计链表 206. 反转链表
    【算法训练营day3】LeetCode203.移除链表元素707.设计链表206.反转链表LeetCode203.移除链表元素题目链接:203.移除链表元素初次尝试题目比较简单,之前刷过链表的......
  • Problem P27. [算法课回溯]目标和
    回溯法比较简单易懂,耗时比较长,也能过。有动态规划的解法大家可以自己想一想。#include<iostream>#include<bits/stdc++.h>#include<cstdio>#include<string>using......
  • 数据结构—算法的时间复杂度
    1、什么是时间复杂度     一般情况下,算法中基本语句重复执行的次数是问题规模n的某个函数f(n),算法的时间量度记作T(n)=O(f(n))。它表示随问题规模n的增大,算法执行时......
  • 代码随想录算法训练营第三天 | 203.移除链表元素 707.设计链表 206.反转链表
    链表的数据结构基础链表结构链表是一种通过指针串联在一起的线性结构。每一个节点由两钟部分构成,一部分是数据域,一部分是指针域,指针域存放的指针指向另一个节点。链表......
  • 深度学习算法基础
    1,基本概念1.1,余弦相似度1.2,欧式距离1.3,余弦相似度和欧氏距离的区别2,容量、欠拟合和过拟合3,正则化方法4,超参数和验证集5,估计、偏差和方差6,随机梯度下降算法......
  • 力扣-排序算法
    部分题解保存排序数组-快速排序classSolution{privatefinalstaticRandomrandom=newRandom(System.currentTimeMillis());publicint[]sortArray(in......
  • InnoDB存储引擎:索引与算法
    InnoDB存储引擎索引概述InnoDB支持以下几种常见的索引:B+树索引(传统意义上的索引,这是目前关系型数据库系统中查找最为常用和最为有效的索引;B+树索引并不能找到一个给......