
3. Why Diffusion Models Use the Mean Squared Error Loss (Optional)


Heads-up: this article is quite difficult and contains a lot of mathematical derivation. If you would rather avoid heavy math, feel free to skip it.

Before reading this article, you should know: what a Markov chain is, what maximum likelihood estimation is, what KL divergence is, the KL divergence between two normal distributions, and Bayes' rule.

The following content mainly draws on the blog post What are Diffusion Models? as well as Hung-yi Lee's course.


1. The Markov Chain and \(p_\theta(\mathbf{x})\)

Conclusions derived in this section:

  • \(q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)\)
  • \(p_{\theta}(\mathbf{x}_{0:T}) = p(\mathbf{x}_T)\prod_{t=1}^T p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_t)\)

In diffusion models, for computational convenience, we assume that the images \(\mathbf{x}_0,\mathbf{x}_1,\cdots,\mathbf{x}_T\) in the forward process form a Markov chain, and we denote the probability distribution of an image \(\mathbf{x}\) in the forward process by \(q(\mathbf{x})\).

Therefore, we have

\[q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) \]

Meanwhile, we let \(p_\theta(\mathbf{x})\) denote the probability that the model generates image \(\mathbf{x}\) in the reverse process.

Thus, when applying maximum likelihood estimation to a diffusion model, the samples are the noise-free images \(\mathbf{x}_0\), and the likelihood \(p_\theta(\mathbf{x}_0)\) is the probability that the model ultimately generates \(\mathbf{x}_0\). Naturally, the goal of maximum likelihood estimation is to find the model that maximizes \(p_\theta(\mathbf{x}_0)\).

Note that in the reverse process, \(\mathbf{x}_T\) is a pure-noise image sampled directly from a standard normal distribution; it is not generated by the model. Hence \(p_\theta(\mathbf{x}_T)\) does not depend on the choice of model and can be written as \(p(\mathbf{x}_T)\).

Since \(\mathbf{x}_0,\mathbf{x}_1,\cdots,\mathbf{x}_T\) form a Markov chain, we have

\[p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T)\prod_{t=1}^T p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t) \]
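To make these factorizations concrete, here is a minimal NumPy sketch that samples a forward chain one step at a time, exactly as the product \(\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)\) prescribes, using the Gaussian transition \(q(\mathbf{x}_t \mid \mathbf{x}_{t-1})=\mathcal{N}(\sqrt{1-\beta_t}\,\mathbf{x}_{t-1}, \beta_t \mathbf{I})\) recalled later in Section 3. The linear beta schedule and array shapes are illustrative assumptions, not part of the derivation:

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # illustrative linear schedule (an assumption)

def forward_chain(x0):
    """Sample x_1..x_T step by step via
    q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I)."""
    xs = [x0]
    for beta in betas:
        mean = np.sqrt(1.0 - beta) * xs[-1]
        xs.append(mean + np.sqrt(beta) * rng.standard_normal(x0.shape))
    return xs  # xs[t] is x_t; the joint density factorizes as the product above

x0 = rng.standard_normal((3, 32, 32))       # stand-in for a clean image
chain = forward_chain(x0)
print(np.mean(chain[-1]), np.std(chain[-1]))  # x_T is close to standard normal noise
```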

2. Maximum Likelihood Estimation

Conclusion derived in this section: \(\min -\log{p_\theta(\mathbf{x}_0)}\) is equivalent to \(\min L_T+L_{T-1}+\cdots+L_0\), where

\[\begin{aligned} L_T & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_T\right)\right) \\ L_t & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_t \mid \mathbf{x}_{t+1}, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_t \mid \mathbf{x}_{t+1}\right)\right) \quad \text { for } 1 \leq t \leq T-1 \\ L_0 & =-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right) \end{aligned} \]

As stated above, the goal of maximum likelihood estimation is \(\max{p_\theta(\mathbf{x}_0)}\); for convenience, we can convert this to \(\min -\log{p_\theta(\mathbf{x}_0)}\).

Rearranging \(-\log{p_\theta(\mathbf{x}_0)}\), we obtain

\[\begin{aligned} -\log p_\theta\left(\mathbf{x}_0\right) &\le -\log p_\theta\left(\mathbf{x}_0\right)+D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)\right)\\ & =-\log p_\theta\left(\mathbf{x}_0\right)+\mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}+\log p_\theta\left(\mathbf{x}_0\right)\right] \\ & =\mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right]\\ \end{aligned} \]

\[-\log p_\theta\left(\mathbf{x}_0\right) \le\mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \tag{1} \]

Here, \(D_{\mathrm{KL}}(q\|p_\theta)\) denotes the KL divergence between distributions \(q\) and \(p_\theta\), and the expectation is \(\mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}(f) = \int q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right) \times f \ \mathrm{d} \mathbf{x_{1:T}}\).


Next, we take expectations of both sides of Equation (1) with respect to \(q(\mathbf{x}_0)\):

\[\begin{aligned} \int -\log p_\theta\left(\mathbf{x}_0\right) \cdot q\left(\mathbf{x}_0\right) \mathrm{d} \mathbf{x}_0 &\le \int \mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \cdot q(\mathbf{x}_0)\mathrm{d} \mathbf{x}_0 \\ -\mathbb{E}_{q\left(\mathbf{x}_0\right)} \log p_\theta\left({\mathbf{x}}_0\right)&\le \iint \left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \cdot q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right) \cdot q(\mathbf{x}_0)\mathrm{d} \mathbf{x}_{1:T}\mathrm{d}\mathbf{x}_0\\ &=\int \left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] q(\mathbf{x}_{0: T}) \mathrm{d}\mathbf{x}_{0:T}\\ &= \mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \end{aligned} \]

\[-\mathbb{E}_{q\left(\mathbf{x}_0\right)} \log p_\theta\left({\mathbf{x}}_0\right) \le \mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \tag{2} \]

For brevity, we write \(-\mathbb{E}_{q\left(\mathbf{x}_0\right)} \log p_\theta\left({\mathbf{x}}_0\right)\) as \(L_{\mathrm{CE}}\), and \(\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right]\) as \(L_{\mathrm{VLB}}\).

Since \(L_{\mathrm{VLB}}\) upper-bounds \(L_{\mathrm{CE}}\), we only need to \(\min L_{\mathrm{VLB}}\) in order to achieve \(\min -\log p_\theta(\mathbf{x}_0)\).


Next, we transform \(L_{\mathrm{VLB}}\):

\[\begin{aligned} & L_{\mathrm{VLB}}=\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \\ & =\mathbb{E}_q\left[\log \frac{\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}{p_\theta\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}\right] \\ & =\mathbb{E}_q\left[-\log p_\theta\left(\mathbf{x}_T\right)+\sum_{t=1}^T \log \frac{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}\right] \\ & =\mathbb{E}_q\left[-\log p_\theta\left(\mathbf{x}_T\right)+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}+\log \frac{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}\right] \\ & =\mathbb{E}_q\left[-\log p_\theta\left(\mathbf{x}_T\right)+\sum_{t=2}^T \log \left(\frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)} \cdot \frac{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}\right)+\log \frac{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}\right] \\ & =\mathbb{E}_q\left[-\log p_\theta\left(\mathbf{x}_T\right)+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}+\log \frac{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}\right] \\ & =\mathbb{E}_q\left[-\log p_\theta\left(\mathbf{x}_T\right)+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}+\log \frac{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}+\log \frac{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}\right] \\ & =\mathbb{E}_q\left[\log \frac{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_T\right)}+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\ & =\mathbb{E}_q\left[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t=2}^T \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}\right] \end{aligned} \]

\[\min L_{\mathrm{VLB}} \rightarrow \min L_T+L_{T-1}+\cdots+L_0 \tag{3} \]

where

\[\begin{aligned} L_T & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_T\right)\right) \\ L_t & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_t \mid \mathbf{x}_{t+1}, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_t \mid \mathbf{x}_{t+1}\right)\right) \quad \text { for } 1 \leq t \leq T-1 \\ L_0 & =-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right) \end{aligned} \]


Note that for \(L_T\), both distributions \(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)\) and \(p_\theta\left(\mathbf{x}_T\right)\) do not depend on the model, so \(L_T\) is a constant; we only need to minimize \(L_t\) and \(L_0\).

Furthermore, since \(q(\mathbf{x}_0\mid\mathbf{x}_1,\mathbf{x}_0)\) is a point mass concentrated at \(\mathbf{x}_0\), we have \(\min D_{\mathrm{KL}}(q(\mathbf{x}_0|\mathbf{x}_1,\mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_0|\mathbf{x}_1)) \rightarrow \min -\log p_\theta(\mathbf{x}_0\mid \mathbf{x}_1)\). Thus \(L_0\) can be cast in the same form as \(L_t\), and it suffices to minimize \(L_t\).
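Each remaining \(L_t\) term is a KL divergence between two normal distributions, which has a closed form. Below is a minimal sketch, assuming diagonal covariances (the function name and interface are my own illustration, not from the referenced posts):

```python
import numpy as np

def gaussian_kl(mu1, var1, mu2, var2):
    """KL( N(mu1, diag(var1)) || N(mu2, diag(var2)) ), summed over dimensions."""
    return 0.5 * np.sum(np.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)

# Sanity check: the KL divergence between identical Gaussians is zero.
mu, var = np.zeros(4), np.ones(4)
print(gaussian_kl(mu, var, mu, var))  # 0.0
```

When both variances are held fixed, only the \((\mu_1-\mu_2)^2\) term depends on the means, which foreshadows the mean-matching result in Section 4.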

3. The Term \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\) in \(L_t\)

Conclusion derived in this section: \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \tilde{\boldsymbol{\mu}}_t,\tilde{\beta}_t \mathbf{I}\right)\), where \(\tilde{\boldsymbol{\mu}}_t=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}\right)\) and \(\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t\).

Using Bayes' rule, we can rewrite \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\) as

\[q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)=q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \]

And since \(\mathbf{x}_0,\mathbf{x}_1,\cdots,\mathbf{x}_T\) form a Markov chain, \(q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) = q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)\).

As mentioned at the end of the previous article, in the forward process \(q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_t ; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}\right)\) and \(q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)=\mathcal{N}\left(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right)\).

Hence

\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) & =q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \\ & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(\mathbf{x}_t-\sqrt{\alpha_t} \mathbf{x}_{t-1}\right)^2}{\beta_t}+\frac{\left(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\frac{\mathbf{x}_t^2-2 \sqrt{\alpha_t} \mathbf{x}_t \mathbf{x}_{t-1}+\alpha_t \mathbf{x}_{t-1}^2}{\beta_t}+\frac{\mathbf{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0 \mathbf{x}_{t-1}+\bar{\alpha}_{t-1} \mathbf{x}_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \mathbf{x}_{t-1}+C\left(\mathbf{x}_t, \mathbf{x}_0\right)\right)\right) \end{aligned} \]

where \(C\left(\mathbf{x}_t, \mathbf{x}_0\right)\) is a constant that does not involve \(\mathbf{x}_{t-1}\) and can therefore be ignored.

Matching this exponent against the density of a normal distribution (completing the square), we can read off

Recall that \(\alpha_t=1-\beta_t\) and \(\bar{\alpha}_t=\prod_{i=1}^t \alpha_i\).

\[\tilde{\beta}_t=1 /\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=1 /\left(\frac{\alpha_t-\bar{\alpha}_t+\beta_t}{\beta_t\left(1-\bar{\alpha}_{t-1}\right)}\right)=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t \]

\[\begin{aligned} \tilde{\boldsymbol{\mu}}_t& =\left(\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) /\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \\ & =\left(\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t \\ & =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0 \end{aligned} \]

In the previous article, we derived \(\mathbf{x}_t=\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}\), and therefore

\[\mathbf{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol\epsilon\right) \]

where \(\boldsymbol\epsilon\) denotes the total noise added from \(\mathbf{x}_0\) to \(\mathbf{x}_t\).

Substituting this expression for \(\mathbf{x}_0\) into \(\tilde{\boldsymbol{\mu}}_t\), we get

\[\begin{aligned} \tilde{\boldsymbol{\mu}}_t & =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}\right) \\ & =\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}\right) \end{aligned} \]

Therefore, we have

\[q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}\right),\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t \mathbf{I}\right)\tag{4} \]
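Equation (4) can be sanity-checked numerically: compute \(\tilde{\boldsymbol{\mu}}_t\) once from \((\mathbf{x}_t, \mathbf{x}_0)\) and once from \((\mathbf{x}_t, \boldsymbol{\epsilon})\), and confirm the two agree. A minimal sketch (the schedule and dimensions are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)  # illustrative schedule (an assumption)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

t = 500                              # any 1 <= t < T
x0 = rng.standard_normal(16)
eps = rng.standard_normal(16)
# Closed-form forward sample: x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps
xt = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps

# mu-tilde written in terms of (x_t, x_0)
mu_from_x0 = (np.sqrt(alphas[t]) * (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t]) * xt
              + np.sqrt(alpha_bars[t - 1]) * betas[t] / (1 - alpha_bars[t]) * x0)
# mu-tilde written in terms of (x_t, eps)
mu_from_eps = (xt - (1 - alphas[t]) / np.sqrt(1 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])

assert np.allclose(mu_from_x0, mu_from_eps)  # the two forms coincide
```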

4. Minimizing \(L_t\)

Conclusion derived in this section: minimizing \(L_t\) is equivalent to minimizing \(\| \boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t,t\right) \|^2\).

Here \(\boldsymbol\epsilon\) denotes the total noise added from \(\mathbf{x}_0\) to \(\mathbf{x}_t\), and \(\boldsymbol{\epsilon}_\theta\) is the noise-prediction model, which takes two inputs: the image \(\mathbf{x}_t\) at time \(t\), and the timestep \(t\) itself.

In the previous section, we derived that \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\) follows a normal distribution. (The reverse conditional \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\) without conditioning on \(\mathbf{x}_0\) is intractable, which is exactly why we work with \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\) instead.)

Our goal is to make the reverse process match the forward process as closely as possible. We can therefore reasonably assume that in the reverse process, \(p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\) is also normally distributed and approximates \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\).

Let \(p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right), \mathbf{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right)\). Since the variance \(\tilde{\beta}_t\) is a constant, we simply set \(\mathbf{\Sigma}_\theta\left(\mathbf{x}_t, t\right)= \tilde{\beta}_t \mathbf{I}\). What remains is to make \(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\) as close to \(\tilde{\boldsymbol{\mu}}_t\) as possible.

Notice that the only quantity in \(\tilde{\boldsymbol{\mu}}_t\) that is unknown during the reverse process is the total noise \(\boldsymbol\epsilon\) added from \(\mathbf{x}_0\) to \(\mathbf{x}_t\), so we can train a model to predict \(\boldsymbol\epsilon\).

This model is the Noise Predicter mentioned in the first article of this series. We denote it \(\boldsymbol{\epsilon}_\theta\); it takes two inputs, the image \(\mathbf{x}_t\) at time \(t\) and the timestep \(t\), and its prediction is written \(\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\).

Therefore,

\[\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right) =\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right) \]

The KL divergence has the following property:

If \(P\) and \(Q\) are two normal distributions with means \(\mu_1\), \(\mu_2\) and variances \(\sigma_1^2\), \(\sigma_2^2\), where \(\sigma_1^2\) and \(\sigma_2^2\) are both constants, then

\[\min D_{KL}(P||Q) \rightarrow \min ||\mu_1-\mu_2||^2 \]
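This follows from the closed form of the KL divergence between two one-dimensional normal distributions (the multivariate diagonal case just sums over coordinates):

\[D_{\mathrm{KL}}(P \| Q)=\log \frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+\left(\mu_1-\mu_2\right)^2}{2 \sigma_2^2}-\frac{1}{2} \]

With \(\sigma_1^2\) and \(\sigma_2^2\) held constant, every term except \(\left(\mu_1-\mu_2\right)^2 / 2 \sigma_2^2\) is a constant, so minimizing the KL divergence reduces to minimizing \(\|\mu_1-\mu_2\|^2\).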

Therefore,

\[\begin{aligned} \min L_t &\rightarrow \min||\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)||^2\\ &\rightarrow \min\| \boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t,t\right) \|^2 \end{aligned} \tag{5} \]
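Equation (5) is precisely the simplified DDPM training objective. Below is a minimal PyTorch-style sketch of one training step; the `eps_model` network and its \((\mathbf{x}_t, t)\) interface are assumptions for illustration, and any noise predicter with that signature would do:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # illustrative schedule (an assumption)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

def diffusion_loss(eps_model, x0):
    """MSE between the true noise and the predicted noise, i.e. Eq. (5)."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)     # random timestep per sample
    abar = alpha_bars.to(x0.device)[t].view(b, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)                          # the noise epsilon to be predicted
    xt = abar.sqrt() * x0 + (1.0 - abar).sqrt() * eps   # x_t = sqrt(abar) x_0 + sqrt(1-abar) eps
    return torch.mean((eps - eps_model(xt, t)) ** 2)    # || eps - eps_theta(x_t, t) ||^2
```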

5. Summary

This completes the derivation of the loss function via maximum likelihood estimation.

Our conclusion is:

\(\min -\log{p_\theta(\mathbf{x}_0)}\) is equivalent to \(\min L_T+L_{T-1}+\cdots+L_0\), where

\(\begin{aligned} L_T & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_T\right)\right) \\ L_t & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_t \mid \mathbf{x}_{t+1}, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_t \mid \mathbf{x}_{t+1}\right)\right) \quad \text { for } 1 \leq t \leq T-1 \\ L_0 & =-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right) \end{aligned}\)

Here \(L_T\) can be ignored, \(L_0\) can be cast in the form of \(L_t\), and minimizing \(L_t\) amounts to minimizing \(\| \boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t,t\right) \|^2\).

Hence the loss function is simply the mean squared error loss.
