首页 > 其他分享 >机器学习代码——线性模型

机器学习代码——线性模型

时间:2024-03-31 23:00:44浏览次数:21  
标签:plt fit 代码 self test train 线性 np 模型

3.1线性回归

import numpy as np
import matplotlib.pyplot as plt


class LinearRegressionClosedFormSol:
    """
    线性回归,模型的闭式解
    1.数据的预处理:是否训练偏置项fit_intercept(默认True),是否标准化normalized(True)
    2.模型的训练:闭式解 fit(self, x_train, y_train)
    3.模型的预测 predict(self, x_test)
    4.均方误差,判决系数
    5.模型预测可视化
    """
    def __init__(self, fit_intercept=True, normalized=True):
        self.fit_intercept = fit_intercept  # 是否训练偏置项
        self.normalized = normalized  # 是否对样本进行标准化
        self.theta = None
        if self.normalized:
            # 如果需要标准化,则计算样本特征的均值和标准方差,以便对测试样本标准化,模型系数的还原
            self.feature_mean, self.feature_std = None, None
            self.mse = None  # 模型预测的均方误差
            self.r2, self.r2_adj = 0.0, 0.0  # 判决系数和修正判决系数
            self.n_samples, self.n_features = 0, 0  # 样本量和特征属性数目

    def fit(self, x_train, y_train):
        """
        样本的预处理,模型系数的求解,闭式解公式
        :param x_train: 训练样本:ndarray,m*k
        :param y_train: 目标值:ndarray,m*1
        :return:
        """
        if self.normalized:
            self.feature_mean = np.mean(x_train, axis=0)  # 样本特征均值  axis = 0:压缩行,对各列求均值,返回1*n的矩阵,
            self.feature_std = np.std(x_train, axis=1) + 1e-8  # 样本特征标准方差,1e-8是避免分母是0
            x_train = (x_train - self.feature_mean) / self.feature_std   # 标准化
        if self.fit_intercept:
            x_train = np.c_[x_train, np.ones_like(y_train)]  # 在样本后面加一列1,np.c_ 用于连接两个矩阵
        # 训练模型
        self._fit_closed_form_solution(x_train, y_train)

    def _fit_closed_form_solution(self, x_train, y_train):
        """
        模型系数的求解,闭式解公式
        :param x_train:数据预处理后的训练样本:ndarray,m*k
        :param y_train:目标值:ndarray,m*1
        :return:
        """
        # pinv:伪逆,(X'*X)^(-1)*X'
        self.theta = np.linalg.pinv(x_train).dot(y_train)
        # xtx = np.dot(x_train.T, x_train) + 0.01 * np.eye(x_train.shape[1])  # 防止不可逆
        # self.theta = np.linalg.inv(xtx).dot(x_train.T).dot(y_train)

    def get_params(self):
        """
        获取模型的系数
        :return:
        """
        if self.fit_intercept:
            weight, bias = self.theta[:-1], self.theta[-1]
        else:
            weight, bias = self.theta, np.array([0])
        if self.normalized:
            weight = weight / self.feature_std  # 还原模型系数
            bias = bias - weight.T.dot(self.feature_mean)
        return weight, bias

    def predict(self, x_test):
        """
        模型的预测
        :param x_test:
        :return:
        """
        try:
            self.n_samples, self.n_features = x_test.shape[0], x_test.shape[1]
        except IndexError:
            self.n_samples, self.n_features = x_test.shape[0], 1
        if self.normalized:
            x_test = (x_test - self.feature_mean) / self.feature_std
        if self.fit_intercept:
            x_test = np.c_[x_test, np.ones(shape=x_test.shape[0])]  # shape[0]代表行数
        return x_test.dot(self.theta)

    def cal_mse_r2(self, y_test, y_pred):
        """
        模型预测的均方误差MSE,判决系数和修正判决系数
        :param y_test: 测试样本真值
        :param y_pred: 测试样本预测值
        :return:
        """
        self.mse = ((y_pred - y_test) ** 2).mean()  # 均方误差
        self.r2 = 1 - ((y_test - y_pred) ** 2).sum() / ((y_test - y_pred) ** 2).sum()
        self.r2_adj = 1 - (1 - self.r2) * (self.n_samples - 1) / (self.n_samples - self.n_features - 1)
        return self.mse, self.r2, self.r2_adj

    def plt_predict(self, y_test, y_pred, is_sort=True):
        """
        预测结果的可视化
        :param y_test: 测试样本真值
        :param y_pred: 测试样本预测值
        :param is_sort: 是否对预测值进行排序,然后可视化
        :return:
        """
        plt.figure(figsize=(7, 5))
        if is_sort:
            idx = np.argsort(y_test)
            plt.plot(y_test[idx], "k--", lw=1.5, label="Test True Val")
            plt.plot(y_pred[idx], "r:", lw=1.8, label="Predict Val")
        else:
            plt.plot(y_test, "ko-", lw=1.5, label="Test True Val")
            plt.plot(y_pred, "r*-", lw=1.8, label="Predict Val")
        plt.xlabel("Test samples number", fontdict={"fontsize": 12})
        plt.ylabel("Predicted samples values", fontdict={"fontsize": 12})
        plt.title("The predicted values of test samples \n "
                  "MSE = %.5f, R2 = %.5f, R2_adj = %.5f" % (self.mse, self.r2, self.r2_adj))
        plt.grid(ls=":")
        plt.legend(frameon=False)
        plt.show()


