Denoising Diffusion Probabilistic Models去噪扩散模型(DDPM)
2024/2/28
论文链接:Denoising Diffusion Probabilistic Models(neurips.cc)
这篇文章对DDPM写个大概,公式推导会放在以后的文章里。
一、引言 Introduction
各类深度生成模型在多种数据模态上展示了高质量的样本。生成对抗网络(GANs)、自回归模型、流模型和变分自编码器(VAEs)已经合成了引人注目的图像和音频样本。此外,在基于能量的建模和得分匹配方面也取得了显著进展,生成的图像与GANs生成的图像相当。
扩散概率模型是一个参数化马尔科夫链,使用变分推断(Variational Inference)进行训练,以便在有限时间内产生于数据相匹配的样本。这个链的转移是学习来逆转扩散过程的,扩散过程是一种马尔可夫链,它逐渐向与采样相反的方向添加噪声到数据中,直到信号被破坏。当扩散包含的是少量的高斯噪声时,只需将采样链转移设置为条件高斯分布,这样就可以实现一个特别简单的神经网络参数化。
变分推断(Variational Inference):这是一种用于估计概率模型参数的统计方法。它通过优化一个目标函数来近似真实的后验分布,这个目标函数通常是真实后验分布与一个易于计算的分布(变分分布)之间的差异。
流模型(Flows):流模型是一种生成模型,它通过一系列可逆的变换(称为流)将数据从高维空间映射到低维空间,然后再映射回高维空间,以生成新的数据样本。流模型的优势在于其变换是可逆的,这有助于保持数据的多样性。
能量基建模(Energy-based Modeling):这是一种基于能量函数的建模方法,通常用于二分类问题。能量函数定义了输入数据与特定标签的不匹配程度。在图像生成的背景下,能量基模型可以用来评估和改进生成图像的质量。
得分匹配(Score Matching):这是一种用于训练生成模型的技术,特别是在概率密度估计中。它涉及计算真实数据分布的得分函数,并使生成模型的得分函数与之匹配,以此来提高生成样本的质量。
二、模型具体细节
扩散是指物质粒子从高浓度区域向低浓度区域移动的过程,扩散模型的灵感来自非平衡热力学,扩散模型想做的就是通过向图片中加入高斯噪声模拟这个过程,最后通过逆向过程从随机噪声中生成图片。
2.1 前向加噪
我们需要进行随机采样生成和图片尺寸大小相同的噪声图片。噪声图片中所有通道数值遵从正态分布。我们根据\(T\)步将生成的噪声图片与原图片进行混合,每一步的混合方式满足以下公式:
\[\begin{aligned}\sqrt{\beta}\times\epsilon+\sqrt{1-\beta}\times x\end{aligned} \]其中,\(x\)为原始图片,\(\epsilon\)是高斯噪声,\(\beta\)是一个介于[0.0,1.0]之间的数字,用于产生\(x\)和\(\epsilon\)前的系数。
我们输入\(x_0\)套用公式后我们得到了\(x_1\):
\[x_1=\sqrt{\beta_1}\times\epsilon_1+\sqrt{1-\beta_1}\times x_0 \]输入\(x_1\)套用公式后我们得到了\(x_2\):
\[x_2=\sqrt{\beta_2}\times\epsilon_2+\sqrt{1-\beta_2}\times x_1 \]......
以此类推,我们可以得到前一时刻与后一时刻的关系:
\[x_t=\sqrt{\beta_t}\times\epsilon_t+\sqrt{1-\beta_t}\times x_{t-1} \]其中\(\epsilon_t\)都是基于标准正态分布重新采样的随机数,而其中的\(\beta_t\)是从一个接近0的数字逐步递增,最后趋近于1,\(0<\beta_1<\beta_2<\beta_3<\beta_{t-1}<\beta_t<1\).
有:
\(q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\mathbf{x}_{t-1},\beta_t\mathbf{I})\)
随着步长\(t\)增加,原来的样本\(x_0\)的特征变得不可区分。当$T\to\infty \(时,\)\mathbf{x}_T$等价于各相同性高斯分布。
过程如上图所示,上诉过程有一个很好的特性,可以使用重参数化技巧(reparameterization trick)(参见VAE),在任何任意时间步长\(t\)上采样\(x_t\)。
为了简化后续的推导,我们引入一个新变量\(\alpha_t=1-\beta_t\),上诉公式变为:
\[x_t=\sqrt{1-\alpha_t}\times\epsilon_t+\sqrt{\alpha_t}\times x_{t-1} \]接下来需要思考的是通过公式能否使\(x_0\)直接得到\(x_T\),我们从
\[x_t=\sqrt{1-\alpha_t}\times\epsilon_t+\sqrt{\alpha_t}\times x_{t-1} \]向后推,得到:
\[\begin{aligned} \mathbf{x}_{t}& =\sqrt{\alpha_t}\mathbf{x}_{t-1}+\sqrt{1-\alpha_t}\boldsymbol{\epsilon}_{t-1} & ;\text{其中, }\boldsymbol{\epsilon}_{t-1},\boldsymbol{\epsilon}_{t-2},\cdots\sim\mathcal{N}(\mathbf{0},\mathbf{I}) \\ &=\sqrt{\alpha_t\alpha_{t-1}}\mathbf{x}_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\bar{\boldsymbol{\epsilon}}_{t-2}& ;\text{其中, }\bar{\boldsymbol{\epsilon}}_{t-2}\text{ 合并两个高斯量 }(*). \\ &=\ldots \\ &=\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon} \end{aligned}\]其中,\(\bar{\alpha}_t=\prod_{i=1}^t\alpha_i\)
\((*)\)当我们合并两个具有不同方差的高斯量\(\mathcal{N}(\mathbf{0},\sigma_1^2\mathbf{I})\)和\(\mathcal{N}(\mathbf{0},\sigma_2^2\mathbf{I})\)时,新的分布是\(\mathcal{N}(\mathbf{0},(\sigma_1^2+\sigma_2^2)\mathbf{I})\),这里合并的标准差是\(\sqrt{(1-\alpha_t)+\alpha_t(1-\alpha_{t-1})}=\sqrt{1-\alpha_t\alpha_{t-1}}\)
经过推导我们可以得到公式:
\(\begin{aligned}x_t=\sqrt{1-\bar{\alpha}_t}\times\epsilon+\sqrt{\bar{\alpha}_t}\times x_0\end{aligned}\)
通常,当样本变得更嘈杂时,我们可以承受更大的更新步骤,因此
\(\begin{aligned}\beta_1<\beta_2<\cdots<\beta_T\end{aligned}\)
\(\bar{\alpha}_1>\cdots>\bar{\alpha}_T\)。
2.2 反向过程
反向过程的目的是将有噪声的图片恢复成原始图片,如果我们可以反转上述过程,从\(q(\mathbf{x}_{t-1}|\mathbf{x}_t)\)中采样,将可以从高斯噪声中生成图片。因为前向加噪是一个随机过程,所以反向过程也是一个随机过程,所以我们可以用\(P(x_{t-1}|x_t)\)表示在给定\(x_t\)的情况下,前一时刻\(x_{t-1}\)的概率,根据贝叶斯公式有:
\(P(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t},x_0)=\frac{P(\boldsymbol{x}_t|\boldsymbol{x}_{t-1},x_0)P(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}{P(\boldsymbol{x}_t|\boldsymbol{x}_0)}\)
根据公式:
\(\begin{gathered} x_t=\sqrt{1-\alpha_t}\times\epsilon_t+\sqrt{\alpha_t}\times x_{t-1} \\ x_t=\sqrt{1-\bar{\alpha}_t}\times\epsilon+\sqrt{\bar{\alpha}_t}\times x_0 \end{gathered}\)
我们可以得到\(x_t\)是分别满足\(N(\sqrt{\alpha_t}x_{t-1},1-\alpha_t)\)和\(N(\sqrt{\bar{\alpha}_t}x_0,1-\bar{\alpha}_t)\)的正态分布(因为噪声\(\epsilon\)是满足高斯分布的),\(x_{t-1}\)是满足\(N(\sqrt{\bar{\alpha}_{t-1}}x_0,1-\bar{\alpha}_{t-1})\)的正态分布。我们可以将上式改为:
\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t,\right. & \left.\mathbf{x}_0\right)=q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \\ & \left(q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \sim \mathcal{N}\left(\mathbf{x}_t ; \sqrt{\alpha_t} \mathbf{x}_{t-1},\left(1-\alpha_t\right) \mathbf{I}\right)\right) \\ & \left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right) \sim \mathcal{N}\left(\mathbf{x}_{t-1} ; \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0,\left(1-\bar{\alpha}_{t-1}\right) \mathbf{I}\right)\right) \\ & \left(q\left(\mathbf{x}_t \mid \mathbf{x}_0\right) \sim \mathcal{N}\left(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right)\right) \\ & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(\mathbf{x}_t-\sqrt{\alpha_t} \mathbf{x}_{t-1}\right)^2}{\beta_t}+\frac{\left(\mathbf{x}_{l-1}-\sqrt{\bar{\alpha}_{l-1}} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_{t 1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\frac{\mathbf{x}_l^2-2 \sqrt{\alpha_t} \mathbf{x}_t \mathbf{x}_{t-1}+\alpha_t \mathbf{x}_{t-1}^2}{\beta_t}+\frac{\mathbf{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0 \mathbf{x}_{t-1}+\bar{\alpha}_{t-1} \mathbf{x}_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_l}{\beta_t}+\frac{1}{1-\bar{\alpha}_t}\right) \mathbf{x}_{t-1}^2-\left(\frac{2 \sqrt{\alpha_l}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t} \mathbf{x}_0\right) \mathbf{x}_{l-1}+C\left(\mathbf{x}_l, \mathbf{x}_0\right)\right)\right) \end{aligned}\]其中\(C(\mathbf{x}_t,\mathbf{x}_0)\)不涉及\(\mathbf{x}_{t-1}\)某些功能,省略了详细信息。
从中我们可以得知\(P(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t},x_0)\)是满足\(\begin{aligned}\boldsymbol{N}\left(\frac{\sqrt{a_t}(1-\bar{a}_{t-1})}{1-\bar{a}_t}x_t+\frac{\sqrt{\bar{a}_{t-1}}(1-a_t)}{1-\bar{a}_t}\times\frac{x_t-\sqrt{1-\bar{a}_t}\times\epsilon}{\sqrt{\bar{a}_t}},\left(\color{}{\sqrt{\frac{\beta_t(1-\bar{a}_{t-1})}{1-\bar{a}_t}}}\right)^2\right)\end{aligned}\)
这里只要我们知道了\(\epsilon\)就可以知道前一个时刻的图像,这里我们训练一个神经网络模型,来预测此图像相对于\(x_0\)原图所加入的噪声。
根据实验可知,\(x_T\)是一任何张满足标准正态分布的噪声图片。我们使用标准正态分布随机采样就能得到\(x_T\)。
反向过程通过\(T\)步从\(p(x_T)=\mathcal{N}(x_T;\mathbf{0},\mathbf{I})\)开始的噪声。
\[\begin{aligned} \textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)& =\mathcal{N}\left(x_{t-1};\textcolor{lightgreen}{\mu_\theta}(x_t,t),{\Sigma_\theta(x_t,t)}\right) \\ \textcolor{lightgreen}{p_\theta}(x_{0:T})& =\textcolor{lightgreen}{p_\theta}(x_T)\prod_{t=1}^T\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t) \\ \textcolor{lightgreen}{p_\theta}(x_0)& =\int\textcolor{lightgreen}{p_\theta}(x_{0:T})dx_{1:T} \end{aligned}\]其中\(\color{lightgreen}{\theta}\)是我们训练的参数。
2.3 Loss损失
文中对负对数似然上优化了ELBO(来自琴生不等式)
\[\begin{gathered} \mathbb{E}[-\log \textcolor{lightgreen}{p_\theta}(x_0)] \leq\mathbb{E}_q[-\log\frac{\textcolor{lightgreen}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)}] \\ =L \end{gathered}\]损失可以按如下方式重写:
\[\begin{aligned} \text{L}& =\mathbb{E}_q[-\log\frac{\textcolor{lightgreen}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)}] \\ &=\mathbb{E}_q[-\log p(x_T)-\sum_{t=1}^T\log\frac{\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})}] \\ &=\mathbb{E}_q[-\log\frac{p(x_T)}{q(x_T|x_0)}-\sum_{t=2}^T\log\frac{\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}-\log\textcolor{lightgreen}{p_\theta}(x_0|x_1)] \\ &=\mathbb{E}_q[D_{KL}(q(x_T|x_0)||p(x_T))+\sum_{t=2}^TD_{KL}(q(x_{t-1}|x_t,x_0)||{\textcolor{lightgreen}{p_\theta}}(x_{t-1}|x_t))-\log \textcolor{lightgreen}{p_\theta}(x_0|x_1)] \end{aligned}\]因为我们保持\(\beta_1,\ldots,\beta_T\)恒定,所以\(D_{KL}(q(x_T|x_0)||p(x_T))\)也是恒定的。
2.4 计算 \(D_{KL}(q(x_{t-1}|x_t,x_0)\|\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t))\)
在给定初始\(x_0\)的条件下,前向过程的后验概率为:
\[\begin{aligned} q(x_{t-1}|x_t,x_0)& =\mathcal{N}\left(x_{t-1};\tilde{\mu}_t(x_t,x_0),\tilde{\beta}_t\mathbf{I}\right) \\ \tilde{\mu}_t(x_t,x_0)& =\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha_t}}x_0+\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha_t}}x_t \\ \tilde{\beta}_{t}& =\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha_t}}\beta_t \end{aligned}\]论文中设置\(\textcolor{lightgreen}{\Sigma_\theta}(x_t,t)=\sigma_t^2\mathbf{I}\),其中\(\sigma_t^2\)设置为常量\(\beta_t\)或\(\tilde{\beta_t}\)。
然后,
\[\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\textcolor{lightgreen}{\mu_\theta}(x_t,t),\sigma_t^2\mathbf{I}) \]对于给定的噪声\(\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})\),使用\(q(x_t|x_0)\)
\[\begin{aligned} x_t(x_0,\epsilon)& =\sqrt{\bar{\alpha_t}}x_0+\sqrt{1-\bar{\alpha_t}}\epsilon \\ {x_0}& =\frac1{\sqrt{\bar{\alpha}_t}}\Big(x_t(x_0,\epsilon)-\sqrt{1-\bar{\alpha}_t}\epsilon\Big) \end{aligned}\]这里,
\[\begin{aligned} L_{t-1}& =D_{KL}(q(x_{t-1}|x_t,x_0)\|\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)) \\ &=\mathbb{E}_q\left[\frac1{2\sigma_t^2}\left\|\tilde{\mu}(x_t,x_0)-\textcolor{lightgreen}{\mu_\theta}(x_t,t)\right\|^2\right] \\ &=\mathbb{E}_{x_0,\epsilon}\left[\frac1{2\sigma_t^2}\left\|\frac1{\sqrt{\alpha_t}}\left(x_t(x_0,\epsilon)-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\epsilon\right)-\textcolor{lightgreen}{\mu_\theta}(x_t(x_0,\epsilon),t)\right\|^2\right] \end{aligned}\]使用模型重新参数化以预测噪声
\[\begin{gathered} \textcolor{lightgreen}{\mu_\theta}(x_t,t) =\tilde{\mu}\left(x_t,\frac1{\sqrt{\bar{\alpha}_t}}\left(x_t-\sqrt{1-\bar{\alpha}_t}\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\right)\right) \\ =\frac1{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\Big) \end{gathered}\]其中 \(\epsilon_\mathrm{\theta}\) 是预测 其中 \(\epsilon_\mathrm{\theta}\) 是预测 \(\epsilon\) 给定 \((x_t,t)\) 的学习函数。
这里给定,
\[L_{t-1}=\mathbb{E}_{x_0,\epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}\left\|\epsilon-\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\right\|^2\right] \]用来训练预测噪声。
2.5 简化损失
\[L_{\mathrm{simple}}(\theta)=\mathbb{E}_{t,x_0,\epsilon}\left[\left\|\epsilon-\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\right\|^2\right] \]这在\(t=1\)时最小化\(-\log\textcolor{lightgreen}{p_\theta}(x_0|x_1)\),并且在\(t>1\)时最小化\(L_{t-1}\),同时丢弃\(L_{t-1}\)中的权重。
丢弃权重\(\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha_t})}\)会增加给予更高 t (具有更高噪声水平) 的权重,从而提高样本质量。
三、代码实现
Denoise Diffusion 降噪扩散
1. 代码解析
1. 初始化
注意:以下代码块都是在DenoiseDiffusion
类中
eps_model
是\(\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\)模型
n_steps
是\(t\)
device
是放置常量的设备
class DenoiseDiffusion:
def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
super().__init__()
self.eps_model = eps_model
self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
self.alpha = 1. - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.n_steps = n_steps
self.sigma2 = self.beta
为了方便代码理解,这里将class DenoiseDiffusion
拆分进行解释,理解代码每一步在做什么。
self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
这里是生成了一个tensor,该tensor包含n_steps
个数据,包含从 0.0001
到 0.02
的等间隔数值,代表了公式中的 \(\beta_1,\ldots,\beta_T\)
self.alpha = 1. - self.beta
代表 \(\alpha_t=1-\beta_t\)
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
代表 \(\bar{\alpha_t}=\prod_{s=1}^t\alpha_s\)
self.n_steps = n_steps
代表 \(T\)
self.sigma2 = self.beta
代表 $\sigma^2=\beta $
2. 获取\(q(x_t|x_0)\)分布
关于公式 \(q(x_t|x_0)=\mathcal{N}\Big(x_t;\sqrt{\bar{\alpha}_t}x_0,(1-\bar{\alpha}_t)\mathbf{I}\Big)\) 的代码实现
#该函数返回一个包含两个张量的元组
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
mean = gather(self.alpha_bar, t) ** 0.5 * x0
var = 1 - gather(self.alpha_bar, t)
return mean, var
gather
这个操作会根据 t
中的索引从 self.alpha_bar
中提取元素。t
是索引张量,包含了要提取的元素的索引。
mean = gather(self.alpha_bar, t) ** 0.5 * x0
计算 \(\sqrt{\bar{\alpha}_t}x_0\)
var = 1 - gather(self.alpha_bar, t)
计算 \((1-\bar{\alpha}_t)\mathbf{I}\)
3. 来自\(q(x_t|x_0)\)的样本
关于公式 \(q(x_t|x_0)=\mathcal{N}\Big(x_t;\sqrt{\bar{\alpha}_t}x_0,(1-\bar{\alpha}_t)\mathbf{I}\Big)\) 的代码实现
def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
if eps is None:
eps = torch.randn_like(x0)
mean, var = self.q_xt_x0(x0, t)
return mean + (var ** 0.5) * eps
上述代码中if eps is None:
所包含的内容代表 \(\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})\)
mean, var = self.q_xt_x0(x0, t)
代表获取 \(q(x_t|x_0)\)
最后返回来自 \(q(x_t|x_0)\) 的样本
4. 来自\(\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)\)的样本
这段代码实现公式
\[\begin{aligned} \textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)& =\mathcal{N}\left(x_{t-1};\textcolor{lightgreen}{\mu_\theta}(x_t,t),\sigma_t^2\mathbf{I}\right) \\ \textcolor{lightgreen}{\mu_\theta}(x_t,t)& =\frac1{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\Big) \end{aligned}\] def p_sample (self, xt: torch.Tensor, t: torch.Tensor):
eps_theta = self.eps_model(xt, t)
alpha_bar = gather(self.alpha_bar, t)
alpha = gather(self.alpha, t)
eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
var = gather(self.sigma2, t)
eps = torch.randn(xt.shape, device=xt.device)
return mean + (var ** .5) * eps
上述代码中,eps_theta = self.eps_model(xt, t)
表示\(\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\)。
alpha_bar = gather(self.alpha_bar, t)
是在收集\(\bar{\alpha}_t\)
alpha = gather(self.alpha, t)
表示\(\alpha_{t}\)
eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
表示\(\frac\beta{\sqrt{1-\overline{\alpha}t}}\)
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
计算的是\(\frac1{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha_t}}}\textcolor{lightgreen}{\epsilon_\theta}(x_t,t)\Big)\)
var = gather(self.sigma2, t)
表示的是\(\sigma^2\)
eps = torch.randn(xt.shape, device=xt.device)
代表 \(\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})\)
最后return mean + (var ** .5) * eps
返回样本。
5. 简化损失
这段代码实现的是 \(L_{\mathrm{simple}}(\theta)=\mathbb{E}_{t,x_0,\epsilon}\left[\left\|\epsilon-\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t)\right\|^2\right]\) 公式
def loss(self, x0: Tensor, noise: Optional[torch.Tensor] = None):
batch_size - x0.shape[0]
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
if noise is None:
noise = torch.randn_like(x0)
xt = self.q_sample(x0, t, eps=noise)
eps_theta = self.eps_model(xt, t)
return F.mse_loss(noise, eps_theta)
上述代码中,batch_size - x0.shape[0]
是为了获取批量大小。
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
是对批次中的每个样品得到随机的 \(t\)
if noise is None:
中的代表着 \(\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})\)
xt = self.q_sample(x0, t, eps=noise)
中xt
是\(q(x_t|x_0)\)中得到的样本。
eps_theta = self.eps_model(xt, t)
是获取公式 \(\textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar{\alpha_t}}x_0+\sqrt{1-\bar{\alpha_t}}\epsilon,t)\)
最后return F.mse_loss(noise, eps_theta)
返回MSE损失。
2. 完整代码
下面是完整的Denoise Diffusion代码
from typing import Tuple, Optional
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn
from labml_nn.diffusion.ddpm.utils import gather
class DenoiseDiffusion:
"""
## Denoise Diffusion
"""
def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
super().__init__()
self.eps_model = eps_model
self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
self.alpha = 1. - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.n_steps = n_steps
self.sigma2 = self.beta
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
mean = gather(self.alpha_bar, t) ** 0.5 * x0
var = 1 - gather(self.alpha_bar, t)
return mean, var
def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
if eps is None:
eps = torch.randn_like(x0)
mean, var = self.q_xt_x0(x0, t)
return mean + (var ** 0.5) * eps
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
eps_theta = self.eps_model(xt, t)
alpha_bar = gather(self.alpha_bar, t)
alpha = gather(self.alpha, t)
eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
var = gather(self.sigma2, t)
eps = torch.randn(xt.shape, device=xt.device)
return mean + (var ** .5) * eps
def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
batch_size = x0.shape[0]
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
if noise is None:
noise = torch.randn_like(x0)
xt = self.q_sample(x0, t, eps=noise)
eps_theta = self.eps_model(xt, t)
return F.mse_loss(noise, eps_theta)
参考文献
[1].Diffusion Models 10 篇必读论文(1)DDPM - 知乎 (zhihu.com)
[2].去噪扩散模型
[3].[What are Diffusion Models? | Lil'Log (lilianweng.github.io)](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#:~:text=Diffusion models are inspired by,data samples from the noise)
标签:Diffusion,bar,Models,Denoising,self,sqrt,mathbf,theta,alpha From: https://www.cnblogs.com/TTS-TTS/p/18063486