首页 > 其他分享 >Part1: Overview of Diffusion Process

Part1: Overview of Diffusion Process

时间:2024-02-25 21:22:40浏览次数:35  
标签:Diffusion right mathbf Process Part1 textit theta left

本文将会概括性地介绍\(\textit{Diffusion Process}\)算法与实践,主要参考论文《Denoising Diffusion Probabilistic Models》。它的一些改进与优化,将“扩散方法”带入主流视野。

而具体的数学推导部分,请参考其它系列文章。整个系列有相对完整的公式推导,若正文中有涉及到的省略部分,皆额外整理在Part4,并会在正文中会指明具体位置。

Image Synthesis

在正式引入\(\textit{diffusion}\)前,希望先简单介绍图像生成的相关背景。一张图像由很多个像素点组成,对于彩色图像,每个像素点由三个0至255的整数表达,比如[255, 255, 255]代表像素点对应的颜色为白色。而一张512*512的图像就意味着共有26w左右的像素点。

image


图像生成的目标是学习像素及像素间的概率分布。结合具体的例子来理解概率分布,上图为拍摄于户外的照片。
假定它是由比较好的模型生成得到的图像。看到“图像上方是一片天空”,因为它是天空,所以上方的像素点不会是大面积的绿色覆盖。在生成天空区域时进行数值采样,采样到淡白色或者淡蓝色对应数值的概率,远大于采样到绿色对应数值的概率。

通过常规的机器学习方法,直接学习上述提到的概率分布很难,同时计算也很复杂。\(\textit{diffusion}\)的提出为学习图像的概率分布提供新的思路和解决办法。

Diffusion Process

图像生成领域的\(\textit{Diffusion Process}\)最早在2015年发表的论文《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》提到,作者认为它是一种适配于深度学习背景下的图片生成框架。其灵感来源于统计物理学中的非平衡态热力学\((\textit{non-equilibrium})\)。在该领域,\(\textit{diffusion}\)并非新名词,用于描述分子运动的现象。

Diffusion is the movement of a substance from a region of high concentration to a region of low concentration without bulk motion.

image

上图中,起初红色分子都聚集在液体的左上角,而当趋于稳定后,红色分子遍布在整个液体中。

那在图像生成领域,论文作者获取到的灵感是什么?
图像生成的其中一个难点在于不知晓图像服从何种分布,如此统计机器学习无法派上用场。故\(\textit{diffusion}\)的思路是,在\(\textit{forward process}\)中,通过缓慢、持续地往原数据上增加其它分布的数据,将图片原本分布破坏掉,转变到其它分布;同时,存在一个\(\textit{reverse process}\),即通过深度模型学习「如何将破坏掉的分布复原回破坏前的分布」。
前后两个过程都是缓慢且逐步进行的,天然地可以用马尔可夫链\((\textit{Markov chain})\)建模。并且,选择的其它分布是常见分布,比如高斯分布或二项分布等。这类分布的特性为人们熟知,可以方便地进行计算。

image

上图第一行描述二维数据分布从\(t = 0\)至\(t = T\)的变化情况,数据分布初始时为Swiss Roll,在逐步经过Gaussian的diffusion后,最终完全转变到服从Gaussian。

下面将围绕以下几部分展开介绍,符号定义与论文《Denoising Diffusion Probabilistic Models》保持一致:

  1. Forward Process
  2. Reverse Process
    • Loss Function
    • Model Architecture
  3. Train and Sample

Forward Process

给定一张图片,将其初始状态记为\(\mathbf{x}_0\),逐渐地对其增加服从其它分布的数据,如服从高斯分布,重复若干轮后,原图片的特征逐渐被销毁,完全难以分辨。为了方便,下文记“其它分布的数据”为噪声。

image

若对上图经过缓慢地加噪,其变换情况如下图所示。

image

整个缓慢的加噪过程建模为马尔可夫过程,状态间的转变\(\mathbf{x_{t-1}} \rightarrow \mathbf{x_{t}}\)服从高斯分布:

\[q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_t ; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}\right) \quad \tag{1} \]

故时刻\(t\)的图片\(\mathbf{x_t}\),由\(\mathbf{x_{t-1}}\)经过\((2)\)式得到,\((2)\)式的推导见Part4中高斯分布的性质1

