首页 > 其他分享 >diffusion model(一):DDPM技术小结 (denoising diffusion probabilistic)

diffusion model(一):DDPM技术小结 (denoising diffusion probabilistic)

时间:2024-05-24 10:28:58浏览次数:26  
标签:diffusion dim frac denoising self sqrt overline alpha probabilistic

发布日期:2023/05/18
主页地址:http://myhz0606.com/article/ddpm

1 从直觉上理解DDPM

在详细推到公式之前,我们先从直觉上理解一下什么是扩散

对于常规的生成模型,如GAN,VAE,它直接从噪声数据生成图像,我们不妨记噪声数据为\(z\),其生成的图片为\(x\)

对于常规的生成模型

学习一个解码函数(即我们需要学习的模型)p,实现 \(p(z)=x\)

\[z \stackrel{p} \longrightarrow x \tag{1} \]

常规方法只需要一次预测即能实现噪声到目标的映射,虽然速度快,但是效果不稳定。

常规生成模型的训练过程(以VAE为例)

\[x \stackrel{q} \longrightarrow z \stackrel{p} \longrightarrow \widehat{x} \tag{2} \]

对于diffusion model

它将噪声到目标的过程进行了多步拆解。不妨假设一共有\(T+1\)个时间步,第\(T\)个时间步 \(x_T\)是噪声数据,第0个时间步的输出是目标图片\(x_0\)。其过程可以表述为:

\[z = x_T \stackrel{p} \longrightarrow x_{T-1} \stackrel{p} \longrightarrow \cdots \stackrel{p} \longrightarrow x_{1} \stackrel{p} \longrightarrow x_0 \tag{3} \]

对于DDPM它采用的是一种自回归式的重建方法,每次的输入是当前的时刻及当前时刻的噪声图片。也就是说它把噪声到目标图片的生成分成了T步,这样每一次的预测相当于是对残差的预测。优势是重建效果稳定,但速度较慢。

训练整体pipeline包含两个过程

2 diffusion pipeline

2.1前置知识:

高斯分布的一些性质

(1)如果\(X \sim \mathcal{N}(\mu, \sigma^2)\),且\(a\)与\(b\)是实数,那么\(aX+b \sim \mathcal{N}(a\mu+b, (a\sigma)^2)\)

(2)如果\(X \sim \mathcal{N}(\mu(x), \sigma^2(x)) ,Y \sim \mathcal{N}(\mu(y), \sigma^2(y)),\)且\(X,Y\)是统计独立的正态随机变量,则它们的和也满足高斯分布(高斯分布可加性).

\[X+Y \sim \mathcal{N}(\mu(x)+\mu{(y), \sigma^2(x) + \sigma^2(y)}) \\ X-Y \sim \mathcal{N}(\mu(x)-\mu{(y), \sigma^2(x) + \sigma^2(y)}) \tag{4} \]

均值为\(\mu\)方差为\(\sigma\)的高斯分布的概率密度函数为

\[\begin{align*} f(x) &= \frac{1}{\sqrt{2\pi} \sigma } \exp \left ({- \frac{(x - \mu)^2}{2\sigma^2}} \right) \\ &= \frac{1}{\sqrt{2\pi} \sigma } \exp \left[ -\frac{1}{2} \left( \frac{1}{\sigma^2}x^2 - \frac{2\mu}{\sigma^2}x + \frac{\mu^2}{\sigma^2} \right ) \right] \tag{5} \end{align*} \]

2.2 加噪过程

1 前向过程:将图片数据映射为噪声

每一个时刻都要添加高斯噪声,后一个时刻都是由前一个时刻加噪声得到。(其实每一个时刻加的噪声就是训练所用的标签)。即

\[x_0 \stackrel{q} \longrightarrow x_1 \stackrel{q} \longrightarrow x_{2} \stackrel{q} \longrightarrow \cdots \stackrel{q} \longrightarrow x_{T-1} \stackrel{q} \longrightarrow x_T=z \tag{6} \]

下面我们详细来看

