首页 > 其他分享 >代码复现:Copy-Paste 数据增强for语义分割

代码复现:Copy-Paste 数据增强for语义分割

时间:2022-10-22 16:27:03浏览次数:101  
标签:src img rescale cv2 mask 复现 np Copy Paste

一、前言

前些天分享了一篇谷歌的数据增强论文,解读在这:https://www.cnblogs.com/tangjielin/p/16812816.html。

可能由于方法比较简单,官方没有开源代码,于是,我自己尝试在语义分割数据集上进行了实现,代码见GitHub。

先看下实现的效果:

原图:

2059520595

使用复制-粘贴方法增强后:

copy_paste_20595copy_paste_20595

二、思路及代码

从上面的可视化结果,可以看出,我们需要两组样本:一组image+annotation为源图,一组image+annotation为主图,我们的目的是将源图及其标注信息叠加到主图及其标注信息上;同时,需要对源图的信息做随机水平翻转、大尺度抖动/随机缩放的操作。

思路如下:

  1. 随机选取源图像 \(I_{s r c}\) (用于提取目标) 、主图像 \(I_{\text {main }}\) (用于将所提取的目前粘贴在其之上);
  2. 对 \(I_{s r c}\) 和 \(I_{\text {main }}\) 分别进行随机水平翻转;
  3. 根据参数设置,对 \(I_{\text {src }}\) 和 \(I_{\text {main }}\) 进行大尺度抖动(Large Scale Jittering,LSJ),或者仅对 \(I_{\text {srci }}\) 进行随机尺度缩放;
  4. 将 \(I_{s r c}\) 及其对应的mask \(k_{\text {src 分别使用公式 }} I_1 \times \alpha+I_2 \times(1-\alpha)\) 进行合成,生成合成的图像及其对应mask;
  5. 保存图像及mask,其中, mask转为8位调色板模式保存;

具体实现的代码如下(需要你的数据集为VOC格式,如果是coco格式,需要先将coco数据集的mask提取出来,可以参考这篇博客):

# -*- coding: utf-8 -*-
"""
PROJECT_NAME: RS_Toolbox 
FILE_NAME: Copy_Paste 
AUTHOR: welt 
E_MAIL: tjlwelt@foxmail.com
DATE: 2022/10/21 
"""


from PIL import Image
import imgviz
import cv2
import argparse
import os
import numpy as np
import tqdm


def save_colored_mask(mask, save_path):
    lbl_pil = Image.fromarray(mask.astype(np.uint8), mode="P")
    colormap = imgviz.label_colormap()
    lbl_pil.putpalette(colormap.flatten())
    lbl_pil.save(save_path)


def random_flip_horizontal(mask, img, p=0.5):
    if np.random.random() < p:
        img = img[:, ::-1, :]
        mask = mask[:, ::-1]
    return mask, img


def img_add(img_src, img_main, mask_src):
    if len(img_main.shape) == 3:
        h, w, c = img_main.shape
    elif len(img_main.shape) == 2:
        h, w = img_main.shape
    mask = np.asarray(mask_src, dtype=np.uint8)
    sub_img01 = cv2.add(img_src, np.zeros(np.shape(img_src), dtype=np.uint8), mask=mask)
    mask_02 = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    mask_02 = np.asarray(mask_02, dtype=np.uint8)
    sub_img02 = cv2.add(img_main, np.zeros(np.shape(img_main), dtype=np.uint8),
                        mask=mask_02)
    img_main = img_main - sub_img02 + cv2.resize(sub_img01, (img_main.shape[1], img_main.shape[0]),
                                                 interpolation=cv2.INTER_NEAREST)
    return img_main


def rescale_src(mask_src, img_src, h, w):
    if len(mask_src.shape) == 3:
        h_src, w_src, c = mask_src.shape
    elif len(mask_src.shape) == 2:
        h_src, w_src = mask_src.shape
    max_reshape_ratio = min(h / h_src, w / w_src)
    rescale_ratio = np.random.uniform(0.2, max_reshape_ratio)

    # reshape src img and mask
    rescale_h, rescale_w = int(h_src * rescale_ratio), int(w_src * rescale_ratio)
    mask_src = cv2.resize(mask_src, (rescale_w, rescale_h),
                          interpolation=cv2.INTER_NEAREST)
    # mask_src = mask_src.resize((rescale_w, rescale_h), Image.NEAREST)
    img_src = cv2.resize(img_src, (rescale_w, rescale_h),
                         interpolation=cv2.INTER_LINEAR)

    # set paste coord
    py = int(np.random.random() * (h - rescale_h))
    px = int(np.random.random() * (w - rescale_w))

    # paste src img and mask to a zeros background
    img_pad = np.zeros((h, w, 3), dtype=np.uint8)
    mask_pad = np.zeros((h, w), dtype=np.uint8)
    img_pad[py:int(py + h_src * rescale_ratio), px:int(px + w_src * rescale_ratio), :] = img_src
    mask_pad[py:int(py + h_src * rescale_ratio), px:int(px + w_src * rescale_ratio)] = mask_src

    return mask_pad, img_pad