\[\mathbf{x}_t =\sqrt{1-\beta_t} \mathbf{x}_{t-1}+\sqrt{\beta_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I}) \tag{2}\]

其中,\(\beta_t\)是个参数,为0到1的小数,表示时刻\(t\)噪声增加的幅度。有过相关实践的同学,应该会在模型文件或者代码中看到scheduler的字样,它便是用来产生\(\beta_t\)的模块。

对\((2)\)式中的项\(\beta_t\)改写,定义:

\[\begin{aligned} \alpha_t & =1-\beta_t \\ \bar{\alpha}_t & =\prod_{i=1}^t \alpha_i \end{aligned} \]


则通过推导(见Part4中推导一),可以得到\((3)\)式:

\[\begin{aligned} \mathbf{x}_t =\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \epsilon \end{aligned}, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})\tag{3} \]

这是一个很好的性质,意味着任意时刻的状态\(\mathbf{x}_t\)皆可由\(\mathbf{x}_0\)基于一次运算得到。在后续实际进行训练时,自然地与随机梯度下降\((\textit{Stochastic Gradient Descent})\)结合使用。

并且,当\(t\)足够长,发现\((3)\)式的第一项趋近于0,而\(\mathbf{x}_t \approx \epsilon\)。这也印证,经过逐步缓慢的“扩散过程”后,原数据的分布会被破坏,转而服从所加入噪声的分布。

Reverse Process

当\(\textit{forward process}\)进行完毕后,认为\(\mathbf{x}_T\)基本看作是一张噪声图片,它的像素点服从标准高斯分布。
$$p(\mathbf{x}_T) \sim \mathcal{N}(0, \mathbf{I})$$
前文提到\(diffusion\)方法需要在reverse阶段对被破坏的分布进行复原,如视频所示:

在这一过程中,深度学习模型介入,基于\((4)\)式表达这一过程:

\[\begin{gathered}p_\theta\left(x_0\right):=\int p_\theta\left(x_{0: T}\right) d x_{1: T} \\ p_\theta\left(\mathbf{x}_{0: T}\right):=p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right), \quad p(\mathbf{x}_T) \sim \mathcal{N}(\mathbf{x}_T;0, \mathbf{I}) \\ p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right):=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right), \boldsymbol{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right) \end{gathered}\tag{4} \]

其中,\(\theta\)对应模型的参数。

Loss Function

扩散模型的损失函数基于变分推断\((\textit{variational inference})\)得到。简单来说,最初的目标是最大化对数似然估计,但直接对\(\ln p_\theta\left(x_0\right)\)无从做起。基于变分推断的思想,引入该目标的下界ELBO,通过最大化该下界实现计算。具体推导ELBO的方式请查看Part2

经过一系列推导,最后简化得到的损失函数如下所示:
$$L = \left|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right)\right|^2
\tag{5}$$
目标是给定任意时刻\(t\)的状态\(\mathbf{x}_t\),预测此状态中所包含的噪声\(\epsilon_{\theta}(\mathbf{x}_t, t)\),并与真实增加的噪声\(\epsilon\)比较。\((5)\)式完全就是平方差损失函数\((\textit{mean squared error})\)。

Model Architecture

基于\((5)\)式,不难发现模型的输入与输出皆为图片格式的数据。在DDPM中,使用U-Net作为该部分模型的结构。下图为U-Net原始的结构框架:
image

DDPM中使用的UNet在原基础上,对于每个Block加入Attention以及Residual等模块。

拿上图所示的UNet举例,其包含DownMid以及Up三部分,体现在图中时,左边包含有红色向下箭头的部分皆属于Down,最底下包含数字1024的为Mid,而右边属于Up部分。不难看出,UNet的形状是镜像对称的,但是DownUp输入的张量维度却是不同,主要原因在于:Up中所有的子模块输入,除了接收上一层传递过来的张量,还会接收对应层级Down子模块传递过来的输出。拿Up部分最低一层56x56x1024来说,其中56x56x512来自较低一层传递而来,另一部分的56x56x512从左边的64x64x512经过copy+resize而来,两者concate之后构成了当前层的输入。

更加具体的代码实现可以参考

Train and Sample

这样一来,模型的训练算法如下所示:
image

而采样算法如下所示:
image

