首页 > 其他分享 >最小二乘解的理解

最小二乘解的理解

时间:2024-09-19 14:46:49浏览次数:14  
标签:多项式 SVD 矩阵 最小 理解 二乘解 拟合 np

记录一下工作时遇到的拟合问题,将两个数据的关系建模为最小二乘的模型:

\[y = a_0 + a_1 x + a_2 x^2 + a_3 x^3 + a_4 x^4 \]

使用了python里面的numpy.linalg.lstsq函数进行拟合,以下是一个简单的示例


import numpy as np
import matplotlib.pyplot as plt

# 样本数据点
x = np.array([-10, -8, -5, -3, 0, 2, 5, 7, 9, 12])
y = np.array([1200, 800, 200, 100, 50, 80, 300, 600, 1000, 1800])

# 构建设计矩阵
X = np.vstack([x**n for n in range(5)]).T  # 包含 x 的 0 次到 4 次方

# 求解最小二乘问题
coefficients, residuals, rank, singular_values = np.linalg.lstsq(X, y, rcond=None)

# 打印系数
print("拟合多项式的系数为:", coefficients)

# 计算拟合的 y 值
y_fit = X @ coefficients

# 绘制结果
plt.scatter(x, y, color='blue', label='原始数据')
plt.plot(x, y_fit, color='red', label='拟合多项式')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.title('最小二乘法多项式拟合')
plt.show()

调库拟合感觉很抽象,没办法深入理解到底做了什么事情。
于是研究了一下拟合的时候具体做了哪些步骤:

在拟合多项式 $y = a_0 + a_1 x + a_2 x^2 + a_3 x^3 + a_4 x^4 $的系数时,我们需要解决一个线性方程组,使得预测值与实际值之间的误差平方和最小。这就是典型的最小二乘问题。


numpy.linalg.lstsq 的内部操作步骤:

  1. 构建设计矩阵 $ A $:

    • 我们将多项式的每一项视为一个特征,构建一个设计矩阵$ A $,其大小为 $ m \times n $(其中 $ m $是数据点的数量, $ n $ 是多项式系数的数量)。

    • 矩阵 $ A $ 的每一列对应于 $ x $ 的不同次幂:

      \[A = \begin{bmatrix} 1 & x_1 & x_1^2 & x_1^3 & x_1^4 \\ 1 & x_2 & x_2^2 & x_2^3 & x_2^4 \\ \vdots & \vdots & \vdots & \vdots & \vdots \\ 1 & x_m & x_m^2 & x_m^3 & x_m^4 \\ \end{bmatrix} \]

  2. 设定观测向量 $ b $:

    • 向量 $ b $ 包含了对应的 $ y $ 值:

      \[b = \begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_m \\ \end{bmatrix} \]

  3. 定义最小二乘目标:

    • 我们的目标是找到系数向量 $ x $(即 $ a_0, a_1, a_2, a_3, a_4 $),使得残差的平方和最小:

      \[\min_x \| Ax - b \|_2^2 \]

  4. 计算矩阵的奇异值分解(SVD):

    • 内部,lstsq 函数使用 奇异值分解(Singular Value Decomposition, SVD) 对矩阵 $ A $ 进行分解:

      \[A = U \Sigma V^T \]

      • $ U $ 是 $ m \times m $ 的正交矩阵。
      • $ \Sigma $ 是 $ m \times n $ 的对角矩阵,对角线上的元素为 $ A $ 的奇异值。
      • $ V^T $ 是 $ n \times n $ 的正交矩阵的转置。
  5. 计算矩阵 $ A $ 的伪逆 $ A^+ $:

    • 基于 SVD 分解,计算 $ A $ 的 Moore-Penrose 伪逆

      \[A^+ = V \Sigma^+ U^T \]

      • $ \Sigma^+ $ 是将 $ \Sigma $ 中非零元素取倒数后转置而得。
  6. 求解最小二乘解 $ x $:

    • 使用伪逆计算解:

      \[x = A^+ b \]

      • 这给出了使得 $ | Ax - b |_2 $ 最小的系数向量 $ x $。
  7. 计算残差 $ r $:

    • 计算残差向量:

      \[r = b - Ax \]

    • 残差的平方和为:

      \[\text{residuals} = \| r \|_2^2 \]

  8. 确定矩阵 $ A $ 的秩:

    • 通过奇异值判断 $ A $ 的 有效秩(rank),这有助于了解问题的可解性和解的稳定性。
  9. 返回结果:

    • lstsq 函数返回以下内容:
      • 系数向量 $ x $:拟合多项式的系数。
      • 残差数组:残差的平方和(若方程组为超定,即方程数大于未知数)。
      • 秩 $ rank $:矩阵 $ A $ 的秩。
      • 奇异值数组:矩阵 $ A $ 的奇异值。

总结:

  • numpy.linalg.lstsq 函数通过对设计矩阵 ( A ) 进行奇异值分解,计算其伪逆,然后求解最小二乘问题,得到使预测值与实际值误差平方和最小的系数。
  • 使用 SVD 有以下优点:
    • 数值稳定性:SVD 对于病态矩阵(条件数很大)能够提供稳定的求解。
    • 处理欠定和超定方程组:即使矩阵 ( A ) 不满秩,SVD 也能找到范数最小的最小二乘解。

