首页 > 其他分享 >关于吴恩达机器学习中反向传播的理解

关于吴恩达机器学习中反向传播的理解

时间:2024-01-14 20:45:35浏览次数:22  
标签:吴恩达 right 机器 -- boldsymbol 反向 delta Theta left

title: 关于吴恩达机器学习中反向传播的理解
date: 2022-04-12
categories: 数学
mathjax: true
tags:
- 机器学习
- 线性代数

原文

在机器学习视频反向传播章节[1]中:

我们用 \(\delta\) 来表示误差,则: \(\boldsymbol\delta^{\left(4\right)}=\boldsymbol a^{\left(4\right)}−\boldsymbol y\) 。我们利用这个误差值来计算前一层的误差:

\(\boldsymbol\delta^{\left(3\right)}=\left(\boldsymbol\Theta^{\left(3\right)}\right)^T\boldsymbol\delta^{\left(4\right)}\cdot g^\prime\left(\boldsymbol z^{\left(3\right)}\right)\) 。其中 \(g^\prime\left(\boldsymbol{z}^{\left(3\right)}\right)\) 是 \(S\) 形函数的导数,

\(g^\prime\left(\boldsymbol z^{\left(3\right)}\right)=\boldsymbol a^{\left(3\right)}\cdot\left(1−\boldsymbol a^{\left(3\right)}\right)\) 。而 \(\left(\boldsymbol\Theta^{\left(3\right)}\right)^T\boldsymbol\delta^{\left(4\right)}\) 则是权重导致的误差的和。

问题

\[\boldsymbol\delta^{\left(3\right)}=\left(\boldsymbol\Theta^{\left(3\right)}\right)^T\boldsymbol\delta^{\left(4\right)}\cdot g^\prime\left(\boldsymbol z^{\left(3\right)}\right) \]

看到这道算式时我百思不得其解。为什么凭空会有转置?

在我自己推一遍之后,发现原公式中可能有些不严谨的地方,所以在此阐述我的理解,欢迎大家指正:

前提

对数似然代价函数: \(J\left(\Theta\right)=y\ln h_\Theta\left(x\right)+\left(1-y\right)\ln\left(1-h_\Theta\left(x\right)\right)\)

估计函数: \(h_\Theta\left(x\right)=\sum_i\Theta_ix_i= \begin{bmatrix}\Theta_1&\Theta_2&\cdots&\Theta_n\end{bmatrix} \begin{bmatrix}x_1\\x_2\\\vdots\\x_n\end{bmatrix}\)

Logistic激活函数: \(g\left(x\right)=\frac1{1+{\rm e}^{-x}}\)

此外激活函数导数为: \(g^\prime\left(x\right)=g\left(x\right)\left[1-g\left(x\right)\right]\)

我的理解

flowchart LR x1--"(Θ<sub>1</sub><sup>(1)</sup>)<sub>1</sub>"-->z12 x1--"(Θ<sub>1</sub><sup>(1)</sup>)<sub>2</sub>"-->z22 x2--"(Θ<sub>2</sub><sup>(1)</sup>)<sub>1</sub>"-->z22 x2--"(Θ<sub>2</sub><sup>(1)</sup>)<sub>2</sub>"-->z12 a12--"(Θ<sub>1</sub><sup>(2)</sup>)<sub>1</sub>"-->z13 a12--"(Θ<sub>1</sub><sup>(2)</sup>)<sub>2</sub>"-->z23 a22--"(Θ<sub>2</sub><sup>(2)</sup>)<sub>1</sub>"-->z23 a22--"(Θ<sub>2</sub><sup>(2)</sup>)<sub>2</sub>"-->z13 z12--g-->a12 z22--g-->a22 z13--g-->a13 z23--g-->a23 a13-.->y1-.->j a23-.->y2-.->j subgraph x x1((x<sub>1</sub>)) x2((x<sub>2</sub>)) end subgraph 第一层 direction LR z12(("z<sub>1</sub><sup>(2)</sup>")) a12(("a<sub>1</sub><sup>(2)</sup>")) z22(("z<sub>2</sub><sup>(2)</sup>")) a22(("a<sub>2</sub><sup>(2)</sup>")) end subgraph 第二层 z13(("z<sub>1</sub><sup>(3)</sup>")) a13(("a<sub>1</sub><sup>(3)</sup>")) z23(("z<sub>2</sub><sup>(3)</sup>")) a23(("a<sub>2</sub><sup>(3)</sup>")) end subgraph y y1((ŷ<sub>1</sub>)) y2((ŷ<sub>2</sub>)) end j(("J(θ)"))