记\(\beta_t = 1 - \alpha_t,\beta_t\)随\(t\)的增加而增大(论文中[2]从0.0001 -> 0.02) (这是因为一开始加一点噪声就很明显,后面需要增大噪声的量才明显).DDPM将加噪声过程建模为一个马尔可夫过程\(q(x_{1:T}|x_0):= \prod \limits_{t=1}^Tq(x_t|x_{t-1}) ,\)其中\(q(x_t|x_{t-1}):=\mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1 - \alpha_t) \textbf{I})\)

\[\begin{align*} x_t &= \sqrt{\alpha_t}x_{t-1} + \sqrt{(1 - \alpha_t)}z_t \\ &= \sqrt{\alpha_t}x_{t-1} + \sqrt{\beta_t}z_t \tag{7} \end{align*} \]

\(x_t\)为在t时刻的图片,当\(t=0\)时为原图;\(z_t\)为在t时刻所加的噪声,服从标准正态分布\(z_t \sim \mathcal{N}(0, \textbf{I});\)\(\alpha_t\)是常数,是自己定义的变量;从上式可见,随着\(T\)增大,\(x_t\)越来越接近纯高斯分布.

同理:

\[x_{t-1} = \sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1 - \alpha_{t-1}}z_{t-1} \tag{8} \]

将式(8)代入式(7)可得:

\[\begin{align*} x_t &= \sqrt{\alpha_t} (\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1 - \alpha_{t-1}}z_{t-1}) + \sqrt{1 - \alpha_t}z_t \\ &= \sqrt{\alpha_t \alpha_{t-1}}x_{t-2} + (\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1} + \sqrt{1 - \alpha_t}z_t) \tag{9} \end{align*} \]

由于\(z_{t-1}\)服从均值为0,方差为1的高斯分布(即标准正态分布),根据定义\(\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1}\)服从的是均值为0,方差为\(\alpha_t (1 - \alpha_{t-1})\)的高斯分布.即\(\sqrt{\alpha_t (1 - \alpha_{t-1})} z_{t-1} \sim \mathcal{N}(0, \alpha_t (1 - \alpha_{t-1})\textbf{I})\).同理可得\(\sqrt{1 - \alpha_t}z_t \sim \mathcal{N}(0, (1 - \alpha_t)\textbf{I})\).则(高斯分布可加性,可以通过定义推得,不赘述)

\[ (\sqrt{\alpha_t (1 - \alpha_{t-1})} , z_{t-1} + \sqrt{1 - \alpha_t}z_t) \sim \mathcal{N}(0, \alpha_t (1 - \alpha_{t-1}) + 1 - \alpha_t) = \mathcal{N}(0, 1 - \alpha_t \alpha_{t-1}) \tag{10} \]

我们不妨记\(\overline{z}_{t-2} \sim \mathcal{N}(0, \textbf{I}),\)则\(\sqrt{1 - \alpha_t \alpha_{t-1}} \overline{z}_{t-2} \sim \mathcal{N}(0, (1 - \alpha_t \alpha_{t-1})\textbf{I})\)则式(10)最终可改写为

\[x_t = \sqrt{\alpha_t \alpha_{t-1}} x_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \overline{z}_{t-2} \tag{11} \]

通过递推,容易得到

\[\begin{align*} x_t &= \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1} x_0 + \sqrt{1 - \alpha_t \alpha_{t-1} \dots \alpha_1} \overline{z}_0 \\ &= \sqrt{\prod_{i=1}^{t} {\alpha_i}}x_0 + \sqrt{1 - \prod_{i=1}^{t} {\alpha_i}} \overline {z}_0 \\ &\stackrel{\mathrm{令} \overline{\alpha}_{t} = \prod_{i=1}^{t} {\alpha_i}} = \sqrt{\overline{\alpha}_{t}}x_0+\sqrt{1 - \overline{\alpha}_{t}}\overline{z}_{0} \tag{12} \end{align*} \]

其中\(\overline{z}_{0} \sim \mathcal{N}(0, \mathrm{I}),x_0\)为原图.从式(13)可见,我们可以从\(x_0\)得到任意时刻的\(x_t\)的分布,而无需按照时间顺序递推!这极大提升了计算效率.

\[\begin{align*} q(x_t|x_0) &= \mathcal{N}(x_t; \mu{(x_t, t)},\sigma^2{(x_t, t)}{}\textbf{I}) \\ &= \mathcal{N}(x_t; \sqrt{\overline{\alpha}_{t}}x_0,(1 - \overline{\alpha}_{t})\textbf{I}) \tag{13} \end{align*} \]

