论文链接:[2405.04312] Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer (arxiv.org)
现有图像扩散模型生成的图像的分辨率通常被限制在 1024 ×1024 像素或更低,在生成超高分辨率图像(例如 4096 ×4096)时内存会二次增加,
上采样到更高分辨率的图像的最大挑战是显着的 GPU 内存需求。另外一个问题是如果要将图像完整的输入模型中,会占用的空间。
因此本文提出了一种单向块注意(UniBA)算法,该算法可以显著降低从O(N2)到O(N)的生成空间复杂度,大大提高了最高的可用分辨率。
Methodology
Unidirectional Block Attention (UniBA)
在UNet、DiT等模型中,块之间的依赖关系是双向的,即在计算时必须同时生成图像中的所有块。为了节省块隐藏状态的内存,我们希望设计一种算法,使其允
许同一图像中的块被分成几批来生成,每批只需要同时生成一部分块,并按批次顺序生成。
主要思想是将图片划分为块,其中B为块的大小。并提出了如下图所示的注意力实现:
左图:单向块注意力中,每个块直接取决于自身层的三个块:左上角的块、左侧和上面的块。
右图:Inf-DiT 的推理过程。Inf-DiT 根据内存大小每次生成 n × n的block。在这个过程中,只有后续块所依赖的块的KV-cache存储在内存中。
Inf-DiT 架构中,块之间的依赖关系是注意力操作。且transformer中单向块注意力可以计算如下:
表示第n层i行j列的块的隐藏状态,为块间相对位置编码。
虽然该方法每一个block的计算依赖的范围变小了,但是由于特征逐层传递,还是可以捕捉到长距离的信息;
在上图中,随着block计算的向前推进,不断有block的hidden states的值被丢弃。即可空间复杂度由原来的变为
Basic Model Architecture
Inf-DiT 的架构使用了与DiT类似的主干,它将Vision Transformer (ViT)应用于扩散模型,与基于卷积的体系结构(如UNet)相比,DiT仅利用注意力作为patch之间的
交互机制,可以方便地实现单向块注意。为了适应单向块注意,提高上采样的性能,我们做了如下几个修改和优化。
Model input
考虑到颜色偏移和细节损失等压缩产生的损失,Inf-DiT 的重建是在 RGB 像素空间中进行的,而不是潜在空间。在超分为f倍时,首先将低分辨率RGB图像上采样f倍,然后将其与扩散的噪声输入在特征维数上连接起来,然后将其输入到模型中。
Position Encoding
参考RoPE旋转位置编码。首先创建一个足够大的位置编码表,使用随机起点:对于每个训练图像,为图像的左上角随机分配一个位置 (x, y),而不是默认的 (0,0)。此外,考虑到同一块内和不同块之间的交互差异,还引入了块级相对位置编码,它根据注意前的相对位置分配不同的可学习嵌入。
Global and Local Consistency
Global Consistency with CLIP Image Embedding
利用预训练的CLIP中的图像编码器从低分辨率图像中提取图像嵌入,称之为语义输入。由于CLIP是在互联网上海量的图像-文本对上训练的,其图像编码器可以有效地从低分辨率图像中提取全局信息。将全局语义嵌入添加到DiT的时间嵌入中,并将其输入到每一层,使模型能够直接从高级语义信息中学习。
使用 CLIP 中的图像-文本潜在空间,即使模型没有在任何图像-文本对上进行训练,也可以使用文本来指导生成的方向。
给定一个正提示和一个负提示,就可以更新图像嵌入:
α用于控制语义的引导强度。在推理过程中,我们可以简单地使用
代替 作为全局语义嵌入来进行控制。
Local Consistency with Nearby LR Cross Attention
模型学习 LR 和 HR 图像之间的局部对应关系时仍然可能存在连续性问题。为了解决这个问题,引入了 Nearby LR Cross Attention。在transformer的第一层中,
每个块对周围的3 × 3 LR块进行交叉注意,以捕获附近的LR信息。实验表明,这种方法显着减少了生成不连续图像的概率。
Experiments
HPDV2数据集下超高分辨率的定量实验:
表现了模型生成高分辨率细节和协调全局信息的能力。虽然在4096X4096下的FID值略小于BSRGAN,但FIDcrop 是高分辨率特征的更有代表性的指标
FIDcrop是从高分辨率图像中随机抽取299 × 299个patch进行FID评估,不会像FID一样忽略了高分辨率的细节,因为FID的原始实现需要在特征提取前将输入图像
下采样到299 × 299的分辨率
下表是在DIV2K数据集下的超分定量实验: