首页 > 编程语言 >使用Python和scikit-learn实现支持向量机(SVM)

使用Python和scikit-learn实现支持向量机(SVM)

时间:2024-07-05 22:02:48浏览次数:21  
标签:SVM 边界 Python clf scikit 分类器 plt 向量 gamma

        支持向量机(Support Vector Machine,SVM)是一种强大的监督学习算法,广泛用于分类和回归问题。它能够有效处理线性和非线性数据,并在复杂数据集中表现出色。本文将介绍如何使用Python和scikit-learn库实现SVM,以及如何通过可视化不同参数设置来理解其工作原理。

一、什么是支持向量机(SVM)?

        支持向量机是一种二类分类模型,它的基本思想是在特征空间中找到一个最优的超平面,能够将不同类别的数据点分隔开来,并且使得两侧距离最近的数据点(支持向量)到超平面的距离最大化。对于非线性可分的数据集,SVM通过核函数将数据映射到高维空间,使得数据线性可分。

二、实现步骤

        我们将使用Python的scikit-learn库来实现一个简单的支持向量机分类器,并在一个合成的数据集上进行可视化展示。

1、导入必要的库和数据集生成
from sklearn import svm
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import numpy as np
2、定义绘制决策边界和支持向量的函数
# 定义绘制决策边界和支持向量的函数
def plot_hyperplane(clf, X, y, h=0.02, draw_sv=True, title='Hyperplane'):
    # 确定绘图边界
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    # 设置绘图属性
    plt.title(title)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.xticks(())
    plt.yticks(())

    # 生成网格数据并进行预测
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    
    # 绘制决策边界
    plt.contourf(xx, yy, Z, cmap='hot', alpha=0.5)

    # 定义不同类别的标记和颜色
    markers = ['o', 's', '^']
    colors = ['b', 'r', 'c']
    labels = np.unique(y)
    
    # 绘制数据点
    for label in labels:
        plt.scatter(X[y==label][:, 0],
                    X[y==label][:, 1],
                    c=colors[label],
                    marker=markers[label])
    
    # 如果指定绘制支持向量,则绘制支持向量
    if draw_sv:
        sv = clf.support_vectors_
        plt.scatter(sv[:, 0], sv[:, 1], c='y', marker='x')
3、生成合成数据集

        使用make_moons函数生成一个合成的月亮形数据集,用于后续的分类器训练和可视化。

#生成包含150个样本的合成数据集,噪声为0.15
X, y = make_moons(n_samples=150, noise=0.15, random_state=42)
  • make_moons函数生成一个合成的月亮形状数据集。
  • n_samples: 数据集样本数量。
  • noise: 噪声水平,增加数据的随机性。
  • random_state: 随机种子,确保每次运行生成的数据一致性。
4、初始化不同参数设置的SVM分类器

        初始化了六个不同参数设置的SVM分类器,分别探索了不同的gammaC参数组合对分类性能的影响。

# 初始化不同参数设置的SVM分类器
#初始化一个RBF核支持向量机分类器,设置参数C和gamma。
clf_rbf1 = svm.SVC(C=1, kernel='rbf', gamma=0.01)
clf_rbf2 = svm.SVC(C=1, kernel='rbf', gamma=5)
clf_rbf3 = svm.SVC(C=100, kernel='rbf', gamma=0.01)
clf_rbf4 = svm.SVC(C=100, kernel='rbf', gamma=5)
clf_rbf5 = svm.SVC(C=10000, kernel='rbf', gamma=0.01)
clf_rbf6 = svm.SVC(C=10000, kernel='rbf', gamma=5)

# 创建绘图区域
plt.figure()

# 将所有分类器放入列表中
clfs = [clf_rbf1, clf_rbf2, clf_rbf3, clf_rbf4, clf_rbf5, clf_rbf6]
# 设置每个子图的标题
titles = ['gamma=0.01, C=1',
          'gamma=0.01, C=100',
          'gamma=0.01, C=10000',
          'gamma=5, C=1',
          'gamma=5, C=100',
          'gamma=5, C=10000']
  • 初始化了六个不同参数设置的SVM分类器对象。
  • svm.SVC: 创建一个支持向量机分类器。
  • C: 正则化参数。
  • kernel: 核函数类型,这里使用径向基函数(RBF)。
  • gamma: RBF核函数的参数,影响决策边界的灵活性和复杂度。
  • clfs: 包含所有分类器的列表。
  • titles: 每个分类器对应的标题,用于在图表中显示参数设置。
5、绘制不同参数设置下的SVM决策边界

        最后,我们通过循环遍历每个分类器,对其进行训练并绘制出相应的决策边界和支持向量。

# 对每个分类器进行训练和绘图
for clf, i in zip(clfs, range(len(clfs))):
    clf.fit(X, y)  # 训练分类器
    plt.subplot(3, 2, i+1)  # 创建3行2列的子图,并选择当前子图
    plot_hyperplane(clf, X, y, title=titles[i])  # 绘制决策边界和支持向量

plt.show()  # 显示图形
  • 创建一个12x10英寸大小的图形窗口。
  • 使用zip函数将每个分类器clf与其对应的索引i组合。
  • 对每个分类器进行训练(clf.fit(X, y))并调用plot_hyperplane函数绘制决策边界。
  • plt.subplot(3, 2, i+1): 将图形分成3行2列的子图,当前绘制第i+1个子图。
  • plt.tight_layout(): 调整子图的布局,防止重叠。
  • plt.show(): 显示绘制的图形。
