大家好,这里是好评笔记,公主号:Goodnote,专栏文章私信限时Free。本文详细介绍U-ViT的模型架构和实验细节,虽然没有后续的DiT在AIGC领域火爆,但为后来的研究奠定了基础,但其开创性的探索值得学习。
文章目录
在前面的多模态系列文章中,我们介绍了DiT,其作为AIGC时代的新宠儿,将Transformer和Diffusion结合起来的,是近几年图像和视频生成的领域的优选结构。其实早在2022年9月,清华大学团队就发布了用「基于Transformer的架构U-ViT」替代基于卷积架构的U-Net,只不过没有现在的DiT火热。下面详细介绍U-ViT:
论文
All are Worth Words: A ViT Backbone for Diffusion Models
背景
扩散模型是一种强大的深度生成模型,近年来在高质量图像生成任务中展现了巨大的潜力。它们的发展速度迅猛,广泛应用于文本到图像生成、图像到图像生成、视频生成、语音合成以及3D合成等领域。
在目前的扩散模型中,骨干(backbones)结构的变革在扩散模型中起着核心作用。目前火热的扩散模型,如SD系列模型就是基于CNN的U-Net的,并取得了优异的性能。它通过下采样(encoding)和上采样(decoding)操作,结合跳跃连接,在捕捉局部细节和全局结构方面表现突出。
在SD 3 之前的扩散模型使用的是以基于CNN的U-Net为骨干结构的,在SD 3 之后的系列中,则是使用了DiT作为了骨干结构。
扩散模型系列参考:SD合集
将视觉和Transformer相结合的 ViT 在各种视觉任务中显现出了前景,而基于CNN的U-Net在扩散模型中仍占主导地位。本文中设计了一个简单而通用的基于ViT的架构(U-ViT),替代扩散模型中的U-Net,用于生成图像。
ViT参考:多模态论文笔记——ViT、ViLT
- 设计一个基于其他架构(如ViT)的模型,要在性能上达到甚至超越U-Net并不容易。
- ViT作为Transformer家族的一员,擅长处理全局特征,但其对局部细节的捕捉能力不如CNN。因此,直接用ViT替代U-Net并不现实。
U-ViT 核心设计如下:
- U-ViT是基于ViT(Vision Transformer)架构的,它借鉴了U-Net的结构,用于替代扩散模型中的基于 CNN 的 U-Net,以生成图像。
- 它将时间、条件和噪声图像补丁作为 token 输入,并使用**长跳跃连接(long skip connections)**连接浅层和深层。
效果:
- 在无条件、类条件图像生成和文本到图像生成任务中,U-ViT表现出色。
- 研究表明,长跳跃连接对于扩散模型中的图像建模至关重要,而 CNN-based U-Net 中的下采样和上采样操作并非总是必需的。
架构
本文中,作者设计的U-ViT架构,如下图所示:
图1. 用于扩散模型的U-ViT架构,其特点是将所有输入(包括时间、条件和噪声图像补丁)作为token,并在浅层和深层之间采用(#Blocks-1)/2个长跳跃连接。
架构说明:
- U-ViT遵循ViT的设计方法,对图片进行一个Patch化的操作,并且U-ViT将所有输入(包括时间、条件和图像patch)都视为token。
- 将时间 t t t、条件 c c c、图像patch后的加噪图像 x t x_t xt 作为输入,然后【通过 ϵ θ ( x t , t , c ) \epsilon_\theta(x_t, t, c) ϵθ(xt,t,c)】在 U-ViT模型中预测加入 x t x_t xt中的噪声。
- 受基于CNN的U-Net在扩散模型中的成功启发,U-ViT采用了类似的浅层和深层之间的长跳跃连接。
- ϵ θ ( x t , t , c ) \epsilon_\theta(x_t, t, c) ϵθ(xt,t,c)的目标是像素级预测任务,对低级特征敏感。长跳跃连接为低级特征提供了捷径,使用长跳跃连接(long skip connections)连接浅层和深层,使浅层特征传递到深层,为深层网络提供更丰富的信息。
- 【可选项】U-ViT在输出之前添加一个3×3的卷积块。旨在防止transformer生成的图像中出现潜在的伪影(potential artifacts in images)。
在论文的Background部分,还对扩散模型的扩散原理进行了简单的回顾,如不了解这个内容,建议参考:Stable Diffusion的加噪和去噪详解
训练细节
作者通过系统的实验研究,精心设计了其关键实现细节,并在CIFAR10数据集上进行了消融实验,通过消融实验,作者确定了以下最佳实现细节:
- 长跳跃连接采用 连接后线性投影 的方式。
- 时间信息通过 Token 注入更优。
- 额外卷积块在 线性投影后 添加效果最佳。
- Patch Embedding 使用 线性投影 更好。
- 位置编码使用 一维可学习嵌入(1-dimensional learnable position embedding) 是最佳选择(U-ViT和ViT、ViLT、DiT选择的位置编码一样,都是1D position embeddings,不同的是:DiT是不可学习的,ViT、ViLT和U-ViT是可学习的)。
消融实验(Ablation Study)是一种常见的实验方法,用于评估复杂系统中各个组件或设计对整体性能的贡献。通过系统地移除、替换或修改某个组件,然后观察模型性能的变化,研究人员可以验证该组件的作用并优化设计。
1. 长跳跃连接 (Long Skip Connections)
问题:哪种长跳跃连接方法更优?
实验设置:考虑以下几种主分支
h
m
h_m
hm 和长跳跃分支
h
s
h_s
hs 的组合方法:
- 方法1:将它们连接后执行线性投影: Linear(Concat ( h m , h s ) ) \text{Linear(Concat}(h_m, h_s)) Linear(Concat(hm,hs))
- 方法2:直接相加: h m + h s h_m + h_s hm+hs
- 方法3:线性投影 h s h_s hs 后相加: h m + Linear ( h s ) h_m + \text{Linear}(h_s) hm+Linear(hs)
- 方法4:相加后进行线性投影: Linear ( h m + h s ) \text{Linear}(h_m + h_s) Linear(hm+hs)
- 方法5:不使用长跳跃连接。
结果:
- 方法1(连接后线性投影) 的性能最佳。该方法显著改变了表征信息,提升了模型性能。
- 方法2(直接相加) 表现较差,因为Transformer内部已有加法操作,导致无显著增益。
2. 时间信息的注入方式 (Feeding Time into the Network)
问题:如何将时间
t
t
t 送入网络?
实验设置:
- 方法1:将时间 t t t 作为一个Token输入(如图1所示)。
- 方法2:通过自适应层归一化 (Adaptive LayerNorm, AdaLN) 融入时间信息:
AdaLN ( h , y ) = y s ⋅ LayerNorm ( h ) + y b \text{AdaLN}(h, y) = y_s \cdot \text{LayerNorm}(h) + y_b AdaLN(h,y)=ys⋅LayerNorm(h)+yb
其中, y s y_s ys 和 y b y_b yb 为时间嵌入的线性投影。
结果:
- 方法1(将时间视为Token) 效果更好,尽管实现简单。
3. 额外的卷积块 (Extra Convolutional Block)
问题:Transformer后额外卷积块的位置对性能的影响?
实验设置:
- 方法1:在线性投影后添加一个3×3卷积块,将Token映射到图像Patch。
- 方法2:在线性投影前添加一个3×3卷积块。
- 方法3:不添加卷积块。
结果:
- 方法1(在线性投影后添加卷积块)性能略优。
4. Patch Embedding 的变体
问题:哪种Patch Embedding方式更好?
实验设置:
- 方法1:使用线性投影将Patch映射为Token嵌入(原始方式)。
- 方法2:堆叠3×3卷积块,后接1×1卷积块,将图像映射为Token嵌入。
结果:
- 方法1(原始线性投影) 表现优于卷积堆叠方式。
5. 位置编码 (Position Embedding)
问题:哪种位置编码更优?
实验设置:
- 方法1:一维可学习位置嵌入(ViT默认设置)。
- 方法2:二维正弦位置嵌入,Patch的 position ( i , j ) \text{position}(i, j) position(i,j) 由 i i i 和 j j j 的正弦编码拼接得到, i i i 和 j j j 分别是二维网格中的行索引和列索引。
- 方法3:不使用任何位置编码。
结果:
- 方法1(1D可学习位置嵌入) 性能最佳。
- 方法3(无位置编码) 无法生成有意义的图像,表明位置编码对图像生成至关重要。
深度、宽度、patch大小的影响
论文中还探讨了深度(层数)、宽度(隐藏层尺寸)和patch size对模型性能的影响。效果如下图所示:
-
深度 (Depth):
- 随着模型深度的增加,性能得到了提高(例如:depth=9, 13),证实了 scale 特性。
- 然而,在50K训练迭代后,增加到更大的深度(depth=17)并未带来额外的性能提升。
-
宽度 (Width):
- 增加隐藏层的宽度(例如:width=256, 512)有助于性能的提升。
- 然而,进一步增加到width=768并没有带来性能增益。
-
Patch Size:
- 减小patch size可以提高性能(例如:patch-size=8, 2),但是,减小到patch-size=1时,不再有任何性能提升。
- 作者认为,为了获得良好的性能,较小的patch size(如patch-size=2)是必要的。推测原因是扩散模型的噪声预测任务需要低级别的细节,而这与高级任务(如分类)不同。
-
低维潜在表示:
- 小的patch尺寸对于高分辨率图像的计算代价较高,因此作者选择将图像转换为低维潜在表示,并利用U-ViT对这些低维表示进行建模,【同SD模型,使用VAE进行降维】。
总结
通读完U-ViT,可以看出 U-ViT 和 后面发布并且爆火的 DiT在设计上有异曲同工之处:
- 二者均是将 Transformer 与扩散模型融合的思路
- 实验路径也相似,比如都采用了1 D 的位置编码 、在patch size上,都得出了同样的结论:patch size 为 2*2 是最理想的,都使用了和ViT一样的位置编码:1 D的正余弦。
- 在模型参数量上,两者都在 50M-500M 左右的参数量上做了实验,最终都证实了Transformer的强大 scale 特性。
- 额外的条件信息(时间信息/Timesteps,和文本信息)的注入方式实验中,都验证了自适应层规范化(AdaLN)。只不过U-ViT实验表明将时间 t t t 作为一个Token输入虽然简单,但是表现更好;而DiT实验中则认为AdaLN(准确的说是adaLN-Zero)效果更好。
DiT参考历史文章:多模态论文——DiT
长跳跃连接在图像扩散模型中的作用
长跳跃连接(long skip connections) 在图像扩散模型中的作用和 ResNet(Residual Networks) 的作用有相似之处。下面是长跳跃连接在图像扩散模型中的作用的详细介绍
1. 信息传递
- 直接连接浅层和深层,细节保留:通过长跳跃连接(long skip connections),浅层提取的低级特征(如边缘、纹理等)可以直接传递给深层网络。 使深层网络可以获得来自浅层的更丰富的信息。这种信息传递可以帮助深层网络更好地理解和捕捉图像中的细节和特征,从而提高图像扩散模型的性能。
2. 特征整合
- 融合多层次信息,增强上下文理解:通过跳跃连接,将浅层特征与深层特征融合(如通过加法、拼接等操作),形成丰富的多尺度特征表示。可以获得更丰富、更全局的特征表示。这种特征整合可以帮助模型更好地理解图像的上下文和语义信息,提高生成图像的质量和准确性。
3. 梯度传播
- 缓解梯度消失和梯度爆炸,增强训练稳定性:长跳跃连接通过直接连接浅层和深层,使得梯度能够从深层更有效地反向传播到浅层,避免梯度在传播过程中的逐渐衰减或增大。梯度的顺畅传播有助于网络各层参数的学习更加稳定,从而提高训练的收敛速度和效果。
4. 其他作用
- 支持高分辨率生成:在高分辨率图像生成中,长跳跃连接能够帮助模型更好地传递细粒度特征信息,避免因过多的下采样导致的分辨率损失。
- 减少依赖下采样和上采样操作:相比传统的卷积U-Net中大量依赖下采样和上采样,长跳跃连接可以减少对这些操作的依赖,从而降低结构复杂度。