首页 > 编程语言 >KNN _ K近邻算法 的实现 ----- 机器学习

KNN _ K近邻算法 的实现 ----- 机器学习

时间:2022-11-04 15:23:24浏览次数:40  
标签:KNN iris pred 近邻 np train ----- test self

导入相关包

import numpy as np
import pandas as pd

# 引入 sklearn 里的数据集,iris(鸢尾花)
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split # 切分为训练集和测试集
from sklearn.metrics import accuracy_score # 计算分类预测的准确率

1.数据加载预处理

iris = load_iris()
df = pd.DataFrame(data = iris.data, columns = iris.feature_names)
df['class'] = iris.target
df['class'] = df['class'].map({0:iris.target_names[0], 1:iris.target_names[1], 2:iris.target_names[2]})
df.describe() # 描述
x = iris.data
y = iris.target.reshape(-1,1)
print(x.shape, y.shape)
# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.3, random_state=35, stratify=y)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

2. 核心算法实现

# 距离函数定义
def l1_distance(a, b):
    return np.sum(np.abs(a-b), axis=1 )
def l2_distance(a, b):
    return np.sqrt(np.sum((a-b) ** 2, axis=1) )

# 分类器实现
class kNN(object):
    # 定义一个初始化方法:__init__是类的构造方法
    def __init__(self, n_neighbors = 1, dist_func = l1_distance):
        self.n_neighbors = n_neighbors
        self.dist_func = dist_func
    
    # 调整模型方法
    def fit(self, x, y):
        self.x_train = x
        self.y_train = y
    
    # 模型预测方法
    def predict(self, x):
        # 初始化预测分类数组
        y_pred = np.zeros((x.shape[0], 1), dtype = self.y_train.dtype)
        
        # 遍历输入X的数据点 (每一个测试点的下标序号i和数据)
        for i, x_test in enumerate(x):
            # 测试数据x_test和训练数据计算距离
            distances = self.dist_func(self.x_train,x_test)
            
            # 由近到远排序,取得索引值
            nn_index = np.argsort(distances) # 输出索引值
            
            # 选取最近的K个点, 保存它们对应的分类类别
            nn_y = self.y_train[nn_index[:self.n_neighbors] ].ravel() # 变成一维数组
            
            #统计类别中出现频率最高的那个, 赋给y_pred[i]
            y_pred[i] = np.argmax(np.bincount(nn_y))# binnary count 统计每个值出现的次数 输出成数组
            
        return y_pred

 

3. 测试

# 定义实例
knn = kNN(n_neighbors = 3)
# 训练模型
knn.fit(x_train, y_train)
# 传入测试数据, 做预测
y_pred = knn.predict(x_test)

# 求出预测准确率
accuracy = accuracy_score(y_test, y_pred)

print("预测准确率:",accuracy)

 

 

# 定义实例
knn = kNN()
# 训练模型
knn.fit(x_train, y_train)

# list保存结果
result_list = []

# 针对不同的参数选取,做预测
for p in [1, 2]:
knn.dist_func = l1_distance if p == 1 else l2_distance

# 考虑不同的K取值. 步长为2 ,避免二元分类 偶数打平
for k in range(1, 10, 2):
knn.n_neighbors = k
# 传入测试数据, 做预测
y_pred = knn.predict(x_test)
# 求出预测准确率
accuracy = accuracy_score(y_test, y_pred)
result_list.append([k, 'l1_distance' if p == 1 else 'l2_distance', accuracy])
df= pd.DataFrame(result_list, columns=['k', '距离函数', '预测准确率'])
df

 

 

 

 

标签:KNN,iris,pred,近邻,np,train,-----,test,self
From: https://www.cnblogs.com/slowlydance2me/p/16857906.html

相关文章

  • 使用koa-generator生成koa2项目
    1、新建项目目录,准备在哪里创建项目和写代码,就在哪里创建即可。2、打开命令行窗口。安装koa-generator,安装命令为:npminstall-gkoa-generator(全局安装)3、使用koa-generat......
  • 12-组件篇之消息队列(1)_ev
               ......
  • 学习笔记-VSFTP
    VSFTP配置案例安装服务端yuminstall-yvsftpd客户端yuminstall-yftp匿名访问参数作用anonymous_enable=YES允许匿名访问模式anon_umask......
  • 学习笔记-God-Linux
    God-Linuxbash#判断当前是否是登陆式或非登陆式shellecho$0#上一个命令的最后一个参数.例如:上一条命令(vimtest.txt),cat!$=cattest.txt!$#以......
  • 学习笔记-Secure-Linux
    Secure-LinuxLinux加固+维护+应急响应参考文档内容仅限Linux,web服务和中间件的加固内容请看加固大纲文件可疑文件文件恢复系统密码重置会话......
  • 详解随机森林-机器学习中调参的基本思想【菜菜的sklearn课堂笔记】
    视频作者:[菜菜TsaiTsai]链接:[【技术干货】菜菜的机器学习sklearn【全85集】Python进阶_哔哩哔哩_bilibili]调参的方式总是根据数据的状况而定,所以没有办法一概而论那我......
  • 学习笔记-Iptables
    Iptables什么是iptablesLinux系统在内核中提供了对报文数据包过滤和修改的官方项目名为Netfilter,它指的是Linux内核中的一个框架,它可以用于在不同阶段将某些钩子函......
  • 学习笔记-mysql
    mysqlmy.cnf配置文件port=3309socket=/usr/local/mysql/tmp/mysql.sock[mysqld]#服务器端配置!include/usr/local/mysql/etc/mysqld.......
  • 学习笔记-LAMP
    LAMPLAMP指的Linux(操作系统)、ApacheHTTP服务器,MySQL(有时也指MariaDB,数据库软件)和PHP(有时也是指Perl或Python)的第一个字母,一般用来建立web应用平台Mai......
  • 学习笔记-nfs 配置案例
    nfs配置案例案例1服务端在Centos上配置nfs服务以只读的形式方式共享目录/public(目录需要自己创建).yum-yinstallnfs-utilsvim/etc/exports/public......