首页 > 编程语言 >【Python&语义分割】Segment Anything(SAM)模型全局语义分割代码+掩膜保存(二)

【Python&语义分割】Segment Anything(SAM)模型全局语义分割代码+掩膜保存(二)

时间:2023-10-12 11:59:08浏览次数:53  
标签:-% 分割 掩膜 语义 mask masks path data

 我上篇博文分享了Segment Anything(SAM)模型的基本操作,这篇给大家分享下官方的整张图片的语义分割代码(全局),同时我还修改了一部分支持掩膜和叠加影像的保存。

1 Segment Anything介绍

1.1 概况

        Meta AI 公司的 Segment Anything 模型是一项革命性的技术,该模型能够根据文本指令或图像识别,实现对任意物体的识别和分割。这一模型的推出,将极大地推动计算机视觉领域的发展,并使得图像分割技术进一步普及化。

        论文地址:https://arxiv.org/abs/2304.02643

        项目地址:Segment Anything

1.2 使用方法

        具体使用方法上,Segment Anything 提供了简单易用的接口,用户只需要通过提示,即可进行物体识别和分割操作。例如在图片处理中,用户可以通过 Hover & Click 或 Box 等方式来选取物体。值得一提的是,SAM 还支持通过上传自己的图片进行物体分割操作,提取物体用时仅需数秒。

        总的来说,Meta AI 的 Segment Anything 模型为我们提供了一种全新的物体识别和分割方式,其强大的泛化能力和广泛的应用前景将极大地推动计算机视觉领域的发展。未来,我们期待看到更多基于 Segment Anything 的创新应用,以及在科学图像分析、照片编辑等领域的广泛应用。

​​2 模型代码+注释

