首页 > 其他分享 >手写持向量机(SVM)实现

手写持向量机(SVM)实现

时间:2024-10-15 15:48:29浏览次数:6  
标签:SVM self param np 手写 array 向量 lambda

下面是一个简单的支持向量机(SVM)实现,用于解决线性可分问题。

这个实现不使用任何机器学习库,只使用NumPy进行矩阵运算。

请注意,这个实现主要用于教学目的,实际应用中推荐使用成熟的库,如scikit-learn。

import numpy as np

class SVM:
    def __init__(self, learning_rate=0.001, lambda_param=0.01, n_iterations=1000):
        """
        初始化SVM分类器。
        
        参数:
        learning_rate (float): 学习率。
        lambda_param (float): 正则化参数。
        n_iterations (int): 迭代次数。
        """
        self.lr = learning_rate
        self.lambda_param = lambda_param
        self.n_iterations = n_iterations
        self.w = None
        self.b = None

    def fit(self, X, y):
        """
        训练SVM模型。
        
        参数:
        X (numpy.array): 特征矩阵。
        y (numpy.array): 标签向量。
        """
        n_samples, n_features = X.shape
        self.w = np.zeros(n_features)
        self.b = 0

        # 转换标签为1和-1
        y_ = np.where(y <= 0, -1, 1)

        for _ in range(self.n_iterations):
            for idx, x_i in enumerate(X):
                # 计算条件,检查是否满足SVM的间隔条件
                condition = y_[idx] * (np.dot(x_i, self.w) - self.b) >= 1
                if condition:
                    # 如果满足条件,执行梯度下降更新权重(带正则化)
                    self.w -= self.lr * (2 * self.lambda_param * self.w)
                else:
                    # 如果不满足条件,执行梯度下降更新权重和偏置项
                    self.w -= self.lr * (2 * self.lambda_param * self.w - np.dot(x_i, y_[idx]))
                    self.b -= self.lr * y_[idx]

    def predict(self, X):
        """
        使用训练好的SVM模型进行预测。
        
        参数:
        X (numpy.array): 特征矩阵。
        
        返回:
        predictions (numpy.array): 预测标签。
        """
        linear_output = np.dot(X, self.w) - self.b
        return np.sign(linear_output)

# 生成一些合成数据
X = np.array([[5, 5], [3, 5], [4, 3], [2, 3], [5, 3], [5, 4], [3, 5], [4, 4], [3, 3], [4, 2], [3, 2], [2, 4]])
y = np.array([-1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1])

# 创建SVM模型实例
svm = SVM(learning_rate=0.001, lambda_param=0.01, n_iterations=1000)

# 训练模型
svm.fit(X, y)

# 进行预测
predictions = svm.predict(X)

# 打印预测值和真实值
print("Predictions:", predictions)
print("Real values:", y)

代码解释:

  1. 初始化:在__init__方法中,我们初始化了学习率、正则化参数、迭代次数、权重向量w和偏置项b

  2. 训练:在fit方法中,我们首先将标签转换为-1和1,然后进行迭代,对每个样本进行梯度下降。如果样本满足间隔条件(即y_i * (w^T x_i + b) >= 1),我们只更新权重向量w,否则我们同时更新权重向量w和偏置项b

  3. 预测:在predict方法中,我们计算线性输出,然后使用np.sign函数将输出转换为预测标签。

请注意,这个简单的SVM实现没有包含一些高级特性,如核技巧、软间隔或更复杂的优化算法。对于非线性可分的数据集或需要更高性能的应用场景,建议使用成熟的库,如scikit-learn中的SVM实现。

标签:SVM,self,param,np,手写,array,向量,lambda
From: https://www.cnblogs.com/redufa/p/18467665

相关文章

  • 支持向量机 --优化
    支持向量机1.支持向量SVM最优化问题SVM想要的就是找到各类样本点到超平面的距离最远,也就是找到最大间隔超平面。任意超平面可以用下面这个[线性方程]来描述:\[\omega^Tx+b=0\]二维空间点$(x,y)$到直线$Ax+By+C=0$​的距离公式是:\[\frac{|Ax+By+C|}{\sqrt{A^2+B^2}......
  • 基于MATLAB的BP神经网络手写数字识别系统
    介绍*:本课题为基于MATLAB的BP神经网络手写数字识别系统。带有GUI人机交互式界面。读入测试图片,通过截取某个数字,进行预处理,经过bp网络训练,得出识别的结果。可经过二次改造成识别中文汉字,英文字符等课题。运行效果示例图:......
  • LLM中词向量的表示和词嵌入的一些疑问
    LLM中词向量的表示和词嵌入的一些疑问词向量的一些特点在3blue1brown的视频【官方双语】GPT是什么?直观解释Transformer|深度学习第5章_哔哩哔哩_bilibili中,在15min左右介绍了LLM的词嵌入的过程.其中提到mother的词向量减去father的词向量,会近似于women的词向量-man的词向......
  • 万字详解AI实践,零手写编码用AI完成开发 + 数据清洗 + 数据处理 的每日新闻推荐,带你快
    用AI+dify完成前后端开发+数据处理和数据清洗。引言数据获取和数据处理dify构建workflow进行数据清洗前端页面构建和前后端交互总结引言AI时代对开发人员的加强是非常明显的,一个开发人员可以依靠AI横跨数个自己不熟悉的领域包括前后端、算法等。让我们来做个实践,全程......
  • 手写mybatis之把反射用到出神入化
    前言但在实操上,很多码农根本没法阅读框架源码。首先一个非常大的问题是,面对如此庞大的框架源码,不知道从哪下手。与平常的业务需求开发相比,框架源码中运用了大量的设计原则和设计模式对系统功能进行解耦和实现,也使用了不少如反射、代理、字节码等相关技术。如果你有......
  • 手写mybatis之数据源池化技术实现
    前言在上一章节我们解析了XML中数据源配置信息,并使用Druid创建数据源完成数据库的操作。但其实在Mybatis中是有自己的数据源实现的,包括无池化的UnpooledDataSource实现方式和有池化的PooledDataSource实现方式。你可以把池化技术理解为享元模式的具体实现方......
  • 大模型agent开发之文本向量化
    文本向量化实现方式 在复杂的大模型中文本向量化有很多好处,比如提高检索速度,在大规模数据集上向量通过相似表示可以快速找到相似文本,在处理长文本和跨语言对齐等任务上也可以减少很多开销。在langchain中可以从包langchain.embeddings.openai中可以引入方法OpenAIEmbeddings定义......
  • 更强的RAG:向量数据库和知识图谱的结合
    传统RAG的局限性经典的RAG架构以向量数据库(VectorDB)为核心来检索语义相似性上下文,让大语言模型(LLM)不需要重新训练就能够获取最新的知识,其工作流如下图所示:这一架构目前广泛应用于各类AI业务场景中,例如问答机器人、智能客服、私域知识库检索等等。虽然RAG通过知识增强一......
  • 支持向量机
    一、SVM基本原理SVM(SupportVectorMachine)SVM是机器学习中常用的分类算法。SVM算法可以将......
  • 基于MSER和HOG特征提取的SVM交通标志检测和识别算法matlab仿真
    1.算法运行效果图预览(完整程序运行后无水印)   2.算法运行软件版本matlab2017b 3.部分核心程序(完整版代码包含中文注释和操作步骤视频)function[Ic,Xmin3,Xmax3,Ymin3,Ymax3]=func_merge(I,Trafficxy,Smj,SCALE);%提取交通标志的中心点,判断是否为同一......