文章目录
原文地址:https://arxiv.org/abs/2403.13043
开源代码:https://github.com/bfshi/scaling_on_scales
0. 问题和想法
-
本文提出的问题:对于更好的视觉理解来说,更大的模型一定是必要的吗?
-
核心思想:保持预训练模型的规模不变,通过在越来越多的图像尺寸上运行获得越来越强大的特征,本文的作者结合上面的思想提出了 Scaling on Scales (S2)。
1. 观察和见解
- 观察 1:虽然在许多情况下,使用 S^2 的较小的模型比较大的模型能获得更好的下游性能,但较大的模型仍能在较难的例子中表现出更优越的泛化能力。
- 见解 1:较小模型至少应该具有与较大模型相似的学习能力。实验表明,通过单一的线性变换,能够让具有多尺寸的较小模型更好地近似较大模型的特征。
- 见解 2:本文假设较小模型的泛化能力较弱,是因为只用了单一的图像尺寸进行了预训练。通过实验得出,使用 S^2 缩放进行预训练可提高较小模型的泛化能力,使其能够与较大模型相媲美,甚至超越较大模型的优势。
- 观察 2:在冻住骨干网络并只训练特定于任务的头的实验中,S^2 scaling 方法在密集预测型任务(如语义分割和深度估计)中更具优势,可能是得益于多尺寸特征能提供更好的细节理解能力,而这正是这些任务所特别需要的。
2. 设计和框架
2.1 关键设计
- 将大图像分割成小的子图像,而不是直接在整个大图像上运行,避免了自注意力的平方计算复杂度,并且防止了位置嵌入插值导致的性能下降
- 处理单个子图像,而不是使用窗口注意力,可以使用不支持窗口注意力的预训练模型,避免从头开始训练额外的参数(例如相对位置嵌入)
- 将大的特征图插值为常规大小的特征图,从而确保输出 token 的数量保持不变,减少下游应用程序的计算开销。
2.2 模型框架
具体来说,`S^2-Wrapper` 首先将输入的低分辨率图像通过插值获得更大尺寸的图像,例如将 224^2^ 插值为 448^2^ 尺寸的图像。接着,`S^2-Wrapper` 将 448^2^ 图像划分成四个 224^2^ 子图像,这些子图像与原始的 224^2^ 图像被输入到同一个预训练好的模型中去抽取特征。然后,通过将四个子图像的抽取出的特征进行合并,从而获得 448^2^ 图像所对应的大特征图,再对这个大特征图进行平均池化到与 224^2^ 图像的特征图相同的大小。最后的输出为不同尺寸下的特征图的拼接结果,并且其具有与单尺寸特征相同的空间形状,但具有更高的通道维度。
注意:当然可以直接输入原生高分辨的图像,作者在这里使用插值的方法是为了确保不引入额外的高分辨率信息,以便与模型规模缩放的方法进行公平比较。
3. 源码解析
3.1 utils.py 文件
该文件包含 split_chessboard
和 merge_chessboard
两个函数,分别用于对输入图像进行划分和对输出特征进行合并。
对于 split_chessboard
而言,若输入张量 x
的维度为 (16, 3, 448, 448)
, num_split
为 2, 则输出张量 x_split
为 (64, 3, 224, 224)
def split_chessboard(x, num_split):
B, C, H, W = x.shape
assert H % num_split == 0 and W % num_split == 0
h, w = H // num_split, W // num_split
# 将 x 划分为 (num_split ** 2) 个 sub_image, 其维度由 (B, C, H, W) 变为 (B * (num_split ** 2), C, h, w)
x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w] for i in range(num_split) for j in range(num_split)],dim=0)
return x_split
对于 merge_chessboard
而言,若输入张量 x
的维度为 (4, 768, 14, 14)
, num_split
为 2, 则输出张量 x_merge
为 (1, 768, 28, 28)
def merge_chessboard(x, num_split):
B, C, H, W = x.shape
assert B % (num_split**2) == 0
b = B // (num_split**2)
# 对 sub_image 沿着 batch 维度进行拼接后, 维度由 (B, C, H, W) -> (B / (num_split ** 2), C, num_split * H, num_split * W)
x_merge = torch.cat([torch.cat([x[(i*num_split+j)*b:(i*num_split+j+1)*b] for j in range(num_split)], dim=-1)
for i in range(num_split)], dim=-2)
return x_merge
3.2 core.py 文件
3.2.1 forward 函数中输入参数的含义
下面给出 forward
函数中各个参数的含义 (翻译自本文 github 开源代码) :
-
model
:您的视觉模型或任何输入 BxCxHxW 图像张量并输出 BxNxC 特征张量的函数。 -
input
:输入形状为 BxCxHxW 的图像张量。 -
scales
:包含要提取特征的尺寸的列表。例如,如果默认尺寸为2242,则scales=[1,2]
将提取 2242 和 4482 两个尺寸上的特征。 -
img_sizes
:或者,您可以为每个尺寸去分配图像大小,而不是分配scales
。例如,如果默认尺寸为2242,img_sizes=[224, 448]
将产生与scales=[1,2]
相同的结果。 -
max_split_size
:从大图像分割出来的子图像的最大尺寸。对于每个尺寸来说,图像将被分割成ceil(img_size_that_scale / max_split_size)**2
个子图像。如果为None
,则默认设置为input
的大小。 -
resize_output_to_idx
:将最终输出的特征图调整到哪个尺寸。默认值是scales
或img_sizes
中的第一个比例。 -
num_prefix_token
:特征图中 prefix token 的数量。例如,如果model
返回的特征图包含 1 个 [CLS] token和其他的 spatial token,则设置num_prefix_token=1
。默认为0
。 -
output_shape
:输出特征的形状。要么为bnc
(e.g., ViT) 或bchw
(e.g., ConvNet). 默认为bnc
。
3.2.2 forward 函数的处理逻辑
下面是对 core.py
文件中的 forward
函数的源码进行解析,这同样也是 Scaling on Scales 方法的整体处理逻辑。
本例中我们使用的 model
为预训练好的 VIT-Base-16
,input
的维度为 (1, 3, 224, 224)
,scales
设置为 [1, 2]
,其余参数的配置均保持不变。
注意:将输入为 (1, 3, 224, 224)
的图像张量输入到 ViT-Base-16
中,其最后一层 Block 输出的特征维度为 (1, 197, 768)
。
def forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0,
output_shape='bnc'):
assert input.dim() == 4, "Input image must be in the shape of BxCxHxW."
assert input.shape[2] == input.shape[3], "Currently only square images are supported."
assert output_shape in ['bnc', 'bchw'], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
assert output_shape == 'bnc' or num_prefix_token == 0, "For ConvNet there shouldn't be any prefix token."
b, c, input_size, _ = input.shape
assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes."
# img_sizes 中存放的是图像的各个尺寸, 如 224, 448, 672 等
img_sizes = img_sizes or [input_size * scale for scale in scales]
# 划分的子图像的最大尺寸, 默认为输入图像的尺寸
max_split_size = max_split_size or input_size
# 计算对应尺寸下应划分出的子图像数量
num_splits = [math.ceil(size / max_split_size) for size in img_sizes]
input_multiscale = []
for size, num_split in zip(img_sizes, num_splits):
# 通过对输入图像进行插值缩放到对应的尺寸
x = F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype)
# 将缩放后的图像划分为 num_split 个子图像
x = split_chessboard(x, num_split=num_split)
input_multiscale.append(x)
# 在每个尺寸下做一次前向传递, 并将输出结果存放到 outs_multiscale
outs_multiscale = [model(x) for x in input_multiscale]
if num_prefix_token > 0:
# 提取多个尺寸对应特征的 prefix_token. 如果模型是 ViT 的话, 就是 CLS Token
outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
# 然后更新 outs_multiscale (剔除了 prefix_token)
outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale]
if output_shape == 'bnc':
# 将 outs_multiscale 中多个尺寸下的输出进行维度上的重新排列
# 例如 (1, 196, 768) -> (1, 768, 14, 14)
outs_multiscale = [rearrange(out, 'b (h w) c -> b c h w', h=int(out.shape[1] ** 0.5), w=int(out.shape[1] ** 0.5))
for out in outs_multiscale]
# 在每个尺寸下, 对其划分的多个子图像的输出结果进行合并
outs_multiscale = [merge_chessboard(out, num_split=num_split) for num_split, out in zip(num_splits, outs_multiscale)]
# output_size 其实是所有尺寸下合并后的特征最终要 resize 到的大小
output_size = outs_multiscale[resize_output_to_idx].shape[-2]
# 将不同尺寸下的合并结果进行插值后 (对应着处理逻辑中的 Pooling 操作)
# 再按着维度 1 进行拼接
out = torch.cat([F.interpolate(outs_multiscale[i].to(torch.float32), size=output_size,
mode='area').to(outs_multiscale[i].dtype)
for i in range(len(outs_multiscale))], dim=1)
if output_shape == 'bnc':
out = rearrange(out, 'b c h w -> b (h w) c')
if num_prefix_token > 0:
# 在每个尺寸下, 对其所有子图像对应的 prefix token 进行合并取均值操作
outs_prefix_multiscale = [torch.stack(out.split(b, dim=0), dim=0).mean(dim=0) for out in outs_prefix_multiscale]
# 将所有尺寸的 CLS token 沿着最后一个维度进行拼接
out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1)
# 将合并后的 out 和 out_prefix_multiscale 进行拼接, 从而完成 Scaling on Scales 的全部流程
out = torch.cat([out_prefix_multiscale, out], dim=1)
return out
4. 消融实验
- 将从大尺寸图像中划分出来的子图像输入到视觉模型与直接将大尺寸图像输入到视觉模型进行对比。结果表明,将大尺寸图像划分/切分成多个子图像作为输入的性能表现更好。
- 将不同尺寸的特征图进行拼接 (通道维度上的拼接) 和直接将它们加在一起 (通道维度上的求和) 进行对比。结果表明,将不同尺寸的特征图进行拼接 (通道维度上的拼接) 的性能表现更好。
5. 未来可能的改进方向
- Scale-selective processing:即图像中每个位置的每个尺寸并非都包含同样有用的特征;根据图像内容和高级任务,为每个区域选择某些尺寸去进行处理会更有效,这类似于人类视觉注意力中自下而上和自上而下的选择机制。
- parallel processing of single image:与普通 ViT 一次性处理整幅图像不同,S^2 中的每个子图像都是独立处理的,因此可以并行处理单幅图像的不同子图像,这对于处理单幅大图像的延迟要求很高的场景来讲是很有用的。