如何Fine-Tune微调SAM

2023-12-27 11:28:49

转眼已经到了2023年的末尾,年初ChatGPT爆火,随后SAM横空出世,给今年的科技圈带来了众多看点,在SAM刚刚发布的时候我们也做过相关的实践,感兴趣的话可以自行移步阅读:

《Segment Anything Model (SAM)——卷起来了,那个号称分割一切的CV大模型他来了》

《Segment Anything Model (SAM)——分割一切,具有预测提示输入的图像分割实践》

《SAM-FAST:Accelerating Generative AI with PyTorch: Segment Anything, Fast基于官方PyTorch团队开发原生SAM提速8倍》

Segment Anything Model(SAM)的发布让计算机视觉迎来了ChatGPT时刻。SAM经过超过110亿个分割掩码的训练,是预测性人工智能用例而非生成性人工智能的基础模型。虽然它在广泛的图像模式和问题空间上表现出了令人难以置信的灵活性,但它的发布没有“微调”功能。

本教程将概述使用掩码解码器微调SAM的一些关键步骤,特别是描述SAM的哪些函数用于预/后处理数据,使其处于良好的微调状态。

什么是(SAM)?

分割一切模型(SAM)是Meta AI开发的一个分段模型。它被认为是计算机视觉的第一个基础模型。SAM是在包含数百万张图像和数十亿个mask的庞大数据库上进行训练的,这使得它非常强大。顾名思义,SAM能够为各种图像生成准确的分割掩模。SAM的设计使其能够将人工提示考虑在内,使其对“循环中的人工”注释特别强大。这些提示可以是多模式的:它们可以是要分割的区域上的点、要分割的对象周围的边界框或关于应该分割的内容的文本提示。

该模型分为三个部分:图像编码器、提示编码器和掩码解码器。

显示Segment Anything(SA)模型的基础模型体系结构的图像

官方论文在这里,如下所示:

图像编码器为被分割的图像生成嵌入,而提示编码器为提示生成嵌入。图像编码器是模型中一个特别大的组件。这与基于嵌入预测分割掩码的轻量级掩码解码器形成对比。Meta AI已经将在Segment Anything 10 Billion Mask(SA-1B)数据集上训练的模型的权重和偏差作为模型检查点。

什么是模型微调?

公开可用的现有技术模型具有自定义架构,并且通常提供有预先训练的模型权重。如果这些架构是在没有权重的情况下提供的,那么用户将需要从头开始训练模型,用户将需要使用大量数据集来获得最先进的性能。

模型微调是采用预先训练好的模型(体系结构+权重)并向其显示特定用例的数据的过程。这通常是模型以前从未见过的数据,或者在其原始训练数据集中代表性不足的数据。

微调模型和从头开始之间的区别在于权重和偏差的起始值。如果我们从头开始训练,这些将根据一些策略随机初始化。在这样的启动配置中,模型将对手头的任务“一无所知”,并表现不佳。通过使用预先存在的权重和偏差作为起点,我们可以“微调”权重和偏差,以便我们的模型在自定义数据集上更好地工作。例如:学会识别猫的信息(边缘检测、计数爪子)将有助于识别狗。

为什么要微调模型?

微调模型的目的是在预先训练的模型以前没有看到的数据上获得更高的性能。例如,在从手机摄像头收集的大量数据上训练的图像分割模型将主要从水平角度看到图像。

如果我们试图将这个模型用于从垂直角度拍摄的卫星图像,它可能不会表现得那么好。如果我们试图分割屋顶,该模型可能不会产生最佳结果。预训练是有用的,因为模型通常已经学会了如何分割对象,所以我们想利用这个起点来建立一个可以准确分割屋顶的模型。此外,我们的自定义数据集可能没有数百万个示例,因此我们希望进行微调,而不是从头开始训练模型。

微调是可取的,这样我们就可以在特定的用例中获得更好的性能,而不必承担从头开始训练模型的计算成本。

如何微调分段任意模型?

背景与架构

我们在介绍部分概述了SAM体系结构。图像编码器具有具有许多参数的复杂结构。为了微调模型,我们有必要关注掩码解码器,它重量轻,因此更容易、更快、更高效地进行微调。

为了微调SAM,我们需要提取其架构的底层部分(图像和提示编码器、掩码解码器)。由于两个原因,我们无法使用SamPredictor.predict

1、我们只想微调掩码解码器

2、这个函数调用SamPredictor.predict_tarch,它有@torch.no_grad()装饰器,它阻止我们计算梯度