如图(省略了偏置),输入数据为 \(\boldsymbol x=\begin{bmatrix}x_1\\x_2\end{bmatrix}\) ,实际输出为 \(\boldsymbol y=\begin{bmatrix}y_1\\y_2\end{bmatrix}\)

这张图上表示了所有的运算,例如:

\[a_1^{\left(2\right)}=g\left(z_1^{\left(2\right)}\right) \]

\[z_2^{\left(2\right)}=\left(\Theta_1^{\left(1\right)}\right)_2x_1+\left(\Theta_2^{\left(1\right)}\right)_2x_2 \]

同时,此图认为预测输出为 \(\hat y_1=a_1^{\left(3\right)}\) ,即有误差(注意此处不是定义而是结论):

\[\delta_1^{\left(3\right)}=\hat y_1-y_1=a_1^{\left(3\right)}-y_1 \]

下面我们将上列函数改写成对应元素的写法,先作定义:

  • \(L\) :被 \(\Theta\) 作用的层

  • \(m\) : \(L\) 层单元数量,用 \(j\) 进行遍历(即 \(j\in\left\{1,2,\cdots,m\right\}\) )

  • \(n\) : \(L+1\) 层单元数量,用 \(i\) 进行遍历

推导

综上可得,若 \(L\) 是倒数第二层,则给出定义

\[\begin{align*}\delta_i^{\left(L+1\right)} &=\frac{\partial J}{\partial z_i^{\left(L+1\right)}}\\ &=\frac{\partial J}{\partial a_i^{\left(L+1\right)}}&&\cdot \frac{\partial a_i^{\left(L+1\right)}}{\partial z_i^{\left(L+1\right)}}\\ &=\left(\frac{-y_i}{a_i^{\left(L+1\right)}}+\frac{1-y_i}{1-a_i^{\left(L+1\right)}}\right)&&\cdot g^\prime z_i^{\left(L+1\right)}\\ &=\left(\frac{-y_i}{a_i^{\left(L+1\right)}}+\frac{1-y_i}{1-a_i^{\left(L+1\right)}}\right)&&\cdot a_i^{\left(L+1\right)}\left(1-a_i^{\left(L+1\right)}\right)\\ &=a_i^{\left(L+1\right)}-y_i \end{align*}\]

将同一层 \(\delta_i^{\left(L+1\right)}\) 合并为矩阵得( \(\boldsymbol\delta,\boldsymbol a,\boldsymbol y\) 都是列向量):

\[\boldsymbol\delta^{\left(L+1\right)}=\boldsymbol a^{\left(L+1\right)}-\boldsymbol y \]

下面推隐含层,以第一个单元为例:

\[\begin{align*} \delta_1^{\left(2\right)}&=\frac{\partial J}{\partial z_1^{\left(2\right)}}\\ &=\frac{\partial J}{\partial z_1^{\left(3\right)}}&& \cdot\frac{\partial z_1^{\left(3\right)}}{\partial a_1^{\left(2\right)}}&& \cdot\frac{\partial a_1^{\left(2\right)}}{\partial z_1^{\left(2\right)}}&&+ \frac{\partial J}{\partial z_2^{\left(3\right)}}&& \cdot\frac{\partial z_2^{\left(3\right)}}{\partial a_1^{\left(2\right)}}&& \cdot\frac{\partial a_1^{\left(2\right)}}{\partial z_1^{\left(2\right)}}\\ &=\delta_1^{\left(3\right)}&& \cdot\left(\Theta_1^{\left(2\right)}\right)_1&& \cdot g^\prime z_1^{\left(2\right)}&&+ \delta_2^{\left(3\right)}&& \cdot\left(\Theta_1^{\left(2\right)}\right)_2&& \cdot g^\prime z_1^{\left(2\right)} \end{align*}\]

令:

\[\left\{\begin{align*} \boldsymbol\delta^{\left(L\right)}&=\begin{bmatrix}\delta_1^{\left(L\right)}\\\delta_2^{\left(L\right)}\\\vdots\\\delta_n^{\left(L\right)}\end{bmatrix}\\ \boldsymbol\Theta_i^{\left(L\right)}&=\begin{bmatrix} \left(\Theta_i^{\left(L\right)}\right)_1& \left(\Theta_i^{\left(L\right)}\right)_2& \cdots& \left(\Theta_i^{\left(L\right)}\right)_n \end{bmatrix}\end{align*}\right.\]

可将上式化为矩阵:

\[\delta_1^{\left(2\right)} =\boldsymbol\Theta_1^{\left(2\right)}\boldsymbol\delta^{\left(3\right)} \cdot g^\prime z_1^{\left(2\right)}\]

结论

由上,可写出递推普式

\[\delta_j^{\left(L\right)} =\boldsymbol\Theta_j^{\left(L\right)}\boldsymbol\delta^{\left(L+1\right)}\cdot g^\prime z_j^{\left(L\right)}\]

其中最后一层:

\[\boldsymbol\delta^{\left(Last\right)}=\boldsymbol a^{\left(Last\right)}-\boldsymbol y \]


  1. 机器学习视频反向传播章节 ↩︎

标签:吴恩达,right,机器,--,boldsymbol,反向,delta,Theta,left
From: https://www.cnblogs.com/pokersang/p/17964152

相关文章

  • 机器视觉 - YoloV8 是采用预训练还是从零开始训练的模型
    关于Fine-tuning和预训练和fromscratch训练yolo命令行model的参数的说明既可以选择yolov8n.pt,也可以选择yolov8n.yaml,区别是:model=yolov8n.pt,即为Fine-tuning训练,yolov8n.pt模型文件已经包含了yolov8网络结构、超参数、训练参数、权重参数信息,它是官方的pre......
  • 机器视觉 - YoloV8 划分数据集
    train/val/test的关系纯训练命令行参数mode=trainval=Falsemodel=yolov8n.pt训练+val命令行参数mode=trainval=Truemodel=yolov8n.pt验证预训练模型的命令行参数mode=valsplit=valmodel=yolov8n.pt验证自有模型的命令行参数mode=valsplit=valm......
  • 机器视觉 - YoloV8 命令行使用
    准备data.yaml文件从roboflow上下载CS游戏数据集,因为只有CPU,我对数据集做了瘦身,train:689张,val:23张,test:40张.https://universe.roboflow.com/roboflow-100/csgo-videogame/dataset/2train:../train/imagesval:../valid/imagestest:../test/imagesnc......
  • 实验七:Spark机器学习库Mtlib编程实践
    1、数据导入导入相关的jar包:importorg.apache.spark.ml.feature.PCAimportorg.apache.spark.sql.Rowimportorg.apache.spark.ml.linalg.{Vector,Vectors}importorg.apache.spark.ml.evaluation.MulticlassClassificationEvaluatorimportorg.apache.spark.ml.{Pipeline,......
  • 机器视觉 - YoloV8 命令行安装
    创建python环境下载并安装miniconda安装包,注意miniconda和python版本对应关系,不要选择python最新的版本,以免yolo或pytorch不能兼容最新版python.这里到安装到C:\miniconda3配置conda环境,修改conda配置文件内容,文件名为C:\Users\myuser\.condarcpy虚拟环......
  • 书籍推荐-《多层建筑中的移动工作机器人框架》
    书籍:ARoboticFrameworkfortheMobileManipulator:TheoryandApplication作者:NguyenVanToan,PhanBuiKhoi出版:CRCPress来源:公众号【一点人工一点智能】关注 51CTO @一点人工一点智能,了解更多移动机器人&人工智能信息01书籍介绍《多层建筑中的移动工作机器人框架》通过......
  • 迈向2024:医疗机器人的市场前景与技术革新
    原创|文BFT机器人医疗机器人技术正以前所未有的速度在主流医学领域取得卓越进展,新应用、新技术不断涌现,使得该领域在过去一年中取得了令人惊叹的增长。然而,这仅仅是冰山一角,未来的发展空间仍然广阔无垠。展望2024年,医疗机器人领域将有几个潜力巨大的机会正在等待发掘。PART1医......
  • 机器学习-概率图模型系列-隐含马尔科夫模型-33
    目录1.HiddenMarkovModel2.HMM模型定义注:参考链接https://www.cnblogs.com/pinard/p/6945257.html1.HiddenMarkovModel隐马尔科夫模型(HiddenMarkovModel,以下简称HMM)是比较经典的机器学习模型了,它在语言识别,自然语言处理,模式识别等领域得到广泛的应用,深度学习的崛起,......
  • 必看!2023年机器人领域十大事件!
    原创|文BFT机器人2023年,机器人产业快速发展,成就了机器人领域的一个又一个里程碑。机器人行业涌现了许多令人瞩目的事件,实现了重大突破,展示了机器人技术在各个领域的广泛应用和革命性变革。本文将对2023年机器人领域的十大事件进行盘点,带您回顾这一年的重要突破和创新,展望未来机......
  • 数据科学 机器学习 (训练营)
    地址:https://offerbang.io/......