首页 > 其他分享 >centernet的数据增强操作--仿射变换

centernet的数据增强操作--仿射变换

时间:2022-12-01 18:44:37浏览次数:73  
标签:src scale img -- dst centernet np float32 仿射变换

https://blog.csdn.net/yang332233/article/details/110164808?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522166988875516782428685782%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257D&request_id=166988875516782428685782&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2blogfirst_rank_ecpm_v1~rank_v31_ecpm-3-110164808-null-null.nonecase&utm_term=centernet&spm=1018.2226.3001.4450
其实在这里也分析过。奈何当初写的代码不知道哪里去了;

先看下效果图:

从图上可以看到,在原图随机确定的三个点都映射到变换之后的图,然后这三点包围的外接矩形区域在仿射变换之后都是肯定在的。整体呈现出平移缩放放大的效果。

画图改动代码如下:
在CenterNet-master/src/lib/utils/image.py复制函数get_affine_transform,返回src和dst三对点。

def get_affine_transform_point_src_dst(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)

    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans, src, dst

在/CenterNet-master/src/lib/datasets/sample/ctdet.py中,画图, 添加show_3pt函数

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.utils.data as data
import numpy as np
import torch
import json
import cv2
import os
from utils.image import flip, color_aug
from utils.image import get_affine_transform, affine_transform, get_affine_transform_point_src_dst
from utils.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian
from utils.image import draw_dense_reg
import math

def show_3pt(src_img, inp, src_3pt, dst_3pt):
  h,w,c = src_img.shape
  x = src_3pt[:, 0]
  y = src_3pt[:, 1]
  min_x = np.min(x)
  min_y = np.min(y)

  width_new = w
  height_new = h
  if min_x < 0:
    width_new += (-min_x)
    src_3pt[:, 0] = src_3pt[:, 0] + (-min_x)

  if min_y < 0:
    height_new += (-min_y)
    src_3pt[:, 1] = src_3pt[:, 1] + (-min_y)

  start_x, start_y = 0, 0
  if min_x < 0:
    start_x = -min_x
  if min_y < 0:
    start_y = -min_y

  new_img = np.zeros([int(height_new + 2), int(width_new + 2), int(c)], dtype=np.uint8)
  new_img[int(start_y): int(start_y+h), int(start_x):int(start_x+w), :] = src_img.astype(np.uint8)

  for cnt in range(3):
    pt = (src_3pt[cnt][0], src_3pt[cnt][1])
    # print("pt=", pt)
    cv2.circle(new_img, pt, 14, (0, 0, 255), -1)

  for cnt in range(3):
    pt = (dst_3pt[cnt][0], dst_3pt[cnt][1])
    # print("pt=", pt)
    cv2.circle(inp, pt, 14, (0, 255, 255), -1)



  cv2.imshow("new_img", new_img)
  cv2.imshow("inp", inp)
  cv2.imshow("src_img", src_img)
  cv2.waitKey(0)

在这里调用:

  def __getitem__(self, index):
    img_id = self.images[index]
    file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
    img_path = os.path.join(self.img_dir, file_name)
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])
    anns = self.coco.loadAnns(ids=ann_ids)
    num_objs = min(len(anns), self.max_objs)

    img = cv2.imread(img_path)

    height, width = img.shape[0], img.shape[1]
    c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    if self.opt.keep_res:#False
      input_h = (height | self.opt.pad) + 1
      input_w = (width | self.opt.pad) + 1
      s = np.array([input_w, input_h], dtype=np.float32)
    else:
      s = max(img.shape[0], img.shape[1]) * 1.0
      input_h, input_w = self.opt.input_h, self.opt.input_w
    
    flipped = False
    if self.split == 'train':
      if not self.opt.not_rand_crop:#yes
        s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
        w_border = self._get_border(128, img.shape[1])
        h_border = self._get_border(128, img.shape[0])
        c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
        c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
      else:
        sf = self.opt.scale
        cf = self.opt.shift
        c[0] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
        c[1] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
        s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
      
      if np.random.random() < self.opt.flip:
        flipped = True
        img = img[:, ::-1, :]
        c[0] =  width - c[0] - 1
        

    trans_input, src_3pt, dst_3pt = get_affine_transform_point_src_dst(
      c, s, 0, [input_w, input_h])
    inp = cv2.warpAffine(img, trans_input, 
                         (input_w, input_h),
                         flags=cv2.INTER_LINEAR)

    show_3pt(img, inp, src_3pt, dst_3pt)