⚠️加噪过程是确定的,没有模型的介入. 其目的是制作训练时标签

2.3 去噪过程

给定\(x_T\)如何求出\(x_0\)呢?直接求解是很难的,作者给出的方案是:我们可以一步一步求解.即学习一个解码函数\(p\),这个\(p\)能够知道\(x_{t}\)到\(x_{t-1}\)的映射规则.如何定义这个\(p\)是问题的关键.有了\(p\),只需从\(x_{t}\)到\(x_{t-1}\)逐步迭代,即可得出\(x_0\).

\[z = x_T \stackrel{p} \longrightarrow x_{T-1} \stackrel{p} \longrightarrow \cdots \stackrel{p} \longrightarrow x_{1} \stackrel{p} \longrightarrow x_0 \tag{14} \]

去噪过程是加噪过程的逆向.如果说加噪过程是求给定初始分布\(x_0\)求任意时刻的分布\(x_t\),即\(q(x_t|x_0)\)那么去噪过程所求的分布就是给定任意时刻的分布\(x_t\)求其初始时刻的分布\(x_0,\)即\(p(x_0|x_t)\) ,通过马尔可夫假设,可以对上述问题进行化简

\[\begin{align*} p(x_0|x_t) &= p(x_0|x1)p(x1|x2)\cdots p(x_{t-1}| x_t) \\ &= \prod_{i=0}^{t-1}{p(x_i|x_{i+1})} \tag{15} \end{align*} \]

如何求\({p(x_{t-1}|x_{t})}\)呢?前面的加噪过程我们大力气推到出了\({q(x_{t}|x_{t-1})},\)我们可以通过贝叶斯公式把它利用起来

\[p(x_{t-1}|x_t) = \frac{p(x_{t}|x_{t-1})p(x_{t-1})}{p(x_t)} \tag{16} \]

⚠️这里的(去噪)\(p\)和上面的(加噪)\(q\)只是对分布的一种符号记法。

有了式(17)还是一头雾水,\(p(x_t)\)和\(p(x_{t-1})\)都不知道啊!该怎么办呢?这就要借助模型的威力了.下面来看如何构建我们的模型.

延续加噪过程的推导\(p(x_t|x_0)\)和\(p(x_{t-1}|x_0)\)我们是可以知道的.因此若我们知道初始分布\(x_0\),则

\[\begin{align*} p(x_{t-1}|x_t,x_0) &= \frac{p(x_{t}|x_{t-1}, x_0)p(x_{t-1}|x_0)}{p(x_t|x_0)} &(17) \\ &= \frac{\mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1 - \alpha_t) \textbf{I} ) \mathcal{N}(x_{t-1}; \sqrt{\overline{\alpha}_{t-1}}x_0,(1 - \overline{\alpha}_{t-1}) \textbf{I})} { \mathcal{N}(x_t; \sqrt{\overline{\alpha}_{t}}x_0,(1 - \overline{\alpha}_{t}) \textbf{I} )} &(18) \\ &\stackrel{将式(5)代入} \propto \frac{ \exp \left ({- \frac{(x_t - \sqrt{\alpha_t}x_{t-1} )^2}{2 (1 - \alpha_t)}} \right) \exp \left ({- \frac{(x_{t-1} - \sqrt{\overline{\alpha}_{t-1}}x_0 )^2}{2 (1 - \overline{\alpha}_{t-1})}} \right) } { \exp \left ({- \frac{(x_{t} - \sqrt{\overline{\alpha}_{t}}x_0 )^2}{2 (1 - \overline{\alpha}_{t})}} \right) } &(19) \\ &= \exp \left [-\frac{1}{2} \left ( \frac{(x_t - \sqrt{\alpha_t}x_{t-1} )^2}{1 - \alpha_t} + \frac{(x_{t-1} - \sqrt{\overline{\alpha}_{t-1}}x_0 )^2}{1 - \overline{\alpha}_{t-1}} - \frac{(x_{t} - \sqrt{\overline{\alpha}_{t}}x_0 )^2}{1 - \overline{\alpha}_{t}} \right) \right] &(20) \\ &= \exp \left [ -\frac{1}{2} \left( \left( \frac{\alpha_t}{1-\alpha_t} + \frac{1}{1 - \overline{\alpha}_{t-1}} \right)x^2_{t-1} - \left ( \frac{2\sqrt{\overline{\alpha_{t}}}}{1 - \alpha_t}x_t + \frac{2 \sqrt{\overline{\alpha}_{t-1}}} {1 - \overline{\alpha}_{t-1} }x_0 \right)x_{t-1} + C(x_t, x_0) \right) \right] &(21) \end{align*} \]