2.1 模型预加载

        我这里将掩膜生成的函数单独拿出来了,因为里面集成了掩膜保存的代码。所以先给大家看预处理部分。

    try:
        image = cv2.imread(image_path)  # 读取的图像以NumPy数组的形式存储在变量image中
        print("[%s]正在转换图片格式......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 将图像从BGR颜色空间转换为RGB颜色空间,还原图片色彩(图像处理库所认同的格式)
        print("[%s]正在初始化模型参数......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    except:
        print("图片打开失败!请检查路径!")
        pass
        sys.exit()
    sys.path.append("..")  # 将当前路径上一级目录添加到sys.path列表,这里模型使用绝对路径所以这行没啥用
    sam_checkpoint = model_path  # 定义模型路径

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)  # 定义模型参数
    mask_generator = SamAutomaticMaskGenerator(model=sam,  # 用于掩膜预测的SAM模型
                                               points_per_side=32,  # 图像一侧的采样点数,总采样点数是一侧采样点数的平方,点数给的越多,分割越细
                                               # points_per_batch=64,  # 设置模型同时运行的点的数量。更高的数字可能会更快,但会使用更多的GPU内存
                                               pred_iou_thresh=0.86,  # 滤波阈值,在[0,1]中,使用模型的预测掩膜质量0.86
                                               stability_score_thresh=0.92,
                                               # 滤波阈值,在[0,1]中,使用掩码在用于二进制化模型的掩码预测的截止点变化下的稳定性0.92
                                               # stability_score_offset=1.0,  # 计算稳定性分数时,对截止点的偏移量
                                               # box_nms_thresh=0.7,  # 非最大抑制用于过滤重复掩码的箱体IoU截止点
                                               crop_n_layers=1,  # 如果>0,蒙版预测将在图像的裁剪上再次运行。设置运行的层数,其中每层有2**i_layer的图像裁剪数1
                                               # crop_nms_thresh=0.7,  # 非最大抑制用于过滤不同作物之间的重复掩码的箱体IoU截止值
                                               # crop_overlap_ratio=512 / 1500,  # 设置作物重叠的程度
                                               crop_n_points_downscale_factor=2,
                                               # 在图层n中每面采样的点数被crop_n_points_downscale_factor**n缩减2
                                               # point_grids=None,  # 用于取样的明确网格的列表,归一化为[0,1]
                                               min_mask_region_area=100,
                                               # 如果>0,后处理将被应用于移除面积小于min_mask_region_area的遮罩中的不连接区域和孔。需要opencv。50
                                               # output_mode="binary_mask"  # 掩模的返回形式。
                                               # 可以是’binary_mask’, ‘uncompressed_rle’, 或者’coco_rle’。
                                               # coco_rle’需要pycocotools。对于大的分辨率,'binary_mask’可能会消耗大量的内存
                                               )  # 激活函数

2.2 模型预测代码

masks = mask_generator.generate(image)  # 类别掩膜提取(包含所有的,可按照索引查看)

# ---------------------------masks输出内容---------------------------
# segmentation : np的二维数组,为二值的mask图片
# area : mask的像素面积
# bbox : mask的外接矩形框,为X Y WH格式
# predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
# point_coords : 用于生成该mask的point输入
# stability_score : mask质量的附加指标
# crop_box : 用于以X Y WH格式生成此遮罩的图像裁剪
# ------------------------------------------------------------------

print("[%s]正在绘制图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
plt.figure(figsize=(20, 20))  # 创建一个新的图形窗口,设置其大小为10x10英寸
plt.imshow(image)  # 使用imshow函数在创建的图形窗口中显示图像
print("[%s]正在制作掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("【结果保存阶段】")
show_mask_auto(masks, out_path, out_path1)
plt.axis('on')  # 开启图像坐标轴,使得图像下的像素坐标可以显示出来
print("[%s]正在保存叠加结果......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
plt.savefig(out_image_path, dpi=300)
plt.show()  # 显示已经创建的图形窗口和其中的内容

2.3 掩膜生成+保存代码

        我这里在官方的掩膜生成的函数的基础上,加入了两段保存数据的代码。一个是彩色的mask(叠加显示的mask),一个是单波段的mask(DN值代表序号)。

        大家在使用这个函数时,将这段放在2.1,2.2展示的代码前面即可。

def show_mask_auto(masks_data, out_mask_path, out_path_01):
    """
    :param masks_data: 掩膜数据
    :param out_mask_path: 输出彩色掩膜
    :param out_path_01: 输出单波段掩膜
    :return: None
    """
    if len(masks_data) == 0:
        return
    sorted_masks_data = sorted(masks_data, key=(lambda x: x['area']), reverse=True)  # 按照面积大小降序排列
    ax = plt.gca()  # 获取当前的轴(axes)
    ax.set_autoscale_on(False)  # 关闭轴的自动缩放功能
    img = np.ones((sorted_masks_data[0]['segmentation'].shape[0], sorted_masks_data[0]['segmentation'].shape[1], 4))
    # 创建了一个新的三维数组img。数组的形状是基于segmentation']的形状,其中四个通道通常代表红色、绿色、蓝色和透明度(RGBA)
    img[:, :, 3] = 0  # 将新创建的图像的第四个通道(也就是透明度通道)设置为0
    img_raster = np.zeros((sorted_masks_data[0]['segmentation'].shape[0],
                          sorted_masks_data[0]['segmentation'].shape[1]))
    # 创建一个二维数组,用于保存掩膜做栅格转面
    j = 1
    for sorted_mask_data in sorted_masks_data:
        # 循环所有类别的掩膜
        m = sorted_mask_data['segmentation']
        # 获取当前类别的二值mask图片
        color_mask = np.concatenate([np.random.random(3), [0.65]])
        # 随机生成的RGB颜色,它的形状为(3,),0.65表示颜色的透明度。
        img[m] = color_mask
        # 将颜色赋予图片的数组
        img_raster[m] = j
        # 给掩膜赋值
        j += 1
    """for i in range(0, len(masks_data)):
        # 循环所有类别的掩膜
        rect = patches.Rectangle((masks_data[i]['bbox'][0], masks_data[i]['bbox'][1]), masks_data[i]['bbox'][2],
                                 masks_data[i]['bbox'][3], edgecolor=tuple(random.uniform(0, 1) for _ in range(3)),
                                 facecolor='none', linewidth=2)  # 绘制类别的外接矩形框
        ax.add_patch(rect)  # 将矩形添加到ax对象中"""
    plt.imshow(img, alpha=0.8)
    print("[%s]正在保存类别掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    driver = gdal.GetDriverByName('GTiff')  # 载入数据驱动,用于存储内存中的数组
    ds_result = driver.Create(out_mask_path, sorted_masks_data[0]['segmentation'].shape[1],
                              sorted_masks_data[0]['segmentation'].shape[0], bands=4, eType=gdal.GDT_Float64)
    # 创建一个数组,宽高为原始尺寸
    for i in range(3):
        ds_result.GetRasterBand(i+1).SetNoDataValue(0)  # 将无效值设为0
        ds_result.GetRasterBand(i+1).WriteArray(img[:, :, i])  # 将结果写入数组
    ds_result_raster = driver.Create(out_path_01, sorted_masks_data[0]['segmentation'].shape[1],
                                     sorted_masks_data[0]['segmentation'].shape[0], bands=1, eType=gdal.GDT_Float64)
    # ds_result.SetGeoTransform(ds_geo)  # 导入仿射地理变换参数
    # ds_result.SetProjection(ds_prj)  # 导入投影信息
    ds_result_raster.GetRasterBand(1).SetNoDataValue(0)  # 将无效值设为0
    ds_result_raster.GetRasterBand(1).WriteArray(img_raster)  # 将结果写入数组
    del ds_result
    del ds_result_raster

3 完整代码

# -*- coding: utf-8 -*-
"""
@Time : 2023/10/8 10:15
@Auth : RS迷途小书童
@File :Segment Anything Auto.py
@IDE :PyCharm
@Purpose:Segment Anything Model自动全局语义分割
"""
import sys
import cv2
import random
import numpy as np
from osgeo import gdal
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator


def SAM_auto(image_path, model_path, model_type, device, out_path, out_path1, out_image_path):
    """
    :param image_path: 输入需要分割的影像
    :param model_path: 输入模型路径
    :param model_type: 输入模型类型
    :param device: 输入cpu or cuda
    :param out_path: 输出彩色掩膜文件
    :param out_path1: 输出单波段掩膜文件
    :param out_image_path: 输出叠加图片
    :return: None
    """

    def show_mask_auto(masks_data, out_mask_path, out_path_01):
        """
        :param masks_data: 掩膜数据
        :param out_mask_path: 输出彩色掩膜
        :param out_path_01: 输出单波段掩膜
        :return: None
        """
        if len(masks_data) == 0:
            return
        sorted_masks_data = sorted(masks_data, key=(lambda x: x['area']), reverse=True)  # 按照面积大小降序排列
        ax = plt.gca()  # 获取当前的轴(axes)
        ax.set_autoscale_on(False)  # 关闭轴的自动缩放功能
        img = np.ones((sorted_masks_data[0]['segmentation'].shape[0], sorted_masks_data[0]['segmentation'].shape[1], 4))
        # 创建了一个新的三维数组img。数组的形状是基于segmentation']的形状,其中四个通道通常代表红色、绿色、蓝色和透明度(RGBA)
        img[:, :, 3] = 0  # 将新创建的图像的第四个通道(也就是透明度通道)设置为0
        img_raster = np.zeros((sorted_masks_data[0]['segmentation'].shape[0],
                              sorted_masks_data[0]['segmentation'].shape[1]))
        # 创建一个二维数组,用于保存掩膜做栅格转面
        j = 1
        for sorted_mask_data in sorted_masks_data:
            # 循环所有类别的掩膜
            m = sorted_mask_data['segmentation']
            # 获取当前类别的二值mask图片
            color_mask = np.concatenate([np.random.random(3), [0.65]])
            # 随机生成的RGB颜色,它的形状为(3,),0.65表示颜色的透明度。
            img[m] = color_mask
            # 将颜色赋予图片的数组
            img_raster[m] = j
            # 给掩膜赋值
            j += 1
        """for i in range(0, len(masks_data)):
            # 循环所有类别的掩膜
            rect = patches.Rectangle((masks_data[i]['bbox'][0], masks_data[i]['bbox'][1]), masks_data[i]['bbox'][2],
                                     masks_data[i]['bbox'][3], edgecolor=tuple(random.uniform(0, 1) for _ in range(3)),
                                     facecolor='none', linewidth=2)  # 绘制类别的外接矩形框
            ax.add_patch(rect)  # 将矩形添加到ax对象中"""
        plt.imshow(img, alpha=0.8)
        print("[%s]正在保存类别掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        driver = gdal.GetDriverByName('GTiff')  # 载入数据驱动,用于存储内存中的数组
        ds_result = driver.Create(out_mask_path, sorted_masks_data[0]['segmentation'].shape[1],
                                  sorted_masks_data[0]['segmentation'].shape[0], bands=4, eType=gdal.GDT_Float64)
        # 创建一个数组,宽高为原始尺寸
        for i in range(3):
            ds_result.GetRasterBand(i+1).SetNoDataValue(0)  # 将无效值设为0
            ds_result.GetRasterBand(i+1).WriteArray(img[:, :, i])  # 将结果写入数组
        ds_result_raster = driver.Create(out_path_01, sorted_masks_data[0]['segmentation'].shape[1],
                                         sorted_masks_data[0]['segmentation'].shape[0], bands=1, eType=gdal.GDT_Float64)
        # ds_result.SetGeoTransform(ds_geo)  # 导入仿射地理变换参数
        # ds_result.SetProjection(ds_prj)  # 导入投影信息
        ds_result_raster.GetRasterBand(1).SetNoDataValue(0)  # 将无效值设为0
        ds_result_raster.GetRasterBand(1).WriteArray(img_raster)  # 将结果写入数组
        del ds_result
        del ds_result_raster

    print("【程序准备阶段】")
    print("[%s]正在读取图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    try:
        image = cv2.imread(image_path)  # 读取的图像以NumPy数组的形式存储在变量image中
        print("[%s]正在转换图片格式......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 将图像从BGR颜色空间转换为RGB颜色空间,还原图片色彩(图像处理库所认同的格式)
        print("[%s]正在初始化模型参数......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    except:
        print("图片打开失败!请检查路径!")
        pass
        sys.exit()
    sys.path.append("..")  # 将当前路径上一级目录添加到sys.path列表,这里模型使用绝对路径所以这行没啥用
    sam_checkpoint = model_path  # 定义模型路径

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)  # 定义模型参数
    mask_generator = SamAutomaticMaskGenerator(model=sam,  # 用于掩膜预测的SAM模型
                                               points_per_side=32,  # 图像一侧的采样点数,总采样点数是一侧采样点数的平方,点数给的越多,分割越细
                                               # points_per_batch=64,  # 设置模型同时运行的点的数量。更高的数字可能会更快,但会使用更多的GPU内存
                                               pred_iou_thresh=0.86,  # 滤波阈值,在[0,1]中,使用模型的预测掩膜质量0.86
                                               stability_score_thresh=0.92,
                                               # 滤波阈值,在[0,1]中,使用掩码在用于二进制化模型的掩码预测的截止点变化下的稳定性0.92
                                               # stability_score_offset=1.0,  # 计算稳定性分数时,对截止点的偏移量
                                               # box_nms_thresh=0.7,  # 非最大抑制用于过滤重复掩码的箱体IoU截止点
                                               crop_n_layers=1,  # 如果>0,蒙版预测将在图像的裁剪上再次运行。设置运行的层数,其中每层有2**i_layer的图像裁剪数1
                                               # crop_nms_thresh=0.7,  # 非最大抑制用于过滤不同作物之间的重复掩码的箱体IoU截止值
                                               # crop_overlap_ratio=512 / 1500,  # 设置作物重叠的程度
                                               crop_n_points_downscale_factor=2,
                                               # 在图层n中每面采样的点数被crop_n_points_downscale_factor**n缩减2
                                               # point_grids=None,  # 用于取样的明确网格的列表,归一化为[0,1]
                                               min_mask_region_area=100,
                                               # 如果>0,后处理将被应用于移除面积小于min_mask_region_area的遮罩中的不连接区域和孔。需要opencv。50
                                               # output_mode="binary_mask"  # 掩模的返回形式。
                                               # 可以是’binary_mask’, ‘uncompressed_rle’, 或者’coco_rle’。
                                               # coco_rle’需要pycocotools。对于大的分辨率,'binary_mask’可能会消耗大量的内存
                                               )  # 激活函数
    print("【模型预测阶段】")
    print("[%s]正在分割图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    masks = mask_generator.generate(image)  # 类别掩膜提取(包含所有的,可按照索引查看)

    # ---------------------------masks输出内容---------------------------
    # segmentation : np的二维数组,为二值的mask图片
    # area : mask的像素面积
    # bbox : mask的外接矩形框,为X Y WH格式
    # predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
    # point_coords : 用于生成该mask的point输入
    # stability_score : mask质量的附加指标
    # crop_box : 用于以X Y WH格式生成此遮罩的图像裁剪
    # ------------------------------------------------------------------

    print("[%s]正在绘制图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    plt.figure(figsize=(20, 20))  # 创建一个新的图形窗口,设置其大小为10x10英寸
    plt.imshow(image)  # 使用imshow函数在创建的图形窗口中显示图像
    print("[%s]正在制作掩膜......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    print("【结果保存阶段】")
    show_mask_auto(masks, out_path, out_path1)
    plt.axis('on')  # 开启图像坐标轴,使得图像下的像素坐标可以显示出来
    print("[%s]正在保存叠加结果......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    plt.savefig(out_image_path, dpi=300)
    plt.show()  # 显示已经创建的图形窗口和其中的内容
    print("-----------------------------------------语义分割已完成----------------------------------------")


if __name__ == "__main__":
    print("\n")
    print("--------------------------------------Segment Anything--------------------------------------")
    Image_path = r'B:/Personal/satellite.tif'  # 分割的影像
    Model_path = "G:/Neat Download Manager/Misc/sam_vit_h_4b8939.pth"  # 模型路径
    Out_mask_path = 'B:/Personal/my_figure1.tif'  # 彩色掩膜
    Out_mask_path1 = 'B:/Personal/my_figure2.tif'  # 二维掩膜用于转矢量
    Out_image_path = 'B:/Personal/my_figure3.png'  # 叠加结果
    Model_type = "vit_h"  # 定义模型类型
    Device = "cuda"  # "cpu"  or  "cuda"
    SAM_auto(Image_path, Model_path, Model_type, Device, Out_mask_path, Out_mask_path1, Out_image_path)
    # 图片,模型,类型,算力,彩色掩膜,黑白掩膜,叠加图片


标签:-%,分割,掩膜,语义,mask,masks,path,data
From: https://www.cnblogs.com/RSran/p/17759156.html

相关文章

  • 【Python&语义分割】Segment Anything(SAM)模型详细使用教程+代码解释(一)
    ​1SegmentAnything介绍1.1概况        MetaAI公司的SegmentAnything模型是一项革命性的技术,该模型能够根据文本指令或图像识别,实现对任意物体的识别和分割。这一模型的推出,将极大地推动计算机视觉领域的发展,并使得图像分割技术进一步普及化。    论文......
  • 【Python深度学习】目标检测和语义分割的区别
    ​        在计算机视觉领域,语义分割和目标检测是两个关键的任务,它们都是对图像和视频进行分析,但它们之间存在着明显的区别。本文将通过图像示例,详细阐述语义分割和目标检测之间的差异。一、基本概念        1.1语义分割(SemanticSegmentation)      ......
  • 【Python&语义分割】语义分割的原理及常见模型的介绍
    1概述        语义分割是计算机视觉中的重要任务之一,其目的是将图像中的每个像素分配给特定的类别,从而实现对图像的精细分割。与目标检测不同,语义分割并不需要对物体进行位置和边界框的检测,而是更加注重对图像中每个像素的分类。随着深度学习的兴起,语义分割得到了广泛......
  • 【Python&语义分割】Segment Anything(SAM)模型介绍&安装教程
    ​1SegmentAnything介绍1.1概况        MetaAI公司的SegmentAnything模型是一项革命性的技术,该模型能够根据文本指令或图像识别,实现对任意物体的识别和分割。这一模型的推出,将极大地推动计算机视觉领域的发展,并使得图像分割技术进一步普及化。    论......
  • 基于HSV空间的彩色图像分割技术
    1.引言每当我们看到图像时,它通常都是由各种元素和目标组成的。在某些情况下,我们可能会想要从图像中提取某个特定的对象,大家会怎么做?首先我们会想到的是进行crop相关的操作,这在某种程度上是可行的,但是这通常也会有一些不相关的像素会被包括在内,我确信大多情况下我们不希望这样。事......
  • 什么是语义化版本里的 Major,Minor 和 Patch 版本号
    语义化版本(SemanticVersioning):Major、Minor和Patch版本号解析语义化版本,通常简称为SemVer,是一种软件版本号的标准化方案,旨在使软件版本号的管理更加透明和可预测。它主要由三个部分组成:Major(主版本号)、Minor(次版本号)和Patch(修订版本号)。在这篇文章中,我们将深入解释这三个部分......
  • C# OpenVino Yolov8 Seg 分割
    效果 项目代码usingSystem;usingSystem.Collections.Generic;usingSystem.ComponentModel;usingSystem.Data;usingSystem.Drawing;usingSystem.Linq;usingSystem.Text;usingSystem.Windows.Forms;usingOpenCvSharp;namespaceOpenVino_Yolov8_Demo{publi......
  • Go每日一库之166:go-version(语义化版本)
    今天给大家推荐的是一个版本比较工具。该工具基于语义化标准的版本号进行比较、约束以及校验。以下是go-version的基本情况:安装通过goget进行安装:gogetgithub.com/hashicorp/go-version解析和比较版本号v1,err:=version.NewVersion("1.2")给版本号增加约束并校验v1......
  • openvino道路分割
    我这里仅显示道路和车道线1mask=np.zeros((hh,ww,3),dtype=np.uint8)2mask[np.where(res>0)]=(0,255,0)#路面3mask[np.where(res>1)]=(255,0,0)#车道线 模型的下载还是老方法AccuracyThequalitymetricscalculatedon500imagesfrom"Might......
  • json数据传输压缩以及数据切片分割分块传输多种实现方法,大数据量情况下zlib压缩以及by
    json数据传输压缩以及数据切片分割分块传输多种实现方法,大数据量情况下zlib压缩以及bytes指定长度分割。importsysimportzlibimportjsonimportmathKAFKA_MAX_SIZE=1024*1024CONTENT_MIN_MAX_SIZE=KAFKA_MAX_SIZE*0.9defsplit_data(data):""":param......