首页 > 编程语言 >【论文和源码解读】Scaling on Scales:When Do We Not Need Larger Vision Models?

【论文和源码解读】Scaling on Scales:When Do We Not Need Larger Vision Models?

时间:2024-03-26 16:00:52浏览次数:20  
标签:multiscale num Larger Scaling 源码 尺寸 split 图像 out

文章目录

原文地址https://arxiv.org/abs/2403.13043

开源代码https://github.com/bfshi/scaling_on_scales

0. 问题和想法

  • 本文提出的问题:对于更好的视觉理解来说,更大的模型一定是必要的吗?

  • 核心思想保持预训练模型的规模不变,通过在越来越多的图像尺寸上运行获得越来越强大的特征,本文的作者结合上面的思想提出了 Scaling on Scales (S2)

1. 观察和见解

  1. 观察 1:虽然在许多情况下,使用 S^2 的较小的模型比较大的模型能获得更好的下游性能,但较大的模型仍能在较难的例子中表现出更优越的泛化能力
  2. 见解 1较小模型至少应该具有与较大模型相似的学习能力。实验表明,通过单一的线性变换,能够让具有多尺寸的较小模型更好地近似较大模型的特征。
  3. 见解 2本文假设较小模型的泛化能力较弱,是因为只用了单一的图像尺寸进行了预训练。通过实验得出,使用 S^2 缩放进行预训练可提高较小模型的泛化能力,使其能够与较大模型相媲美,甚至超越较大模型的优势。
  4. 观察 2:在冻住骨干网络并只训练特定于任务的头的实验中,S^2 scaling 方法在密集预测型任务(如语义分割和深度估计)中更具优势,可能是得益于多尺寸特征能提供更好的细节理解能力,而这正是这些任务所特别需要的。

2. 设计和框架

2.1 关键设计

  1. 将大图像分割成小的子图像,而不是直接在整个大图像上运行,避免了自注意力的平方计算复杂度,并且防止了位置嵌入插值导致的性能下降
  2. 处理单个子图像,而不是使用窗口注意力,可以使用不支持窗口注意力的预训练模型,避免从头开始训练额外的参数(例如相对位置嵌入)
  3. 将大的特征图插值为常规大小的特征图,从而确保输出 token 的数量保持不变,减少下游应用程序的计算开销。

2.2 模型框架

Scaling on Scales 的框架

具体来说,`S^2-Wrapper` 首先将输入的低分辨率图像通过插值获得更大尺寸的图像,例如将 224^2^ 插值为 448^2^ 尺寸的图像。接着,`S^2-Wrapper` 将 448^2^ 图像划分成四个 224^2^ 子图像,这些子图像与原始的 224^2^ 图像被输入到同一个预训练好的模型中去抽取特征。然后,通过将四个子图像的抽取出的特征进行合并,从而获得 448^2^ 图像所对应的大特征图,再对这个大特征图进行平均池化到与 224^2^ 图像的特征图相同的大小。最后的输出为不同尺寸下的特征图的拼接结果,并且其具有与单尺寸特征相同的空间形状,但具有更高的通道维度。

注意:当然可以直接输入原生高分辨的图像,作者在这里使用插值的方法是为了确保不引入额外的高分辨率信息,以便与模型规模缩放的方法进行公平比较

3. 源码解析

3.1 utils.py 文件

该文件包含 split_chessboardmerge_chessboard 两个函数,分别用于对输入图像进行划分和对输出特征进行合并。

对于 split_chessboard 而言,若输入张量 x 的维度为 (16, 3, 448, 448), num_split 为 2, 则输出张量 x_split(64, 3, 224, 224)

def split_chessboard(x, num_split):
    B, C, H, W = x.shape
    assert H % num_split == 0 and W % num_split == 0
    h, w = H // num_split, W // num_split
    # 将 x 划分为 (num_split ** 2) 个 sub_image, 其维度由 (B, C, H, W) 变为 (B * (num_split ** 2), C, h, w)
    x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w] for i in range(num_split) for j in range(num_split)],dim=0)
    return x_split