结合高斯分布的定义(6)来看式(22),不难发现\(p(x_{t-1}|x_t,x_0)\)也是服从高斯分布的.并且结合式(6)我们可以求出其方差和均值

⚠️式17做了一个近似\(p(x_t|x_{t-1}, x_0) =p(x_t| x_{t-1}),\)能做这个近似原因是一阶马尔科夫假设,当前时间点只依赖前一个时刻的时间点.

\[\begin{align*} \frac{1}{\sigma_2} &= \frac{\alpha_t}{1-\alpha_t} + \frac{1}{1 - \overline{\alpha}_{t-1}} &(22) \\ \frac{2\mu}{\sigma^2} &= \frac{2\sqrt{\overline{\alpha_{t}}}}{1 - \alpha_t}x_t + \frac{2 \sqrt{\overline{\alpha}_{t-1}}} {1 - \overline{\alpha}_{t-1} }x_0 &(23) \end{align*} \]

可以求得:

\[\begin{align*} \sigma^2 &= \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_{t}} (1 - \alpha_t) \\ \mu &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{\sqrt{\overline{\alpha}_{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t}x_0 \tag{24} \end{align*} \]

通过上式,我们可得

\[p(x_{t-1}|x_t,x_0) = \mathcal{N}(x_{t-1}; \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{\sqrt{\overline{\alpha}_{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t}x_0 , (\frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_{t}} (1 - \alpha_t)) \textbf{I}) \tag{25} \]

该式是真实的条件分布.我们目标是让模型学到的条件分布\(p_\theta(x_{t-1}|x_t)\)尽可能的接近真实的条件分布\(p(x_{t-1}|x_t, x_0).\)从上式可以看到方差是个固定量,那么我们要做的就是让\(p(x_{t-1}|x_t, x_0)\)与\(p_\theta(x_{t-1}|x_t)\)的均值尽可能的对齐,即

(这个结论也可以通过最小化上述两个分布的KL散度推得)

\[\mathrm{arg} \mathop{min}_\theta \parallel u(x_0, x_t), u_\theta(x_t, t) \parallel \tag{26} \]

下面的问题变为:如何构造\(u_\theta(x_t, t)\)来使我们的优化尽可能的简单

我们注意到\(\mu(x_0, x_t)与\mu_\theta(x_t, t)\)都是关于\(x_t\)的函数,不妨让他们的\(x_t\)保持一致,则可将\(\mu_\theta(x_t, t)\)写成

\[\mu_\theta(x_t, t) = \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{\sqrt{\overline{\alpha}_{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t} f_\theta(x_t, t) \tag{27} \]

\(f_\theta(x_t, t)\)是我们需要训练的模型.这样对齐均值的问题就转化成了: 给定\(x_t, t\)来预测原始图片输入\(x_0.\)根据上文的加噪过程,我们可以很容易制造训练所需的数据对! (Dalle2的训练采用的是这个方式).事情到这里就结束了吗?

DDPM作者表示直接从\(x_t\)到\(x_0\)的预测数据跨度太大了,且效果一般.我们可以将式(12)做一下变形

\[\begin{align*} x_t &= \sqrt{\overline{\alpha}_{t}}x_0+\sqrt{1 - \overline{\alpha}_{t}}\overline{z}_{0} \\ x_0 &= \frac{1}{\sqrt{\overline{\alpha}_{t}}}(x_t - \sqrt{1 - \overline{\alpha}_{t}}\overline{z}_{0}) \tag{28} \end{align*} \]

代入到式(24)中

\[\begin{align*} \mu &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{\sqrt{\overline{\alpha}_{t-1}} (1 - \alpha_t) }{1 - \overline{\alpha}_t} \frac{1}{\sqrt{\overline{a}_{t}}}(x_t - \sqrt{1 - \overline{a}_{t}}\overline{z}_{0}) \\ &= \frac{\sqrt{\alpha_t} (1 - \overline{\alpha}_{t-1})} {1 - \overline{\alpha}_t}x_t + \frac{(1 - \alpha_t) }{1 - \overline{\alpha}_t} \frac{1}{\sqrt{\alpha}_{t}}(x_t - \sqrt{1 - \overline{\alpha}_{t}}\overline{z}_{0}) \\ &\stackrel{合并x_t} = \frac{\alpha_t(1 - \overline{\alpha}_{t-1}) + (1 - \alpha_t) }{\sqrt{\alpha}_t (1 - \overline{\alpha}_t)}x_t - \frac{\sqrt{1 - \overline{\alpha}_t}(1 - \alpha_t) }{\sqrt{\alpha_t}(1 - \overline{\alpha}_t)}\overline{z}_0 \\ &= \frac{1 - \overline{\alpha}_t}{\sqrt{\alpha}_t (1 - \overline{\alpha}_t)}x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}\overline{z}_0 \\ &= \frac{1}{\sqrt{\alpha}_t}x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}\overline{z}_0 \tag{29} \end{align*} \]

经过这次化简,我们将\(\mu{(x_0, x_t)} \Rightarrow \mu{(x_t, \overline{z}_0)},\)其中\(\overline{z}_0 \sim \mathcal{N}(0, \textbf{I}),\)可以将式(29)转变为

\[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} x_t - \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}f_\theta(x_t, t) \tag{30} \]

此时对齐均值的问题就转化成:给定\(x_t, t\)预测\(x_t\)加入的噪声\(\overline{z}_0\), 也就是说我们的模型预测的是噪声\(f_\theta{(x_t, t)} = \epsilon_{\theta}(x_t, t) \simeq \overline{z}_0\)

2.3.1 训练与采样过程

训练的目标就是这所有时刻两个噪声的差异的期望越小越好(用MSE或L1-loss).

\[\mathbb{E}_{t \sim T } \parallel \epsilon - \epsilon_{\theta}(x_t, t)\parallel_2 ^2 \tag{31} \]

下图为论文提供的训练和采样过程

image

2.3.2 采样过程

通过以上讨论,我们推导出\(p_\theta(x_{t-1}|x_t)\)高斯分布的均值和方差.\(p_\theta(x_{t-1}|x_t)=\mathcal{N}(x_{t-1}; \mu_{\theta}(x_t, t), \sigma^2(t) \textbf{I})\),根据文献[1]从一个高斯分布中采样一个随机变量可用一个重参数化技巧进行近似

\[\begin{align*} x_{t-1} &= \mu_{\theta}(x_t, t) + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \\ & = \frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t }{\sqrt{1 - \overline{\alpha}_t}}\epsilon_\theta(x_t, t)) + \sigma(t) \epsilon &\tag{32} \end{align*} \]

式(32)和论文给出的采样递推公式一致.

至此,已完成DDPM整体的pipeline.

还没想明白的点,为什么不能根据(7)的变形来进行采样计算呢?

\[x_{t-1} = \frac{1}{\sqrt{\alpha_t}}x_t - \sqrt{\frac{1 - \alpha_t}{\alpha_t}} f_\theta(x_t, t) \tag{33} \]

3 从代码理解训练&预测过程

3.1 训练过程

参考代码仓库: https://github.com/lucidrains/denoising-diffusion-pytorch/tree/main/denoising_diffusion_pytorch

已知项: 我们假定有一批N张图片\(\{x_i |i=1, 2, \cdots, N\}\)

第一步: 随机采样K组成batch,如\(\mathrm{x\_start}= \{ x_k|k=1,2, \cdots, K \}, \mathrm{Shape}(\mathrm{x\_start}) = (K, C, H, W)\)

第二步: 随机采样一些时间步

t = torch.randint(0, self.num_timesteps, (b,), device=device).long()  # 随机采样时间步

第三步: 随机采样噪声

noise = default(noise, lambda: torch.randn_like(x_start))  # 基于高斯分布采样噪声

