首页 > 编程语言 >DBNet源码详解

DBNet源码详解

时间:2023-01-19 15:35:00浏览次数:58  
标签:distance canvas square point DBNet map 源码 np 详解

参考项目:https://github.com/WenmuZhou/DBNet.pytorch

标签制作

制作threshold map 标签

image-20230119092440121

make_border_map.py

程序入口if __name__ == '__main__'

if __name__ == '__main__':
    import numpy as np

    make_border_map = MakeBorderMap()	# 实例化MakeBorderMap对象
    img = cv2.imread("../../datasets/train/img/img_41.jpg")	# 随机选取一张图片做演示
    # shape (4, 4, 2),表示该张图片上有4个文本区域,每个文本区域有4个坐标(x,y)
    points = np.array([[[533, 134], [562, 133], [561, 145], [532, 146]],
                       [[564, 131], [617, 129], [617, 145], [564, 146]],
                       [[620, 126], [657, 127], [656, 143], [618, 143]],
                       [[153, 150], [209, 144], [210, 159], [154, 165]]])
    draw_img = img.copy()
    # 可视化一下该图片上的文本区域
    for pt in points:
        cv2.polylines(draw_img, [pt], True, color=(0, 255, 0))
    cv2.imshow("draw", draw_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    # 该图片上的文本内容
    texts = ['EW15', 'Tanjong', 'Pagar', 'CAUTION']
    # 是否要忽略掉该文本区域。如果文本区域过小,那么需要忽略,后面的代码有解释。
    ignore_tags = [False, False, False, False]
    data = {"img": img, "img_41": "img_41", "text_polys": points, "texts": texts, "ignore_tags": ignore_tags}
    # 实际构造threshold map 标签
    data = make_border_map(data)
    print(data)

该图片文本区域可视化的结果如下:

image-20230119093520311

调用make_border_map对象

def __call__(self, data: dict) -> dict:
	"""
	从scales中随机选择一个尺度,对图片和文本框进行缩放
	:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
	:return:
	"""
    # 需要标注的图片
	im = data['img']
    # 图片上的实际的文本区域
	text_polys = data['text_polys']
    # 是否需要忽略掉该文本区域
	ignore_tags = data['ignore_tags']
	# canvas,即最后的threshold_map,其shape为图片的实际大小,用0初始化
	canvas = np.zeros(im.shape[:2], dtype=np.float32)
    # mask shape为图片的实际大小,用0初始化
	mask = np.zeros(im.shape[:2], dtype=np.float32)

	for i in range(len(text_polys)):    # 文本框的个数
        # 如果该文本区域需要忽略,则跳过
		if ignore_tags[i]:
			continue
        # 实际构造标签的方法
		self.draw_border_map(text_polys[i], canvas, mask=mask)
	canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min

	data['threshold_map'] = canvas
	data['threshold_mask'] = mask
	return data

canvas的初始状态

image-20230119094750682

其中,h表示原图的高度,w表示原图的宽度

mask的初始状态

image-20230119094758686

其中,h表示原图的高度,w表示原图的宽度

进入draw_border_map()方法

def draw_border_map(self, polygon, canvas, mask):
	polygon = np.array(polygon) # (4,2)
	assert polygon.ndim == 2
	assert polygon.shape[1] == 2
	#构造多边形对象
	polygon_shape = Polygon(polygon)
    # 多边形面积小于0,直接return
	if polygon_shape.area <= 0: 
		return
    # 计算膨胀和收缩的距离,该公式为论文中的公式
	distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length     
	subject = (tuple(l) for l in polygon)
    # 计算坐标的偏移
	padding = pyclipper.PyclipperOffset() 
	padding.AddPath(subject, pyclipper.JT_ROUND,
					pyclipper.ET_CLOSEDPOLYGON)
    # 计算出来的distance是正数,所以是膨胀边框
	padded_polygon = np.array(padding.Execute(distance)[0])     
    # 用计算出来的膨胀多边形填充mask
	cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], (255, 255, 255))
	plt.imshow(mask)
	plt.show()
	# cv2.imshow("mask", mask)
	# cv2.waitKey(0)
	# # closing all open windows
	# cv2.destroyAllWindows()

	xmin = padded_polygon[:, 0].min()   # x坐标的最小值
	xmax = padded_polygon[:, 0].max()   # x坐标的最大值
	ymin = padded_polygon[:, 1].min()   # y坐标的最小值
	ymax = padded_polygon[:, 1].max()   # y坐标的最大值
	width = xmax - xmin + 1             # 宽
	height = ymax - ymin + 1            # 高
    # 将多边形相对于原始图片的坐标转换为多边形相对于膨胀后的多边形的坐标
	polygon[:, 0] = polygon[:, 0] - xmin
	polygon[:, 1] = polygon[:, 1] - ymin

	xs = np.broadcast_to(
		np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
	ys = np.broadcast_to(
		np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
	# (4,22,39)  定义一个distance_map 用于保存每一个点到每一条边的距离
	distance_map = np.zeros(
		(polygon.shape[0], height, width), dtype=np.float32)
    # 遍历每一个文本框,计算该文本框同膨胀后多边形每一个边的距离
	for i in range(polygon.shape[0]):
		j = (i + 1) % polygon.shape[0]
        # 核心方法:计算扩张后的边界框内所有点到该条边的距离
		absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) 
        # 将得到的距离除以扩张的偏移量进行归一化,并且将值限制在[0,1]以内
		distance_map[i] = np.clip(absolute_distance / distance, 0, 1) 
    # 只保留某个点到距离最近的边的距离
	distance_map = distance_map.min(axis=0)  
    
	# 保证xmin,xmax,ymin,ymax的坐标在canvas的范围内
	xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
	xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
	ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
	ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
    
	canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
		1 - distance_map[
			ymin_valid - ymin:ymax_valid - ymax + height,
			xmin_valid - xmin:xmax_valid - xmax + width],
		canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

上面代码中,cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1)的结果可视化如下:

因为每次传入的mask是之前的mask,所以每次会在之前的mask上继续填充多边形。

第一个文本区域

image-20230119095918791

第一、二个文本区域

image-20230119100047504

第一、二、三个文本区域

image-20230119100110358

第一、二、三、四个文本区域

image-20230119100131541

上面代码中

polygon[:, 0] = polygon[:, 0] - xmin
polygon[:, 1] = polygon[:, 1] - ymin

这两行代码的作用用可视化的方式解释如下:

image-20230119101343639

image-20230119101432170

上面代码中,

xs = np.broadcast_to(
    np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
ys = np.broadcast_to(
    np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))

用可视化的方式解释如下:

image-20230119102904937

定义xs和ys的目的是为了计算每一个点到原始文本框每一条边的距离,具体体现在absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) 这一行代码

进入distance()方法

def distance(self, xs, ys, point_1, point_2):
	'''
	compute the distance from point to a line
	ys: coordinates in the first axis
	xs: coordinates in the second axis
	point_1, point_2: (x, y), the end of the line
	'''
	# height, width = xs.shape[:2]
	square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
	square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
       # 文本框一条边的长度
	square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])   
	# 余弦距离,不知道为什么是这个公式? 如果有理解的大佬还请解释一下。 shape:(22*39)
	cosin = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2))
	square_sin = 1 - np.square(cosin)       # sin² x + cos² x = 1
	square_sin = np.nan_to_num(square_sin)  # 用0替代nan值

	result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)
    # 不知道这行代码的意思?如果有理解的大佬还请解释一下
	result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0]
	# self.extend_line(point_1, point_2, result)
	return result

上面代码中,

square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])

用可视化的方式解释如下:

image-20230119132850069

result的shape为膨胀之后的多边形的shape,表示每一个点到一条边的距离。

再退回到draw_border_map()方法

来到下面这行代码

distance_map[i] = np.clip(absolute_distance / distance, 0, 1)

用可视化的方式解释如下:

image-20230119144403988

再继续,来到下面这行代码

