SAM+使用SAM应用数据集完成分割

2023-12-13 03:52:47

什么是SAM?

????????SAM(Segment Anything Model)是由 Meta 的研究人员团队创建和训练的深度学习模型。在 Segment everything 研究论文中,SAM 被称为“基础模型”。

????????基础模型是在大量数据上训练的机器学习模型(通常通过自监督或半监督学习),其目的是在更具体的任务上使用和重新训练。SAM 是一个预训练模型,旨在适应其他任务(特别是通过微调)。

sam安装

下载安装SAMhttps://github.com/facebookresearch/segment-anything

安装 Segment Anything:

pip install git+https://github.com/facebookresearch/segment-anything.git

或在本地克隆存储库并使用

git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .

Github页面里点击下载一个或者多个模型:

模型文件放到项目的目录即可。

H,L,B分别表示huge,large,base,从大到小。根据硬件能力选择合适的模型。

下列依次ViT-H SAM模型(vit_h),ViT-L?SAM 模型(vit_1), ViT-B SAM 模型(vit_b)

????https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth?

https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

?使用

?方法一:使用官方命令

建立input,output文件夹

在input中存放待分割的图片,output用作存放输出的mask。

在这我使用的是vit_h

python3 scripts/amg.py --checkpoint ./sam_vit_h_4b8939.pth --model-type default --input ./input.jpeg --output output

官方命令即执行amg.py文件,并传入了一些参数,当传入参数固定时可以直接写在amg.py文件中。

方法二:

# coding=gb2312
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
device = "cuda"
sam = sam_model_registry["default"](checkpoint="你下载的权重的位置")
#sam_vit_h_4b8939.pth 是预训练的默认权重,需要单独下载
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

image = cv2.imread('图片 位置.jpeg')
masks = mask_generator.generate(image)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

?

参考:?

https://zhuanlan.zhihu.com/p/627535252

文章来源:https://blog.csdn.net/huaooo1/article/details/134828066
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。