第四步: 计算\(\mathrm{x\_start}\)在所采样的时间步的输出\(x_T\)(即加噪声).(根据公式12)

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

betas = linear_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def q_sample(x_start, t, noise=None):
  """
  \begin{eqnarray}
    x_t &=& \sqrt{\alpha_t}x_{t-1} + \sqrt{(1 - \alpha_t)}z_t \nonumber \\
    &=&  \sqrt{\alpha_t}x_{t-1} + \sqrt{\beta_t}z_t
  \end{eqnarray}
  """
    return (
        extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

x = q_sample(x_start = x_start, t = t, noise = noise)  # 这就是x0在时间步T的输出

第五步: 预测噪声.输入\(x_T,t\)到噪声预测模型,来预测此时的噪声\(\hat{z}_t = \epsilon_\theta(x_T, t)\).论文用到的模型结构是Unet,与传统Unet的输入有所不同的是增加了一个时间步的输入.

model_out = self.model(x, t, x_self_cond=None)  # 预测噪声

这里面有一个需要注意的点:模型是如何对时间步进行编码并使用的

  • 首先会对时间步进行一个编码,将其变为一个向量,以正弦编码为例
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """
        Args:
          x (Tensor), shape like (B,)
        """
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

# 时间步的编码pipeline如下,本质就是将一个常数映射为一个向量
self.time_mlp = nn.Sequential(
    SinusoidalPosEmb(dim),
    nn.Linear(fourier_dim, time_dim),
    nn.GELU(),
    nn.Linear(time_dim, time_dim)
)
  • 将时间步的embedding嵌入到Unet的block中,使模型能够学习到时间步的信息
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift  # 将时间向量一分为2,一份用于提升幅值,一份用于修改相位

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):

        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)

        h = self.block2(h)

        return h + self.res_conv(x)

第六步:计算损失,反向传播.计算预测的噪声与实际的噪声的损失,损失函数可以是L1或mse

@property
    def loss_fn(self):
        if self.loss_type == 'l1':
            return F.l1_loss
        elif self.loss_type == 'l2':
            return F.mse_loss
        else:
            raise ValueError(f'invalid loss type {self.loss_type}')

通过不断迭代上述6步即可完成模型的训练

3.2采样过程

第一步:随机从高斯分布采样一张噪声图片,并给定采样时间步

img = torch.randn(shape, device=device)

第二步: 根据预测的当前时间步的噪声,通过公式计算当前时间步的均值和方差


  posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) # 式(24)x_0的系数
  posterior_mean_coef = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)  # 式(24) x_t的系数

  def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

  def q_posterior(self, x_start, x_t, t):
    posterior_mean = (
        extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
        extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )  # 求出此时的均值
    posterior_variance = extract(self.posterior_variance, t, x_t.shape)  # 求出此时的方差
    posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) # 对方差取对数,可能为了数值稳定性
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

  def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
      preds = self.model_predictions(x, t, x_self_cond)  # 预测噪声
      x_start = preds.pred_x_start  # 模型预测的是在x_t时间步噪声,x_start是根据公式(12)求

      if clip_denoised:
          x_start.clamp_(-1., 1.)

      model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
      return model_mean, posterior_variance, posterior_log_variance, x_start

第三步: 根据公式(32)计算得到前一个时刻图片\(x_{t-1}\)

  @torch.no_grad()
  def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
      b, *_, device = *x.shape, x.device
      batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
      model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised)  # 计算当前分布的均值和方差
      noise = torch.randn_like(x) if t > 0 else 0. # 从高斯分布采样噪声
      pred_img = model_mean + (0.5 * model_log_variance).exp() * noise  # 根据
      return pred_img, x_start

通过迭代以上三步,直至\(T=0\)完成采样.

思考和讨论

DDPM区别与传统的VAE与GAN采用了一种新的范式实现了更高质量的图像生成.但实践发现,需要较大的采样步数才能得到较好的生成结果.由于其采样过程是一个马尔可夫的推理过程,导致会有较大的耗时.后续工作如DDIM针对该特性做了优化,数十倍降低采样所用时间。

参考文献

[1]  Understanding Diffusion Models: A Unified Perspective

[2]  Denoising Diffusion Probabilistic Models