distance_map = distance_map.min(axis=0)  # 只保留某个点到距离最近的边的距离

到这里,已经求出了每一个点到每一条边的距离,这时候只选最近的那个距离,distance_map的shape就变为了(height, width),即膨胀之后的多边形的高度和宽度。

再往下,

canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
	1 - distance_map[
		ymin_valid - ymin:ymax_valid - ymax + height,
		xmin_valid - xmin:xmax_valid - xmax + width],
	canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

这段代码其实就是在往canvas图像上填充值了,用1减去distance_map上每一个像素点的值,相当于把距离做了一下变化,之前最大的距离1,现在变成了0,之前最小的距离0,现在变成了1。然后再同原始的canvas图像上的像素点(初始的时候都是0)计算最大值。

image-20230119144421424

最后,退回到make_border_map对象的调用方法

来到这行代码,

canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min

该行代码的作用就是将canvas里面的值压缩到self.thresh_min到self.thresh_max之间,即0.3-0.7。

可视化结果如下:

image-20230119145050093

到这里threshold_map的标签就制作好了。

未完待续......

标签:distance,canvas,square,point,DBNet,map,源码,np,详解
From: https://www.cnblogs.com/kyle-blog/p/17061577.html

相关文章

  • win10下python3.9的代理报错问题解决(附web3的polygon爬虫源码)
    背景因为工作中经常需要代理访问,而开了代理,request就会报错SSLError,如下:requests.exceptions.SSLError:HTTPSConnectionPool(host='test-admin.xxx.cn',port=443):Ma......
  • 浅谈SpringAOP功能源码执行逻辑
    如题,该篇博文主要是对Spring初始化后AOP部分代码执行流程的分析,仅仅只是粗略的讲解整个执行流程,喜欢细的小伙伴请结合其他资料进行学习。在看本文之前请一定要有动态代理的......
  • 【并发编程】ThreadLocal详解
    文章目录​​1.ThreadLocal简介​​​​2.ThreadLocal的简单使用​​​​3.ThreadLocal的实现原理​​​​4.ThreadLocal不支持继承性​​​​5.InheritableThreadLocal支持......
  • OpenCV Mat类详解
    1.Mat类常用成员函数和成员变量        由于Mat类使用的非常广泛,使用的形式也非常之多,这里只对较为常用的成员函数和成员变量做出了整理;1.1构造函数(1)默认构......
  • SpringBoot源码学习3——SpringBoot启动流程
    系列文章目录和关于我一丶前言在《SpringBoot源码学习1——SpringBoot自动装配源码解析+Spring如何处理配置类的》中我们学习了SpringBoot自动装配如何实现的,在《Sprin......
  • 浅谈Netty中ServerBootstrap服务端源码(含bind全流程)
    文章目录​​一、梳理Java中NIO代码​​​​二、Netty服务端代码​​​​1、newNioEventLoopGroup()​​​​2、group​​​​3、channel​​​​4、NioServerSocketChanne......
  • 浅谈Redisson底层源码
    Redisson源码分析​​一、加锁时使用lua表达式,执行添加key并设置过期时间​​​​二、加锁成功之后给锁添加对应的事件​​​​三、加锁完成,看门狗自动续命未处理完的线程​......
  • drf快速使用 CBV源码分析 drf之APIView分析 drf之Request对象分析
     目录序列化和反序列化drf介绍和安装使用原生django写接口djangoDRF安装drf快速使用模型序列化类视图路由datagrip使用postman测试接口CBV源码分......
  • 浅谈Zookeeper集群选举Leader节点源码
    写在前面:zookeeper源码比较复杂,本文讲解的重点为各个zookeeper服务节点之间的state选举。至于各个节点之间的数据同步,不在文本的侧重讲解范围内。在没有对zookeeper组件有一......
  • 浅谈Redis基本数据类型底层编码(含C源码)
    文章目录​​一、String​​​​1、int​​​​2、embstr​​​​3、raw​​​​4、bitmap​​​​5、hyperloglog​​​​二、List​​​​1、ziplist​​​​2、quicklist......