首页 > 编程语言 >3.2 KNN算法(k-近邻算法)

3.2 KNN算法(k-近邻算法)

时间:2023-06-10 15:34:43浏览次数:36  
标签:KNN neighbors 算法 train 3.2 test estimate

1.什么是k-近邻算法

例如:
image

如果你不知道你现在在哪,你可以通过你和你的邻居的距离推算出你的位置
你的“邻居”来推断出你的类别

2.原理

2.1 定义

如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别

就是看看样本中离我最近的那个样本属于那个类别,我就属于那个类别
k取1的话,万一跟异常值最相似,所以不准确容易收到异常值的影响

2.2 距离公式

两个样本的距离可以通过如下公式计算,又叫欧式距离
image

2.3 例子

例如:电影类型分析
假设我们有现在几部电影
image
其中? 号电影不知道类别,如何去预测?我们可以利用K近邻算法的思想
image
假如说k=1时,找一个最相近的是He's not Really into dues 是爱情片
假如说k=2时,找两个最相近的时He's not Really into duesBeautiful Woman 也是爱情片
。。。。。。。。。。
k=6 无法确定
如果说k=7,三个爱情片,四个动作片,把他变成了动作片,很明显不是,所以k不能太大
k值取得太大样本不均衡的话会受到影响

k值取得过小,容易受到异常点的影响k值取得过大,样本不均衡的影响

2.4KNN算法步骤和API

1.k值取得过小,容易受到异常点的影响k值取得过大,样本不均衡的影响
2.结合前面的约会对象数据,在进行KNN算法之前需要对数据进行无量纲化处理和标准化处理

sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')
n_neighbors:int,可选(默认= 5),k_neighbors查询默认使用的邻居数,就是k值
algorithm:{‘auto’,‘ball_tree’,‘kd_tree’,‘brute’},
可选用于计算最近邻居的算法:‘ball_tree’将会使用 BallTree,‘kd_tree’将使用 KDTree。‘auto’将尝试根据传递给fit方法的值来决定最合适的算法。 (不同实现方式影响效率)
#初始化一个转换器类
estimate=KNeighborsClassifier(n_neighbors=3)
#模型计算
estimate.fit(x_train,y_train)
#5) 模型评估
#方法1:直接对比真实值和预测值
y_predict=estimate.predict(x_test)

2.5案例

image

步骤:
(1)获取数据
(2)数据集的划分
(3)特征工程: 标准化
(4)KNN预估器流程
(5)模型评估

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

def knn_iris():
    """
    用KNN算法对鸢尾花进行分类
    :return:
    """
    #1) 获取数据
    iris = load_iris()
    #2) 划分数据集
    x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=6)##数据集,目标集
    #3) 特征工程:标准化
    transfer=StandardScaler()
    x_train=transfer.fit_transform(x_train)
    x_test=transfer.transform(x_test)#这时候只需要让测试集转化就行,没有必要计算,fit时计算过程,transform是转化
    #4) knn算法预估器
    estimate=KNeighborsClassifier(n_neighbors=3)
    estimate.fit(x_train,y_train)
    #5) 模型评估
    #方法1:直接对比真实值和预测值
    y_predict=estimate.predict(x_test)
    print("y_perdict:\n",y_predict)
    print("直接比对真实值和预测值:\n",y_test==y_predict)
    #方法2:计算准确率
    score=estimate.score(x_test,y_test)
    print("准确率:\n",score)
    return None
if __name__ == "__main__":
    #代码1 KNN
    knn_iris()

注意一点就是标准化的时候
fit_transform(x_train)
transform(x_test)#这个测试值不需要计算,只需要转化就行

标签:KNN,neighbors,算法,train,3.2,test,estimate
From: https://www.cnblogs.com/lipu123/p/17471206.html

相关文章

  • 代码随想录算法训练营第四天|24. 两两交换链表中的节点 , 19.删除链表的倒数第N个节点
    24.两两交换链表中的节点 个人感觉这个不太难,刚开始打算用步进值为2,来搞,但是没有想到链表应该是怎么样的,原来可以直接用: 1cur=cur->next->next 学到了,这是我自己写的代码:1ListNode*MyLinkedList::swapPairs(ListNode*head)2{3ListNode*dummyHead=new......
  • PID控制算法:位置式PID & 增量式PID
    前面的文章已经介绍过什么是pid了,现在再回顾一下:PID:是过程控制中常用的一种针对某个对象或者参数进行自动控制的一种算法。这一篇分享不打算再深究pid的理论知识,如果有不懂或者对pid感兴趣的朋友,可以自行查阅资料,或者看我前面的文章。这次分享一下pid算法的常见实现和流程。主要简......
  • Python+OpenGL使用Cohen-Sutherland算法实现直线裁剪
    问题描述:编写Python程序,使用OpenGL实现用于直线裁剪的Cohen-Sutherland算法。运行程序,绘制一个矩形表示裁剪窗口,然后通过鼠标单击和移动来绘制直线,鼠标抬起时对刚刚绘制的直线进行裁剪,显示最终落在裁剪窗口中的部分。关于Cohen-Sutherland算法请自行查阅资料。准备工作:安装和配置Py......
  • 3.1分类算法之sklean转换器和预估器
    1.转换器**想一下之前做的特征工程的步骤?1、实例化(实例化的是一个转换器类(Transformer))2、调用fit_transform(对于文档建立分类词频矩阵,不能同时调用)**标准化:(x-mean)/stdfit_transform() fit()计算每一列的平均值 transform()(x-mean)/std进行最终的转换我......
  • 【信道估计】基于多用户MMSE-BLE算法实现信道估计附matlab代码
    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。......
  • Python一句话实现秦九韶算法快速计算多项式的值
    关于秦九韶算法快速计算多项式值的原理描述请参考之前推送的文章Python使用秦九韶算法求解多项式的值。本文重点演示Python函数reduce()和lambda表达式的用法。代码没加注释,如果不好理解的话,可以先参考文末相关阅读中的介绍。......
  • 二叉树先序遍历算法的步骤
    //创建二叉树类型的结构体 //创建显得树节点并赋值并将该节点的左子树指针域和右子树指针域分别赋为NULL; //创建一个函数用于遍历二叉树并打印节点的值 //主函数将并将指针分别指向新的树节点  //执行遍历打印二叉树的节点的值 ......
  • 【基础算法】关于高精度计算的问题【很高位数数据的加减乘除(相关代码用C++实现)】
    前言当我们在利用计算机进行一些计算时,可能会遇到这类问题:有些计算要求精度高,希望计算的数的位数可达几十位甚至几百位,虽然计算机的计算精度也算较高了,但因受到硬件的限制,往往达不到实际问题所要求的精度。这时我们就可以通过程序设计来解决这类问题,例如:<fontcolor=red>创建......
  • 代码随想录算法训练营第三天| 203.移除链表元素 、 707.设计链表 、206.反转链表
    链表的构造:link.h:1#ifndefLINK_H2#defineLINK_H3#include<vector>45structListNode{6intval;7ListNode*next;8ListNode():val(0),next(nullptr){}9ListNode(intx):val(x),next(nullptr){}10ListNode(in......
  • 算法
    算法合集读入输出importjava.io.BufferedReader;importjava.io.InputStreamReader;importjava.io.PrintWriter;importjava.util.StringTokenizer;publicclassMain{ publicstaticBufferedReaderin=newBufferedReader(newInputStreamReader(System.in)); pub......