对于 merge_chessboard 而言,若输入张量 x 的维度为 (4, 768, 14, 14), num_split 为 2, 则输出张量 x_merge(1, 768, 28, 28)

def merge_chessboard(x, num_split):
    B, C, H, W = x.shape
    assert B % (num_split**2) == 0
    b = B // (num_split**2)
    # 对 sub_image 沿着 batch 维度进行拼接后, 维度由 (B, C, H, W) -> (B / (num_split ** 2), C, num_split * H, num_split * W)
    x_merge = torch.cat([torch.cat([x[(i*num_split+j)*b:(i*num_split+j+1)*b] for j in range(num_split)], dim=-1)
                         for i in range(num_split)], dim=-2)
    return x_merge

3.2 core.py 文件

3.2.1 forward 函数中输入参数的含义

下面给出 forward 函数中各个参数的含义 (翻译自本文 github 开源代码) :

  • model:您的视觉模型或任何输入 BxCxHxW 图像张量并输出 BxNxC 特征张量的函数。

  • input:输入形状为 BxCxHxW 的图像张量。

  • scales:包含要提取特征的尺寸的列表。例如,如果默认尺寸为2242,则 scales=[1,2] 将提取 2242 和 4482 两个尺寸上的特征。

  • img_sizes:或者,您可以为每个尺寸去分配图像大小,而不是分配 scales。例如,如果默认尺寸为2242img_sizes=[224, 448] 将产生与 scales=[1,2] 相同的结果。

  • max_split_size:从大图像分割出来的子图像的最大尺寸。对于每个尺寸来说,图像将被分割成 ceil(img_size_that_scale / max_split_size)**2 个子图像。如果为 None,则默认设置为 input 的大小。

  • resize_output_to_idx:将最终输出的特征图调整到哪个尺寸。默认值是 scalesimg_sizes 中的第一个比例。

  • num_prefix_token:特征图中 prefix token 的数量。例如,如果 model 返回的特征图包含 1 个 [CLS] token和其他的 spatial token,则设置 num_prefix_token=1。默认为 0

  • output_shape:输出特征的形状。要么为 bnc (e.g., ViT) 或 bchw (e.g., ConvNet). 默认为 bnc

3.2.2 forward 函数的处理逻辑

下面是对 core.py 文件中的 forward 函数的源码进行解析,这同样也是 Scaling on Scales 方法的整体处理逻辑
本例中我们使用的 model 为预训练好的 VIT-Base-16input 的维度为 (1, 3, 224, 224)scales 设置为 [1, 2],其余参数的配置均保持不变。

注意:将输入为 (1, 3, 224, 224) 的图像张量输入到 ViT-Base-16 中,其最后一层 Block 输出的特征维度为 (1, 197, 768)

def forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0,
            output_shape='bnc'):

    assert input.dim() == 4,  "Input image must be in the shape of BxCxHxW."
    assert input.shape[2] == input.shape[3],  "Currently only square images are supported."
    assert output_shape in ['bnc', 'bchw'],  "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
    assert output_shape == 'bnc' or num_prefix_token == 0,  "For ConvNet there shouldn't be any prefix token."

    b, c, input_size, _ = input.shape

    assert scales is not None or img_sizes is not None,  "Please assign either scales or img_sizes."
    # img_sizes 中存放的是图像的各个尺寸, 如 224, 448, 672 等
    img_sizes = img_sizes or [input_size * scale for scale in scales]

    # 划分的子图像的最大尺寸, 默认为输入图像的尺寸
    max_split_size = max_split_size or input_size
    # 计算对应尺寸下应划分出的子图像数量
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]

    input_multiscale = []
    for size, num_split in zip(img_sizes, num_splits):
        # 通过对输入图像进行插值缩放到对应的尺寸
        x = F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype)
        # 将缩放后的图像划分为 num_split 个子图像
        x = split_chessboard(x, num_split=num_split)
        input_multiscale.append(x)

    # 在每个尺寸下做一次前向传递, 并将输出结果存放到 outs_multiscale
    outs_multiscale = [model(x) for x in input_multiscale]
    if num_prefix_token > 0:
        # 提取多个尺寸对应特征的 prefix_token. 如果模型是 ViT 的话, 就是 CLS Token
        outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
        # 然后更新 outs_multiscale (剔除了 prefix_token)
        outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale]
    if output_shape == 'bnc':
        # 将 outs_multiscale 中多个尺寸下的输出进行维度上的重新排列
        # 例如 (1, 196, 768) -> (1, 768, 14, 14)
        outs_multiscale = [rearrange(out, 'b (h w) c -> b c h w', h=int(out.shape[1] ** 0.5), w=int(out.shape[1] ** 0.5))
                           for out in outs_multiscale]

    # 在每个尺寸下, 对其划分的多个子图像的输出结果进行合并
    outs_multiscale = [merge_chessboard(out, num_split=num_split) for num_split, out in zip(num_splits, outs_multiscale)]

    # output_size 其实是所有尺寸下合并后的特征最终要 resize 到的大小
    output_size = outs_multiscale[resize_output_to_idx].shape[-2]
    # 将不同尺寸下的合并结果进行插值后 (对应着处理逻辑中的 Pooling 操作)
    # 再按着维度 1 进行拼接
    out = torch.cat([F.interpolate(outs_multiscale[i].to(torch.float32), size=output_size,
                                   mode='area').to(outs_multiscale[i].dtype)
                     for i in range(len(outs_multiscale))], dim=1)

    if output_shape == 'bnc':
        out = rearrange(out, 'b c h w -> b (h w) c')
    if num_prefix_token > 0:
        # 在每个尺寸下, 对其所有子图像对应的 prefix token 进行合并取均值操作
        outs_prefix_multiscale = [torch.stack(out.split(b, dim=0), dim=0).mean(dim=0) for out in outs_prefix_multiscale]
        # 将所有尺寸的 CLS token 沿着最后一个维度进行拼接
        out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1)
        # 将合并后的 out 和 out_prefix_multiscale 进行拼接, 从而完成 Scaling on Scales 的全部流程
        out = torch.cat([out_prefix_multiscale, out], dim=1)

    return out

4. 消融实验

  1. 将从大尺寸图像中划分出来的子图像输入到视觉模型与直接将大尺寸图像输入到视觉模型进行对比。结果表明,将大尺寸图像划分/切分成多个子图像作为输入的性能表现更好
    split ablation
  2. 将不同尺寸的特征图进行拼接 (通道维度上的拼接) 和直接将它们加在一起 (通道维度上的求和) 进行对比。结果表明,将不同尺寸的特征图进行拼接 (通道维度上的拼接) 的性能表现更好
    concatenation ablation

5. 未来可能的改进方向

  1. Scale-selective processing即图像中每个位置的每个尺寸并非都包含同样有用的特征;根据图像内容和高级任务,为每个区域选择某些尺寸去进行处理会更有效,这类似于人类视觉注意力中自下而上和自上而下的选择机制。
  2. parallel processing of single image:与普通 ViT 一次性处理整幅图像不同,S^2 中的每个子图像都是独立处理的,因此可以并行处理单幅图像的不同子图像,这对于处理单幅大图像的延迟要求很高的场景来讲是很有用的。

标签:multiscale,num,Larger,Scaling,源码,尺寸,split,图像,out
From: https://blog.csdn.net/Eternity666long/article/details/137047532