附加说明:

  • 为什么不是直接求解正常方程 $ A^T A x = A^T b $?

    • 直接求解可能导致数值不稳定,尤其是当 $ A^T A $ 的条件数很大时。
    • 使用 SVD 可以避免这些问题,提供更可靠的解。
  • 计算复杂度:

    • SVD 的计算复杂度较高,但对于中小规模的问题(如多项式拟合),现代计算机可以快速完成计算。

举例:

假设我们有以下数据点:

import numpy as np

# 样本数据点
x = np.array([-10, -8, -5, -3, 0, 2, 5, 7, 9, 12])
y = np.array([1200, 800, 200, 100, 50, 80, 300, 600, 1000, 1800])

# 构建设计矩阵 A
A = np.vstack([x**0, x**1, x**2, x**3, x**4]).T

# 使用最小二乘法求解
coefficients, residuals, rank, singular_values = np.linalg.lstsq(A, y, rcond=None)

print("拟合多项式的系数为:", coefficients)
print("残差的平方和为:", residuals)
print("矩阵 A 的秩为:", rank)
print("矩阵 A 的奇异值为:", singular_values)

详细步骤解读:

  1. 构建设计矩阵 $ A $:

    • 每一列对应 $ x $ 的 0 到 4 次幂,共 5 列。
    • $ A $ 的大小为 $ 10 \times 5 $。
  2. 使用 np.linalg.lstsq 求解:

    • 函数内部对 $ A $ 进行 SVD 分解。
    • 计算 $ A $ 的伪逆 $ A^+ $。
    • 求解系数向量 $ x = A^+ b $。
  3. 解的解释:

    • coefficients:拟合的多项式系数 $ [a_0, a_1, a_2, a_3, a_4] $。
    • residuals:预测值与实际 $ y $ 值之间的误差平方和。
    • rank:矩阵 $ A $ 的秩,表示特征的线性独立数量。
    • singular_values:用于判断矩阵是否存在病态或奇异情况。

标签:多项式,SVD,矩阵,最小,理解,二乘解,拟合,np
From: https://www.cnblogs.com/hemol/p/18420518

相关文章

  • 深度学习-16-深入理解BERT基于本地数据微调训练文本分类模型的流程
    文章目录1加载库和设置通用参数1.1DistilBert1.2模型库1.3微调任务2准备数据2.1加载数据2.2切分数据2.3数据分词2.4制作数据集3使用TrainerAPI微调transformer3.1加载预训练模型3.2定义训练器3.3执行训练3.4评估性能3.5保存模......
  • 深度学习-17-深入理解BERT基于Hugging Face的模型训练步骤
    文章目录1大模型的架构1.1Transformer架构1.2BERT(双向Transformer架构)1.3GPT(GenerativePretrainedTransformer)1.4T5(Text-To-TextTransferTransformer)1.5DistilBERT1.6不同架构的优缺点对比2HuggingFace模型训练步骤2.1平台功能2.1......
  • 【高中数学/等比中项/极值/基本不等式】已知a>0,b>0,9是3^a与27^b的等比中项,求:(a^2+2)
    【问题】(某地模考题)已知a>0,b>0,9是3^a与27^b的等比中项,求:(a^2+2)/a+(3b^2+1)/b的最小值?【解答】由”9是3^a与27^b的等比中项“得到3^a/9=9/27^b,继而得到a+3b=4......(1)(a^2+2)/a+(3b^2+1)/b=a+2/a+3b+1/b=4+2/a+1/b......(2)由(1)得出2=a/2+3b/2,1=a/4+3b/4代入(2)得4+1/2+3b/2a+a......
  • 深入理解 dladdr:符号信息查询与应用场景详解
    dladdr是一个用于获取与特定地址相关的符号信息的函数,它在Linux和类UNIX系统中非常有用,尤其是在进行调试或诊断时。以下是详细的介绍和一些使用示例:1.基本概念dladdr函数通常用于获取共享库中的符号信息。它可以根据给定的地址,返回该地址对应的符号信息,例如函数名称、所在的......
  • 深入理解Go并发编程:避免Goroutine泄漏与错误处理
    Go语言以其强大的并发模型和高效的协程(goroutine)而闻名。协程的轻量级和易用性使得并发编程变得更加简单。然而,如果不正确管理协程,可能会导致Goroutine泄漏,从而消耗系统资源,影响程序性能。本文将深入探讨如何避免Goroutine泄漏,并提供实用的代码示例和技巧,帮助您编写更加健壮......
  • 《深入理解 Java 线程池:高效管理线程的利器》
    线程池1.什么是线程池?​线程池内部维护了若干个线程,没有任务的时候,这些线程都处于等待空闲状态。如果有新的线程任务,就分配一个空闲线程执行。如果所有线程都处于忙碌状态,线程池会创建一个新线程进行处理或者放入队列(工作队列)中等待。2.线程池常用类和接口​在Java标......
  • 对面向对象的理解
    面向对象编程(Object-OrientedProgramming,简称OOP)是一种编程范式,它将软件结构建模为对象的集合,每个对象都是数据和行为的封装体。以下是对面向对象编程的深入理解:核心概念对象(Object):对象是面向对象编程的基本单元,它代表现实世界中的一个实体。对象具有属性(称为字段或属性)和......
  • MySQL 二进制日志(binlog):理解与应用
    在MySQL数据库的世界里,二进制日志(binlog)是一个至关重要的组成部分。那么,什么是MySQL的二进制日志呢?它又有着哪些重要的作用呢?让我们一起来深入探讨。一、什么是MySQL的二进制日志(binlog)MySQL的二进制日志是一种记录数据库变更的文件。它以二进制格式记录了数据库中......