其中,\(\sigma_t\)在具体实现中一般设置为\(\sqrt{\beta_t}\),是在diffusion process时该时间步\(t\)对应的标准差。

Summary

本文相对粗旷地介绍了“扩散过程”。前半阶段,主要基于论文《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》引入\(\textit{diffusion}\)的具体思想;后半阶段,以《Denoising Diffusion Probabilistic Models》为例,引入其训练与采样的算法步骤。

对于\(\textit{diffusion}\),其最终在代码层面的实现不算复杂;实际上,比较令人着迷的是其完备的数学推导,以及与其它隐变量模型\((\textit{latent models})\)的关联。

Reference

标签:Diffusion,right,mathbf,Process,Part1,textit,theta,left
From: https://www.cnblogs.com/shayue/p/18033069

相关文章

  • [Rust] Exit a program using std::process in Rust
    Inthislessonwe'lllearnhowtoexitaprogramusingthe std::process moduleinRustandit's exit() method. usestd::io;usestd::process;fnmain(){letmutfirst=String::new();io::stdin().read_line(&mutfirst).unwrap()......
  • process.env.API_KEY undefined问题解决
    问题现象已经在root路径下面创建.env文件,但是使用process.env.API_KEY获取不到值。分析获取不到env文件中的值,检查env文件已配置API_KEY,检查是否安装了dotenv,检查是否导入配置了dotenv解决方法在index.ts中导入import'dotenv/config';应该在使用env的模块前面就导入dote......
  • Go - concurrent processing is not always faster than sequential processing
      ......
  • day38 动态规划part1 代码随想录算法训练营 746. 使用最小花费爬楼梯
    题目:746.使用最小花费爬楼梯我的感悟:哈哈,我居然自己独立写出来了,确实,只要定义定清楚了,哪怕定的含义只有自己能看懂,只要定义一致就可以求出解决来!!!我真是个大天才!!理解难点:听课笔记:代码示例:classSolution:defminCostClimbingStairs(self,cost:List[int])->int:......
  • day38 动态规划part1 代码随想录算法训练营 70. 爬楼梯
    题目:70.爬楼梯我的感悟:居然自己先写出来了!!继续努力!!理解难点:听课笔记:我的代码:classSolution:defclimbStairs(self,n:int)->int:ifn==1:return1dp=[0]*(n+1)dp[1]=1dp[2]=2foriinran......
  • subprocess中的return_code与poll
    subprocess中的return_code与pollp=subprocess.Popen('ping8.8.8.8',shell=True,stdout=subprocess.PIPE,stderror=subprocess.DEVNULL)whilenotp.poll():#p.poll()即为return_codeprint(p.stdout.read().decode())#return_code=p.poll()#......
  • 一文搞懂Flink Window机制 Windows和 Function 和 Process组合处理事件
    一文搞懂FlinkWindow机制和Function和Process组合处理事件Windows是处理无线数据流的核心,它将流分割成有限大小的桶(buckets),并在其上执行各种计算。Windows是处理无线数据流的核心,它将流分割成有限大小的桶(buckets),并在其上执行各种计算。窗口化的Flink程......
  • mysql: show processlist 详解
    showprocesslist显示的信息都是来自MySQL系统库information_schema中的processlist表。所以使用下面的查询语句可以获得相同的结果:select*frominformation_schema.processlist了解这些基本信息后,下面我们看看查询出来的结果都是什么意思。Id:就是这个线程的唯一标......
  • Java版Flink(十二)底层函数 API(process function)
    一、概述之前的转化算子是无法访问事件的时间戳信息和水位线watermark,但是,在某些情况下,显得很重要。Flink提供了DataStreamAPI的Low-Level转化算子。比如说可以访问事件时间戳、watermark、以及注册定时器,还可以输出一些特定的事件,比如超时事件等。ProcessFunction用......
  • 【Flink】使用CoProcessFunction完成实时对账、基于时间的双流join
    【Flink】使用CoProcessFunction完成实时对账、基于时间的双流join文章目录零处理函数回顾一CoProcessFunction的使用1CoProcessFunction使用2实时对账(1)使用离线数据源(批处理)(2)使用高自定义数据源(流处理)二基于时间的双流Join1基于间隔的Join(1)正向join(2)反向join2......