相关文章

  • 【附源码】Node.js毕业设计个人健康信息记录移动应用app(Express)
    本系统(程序+源码)带文档lw万字以上  文末可获取本课题的源码和程序系统程序文件列表系统的选题背景和意义选题背景:随着科技的进步和互联网的普及,移动应用已经成为人们日常生活中不可或缺的一部分。在健康管理领域,个人健康信息记录移动应用APP的开发和应用也日益受到关注......
  • 【附源码】Node.js毕业设计个人健康管理小助手(Express)
    本系统(程序+源码)带文档lw万字以上  文末可获取本课题的源码和程序系统程序文件列表系统的选题背景和意义选题背景:随着社会节奏的加快和工作压力的增大,个人健康管理成为了人们日益关注的焦点。传统的健康管理方式往往需要用户手动记录健康数据,如饮食、运动、睡眠等,然后进......
  • 【附源码】Node.js毕业设计个人财务管理系统(Express)
    本系统(程序+源码)带文档lw万字以上  文末可获取本课题的源码和程序系统程序文件列表系统的选题背景和意义选题背景:在当今社会,随着经济的快速发展和人们生活水平的提高,个人财务管理已经成为了我们生活中不可或缺的一部分。无论是日常生活的消费记录、投资理财,还是购房、购......
  • 【MATLAB源码-第15期】基于matlab的MSK的理论误码率与实际误码率BER对比仿真,采用差分
    操作环境:MATLAB2022a1、算法描述在数字调制中,最小频移键控(Minimum-ShiftKeying,缩写:MSK)是一种连续相位调制的频移键控方式,在1950年代末和1960年代产生。[1]与偏移四相相移键控(OQPSK)类似,MSK同样将正交路基带信号相对于同相路基带信号延时符号间隔的一半,从而消除了已调信号......
  • 【MATLAB源码-第16期】基于matlab的MSK定是同步仿真,采用gardner算法和锁相环。
    操作环境:MATLAB2022a1、算法描述**锁相环(PLL)**是一种控制系统,用于将一个参考信号的相位与一个输入信号的相位同步。它在许多领域中都有应用,如通信、无线电、音频、视频和计算机系统。锁相环通常由以下几个关键组件组成:1.**相位比较器(PhaseComparator):**这个组件比较输......
  • 《Android Framework源码解析》全网最详尽的Android系统框架层的指南,不容错过!!
    前言在当今数字化时代,移动应用已成为我们日常生活中不可或缺的一部分。随着技术的不断进步,Android作为全球领先的移动操作系统,其市场份额和影响力持续扩大。开发者们面临着一个充满活力且竞争激烈的市场环境,用户对应用的体验和性能要求日益提高。在这样的背景下,深入了解And......
  • Tomcat源码解析(二)
     1.项目源码结构2.Tomcat源码结构 1.在javax中保存的是新的JavaEE规范。可以具体来看看每个目录的作用。模块作用说明annotationannotation这个模块的作用是定义了一些公用的注解,避免在不同的规范中定义相同的注解ejbejb是个古老的传说,我们不管el在jsp中......
  • Building an Automatically Scaling Web Application
    2024年春季云计算课业1:构建一个自动伸缩的Web应用程序截止日期:2024年4月15日,星期一1目标和范围在这项任务中,我们将为(非常)琐碎的Web构建一个小型的自动伸缩测试平台应用任务的目标是熟悉伸缩Web的各个方面应用程序,这将提高您对低级/基本实现的理解云系统的详细信息。正如我们在......
  • java计算机毕业设计(附源码)新知书店(ssm+mysql+maven+LW文档)
    本系统(程序+源码)带文档lw万字以上  文末可领取本课题的JAVA源码参考系统程序文件列表系统的选题背景和意义选题背景:新知书店,作为一家专注于传播知识和文化的零售场所,承载着促进社会文化发展和满足人们精神需求的重要使命。在数字化时代背景下,实体书店面临着前所未有的挑......
  • 智慧工地解决方案,智慧工地项目管理系统源码,支持大屏端、PC端、手机端、平板端
    智慧工地解决方案依托计算机技术、物联网、云计算、大数据、人工智能、VR&AR等技术相结合,为工程项目管理提供先进技术手段,构建工地现场智能监控和控制体系,弥补传统方法在监管中的缺陷,最线实现项目对人、机、料、法、环的全方位实时监控。支持多端展示(大屏、PC端、手机端、平板......