随着Meta发布的Segment Anything Model(SAM),计算机视觉迎来了ChatGPT时刻。SAM经过超过110亿个分割掩码的训练,是预测性人工智能用例而非生成性人工智能的基础模型。虽然它在广泛的图像模式和问题空间上表现出了令人难以置信的灵活性,但它的发布没有“微调”功能。
本教程将概述使用掩码解码器微调SAM的一些关键步骤,特别是描述SAM的哪些函数用于预/后处理数据,使其处于良好的微调状态。
What is the Segment Anything Model (SAM)?
分段任意模型(SAM)是Meta AI开发的一个分段模型。它被认为是计算机视觉的第一个基础模型。SAM是在包含数百万张图像和数十亿个口罩的庞大数据库上进行训练的,这使得它非常强大。顾名思义,SAM能够为各种图像生成准确的分割掩模。SAM的设计使其能够将人工提示考虑在内,使其对“循环中的人工”注释特别强大。这些提示可以是多模式的:它们可以是要分割的区域上的点、要分割的对象周围的边界框,或者关于应该分割的内容的文本提示。
该模型分为三个部分:图像编码器、提示编码器和掩码解码器。
图像编码器为被分割的图像生成嵌入,而提示编码器为提示生成嵌入。图像编码器是该模型的一个特别大的组件。这与基于嵌入预测分割掩码的轻量级掩码解码器形成对比。Meta AI已经将在Segment Anything 10 Billion Mask(SA-1B)数据集上训练的模型的权重和偏差作为模型检查点。
在解释者博客文章中了解更多关于Segment Anything如何工作的信息:https://encord.com/blog/segment-anything-model-explained/
What is Model Fine-Tuning?
公开提供的最先进的模型具有自定义架构,通常提供预先训练的模型权重。如果这些体系结构是在没有权重的情况下提供的,那么用户将需要从头开始训练模型,他们将需要使用大量数据集来获得最先进的性能。
模型微调是采用预先训练好的模型(架构+权重)并向其显示特定用例的数据的过程。这通常是模型以前从未见过的数据,或者在其原始训练数据集中代表性不足的数据。
微调模型和从头开始之间的区别在于权重和偏差的起始值。如果我们从头开始训练,这些将根据一些策略随机初始化。在这样的启动配置中,模型将对手头的任务“一无所知”,并表现不佳。通过使用预先存在的权重和偏差作为起点,我们可以“微调”权重和偏差,以便我们的模型在自定义数据集上更好地工作。例如,学会识别猫的信息(边缘检测、计数爪子)将有助于识别狗。
Why Would I Fine-Tune a Model?
微调模型的目的是在预先训练的模型以前没有看到的数据上获得更高的性能。例如,在从手机摄像头收集的大量数据上训练的图像分割模型将主要从水平角度看到图像。
如果我们试图将这个模型用于从垂直角度拍摄的卫星图像,它可能不会表现得那么好。如果我们试图分割屋顶,该模型可能不会产生最佳结果。预训练是有用的,因为模型通常已经学会了如何分割对象,所以我们希望利用这个起点来构建一个可以准确分割屋顶的模型。此外,我们的自定义数据集可能没有数百万个示例,因此我们希望进行微调,而不是从头开始训练模型。
微调是可取的,这样我们就可以在特定的用例中获得更好的性能,而不必承担从头开始训练模型的计算成本。
How to Fine-Tune Segment Anything Model [With Code]
背景与架构
我们在介绍部分概述了SAM体系结构。图像编码器具有具有许多参数的复杂结构。为了微调模型,我们有必要关注掩码解码器,它重量轻,因此更容易、更快、更高效地进行微调。
为了微调SAM,我们需要提取其架构的底层部分(图像和提示编码器、掩码解码器)。我们无法使用SamPredictor.predict(链接):
我们只想微调掩码解码器
这个函数调用SamPredictor.predict_tarch,它有@torch.no_grad()装饰器(链接),它阻止我们计算梯度
因此,我们需要检查SamPredictor.prpredict函数,并在我们想要微调的部分(掩码解码器)启用梯度计算的情况下调用适当的函数。这样做也是了解更多SAM如何工作的好方法。
Creating a Custom Dataset
我们需要三件事来微调我们的模型:
要在其上绘制分割的图像
分割地面实况掩码
提示输入到模型中
我们选择了印章验证数据集(链接),因为它有SAM在其训练中可能没有看到的数据(即,文件上的印章)。我们可以通过使用预先训练的权重运行推理来验证它在该数据集上的表现良好,但并不完美。ground truth masks也非常精确,这将使我们能够计算出准确的损失。最后,这个数据集包含分割掩码周围的边界框,我们可以将其用作SAM的提示。下面显示了一个示例图像。这些边界框与人工注释器在生成分段时要经过的工作流程非常一致。
Input Data Preprocessing
我们需要对从numpy数组到pytorch张量的扫描进行预处理。要做到这一点,我们可以遵循SamPredictor.set_image(链接)和预处理图像的SamPredictor.set_arch_image(链接)内部发生的情况。首先,我们可以使用utils.transform。ResizeLongestSide可调整图像的大小,因为这是预测器(链接)内部使用的转换器。然后,我们可以将图像转换为pytorch张量,并使用SAM预处理方法(链接)完成预处理。
Training Setup
我们下载vit_b模型的模型检查点,并将其加载到:
sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
我们可以使用默认值设置Adam优化器,并指定要调整的参数是掩码解码器的参数:
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters())
同时,我们可以设置我们的损失函数,例如均方误差
loss_fn = torch.nn.MSELoss()
Training Loop
在主训练循环中,我们将迭代我们的数据项,生成掩码,并将它们与我们的ground truth掩码进行比较,以便我们可以基于损失函数优化模型参数。
在这个例子中,我们使用GPU进行训练,因为它比使用CPU快得多。在适当的张量上使用.to(设备)是很重要的,以确保CPU上没有某些张量,GPU上没有其他张量。
我们希望通过将编码器封装在torch.no.grad()上下文管理器中来嵌入图像,因为否则我们将出现内存问题,同时我们不希望微调图像编码器。
with torch.no_grad(): image_embedding = sam_model.image_encoder(input_image)
我们还可以在no.grad上下文管理器中生成提示嵌入。我们使用边界框坐标,转换为pytorch张量。
with torch.no_grad(): sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( points=None, boxes=box_torch, masks=None, )
最后,我们可以生成遮罩。请注意,这里我们处于单掩码生成模式(与正常输出的3个掩码形成对比)。
low_res_masks, iou_predictions = sam_model.mask_decoder( image_embeddings=image_embedding, image_pe=sam_model.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=False, )
这里的最后一步是将遮罩升级回原始图像大小,因为它们的分辨率较低。我们可以使用Sam.postprocess_masks来实现这一点。我们还希望从预测的掩码中生成二进制掩码,以便将其与我们的基本事实进行比较。为了不破坏反向传播,使用torch泛函是很重要的。
点击查看代码
upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)
from torch.nn.functional import threshold, normalize
binary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)
点击查看代码
loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
SAM论文链接:https://arxiv.org/pdf/2304.02643
标签:SAM,finetune,微调,我们,图像,掩码,视觉,模型 From: https://www.cnblogs.com/SunshineWeather/p/18209059