Forward-Mode and Reverse-Mode Automatic Differentiation of Composite Functions
About
- First published: 2024-09-13
- References:
- https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation
- Calculus Early Transcendentals 9e - James Stewart (2020)
- https://en.wikipedia.org/wiki/Automatic_differentiation
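- My understanding is limited; if you spot any mistakes, please point them out.
Forward and Reverse Automatic Differentiation: The Math
First, let us review the differentiation rules of calculus.
Review of Differentiation Rules
Product Rule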
\[y = f(x) = u(x) \times v(x) \]\[\begin{aligned} \frac{dy}{dx} &= \frac{du}{dx} \times v + \frac{dv}{dx} \times u \\ f'(x) &= u'v + v'u \end{aligned} \]\[\begin{aligned} f(x)&=(3 x-5) \times(4 x+7) \\ u&=3 x-5 \quad v=4 x+7 \\ u^{\prime}&=3 \quad v^{\prime}=4 \\ f^{\prime}(x)&=3(4 x+7)+4(3 x-5) \\ &=12 x+21+12 x-20 \\ &=24 x+1 \end{aligned} \]
Quotient Rule
\[y = f(x) = \frac{u(x)}{v(x)} \]\[\begin{aligned} f'(x) &= \frac{u'v - v'u}{v^2} \\ \frac{dy}{dx} &= \frac{\frac{du}{dx}v - \frac{dv}{dx}u}{v^2} \end{aligned} \]\[\begin{aligned} f(x)&=\frac{3 x-5}{4 x+7} \\ u&=3 x-5 \quad v=4 x+7 \\ u^{\prime}&=3 \quad v^{\prime}=4 \\ f^{\prime}(x)&=\frac{3(4 x+7)-4(3 x-5)}{(4 x+7)^2} \\ &=\frac{12 x+21-12 x+20}{(4 x+7)^2} \\ &=\frac{41}{(4 x+7)^2} \end{aligned} \]
Derivatives of sin and cos
\[\begin{aligned} y &= \sin(x) \\ \frac{dy}{dx} &= \cos(x) \end{aligned} \]\[\begin{aligned} y &= \cos(x) \\ \frac{dy}{dx} &= -\sin(x) \end{aligned} \]
Chain Rule (Single-Variable Composition)
\[y = f(u) \quad u = g(x) \]\[\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} \]\[\begin{aligned} y&=(2 x+4)^3 \\ y&=u^3 \text { and } u=2 x+4 \\ \frac{d y}{d u}&=3 u^2 \quad \frac{d u}{d x}=2 \\ \frac{d y}{d x}&=3 u^2 \times 2=3(2 x+4)^2 \times 2 \\ &=6(2 x+4)^2 \end{aligned} \]
Multivariable Chain Rule (Case 1)
\[\begin{aligned} z &= f(x,y) \\ x &= g(t) \\ y &= h(t) \end{aligned} \]\[\frac{d z}{d t}=\frac{\partial f}{\partial x} \frac{d x}{d t}+\frac{\partial f}{\partial y} \frac{d y}{d t} \]
Multivariable Chain Rule (Case 2)
\[\begin{aligned} z &= f(x,y) \\ x & = g(s,t) \\ y &= h(s,t) \end{aligned} \]\[\frac{\partial z}{\partial s}=\frac{\partial z}{\partial x} \frac{\partial x}{\partial s}+\frac{\partial z}{\partial y} \frac{\partial y}{\partial s} \quad \frac{\partial z}{\partial t}=\frac{\partial z}{\partial x} \frac{\partial x}{\partial t}+\frac{\partial z}{\partial y} \frac{\partial y}{\partial t} \]
When computing \(\frac{\partial z}{\partial s}\), we hold \(t\) fixed and take the ordinary derivative of \(z\) with respect to \(s\); that is, we apply the multivariable chain rule (Case 1). \(\frac{\partial z}{\partial t}\) is computed the same way, with \(s\) held fixed.
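As a quick worked instance of Case 2, take the following functions, chosen purely for illustration (they do not come from the references). Substituting directly gives \(z = (s+t)^2\, st\), and differentiating that expression reproduces the same results:

\[\begin{aligned} z &= x^2 y, \quad x = s + t, \quad y = st \\ \frac{\partial z}{\partial s} &= \frac{\partial z}{\partial x} \frac{\partial x}{\partial s}+\frac{\partial z}{\partial y} \frac{\partial y}{\partial s} = 2xy \cdot 1 + x^2 \cdot t = 2(s+t)st + (s+t)^2 t \\ \frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x} \frac{\partial x}{\partial t}+\frac{\partial z}{\partial y} \frac{\partial y}{\partial t} = 2xy \cdot 1 + x^2 \cdot s = 2(s+t)st + (s+t)^2 s \end{aligned} \]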
Multivariable Chain Rule (General Version)
\[\begin{aligned} u &= f(x_1, x_2, \ldots, x_n) \\ x_k &= g_k(t_1, t_2, \ldots, t_m) \qquad \text{for } 1 \leq k \leq n \end{aligned} \]\[\frac{\partial u}{\partial t_i}=\frac{\partial u}{\partial x_1} \frac{\partial x_1}{\partial t_i}+\frac{\partial u}{\partial x_2} \frac{\partial x_2}{\partial t_i}+\cdots+\frac{\partial u}{\partial x_n} \frac{\partial x_n}{\partial t_i} \qquad \text{for } 1 \leq i \leq m \]
Composite Functions, Partial Derivatives, the Chain Rule, and Forward/Reverse Automatic Differentiation
Order of Evaluation in Forward and Reverse Mode
Consider the composite function:
\[\begin{aligned} y & =f(g(h(x)))=f\left(g\left(h\left(w_0\right)\right)\right)=f\left(g\left(w_1\right)\right)=f\left(w_2\right)=w_3 \\ w_0 & =x \\ w_1 & =h\left(w_0\right) \\ w_2 & =g\left(w_1\right) \\ w_3 & =f\left(w_2\right)=y \end{aligned} \]
The chain rule gives:
\[\begin{aligned} \frac{\partial y}{\partial x}&=\frac{\partial y}{\partial w_2} \frac{\partial w_2}{\partial w_1} \frac{\partial w_1}{\partial x}=\frac{\partial f\left(w_2\right)}{\partial w_2} \frac{\partial g\left(w_1\right)}{\partial w_1} \frac{\partial h\left(w_0\right)}{\partial x} \end{aligned} \]
Order of evaluation:
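- In forward mode, we first compute \(\partial w_1 / \partial x\), then \(\partial w_2/\partial w_1\), and finally \(\partial y / \partial w_2\), i.e. from the innermost function outward.
- In reverse mode, we first compute \(\partial y / \partial w_2\), then \(\partial w_2/\partial w_1\), and finally \(\partial w_1 / \partial x\), i.e. from the output back toward the input.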
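To make the two orders concrete, here is a minimal Python sketch for a single chain \(y = f(g(h(x)))\). The choice \(h = \sin\), \(g(w) = w^2\), \(f = \exp\) and all function names (`forward_mode`, `reverse_mode`, `dh`, and so on) are hypothetical, picked only for illustration. Forward mode carries \(\partial w_i/\partial x\) along with each value; reverse mode first evaluates and stores the values, then carries \(\partial y/\partial w_i\) backward.

```python
import math

# Hypothetical example chain (not from the references): h = sin, g = square,
# f = exp, so y = f(g(h(x))) = exp(sin(x) ** 2).
def h(w0):
    return math.sin(w0)

def dh(w0):
    return math.cos(w0)

def g(w1):
    return w1 ** 2

def dg(w1):
    return 2.0 * w1

def f(w2):
    return math.exp(w2)

def df(w2):
    return math.exp(w2)

def forward_mode(x):
    """Carry the tangent dw_i/dx alongside each value, from the inside out."""
    w0, dw0 = x, 1.0               # seed: dx/dx = 1
    w1, dw1 = h(w0), dh(w0) * dw0  # dw1/dx
    w2, dw2 = g(w1), dg(w1) * dw1  # dw2/dx = dw2/dw1 * dw1/dx
    w3, dw3 = f(w2), df(w2) * dw2  # dy/dx  = dy/dw2 * dw2/dx
    return w3, dw3

def reverse_mode(x):
    """Store the values first, then carry the adjoint dy/dw_i from the outside in."""
    # Forward sweep: evaluate and remember every intermediate value.
    w0 = x
    w1 = h(w0)
    w2 = g(w1)
    w3 = f(w2)
    # Backward sweep: multiply local derivatives starting from the output.
    bar_w3 = 1.0                # seed: dy/dy = 1
    bar_w2 = bar_w3 * df(w2)    # dy/dw2
    bar_w1 = bar_w2 * dg(w1)    # dy/dw1 = dy/dw2 * dw2/dw1
    bar_w0 = bar_w1 * dh(w0)    # dy/dx  = dy/dw1 * dw1/dx
    return w3, bar_w0
```

Both functions return the same derivative, \(2\sin(x)\cos(x)\,e^{\sin^2(x)}\). For a chain with one input and one output the two orders cost the same; the difference only shows up once there are several inputs or outputs.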
Forward Mode
Consider the composite function:
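\[\begin{aligned} r &= ? \\ s &= ? \\ t &= ? \\ x &= g(r,s,t) \\ y & = h(r,s,t) \\ z &= i(r,s,t) \\ u &= f(x,y,z) \end{aligned} \]
Differentiating every variable with respect to an input \(v\) that is left unspecified for now, forward mode computes: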
\[\begin{aligned} \frac{\partial r}{\partial v} &= ? \\ \frac{\partial s}{\partial v} &= ? \\ \frac{\partial t}{\partial v} &= ? \\ \\ \frac{\partial x}{\partial v} &= \frac{\partial x}{\partial r}\frac{\partial r}{\partial v} + \frac{\partial x}{\partial s}\frac{\partial s}{\partial v} + \frac{\partial x}{\partial t}\frac{\partial t}{\partial v} \\ \frac{\partial y}{\partial v} &= \frac{\partial y}{\partial r}\frac{\partial r}{\partial v} + \frac{\partial y}{\partial s}\frac{\partial s}{\partial v} + \frac{\partial y}{\partial t}\frac{\partial t}{\partial v} \\ \frac{\partial z}{\partial v} &= \frac{\partial z}{\partial r}\frac{\partial r}{\partial v} + \frac{\partial z}{\partial s}\frac{\partial s}{\partial v} + \frac{\partial z}{\partial t}\frac{\partial t}{\partial v} \\ \\ \frac{\partial u}{\partial v}&=\frac{\partial u}{\partial x} \frac{\partial x}{\partial v}+\frac{\partial u}{\partial y} \frac{\partial y}{\partial v}+\frac{\partial u}{\partial z} \frac{\partial z}{\partial v} \end{aligned} \]
Setting \(v=r\), i.e. treating \(r\) as the independent variable and holding \(s\) and \(t\) fixed, gives
\[\begin{aligned} \frac{\partial r}{\partial v} &= 1 \\ \frac{\partial s}{\partial v} &= 0 \\ \frac{\partial t}{\partial v} &= 0 \\ \frac{\partial u}{\partial r}&=\frac{\partial u}{\partial x} \frac{\partial x}{\partial r}+\frac{\partial u}{\partial y} \frac{\partial y}{\partial r}+\frac{\partial u}{\partial z} \frac{\partial z}{\partial r} \end{aligned} \]
Setting \(v=s\), i.e. treating \(s\) as the independent variable and holding \(r\) and \(t\) fixed, gives
\[\begin{aligned} \frac{\partial r}{\partial v} &= 0 \\ \frac{\partial s}{\partial v} &= 1 \\ \frac{\partial t}{\partial v} &= 0 \\ \frac{\partial u}{\partial s}&=\frac{\partial u}{\partial x} \frac{\partial x}{\partial s}+\frac{\partial u}{\partial y} \frac{\partial y}{\partial s}+\frac{\partial u}{\partial z} \frac{\partial z}{\partial s} \end{aligned} \]
Setting \(v=t\), i.e. treating \(t\) as the independent variable and holding \(r\) and \(s\) fixed, gives
\[\begin{aligned} \frac{\partial r}{\partial v} &= 0 \\ \frac{\partial s}{\partial v} &= 0 \\ \frac{\partial t}{\partial v} &= 1 \\ \frac{\partial u}{\partial t}&=\frac{\partial u}{\partial x} \frac{\partial x}{\partial t}+\frac{\partial u}{\partial y} \frac{\partial y}{\partial t}+\frac{\partial u}{\partial z} \frac{\partial z}{\partial t} \end{aligned} \]
Reverse Mode
Consider the composite function:
\[\begin{aligned} u_1 &= r(x_1, x_2) \\ u_2 &= q(x_1, x_2) \\ y_1 &= f(u_1, u_2) \\ y_2 &= g(u_1, u_2) \\ y_3 &= h(u_1, u_2) \end{aligned} \]
The reverse-mode computation is:
\[\begin{aligned} \frac{\partial s}{\partial y_1} &= ? \\ \frac{\partial s}{\partial y_2} &= ? \\ \frac{\partial s}{\partial y_3} &= ? \\ \\ \frac{\partial s}{\partial u_1} &= \frac{\partial s}{\partial y_1}\frac{\partial y_1}{\partial u_1} + \frac{\partial s}{\partial y_2}\frac{\partial y_2}{\partial u_1} + \frac{\partial s}{\partial y_3}\frac{\partial y_3}{\partial u_1} \\ \frac{\partial s}{\partial u_2} &= \frac{\partial s}{\partial y_1}\frac{\partial y_1}{\partial u_2} + \frac{\partial s}{\partial y_2}\frac{\partial y_2}{\partial u_2} + \frac{\partial s}{\partial y_3}\frac{\partial y_3}{\partial u_2} \\ \\ \frac{\partial s}{\partial x_1} &= \frac{\partial s}{\partial u_1}\frac{\partial u_1}{\partial x_1} + \frac{\partial s}{\partial u_2}\frac{\partial u_2}{\partial x_1} \\ \frac{\partial s}{\partial x_2} &= \frac{\partial s}{\partial u_1}\frac{\partial u_1}{\partial x_2} + \frac{\partial s}{\partial u_2}\frac{\partial u_2}{\partial x_2} \end{aligned} \]
Here \(s\) can be thought of as some function of the outputs, \(s = F(y_1, y_2, y_3)\), that is left unspecified for now.
Setting \(s=y_1\), i.e. choosing \(y_1\) as the output to differentiate so that \(y_2\) and \(y_3\) do not contribute, gives
\[\begin{aligned} \frac{\partial s}{\partial y_1} &= 1 \\ \frac{\partial s}{\partial y_2} &= 0 \\ \frac{\partial s}{\partial y_3} &= 0 \\ \\ \frac{\partial s}{\partial u_1} &= \frac{\partial s}{\partial y_1}\frac{\partial y_1}{\partial u_1}\\ \frac{\partial s}{\partial u_2} &= \frac{\partial s}{\partial y_1}\frac{\partial y_1}{\partial u_2} \\ \\ \frac{\partial s}{\partial x_1} &= \frac{\partial s}{\partial u_1}\frac{\partial u_1}{\partial x_1} + \frac{\partial s}{\partial u_2}\frac{\partial u_2}{\partial x_1} \\ \frac{\partial s}{\partial x_2} &= \frac{\partial s}{\partial u_1}\frac{\partial u_1}{\partial x_2} + \frac{\partial s}{\partial u_2}\frac{\partial u_2}{\partial x_2} \end{aligned} \]
Automatic Differentiation by Example
Example
Suppose there are two input variables (\(x_1\), \(x_2\)) and two output variables (\(y_1\), \(y_2\)):
\[\begin{aligned} m_1 &= x_1 \cdot x_2 + \sin(x_1) \\ m_2 &= 4x_1 + 2x_2 + \cos(x_2) \\ y_1 &= m_1 + m_2 \\ y_2 &= m_1 \cdot m_2 \end{aligned} \tag{1} \]
Substituting \(m_1\) and \(m_2\):
\[\begin{aligned} y_1 &= x_1 \cdot x_2 + \sin(x_1) + 4x_1 + 2x_2 + \cos(x_2) \\ y_2 &= (x_1 \cdot x_2 + \sin(x_1)) \cdot (4x_1 + 2x_2 + \cos(x_2)) \end{aligned} \]
The partial derivatives are:
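\[\begin{aligned} \frac{\partial y_1}{\partial x_1} &= x_2 + \cos(x_1) + 4 \\ \frac{\partial y_1}{\partial x_2} &= x_1 + 2 - \sin(x_2) \\ \frac{\partial y_2}{\partial x_1} &= (x_2 + \cos(x_1)) \cdot m_2 + m_1 \cdot 4 \\ \frac{\partial y_2}{\partial x_2} &= x_1 \cdot m_2 + m_1 \cdot (2 - \sin(x_2)) \end{aligned} \]
Next, we use this example to show how forward-mode and reverse-mode automatic differentiation are carried out.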
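Before doing so, the hand-derived partial derivatives can be sanity-checked numerically. The sketch below (all function names are hypothetical) evaluates Eq. (1), writes out the analytic Jacobian, including \(\partial y_2/\partial x_2 = x_1 \cdot m_2 + m_1 \cdot (2 - \sin(x_2))\) from the product rule, and compares it with central finite differences at an arbitrary point. Finite differences are numerical differentiation, not automatic differentiation; they serve here only as an independent check.

```python
import math

def outputs(x1, x2):
    """Evaluate Eq. (1) and return (y1, y2)."""
    m1 = x1 * x2 + math.sin(x1)
    m2 = 4 * x1 + 2 * x2 + math.cos(x2)
    return m1 + m2, m1 * m2

def analytic_jacobian(x1, x2):
    """Hand-derived partials; row i is [dy_i/dx1, dy_i/dx2]."""
    m1 = x1 * x2 + math.sin(x1)
    m2 = 4 * x1 + 2 * x2 + math.cos(x2)
    return [
        [x2 + math.cos(x1) + 4, x1 + 2 - math.sin(x2)],
        [(x2 + math.cos(x1)) * m2 + m1 * 4, x1 * m2 + m1 * (2 - math.sin(x2))],
    ]

def numeric_jacobian(x1, x2, h=1e-6):
    """Central-difference approximation (f(x+h) - f(x-h)) / 2h of the same Jacobian."""
    jac = [[0.0, 0.0], [0.0, 0.0]]
    for j, (e1, e2) in enumerate([(h, 0.0), (0.0, h)]):
        plus = outputs(x1 + e1, x2 + e2)
        minus = outputs(x1 - e1, x2 - e2)
        for i in range(2):
            jac[i][j] = (plus[i] - minus[i]) / (2 * h)
    return jac

J_exact = analytic_jacobian(0.5, 1.5)
J_approx = numeric_jacobian(0.5, 1.5)
```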
Forward-Mode Automatic Differentiation
We will use the following form of the chain rule:
\[\begin{aligned} \frac{\partial w}{\partial t} &= \sum_i \left(\frac{\partial w}{\partial u_i} \cdot \frac{\partial u_i}{\partial t}\right) \\ &= \frac{\partial w}{\partial u_1} \cdot \frac{\partial u_1}{\partial t} + \frac{\partial w}{\partial u_2} \cdot \frac{\partial u_2}{\partial t} + \cdots \end{aligned} \]
where:
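- \(w\) denotes the variable being differentiated; in the example, \(y_1\) or \(y_2\) (the same rule is also applied to every intermediate variable)
- \(u_i\) denotes the variables that \(w\) directly depends on; in the example, \(m_1\) and \(m_2\) when \(w\) is \(y_1\) or \(y_2\), and \(a\) and \(b\) when \(w = m_1\)
- \(t\) denotes an input variable that is left unspecified for now; in the example, either \(x_1\) or \(x_2\)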
Before computing anything, we first decompose Eq. (1) into elementary operations:
\[\begin{aligned} x_1 &= ? \\ x_2 &= ? \\ \\ a &= x_1 \cdot x_2 \\ b &= \sin(x_1) \\ \\ c &= 4x_1 + 2x_2 \\ d &= \cos(x_2) \\ \\ m_1 &= a + b \\ m_2 &= c + d \\ \\ y_1 &= m_1 + m_2 \\ y_2 &= m_1 \cdot m_2 \end{aligned} \tag{2} \]
Now we differentiate every variable with respect to the still-unspecified input \(t\):
\[\begin{aligned} \frac{\partial x_1}{\partial t} &= ? \\ \frac{\partial x_2}{\partial t} &= ? \\ \\ \frac{\partial a}{\partial t} &= x_2\frac{\partial x_1}{\partial t} + x_1 \frac{\partial x_2}{\partial t} \\ \frac{\partial b}{\partial t} &= \cos(x_1) \frac{\partial x_1}{\partial t} \\ \\ \frac{\partial c}{\partial t} &= 4\frac{\partial x_1}{\partial t} + 2 \frac{\partial x_2}{\partial t} \\ \frac{\partial d}{\partial t} &= -\sin(x_2)\frac{\partial x_2}{\partial t} \\ \\ \frac{\partial m_1}{\partial t} &= \frac{\partial a}{\partial t} + \frac{\partial b}{\partial t} \\ \frac{\partial m_2}{\partial t} &= \frac{\partial c}{\partial t} + \frac{\partial d}{\partial t} \\ \\ \frac{\partial y_1}{\partial t} &= \frac{\partial m_1}{\partial t} + \frac{\partial m_2}{\partial t} \\ \frac{\partial y_2}{\partial t} &= \frac{\partial m_1}{\partial t} \cdot m_2 + \frac{\partial m_2}{\partial t} \cdot m_1 \end{aligned} \]
We left \(t\) unspecified earlier; now it is time to choose it:
- Substituting \(t=x_1\) into the equations above gives \(\frac{\partial x_1}{\partial t} = 1\) and \(\frac{\partial x_2}{\partial t}=0\), from which we can compute \(\frac{\partial y_1}{\partial x_1}\) and \(\frac{\partial y_2}{\partial x_1}\).
- Substituting \(t=x_2\) into the equations above gives \(\frac{\partial x_1}{\partial t} = 0\) and \(\frac{\partial x_2}{\partial t}=1\), from which we can compute \(\frac{\partial y_1}{\partial x_2}\) and \(\frac{\partial y_2}{\partial x_2}\).
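The two passes can be sketched in Python by transcribing the tangent equations above line by line. `forward_tangents` is a hypothetical name, and the seeds `(dx1, dx2)` play the role of \(\partial x_1/\partial t\) and \(\partial x_2/\partial t\).

```python
import math

def forward_tangents(x1, x2, dx1, dx2):
    """One forward-mode pass over Eq. (2).

    (dx1, dx2) are the seeds dx1/dt and dx2/dt. Returns the outputs (y1, y2)
    and their tangents (dy1/dt, dy2/dt).
    """
    a, da = x1 * x2, x2 * dx1 + x1 * dx2
    b, db = math.sin(x1), math.cos(x1) * dx1
    c, dc = 4 * x1 + 2 * x2, 4 * dx1 + 2 * dx2
    d, dd = math.cos(x2), -math.sin(x2) * dx2
    m1, dm1 = a + b, da + db
    m2, dm2 = c + d, dc + dd
    y1, dy1 = m1 + m2, dm1 + dm2
    y2, dy2 = m1 * m2, dm1 * m2 + dm2 * m1
    return (y1, y2), (dy1, dy2)

x1, x2 = 0.5, 1.5
_, (dy1_dx1, dy2_dx1) = forward_tangents(x1, x2, 1.0, 0.0)  # pass 1: t = x1
_, (dy1_dx2, dy2_dx2) = forward_tangents(x1, x2, 0.0, 1.0)  # pass 2: t = x2
```

Each call yields one column of the Jacobian, that is, the derivatives of both outputs with respect to a single input.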
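From this we can conclude:
- With \(n\) input variables (two in this example), the equations above must be evaluated \(n\) times, once per input.
- If the input to a neural network is a 1280 x 720 image (921,600 values, taking one value per pixel) and the output is 51 floating-point numbers, forward mode needs 921,600 passes.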
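Rather than hand-writing the tangent equations, forward mode is commonly implemented by operator overloading on so-called dual numbers, where every value carries its tangent along with it. A minimal sketch, assuming only `+`, `*`, sin and cos are needed for this example; the `Dual` class and the helper names are hypothetical, not a library API.

```python
import math
from dataclasses import dataclass

@dataclass
class Dual:
    """A value together with its tangent d(value)/dt (forward mode)."""
    val: float
    tan: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        # Product rule: (uv)' = u'v + uv'
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__

def dual_sin(x):
    return Dual(math.sin(x.val), math.cos(x.val) * x.tan)

def dual_cos(x):
    return Dual(math.cos(x.val), -math.sin(x.val) * x.tan)

def example(x1, x2):
    """Eq. (1) written with the Dual operations above."""
    m1 = x1 * x2 + dual_sin(x1)
    m2 = 4 * x1 + 2 * x2 + dual_cos(x2)
    return [m1 + m2, m1 * m2]

def jacobian_forward(fn, xs):
    """Forward-mode Jacobian: one pass per input variable."""
    columns = []
    for j in range(len(xs)):
        # Seed dx_j/dt = 1 and dx_i/dt = 0 for i != j, i.e. t = x_j.
        seeded = [Dual(x, 1.0 if i == j else 0.0) for i, x in enumerate(xs)]
        columns.append([y.tan for y in fn(*seeded)])
    # Transpose so that jac[i][j] = dy_i/dx_j.
    return [list(row) for row in zip(*columns)]

jac = jacobian_forward(example, [0.5, 1.5])
```

The loop in `jacobian_forward` runs once per input variable, which is exactly the cost of \(n\) passes noted above.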
Reverse-Mode Automatic Differentiation
We will use the following form of the chain rule:
\[\begin{aligned} \frac{\partial s}{\partial u} &= \sum_i \left(\frac{\partial w_i}{\partial u} \cdot \frac{\partial s}{\partial w_i}\right) \\ &= \frac{\partial w_1}{\partial u} \cdot \frac{\partial s}{\partial w_1} + \frac{\partial w_2}{\partial u} \cdot \frac{\partial s}{\partial w_2} + \cdots \end{aligned} \]
where:
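- \(u\) denotes an input (or intermediate) variable
- \(w_i\) denotes the variables that directly depend on \(u\)
- \(s\) denotes an output variable that is left unspecified for now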
Recall the decomposition into elementary operations, Eq. (2):
\[\begin{aligned} x_1 &= ? \\ x_2 &= ? \\ \\ a &= x_1 \cdot x_2 \\ b &= \sin(x_1) \\ \\ c &= 4x_1 + 2x_2 \\ d &= \cos(x_2) \\ \\ m_1 &= a + b \\ m_2 &= c + d \\ \\ y_1 &= m_1 + m_2 \\ y_2 &= m_1 \cdot m_2 \end{aligned} \tag{2} \]
Now compute the reverse-mode derivatives:
\[\begin{aligned} \frac{\partial s}{\partial y_1} &= ? \\ \frac{\partial s}{\partial y_2} &= ? \\ \\ \frac{\partial s}{\partial m_1} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_1} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_1} \\ \frac{\partial s}{\partial m_2} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_2} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_2} \\ \\ \frac{\partial s}{\partial a} &= \frac{\partial s}{\partial m_1}\frac{\partial m_1}{\partial a} \\ \frac{\partial s}{\partial b} &= \frac{\partial s}{\partial m_1}\frac{\partial m_1}{\partial b} \\ \frac{\partial s}{\partial c} &= \frac{\partial s}{\partial m_2}\frac{\partial m_2}{\partial c} \\ \frac{\partial s}{\partial d} &= \frac{\partial s}{\partial m_2}\frac{\partial m_2}{\partial d} \\ \\ \frac{\partial s}{\partial x_1} &= \frac{\partial s}{\partial a}\frac{\partial a}{\partial x_1} + \frac{\partial s}{\partial b}\frac{\partial b}{\partial x_1} + \frac{\partial s}{\partial c}\frac{\partial c}{\partial x_1} \\ \frac{\partial s}{\partial x_2} &= \frac{\partial s}{\partial a}\frac{\partial a}{\partial x_2} + \frac{\partial s}{\partial c}\frac{\partial c}{\partial x_2} + \frac{\partial s}{\partial d}\frac{\partial d}{\partial x_2} \end{aligned} \]
When \(s=y_1\):
\[\begin{aligned} \frac{\partial s}{\partial y_1} &= 1 \\ \frac{\partial s}{\partial y_2} &= 0 \\ \\ \frac{\partial s}{\partial m_1} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_1} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_1} = 1 \\ \frac{\partial s}{\partial m_2} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_2} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_2} = 1 \\ \\ \frac{\partial s}{\partial a} &= \frac{\partial s}{\partial m_1}\frac{\partial m_1}{\partial a} = 1 \\ \frac{\partial s}{\partial b} &= \frac{\partial s}{\partial m_1}\frac{\partial m_1}{\partial b} = 1 \\ \frac{\partial s}{\partial c} &= \frac{\partial s}{\partial m_2}\frac{\partial m_2}{\partial c} = 1 \\ \frac{\partial s}{\partial d} &= \frac{\partial s}{\partial m_2}\frac{\partial m_2}{\partial d} = 1 \\ \\ \frac{\partial s}{\partial x_1} &= \frac{\partial s}{\partial a}\frac{\partial a}{\partial x_1} + \frac{\partial s}{\partial b}\frac{\partial b}{\partial x_1} + \frac{\partial s}{\partial c}\frac{\partial c}{\partial x_1} = 1 \cdot x_2 + 1 \cdot \cos(x_1) + 1 \cdot 4 = x_2 + \cos(x_1) + 4 \\ \frac{\partial s}{\partial x_2} &= \frac{\partial s}{\partial a}\frac{\partial a}{\partial x_2} + \frac{\partial s}{\partial c}\frac{\partial c}{\partial x_2} + \frac{\partial s}{\partial d}\frac{\partial d}{\partial x_2} = 1 \cdot x_1 + 1 \cdot 2 + 1 \cdot (-\sin(x_2)) = x_1 + 2 -\sin(x_2) \end{aligned} \]
The case \(s=y_2\) is computed in the same way.
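For reference, carrying out the same steps with \(s=y_2\), i.e. \(\partial s/\partial y_1 = 0\) and \(\partial s/\partial y_2 = 1\), gives the following; the results agree with differentiating \(y_2\) directly from Eq. (1):

\[\begin{aligned} \frac{\partial s}{\partial y_1} &= 0 \\ \frac{\partial s}{\partial y_2} &= 1 \\ \\ \frac{\partial s}{\partial m_1} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_1} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_1} = m_2 \\ \frac{\partial s}{\partial m_2} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_2} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_2} = m_1 \\ \\ \frac{\partial s}{\partial a} &= \frac{\partial s}{\partial b} = m_2 \\ \frac{\partial s}{\partial c} &= \frac{\partial s}{\partial d} = m_1 \\ \\ \frac{\partial s}{\partial x_1} &= m_2 \cdot x_2 + m_2 \cdot \cos(x_1) + m_1 \cdot 4 = (x_2 + \cos(x_1)) \cdot m_2 + m_1 \cdot 4 \\ \frac{\partial s}{\partial x_2} &= m_2 \cdot x_1 + m_1 \cdot 2 + m_1 \cdot (-\sin(x_2)) = x_1 \cdot m_2 + m_1 \cdot (2 - \sin(x_2)) \end{aligned} \]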
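From this we can conclude:
- With \(n\) output variables (two in this example), the equations above must be evaluated \(n\) times, once per output.
- For the same network, with a 1280 x 720 image as input and 51 floating-point numbers as output, reverse mode needs only 51 passes.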
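A minimal Python sketch of the reverse pass on Eq. (2), transcribing the adjoint equations above; the name `reverse_adjoints` is hypothetical. One detail the equations leave implicit: the backward sweep needs values from the forward evaluation (here \(x_1\), \(x_2\), \(m_1\), \(m_2\)), so they are computed and kept first. The seeds `(bar_y1, bar_y2)` stand for \(\partial s/\partial y_1\) and \(\partial s/\partial y_2\).

```python
import math

def reverse_adjoints(x1, x2, bar_y1, bar_y2):
    """One reverse-mode pass over Eq. (2).

    (bar_y1, bar_y2) are the seeds ds/dy1 and ds/dy2.
    Returns (ds/dx1, ds/dx2).
    """
    # Forward sweep: evaluate and keep the intermediate values
    # that the backward sweep will need.
    a = x1 * x2
    b = math.sin(x1)
    c = 4 * x1 + 2 * x2
    d = math.cos(x2)
    m1 = a + b
    m2 = c + d

    # Backward sweep: propagate ds/d(node) from the outputs toward the inputs.
    bar_m1 = bar_y1 + bar_y2 * m2  # dy1/dm1 = 1, dy2/dm1 = m2
    bar_m2 = bar_y1 + bar_y2 * m1  # dy1/dm2 = 1, dy2/dm2 = m1
    bar_a = bar_m1                 # dm1/da = 1
    bar_b = bar_m1                 # dm1/db = 1
    bar_c = bar_m2                 # dm2/dc = 1
    bar_d = bar_m2                 # dm2/dd = 1
    bar_x1 = bar_a * x2 + bar_b * math.cos(x1) + bar_c * 4
    bar_x2 = bar_a * x1 + bar_c * 2 + bar_d * (-math.sin(x2))
    return bar_x1, bar_x2

x1, x2 = 0.5, 1.5
grad_y1 = reverse_adjoints(x1, x2, 1.0, 0.0)  # s = y1: (dy1/dx1, dy1/dx2)
grad_y2 = reverse_adjoints(x1, x2, 0.0, 1.0)  # s = y2: (dy2/dx1, dy2/dx2)
```

Each call yields one row of the Jacobian (the gradient of one output with respect to all inputs), so the number of calls grows with the number of outputs rather than the number of inputs.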