def Large_Scale_Jittering(mask, img, min_scale=0.1, max_scale=2.0):
    rescale_ratio = np.random.uniform(min_scale, max_scale)
    h, w, _ = img.shape

    # rescale
    h_new, w_new = int(h * rescale_ratio), int(w * rescale_ratio)
    img = cv2.resize(img, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
    mask = cv2.resize(mask, (w_new, h_new), interpolation=cv2.INTER_NEAREST)
    # mask = mask.resize((w_new, h_new), Image.NEAREST)

    # crop or padding
    x, y = int(np.random.uniform(0, abs(w_new - w))), int(np.random.uniform(0, abs(h_new - h)))
    if rescale_ratio <= 1.0:  # padding
        img_pad = np.ones((h, w, 3), dtype=np.uint8) * 168
        mask_pad = np.zeros((h, w), dtype=np.uint8)
        img_pad[y:y + h_new, x:x + w_new, :] = img
        mask_pad[y:y + h_new, x:x + w_new] = mask
        return mask_pad, img_pad
    else:  # crop
        img_crop = img[y:y + h, x:x + w, :]
        mask_crop = mask[y:y + h, x:x + w]
        return mask_crop, img_crop


def copy_paste(mask_src, img_src, mask_main, img_main):
    mask_src, img_src = random_flip_horizontal(mask_src, img_src)
    mask_main, img_main = random_flip_horizontal(mask_main, img_main)

    # LSJ, Large_Scale_Jittering
    if args.lsj:
        mask_src, img_src = Large_Scale_Jittering(mask_src, img_src)
        mask_main, img_main = Large_Scale_Jittering(mask_main, img_main)
    else:
        # rescale mask_src/img_src to less than mask_main/img_main's size
        h, w, _ = img_main.shape
        mask_src, img_src = rescale_src(mask_src, img_src, h, w)

    img = img_add(img_src, img_main, mask_src)
    mask = img_add(mask_src, mask_main, mask_src)

    return mask, img


def main(args):
    # input path
    segclass = os.path.join(args.input_dir, 'SegmentationClass')
    JPEGs = os.path.join(args.input_dir, 'JPEGImages')

    # create output path
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'SegmentationClass'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'JPEGImages'), exist_ok=True)

    masks_path = os.listdir(segclass)
    tbar = tqdm.tqdm(masks_path, ncols=100)
    for mask_path in tbar:
        # get source mask and img
        mask_src = np.asarray(Image.open(os.path.join(segclass, mask_path)), dtype=np.uint8)
        img_src = cv2.imread(os.path.join(JPEGs, mask_path.replace('.png', '.jpg')))

        # random choice main mask/img
        mask_main_path = np.random.choice(masks_path)
        mask_main = np.asarray(Image.open(os.path.join(segclass, mask_main_path)), dtype=np.uint8)
        img_main = cv2.imread(os.path.join(JPEGs, mask_main_path.replace('.png', '.jpg')))

        # Copy-Paste data augmentation
        mask, img = copy_paste(mask_src, img_src, mask_main, img_main)

        mask_filename = "copy_paste_" + mask_path
        img_filename = mask_filename.replace('.png', '.jpg')
        save_colored_mask(mask, os.path.join(args.output_dir, 'SegmentationClass', mask_filename))
        cv2.imwrite(os.path.join(args.output_dir, 'JPEGImages', img_filename), img)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", default="./dataset/Terrace", type=str,
                        help="input annotated directory")
    parser.add_argument("--output_dir", default="./dataset/Terrace_copy_paste", type=str,
                        help="output dataset directory")
    parser.add_argument("--lsj", default=True, type=bool, help="if use Large Scale Jittering")
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    main(args)

标签:src,img,rescale,cv2,mask,复现,np,Copy,Paste
From: https://www.cnblogs.com/tangjielin/p/16816313.html

相关文章

  • xcopy命令拷贝文件时忽略指定文件夹
    现在弄的项目,前端是居于一个框架进行开发的。问题是,框架还不算成熟,仍然在不断修改中。这样问题就来了,我需要常常在具体项目中更新这个框架。怎么更新呢?手动更新,问你死未。真......
  • Vue笔记2 v-bind,截图软件snipaste、computed
                                              ......
  • CopyOnWriteArrayList集合
    CopyOnWriteArrayList是为了增加在写操作的时候的读操作的性能因为并发问题主要是写操作,当一个线程进行写操作时,会使用Reetranlock加锁,然后会复制一份原数组在新数组上进......
  • Kettle需求场景复现
    kettle真实需求开发,可实现kettle入门,包含细节亿点点。前置说明遍历文件夹下的文件,读取所有的sheet页(指定的sheet)落库读取execl文件和csv文件,获得文件中shee......
  • 非视距 TDOA 算法复现
    定位算法复现,非视距环境下TDOA定位算法论文名字:RobustConvexApproximationMethodsforTDOA-BasedLocalizationunderNLOSConditionsmatalab使用CVX实现完全OK,......
  • 复制-粘贴大法(Copy-Paste):简单而有效的数据增强
    论文标题:SimpleCopy-PasteisaStrongDataAugmentationMethodforInstanceSegmentation论文地址:https://arxiv.org/pdf/2012.07177.pdf1、摘要建立有效的实例......
  • 试图复现一次coredump但失败的经历
    昨天实现my_memmove的时候出现了coredump,现在试图复现找出问题1.src字符串复现过程中首先想到的是,有可能是因为src字符串是字符串常量。但是又回想了下,src不是字符......
  • 最新CS RCE(CVE-2022-39197)复现心得分享
    0x01前言CS作为目前最流行的远控工具,其爆出的远程命令行漏洞CVE-2022-39197号称脚本小子杀手神器。之前看了@漂亮鼠大佬的文章《最新CSRCE曲折的复现路》,对文章的内容非常......
  • svn your working copy appears to be locked run cleanup to amend the situation
    https://blog.csdn.net/anmei1912/article/details/101614285https://blog.csdn.net/zouyujie1127/article/details/7683602/右击选择svn,选择clearup, 则解决解决sv......
  • (BADI)Copy PR header text to PO header when ME21N
    货铺QQ群号:834508274进群统一修改群名片,例如BJ_ABAP_森林木。群内禁止发广告及其他一切无关链接,小程序等,进群看公告,谢谢配合不修改昵称会被不定期踢除,谢谢配合下面开始干货:......