标签:plt,fit,代码,self,test,train,线性,np,模型
From: https://blog.csdn.net/weixin_67870062/article/details/137211853

相关文章

  • 11天【代码随想录算法训练营34期】 第五章 栈与队列part02(● 20. 有效的括号 ● 1047
    20.有效的括号classSolution:defisValid(self,s:str)->bool:stk=[]upper=["(","{","["]lower=[")","}","]"]dictionary={")":"(&qu......
  • 【数据结构】线性表-单链表
    编程语言:C++前言:节点:节点是链表的一个基本单元,包含两部分——数据域和指针域,数据域用于存储数据,指针域存储下一个节点的地址,形成链结。什么是单链表:n个节点链结成一个链表,即为线性表(a1,a2,a3……)的链式存储结构,每个节点只包含一个指针域的链表叫做单链表。链表组成:头节点、......
  • YOLOV8逐步分解(3)_trainer训练之模型加载
    yolov8逐步分解(1)--默认参数&超参配置文件加载yolov8逐步分解(2)_DetectionTrainer类初始化过程接上2篇文章,继续讲解yolov8训练过程中的模型加载过程。使用默认参数完成训练器trainer的初始化后,执行训练函数train()开始YOLOV8的训练。1.train()方法实现代码如下所示:......
  • 【信号分析】基于模拟数字信号ASK FSK PSK QAM调制及自相关法估计功率谱、周期图计算
      ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,代码获取、论文复现及科研仿真合作可私信。......
  • 【信道估计】大规模MIMO-OFDM稀疏多径QPSK调解制的DL信道估计附matlab代码
     ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,代码获取、论文复现及科研仿真合作可私信。......
  • Python与供应链-2预测误差及指数平滑需求预测模型
    主要介绍预测误差和指数平滑模型的相关理论,然后再通过Python的statsmodels封装的指数平滑函数预测需求。1预测误差预测误差是指预测结果与预测对象发展变化的真实结果之间的差距。这种误差分为绝对误差和相对误差。绝对误差是预测值与实际观测值的绝对差距,而相对误差则是这种......
  • 小波特征提取算法代码
    functiontezhengtiqu%新归一化方法小波矩特征提取----------------------------------------------------------F=imread('a1.bmp');F=im2bw(F);F=imresize(F,[128128]);%求取最上点fori=1:128forj=1:128if(F(i,j)==1)ytop=i;......
  • PHP代码审计——Day1-Wish List
    前言:发现红日安全代码审计小组写了关于php代码审计demo的系列文章,于是跟着一起学习。参考:[红日安全]代码审计Day1-in_array函数缺陷RIPS-PHP-SECURITY-CALENDAR-2017学习记录漏洞解析classChallenge{constUPLOAD_DIRECTORY='./solutions/';private$file;priv......
  • InnoDB 事务模型
    参考资料https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-transaction-model.htmlACID模型ACID模型是一组数据库设计原则,强调业务数据存储的可靠和关键型应用程序运行的稳定。InnoDB存储引擎遵循了ACID设计,可以保证数据不会因软件崩溃和硬件故障等异常情况而丢失。......
  • EfficientNetV2:谷歌又来了,最小的模型,最高的准确率,最快的训练速度 | ICML 2021
     论文基于training-awareNAS和模型缩放得到EfficientNetV2系列,性能远优于目前的模型。另外,为了进一步提升训练速度,论文提出progressivelearning训练方法,在训练过程中同时增加输入图片尺寸和正则化强度。从实验结果来看,EfficientNetV2的效果非常不错。来源:晓飞的算法工程笔记......