标签:diffusion,dim,frac,denoising,self,sqrt,overline,alpha,probabilistic
From: https://www.cnblogs.com/myhz/p/18210120

相关文章

  • OOTDiffusion环境搭建&推理测试
    引子记得2015年左右,去参加VALSE的时候,就有虚拟试衣的项目亮相。现在回头看看,当时的效果还是十分简陋和不协调的。今天在全球最大的同性交友网站github上突然发现一个不错的虚拟试衣项目,看其效果还是不错,加入了扩散模型,效果看起来有质的提升。OK,让我们开始吧。一、模型介绍论文......
  • Stable Diffusion WebUI 绘画
    配置环境介绍目前平台集成了StableDiffusionWebUI的官方镜像,该镜像中整合如下资源:立即免费体验:https://gpumall.com/login?type=register&source=cnblogsStableDiffusionWebUI版本:v1.7.0Python版本:3.10.6Pytorch版本:2.0.1CUDA版本:11.8Xformers版本:0.0.20Gradio版本......
  • Stable Diffusion:原理、应用与未来展望
    一、引言在人工智能的快速发展中,StableDiffusion作为一种先进的随机过程模型,受到了广泛的关注。StableDiffusion不仅能够描述许多自然和人工系统中的随机演化行为,而且在多个领域展现出了广泛的应用潜力。本文将详细介绍StableDiffusion的原理、应用以及未来的发展趋势。立即......
  • (非原创)Stable Diffusion 提示词prompt tag语法总结
    基本认知提示词会相互污染,要尽可能地做减法。XL版本主推使用自然语言使用注释将修饰词汇限定给某个主体,避免提示词污染1girl(silverlonghair,purpleeyes),yellowsuit2people(1girlAND1boy)2characters(1girlAND1dog)权重调整旧语法:(){}加大权重,[]......
  • 使用stable diffusion设计logo的提示词
    使用stablediffusion设计logo的提示词StableDiffusion是一种基于图像处理和机器学习的算法,可以用于生成各种类型的图像,包括Logo设计。本文将介绍如何使用StableDiffusion来设计Logo,并提供一些提示词以帮助读者更好地理解和应用这种技术。1.了解StableDiffusion的基本原理在......
  • 如何使用stable diffusion设计logo
    好的,我可以帮你写一篇关于如何使用stablediffusion设计logo的文章。这篇文章将从第二级标题开始,主题为:如何使用stablediffusion设计logo。二级标题:什么是StableDiffusion?StableDiffusion是一种尖端的文本到图像扩散模型,可以根据任何给定的文本输入生成逼真的图像。通过使用......
  • [转]矿卡P104再就业AI绘图(附centos安装cuda及配置stable diffusion教程)
    原文地址:矿卡P104再就业AI绘图(附centos安装cuda及配置stablediffusion教程)-哔哩哔哩早就听说p104用的gtx1080同款核心,只是阉割了编解码与视频输出,cuda还在,有8G显存,一看就很适合ai画图,当然,150不到的超低廉价格才是笔者购买它的决定性原因!    废话不多说,在linux上使用该显......
  • StoryDiffusion文字生漫画
    地址https://github.com/HVision-NKU/StoryDiffusion安装condacreate-nstorydiffpython==3.11.0condaactivatestorydiff#修改一下requirements.txtgradio==4.21.0xformers==0.0.25diffusers==0.25.0transformers==4.36.2huggingface-hub==0.20.2spaces==0.19.......
  • 2024CVPR_Low-light Image Enhancement via CLIP-Fourier Guided Wavelet Diffusion(C
    一、Motivation1、单模态监督问题:大多数方法往往只考虑从图像层面监督增强过程,而忽略了图像的详细重建和多模态语义对特征空间的指导作用。这种单模态监督导致不确定区域的次优重建和较差的局部结构,导致视觉结果不理想的出现。------》扩散模型缺乏有效性约束,容易出现多种生成效......
  • Stable Diffusion webui.sh(Version: v1.9.3)选项翻译
    补充解释:Linux/iOS的目录==Windows文件夹options选项:-h,--help显示帮助信息并退出程序 showthishelpmessageandexit--update-all-extensions在启动时更新所有扩展插件 (此为launch.py脚本的参数,下同)launch.pyargument:downloadupdatesforallextensi......