这里其实关键的是确定三对点。两个关键参数c和s

    s = max(img.shape[0], img.shape[1]) * 1.0
    if self.split == 'train':
      if not self.opt.not_rand_crop:#yes
        s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
        w_border = self._get_border(128, img.shape[1])
        h_border = self._get_border(128, img.shape[0])
        c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border) #w_border = 128
        c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border) #h_border = 128

这里的c就是代表center的意思,图像周围去掉128内圈就是c的范围, s是图像最长边然后随机的乘以[0.6,1.4,0.1]
这三对点第一个点就是c为中心点

def get_affine_transform_point_src_dst(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale], dtype=np.float32)

    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans, src, dst

然后第二个点是沿着c向上src_w * -0.5

src_dir = get_dir([0, src_w * -0.5], rot_rad)

src[1, :] = center + src_dir + scale_tmp * shift
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir

这里看到图片有黑边就是因为这里的src_w * -0.5大于c的y,就导致y-0.5×src_w为负数。

第三对点

def get_3rd_point(a, b):
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)


src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

这里是根据前面2个点来计算得到的。这里其实很简单,比如src_pt[0]=[500,500], src_pt[1]=[500,250]
那么direct=[0, 250]
return( [500,250] + [-250, 0])
即[250,250]
有没有发现!这里其实就是之前0-->1的时候向上偏移了比如h,然后这里在1的基础上又向左偏移h。

所以,以上就是三对点产生的过程!

产生黑边就是因为向上的偏移量,src_w * -0.5大于c的y,!左边的黑边就是因为src_w * -0.5 大于c的x。
s = max(img.shape[0], img.shape[1]) * 1.0
s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))

src_dir = get_dir([0, src_w * -0.5], rot_rad) #这里src_w就是s

标签:src,scale,img,--,dst,centernet,np,float32,仿射变换
From: https://www.cnblogs.com/yanghailin/p/16942341.html

相关文章

  • Springboot自动装配源码及启动原理理解
    Springboot自动装配源码及启动原理理解springboot版本:2.2.2传统的Spring框架实现一个Web服务,需要导入各种依赖JAR包,然后编写对应的XML配置文件等,相较而言,SpringBoot显......
  • Mysql在数据应用中的注意事项
    1. 前言1.1.背景● 数据库被广泛应用:各类业务系统、信息化系统,数据仓库、数据分析、数据挖掘。● 数据使用中存在的常见问题。● 了解基本、强制的使用规范,有助于更好的......
  • RocketMQ 全链路灰度探索与实践
    本文作者:肖京,SpringCloudAlibabaPMC,阿里云智能技术专家。01 全链路灰度背景介绍发布新版本时,为了有效、谨慎地验证新版本代码逻辑的正确性,通常会采用灰度发布,从而达到减......
  • PF_RING调研及实践
    一PF_RING简介1.与libpcap不同,pf_ring核心思想是通过DMA将网卡流量直接MMAP到用户空间(绕过内核网络协议栈),避免libpcap的网卡->内核,内核→用户空间的方式,压缩拷贝次数,节省了......
  • 干货|成为优秀软件测试工程师的六大必备能力
    “软件吞噬世界”、“软件定义一切”。随着软件行业的迅速发展,保障软件质量的关键环节——软件测试也变得越来越重要。而执行测试工作的测试工程师,便是软件质量的把关者。......
  • 关键概念
    介绍HyperledgerFabric是分布式账本解决方案的平台,采用模块化架构,提供高安全性、弹性、灵活性和可扩展性。它被设计为支持以可插拔方式实现不同组件,并适应复杂的经济生态......
  • re1
    1.惯例查位数2.放进ida643.shift+f12无脑找与flag有关的字符串4.随便一个双击,ctrl+x查找引用了该字符串的函数跳到这里来了容易看到右下这块有一个分支,左边wro......
  • 会话技术-概述
    会话技术-概述会话:一次会话中包含多次请求和响应。一次会话:浏览器第一次给服务器资源发送请求,会话建立,直到有一方断开为止功能:在一次会话的范围内的多次请求间,共享......
  • GO各种包的用途
    Go语言标准库常用的包及功能Go语言标准库包名功 能bufio带缓冲的I/O操作bytes实现字节操作container封装堆、列表和环形列表等容器crypto加密算法database数据库驱动和......
  • fabirc 开发环境搭建
    启动测试网络#1.拉代码gitclonehttps://github.com/hyperledger/fabric-samples.git#2.进入目录cdfabric-samples/test-network#启动网络并创建通道./network.shupcre......