因此,我们需要检查SamPredictor.prpredict函数,并在我们想要微调的部分(掩码解码器)启用梯度计算的情况下调用适当的函数。这样做也是了解更多SAM如何工作的好方法。

创建自定义数据集

我们需要完成三件事来微调我们的模型:

1、要在其上绘制分割的图像

2、分割实况掩码

3、提示输入到模型中

我们选择了印章验证数据集因为它有SAM在训练中可能没有看到的数据(即,在文档上盖章)。我们可以通过使用预先训练的权重运行推理来验证它在该数据集上的表现良好,但并不完美。地面实况面具也非常精确,这将使我们能够计算出准确的损失。最后,这个数据集包含分割掩码周围的边界框,我们可以将其用作SAM的提示。下面显示了一个示例图像。这些边界框与人工注释器在生成分段时要经过的工作流程非常一致。

输入数据预处理

我们需要对从numpy数组到pytorch张量的扫描进行预处理。要做到这一点,我们可以遵循SamPredictor.set_image和预处理图像的SamPredictor.set_arch_image内部发生的情况。首先,我们可以使用utils.transform.ResizeLongestSide来调整图像的大小,因为这是预测器内部使用的转换器。然后,我们可以将图像转换为pytorch张量,并使用SAM预处理方法完成预处理。

训练设置

我们下载vit_b模型的模型检查点,并将其加载到:

sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')

我们可以使用默认值设置Adam优化器,并指定要调整的参数是掩码解码器的参数:

optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters()) 

同时,我们可以设置我们的损失函数,例如均方误差

loss_fn = torch.nn.MSELoss()

训练循环

在主训练循环中,我们将迭代我们的数据项,生成掩码,并将其与我们的地面实况掩码进行比较,以便我们可以基于损失函数优化模型参数。

在这个例子中,我们使用GPU进行训练,因为它比使用CPU快得多。在适当的张量上使用.to(设备)是很重要的,以确保CPU上没有某些张量,GPU上没有其他张量。

我们希望通过将编码器封装在torch.no.grad()上下文管理器中来嵌入图像,因为否则我们将出现内存问题,同时我们不希望微调图像编码器。

with torch.no_grad():
	image_embedding = sam_model.image_encoder(input_image)

我们还可以在no.grad上下文管理器中生成提示嵌入。我们使用边界框坐标,转换为pytorch张量。

with torch.no_grad():
      sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
          points=None,
          boxes=box_torch,
          masks=None,
      )

最后,我们可以生成掩码。请注意,这里我们处于单掩码生成模式(与正常输出的3个掩码形成对比)。

low_res_masks, iou_predictions = sam_model.mask_decoder(
  image_embeddings=image_embedding,
  image_pe=sam_model.prompt_encoder.get_dense_pe(),
  sparse_prompt_embeddings=sparse_embeddings,
  dense_prompt_embeddings=dense_embeddings,
  multimask_output=False,
)

这里的最后一步是将掩码升级回原始图像大小,因为它们的分辨率较低。我们可以使用Sam.postprocess_masks来实现这一点。我们还希望从预测的掩码中生成二进制掩码,以便将其与我们的基本事实进行比较。为了不破坏反向传播,使用torch泛函是很重要的。

upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)

from torch.nn.functional import threshold, normalize

binary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)

最后,我们可以计算损失并运行优化步骤:

loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()

通过在多个epoch和批次上重复这一过程,我们可以微调SAM解码器。

保存检查点并从中启动模型

一旦我们完成了训练并对性能提升感到满意,我们就可以使用以下方法保存调整模型的状态dict:

torch.save(model.state_dict(), PATH)

然后,当我们想对与我们用来微调模型的数据相似的数据执行推理时,我们可以加载这个状态dict。

针对下游应用的微调

虽然SAM目前不提供开箱即用的微调,但我们正在构建一个与Encord平台集成的自定义微调调谐器。如本文所示,为了实现这一点,我们对解码器进行了微调。这在web应用程序中是一个开箱即用的一键过程,可以自动设置超参数。

微调前

微调后

我们可以看到,这个掩码比原来的掩码更紧。这是对印章验证数据集中的一小部分图像进行微调的结果,然后在一个以前看不见的例子上运行调整后的模型。通过进一步的训练和更多的例子,我们可以获得更好的结果。

结论

现在已经学会了如何微调分段任意模型(SAM)。如果您想开箱即用地微调SAM,您可能也有兴趣了解我们最近在Encord中发布了Segment Anything模型,允许您在不编写任何代码的情况下微调模型。

参考

How To Fine-Tune Segment Anything

文章来源:https://blog.csdn.net/Together_CZ/article/details/135237323
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。