三、结果分析

这幅图展示了使用较小的gamma(0.01)和较小的正则化参数C(1)训练的SVM模型。
决策边界相对平滑,模型的复杂度较低。
可以看到一些支持向量被标记为黄色的'x'符号。

这幅图展示了相同的低gamma值(0.01),但更大的正则化参数C(100)。
决策边界仍然平滑,但稍微更接近一些数据点,对噪声更敏感。

使用了极大的C值(10000),同时gamma仍然较低(0.01)。
决策边界非常接近许多数据点,模型非常复杂,几乎适应了每个数据点,可能存在过拟合的风险。

这幅图展示了较高的gamma值(5)和较小的正则化参数C(1)。
决策边界非常复杂,几乎适应了每个数据点,可能出现了过拟合。

在高gamma(5)和较大C(100)的设置下,决策边界略有平滑化,但仍然相对复杂。

最后一幅图展示了高gamma(5)和非常大的C值(10000)。
决策边界非常复杂,几乎适应了每个数据点,存在严重的过拟合可能性。


总体分析:

        参数gamma控制了决策边界的灵活性,较大的gamma值会导致决策边界更复杂,更贴近训练数据点。参数C是正则化参数,控制了对错误分类的惩罚,较大的C值会导致模型更关注训练数据,可能会导致过拟合。在低gamma和低C值的情况下,决策边界相对平滑,模型简单;而在高gamma和高C值的情况下,决策边界更复杂,可能会过度适应训练数据。

标签:SVM,边界,Python,clf,scikit,分类器,plt,向量,gamma
From: https://blog.csdn.net/2301_77444219/article/details/140188131

相关文章

  • python简单入门(五)
    一、面对对象程序设计基础1. 面对对象程序设计思想概述面向对象程序设计(Object-OrientedProgramming,简称OOP)是一种编程范式,它将数据和操作数据的方法封装在一个对象中。这种方法强调的是将现实世界中的实体抽象为对象,每个对象都有其独特的属性和行为。在Python中,面向对象......
  • 极限学习机(Extreme Learning Machine,ELM)及其Python和MATLAB实现
    极限学习机(ExtremeLearningMachine,ELM)是一种快速而有效的机器学习算法,最初由马洪亮等人于2006年提出。ELM是一种单隐层前馈神经网络,其背景源于对传统神经网络训练过程中反向传播算法的改进与优化。相比传统神经网络,ELM在网络训练速度上具有明显优势,同时在一些实际应用中取得......
  • 蝙蝠优化算法(Bat Algorithm,BA)及其Python和MATLAB实现
    蝙蝠优化算法(BatAlgorithm,简称BA)是一种基于蝙蝠群体行为的启发式优化算法,由Xin-SheYang于2010年提出。该算法模拟了蝙蝠捕食时在探测目标、适应环境和调整自身位置等过程中的行为,通过改进搜索过程来实现优化问题的求解。蝙蝠群体中每一只蝙蝠代表一个潜在解,在搜索过程中,蝙蝠......
  • unbutu源码安装python3.12
    1安装依赖项sudoaptupdatesudoaptinstall-ybuild-essentialzlib1g-devlibncurses5-devlibgdbm-devlibnss3-devlibssl-devlibsqlite3-devlibreadline-devlibffi-devwget2下载Python3.12源代码#下载wgethttps://www.python.org/ftp/python/3.12.0/Python-......
  • Python:自制密码的加密与破译
    importtkinterastkupper_password={'A':('△','▽','○'),'B':('◇','□','☆'),'C':('▷','◁','♤'),'D':('♡&......
  • Python基于卷积神经网络分类模型(CNN分类算法)实现时装类别识别项目实战
    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取。1.项目背景在深度学习领域,卷积神经网络(ConvolutionalNeuralNetworks,CNNs)因其在图像识别和分类任务上的卓越表现而备受关注。CNNs能够自动检测图像中的特......
  • Python实现ABC人工蜂群优化算法优化循环神经网络分类模型(LSTM分类算法)项目实战
    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取。1.项目背景人工蜂群算法(ArtificialBeeColony,ABC)是由Karaboga于2005年提出的一种新颖的基于群智能的全局优化算法,其直观背景来源于蜂群的采蜜行为,蜜蜂根......
  • Python学习笔记29:进阶篇(十八)常见标准库使用之质量控制中的数据清洗
    前言本文是根据python官方教程中标准库模块的介绍,自己查询资料并整理,编写代码示例做出的学习笔记。根据模块知识,一次讲解单个或者多个模块的内容。教程链接:https://docs.python.org/zh-cn/3/tutorial/index.html质量控制质量控制(QualityControl,QC),主要关注于提高......
  • 傻瓜式安装Python解释器
    一,Python解释器安装配置1.在哪安装?? 任意浏览器搜索python.org    (小伙伴们要注意,看清楚官网!!官网!!还是官网!!!带有广告字样的一定要忍住了,不能点)     进入官网(如下) 2.如何安装??找到Downloads,并在其下找到Windows点击进入,下拉找到你要下载的版本,并点......
  • python实现从某个网址爬取图片到本地电脑
    源码如下:importurllib#导入urllib包importurllib.request#导入urllib包里的request方法importre#导入re正则库#这个函数实现打开传入的路径并将页面数据读取出来,实现代码,包括发送请求,打开页面,获取数据。defload_page(url):    request=urllib.request.Req......