首页 > 编程语言 >DBNet源码详解

DBNet源码详解

时间:2023-01-20 11:02:39浏览次数:63  
标签:distance square polygon map DBNet point 源码 np 详解

参考项目:https://github.com/WenmuZhou/DBNet.pytorch

标签制作

制作threshold map 标签

image-20230119092440121

make_border_map.py

程序入口if __name__ == '__main__'

if __name__ == '__main__':
    import numpy as np

    make_border_map = MakeBorderMap()	# 实例化MakeBorderMap对象
    img = cv2.imread("../../datasets/train/img/img_41.jpg")	# 随机选取一张图片做演示
    # shape (4, 4, 2),表示该张图片上有4个文本区域,每个文本区域有4个坐标(x,y)
    points = np.array([[[533, 134], [562, 133], [561, 145], [532, 146]],
                       [[564, 131], [617, 129], [617, 145], [564, 146]],
                       [[620, 126], [657, 127], [656, 143], [618, 143]],
                       [[153, 150], [209, 144], [210, 159], [154, 165]]])
    draw_img = img.copy()
    # 可视化一下该图片上的文本区域
    for pt in points:
        cv2.polylines(draw_img, [pt], True, color=(0, 255, 0))
    cv2.imshow("draw", draw_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    # 该图片上的文本内容
    texts = ['EW15', 'Tanjong', 'Pagar', 'CAUTION']
    # 是否要忽略掉该文本区域。如果文本区域过小,那么需要忽略,后面的代码有解释。
    ignore_tags = [False, False, False, False]
    data = {"img": img, "img_41": "img_41", "text_polys": points, "texts": texts, "ignore_tags": ignore_tags}
    # 实际构造threshold map 标签
    data = make_border_map(data)
    print(data)

该图片文本区域可视化的结果如下:

image-20230119093520311

调用make_border_map对象

def __call__(self, data: dict) -> dict:
	"""
	从scales中随机选择一个尺度,对图片和文本框进行缩放
	:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
	:return:
	"""
    # 需要标注的图片
	im = data['img']
    # 图片上的实际的文本区域
	text_polys = data['text_polys']
    # 是否需要忽略掉该文本区域
	ignore_tags = data['ignore_tags']
	# canvas,即最后的threshold_map,其shape为图片的实际大小,用0初始化
	canvas = np.zeros(im.shape[:2], dtype=np.float32)
    # mask shape为图片的实际大小,用0初始化
	mask = np.zeros(im.shape[:2], dtype=np.float32)

	for i in range(len(text_polys)):    # 文本框的个数
        # 如果该文本区域需要忽略,则跳过
		if ignore_tags[i]:
			continue
        # 实际构造标签的方法
		self.draw_border_map(text_polys[i], canvas, mask=mask)
	canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min

	data['threshold_map'] = canvas
	data['threshold_mask'] = mask
	return data

canvas的初始状态

image-20230119094750682

其中,h表示原图的高度,w表示原图的宽度

mask的初始状态

image-20230119094758686

其中,h表示原图的高度,w表示原图的宽度

进入draw_border_map()方法

def draw_border_map(self, polygon, canvas, mask):
	polygon = np.array(polygon) # (4,2)
	assert polygon.ndim == 2
	assert polygon.shape[1] == 2
	#构造多边形对象
	polygon_shape = Polygon(polygon)
    # 多边形面积小于0,直接return
	if polygon_shape.area <= 0: 
		return
    # 计算膨胀和收缩的距离,该公式为论文中的公式
	distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length     
	subject = (tuple(l) for l in polygon)
    # 计算坐标的偏移
	padding = pyclipper.PyclipperOffset() 
	padding.AddPath(subject, pyclipper.JT_ROUND,
					pyclipper.ET_CLOSEDPOLYGON)
    # 计算出来的distance是正数,所以是膨胀边框
	padded_polygon = np.array(padding.Execute(distance)[0])     
    # 用计算出来的膨胀多边形填充mask
	cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], (255, 255, 255))
	plt.imshow(mask)
	plt.show()
	# cv2.imshow("mask", mask)
	# cv2.waitKey(0)
	# # closing all open windows
	# cv2.destroyAllWindows()

	xmin = padded_polygon[:, 0].min()   # x坐标的最小值
	xmax = padded_polygon[:, 0].max()   # x坐标的最大值
	ymin = padded_polygon[:, 1].min()   # y坐标的最小值
	ymax = padded_polygon[:, 1].max()   # y坐标的最大值
	width = xmax - xmin + 1             # 宽
	height = ymax - ymin + 1            # 高
    # 将多边形相对于原始图片的坐标转换为多边形相对于膨胀后的多边形的坐标
	polygon[:, 0] = polygon[:, 0] - xmin
	polygon[:, 1] = polygon[:, 1] - ymin

	xs = np.broadcast_to(
		np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
	ys = np.broadcast_to(
		np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
	# (4,22,39)  定义一个distance_map 用于保存每一个点到每一条边的距离
	distance_map = np.zeros(
		(polygon.shape[0], height, width), dtype=np.float32)
    # 遍历每一个文本框,计算该文本框同膨胀后多边形每一个边的距离
	for i in range(polygon.shape[0]):
		j = (i + 1) % polygon.shape[0]
        # 核心方法:计算扩张后的边界框内所有点到该条边的距离
		absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) 
        # 将得到的距离除以扩张的偏移量进行归一化,并且将值限制在[0,1]以内
		distance_map[i] = np.clip(absolute_distance / distance, 0, 1) 
    # 只保留某个点到距离最近的边的距离
	distance_map = distance_map.min(axis=0)  
    
	# 保证xmin,xmax,ymin,ymax的坐标在canvas的范围内
	xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
	xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
	ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
	ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
    
	canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
		1 - distance_map[
			ymin_valid - ymin:ymax_valid - ymax + height,
			xmin_valid - xmin:xmax_valid - xmax + width],
		canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

上面代码中,cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1)的结果可视化如下:

因为每次传入的mask是之前的mask,所以每次会在之前的mask上继续填充多边形。

第一个文本区域

image-20230119095918791

第一、二个文本区域

image-20230119100047504

第一、二、三个文本区域

image-20230119100110358

第一、二、三、四个文本区域

image-20230119100131541

上面代码中

polygon[:, 0] = polygon[:, 0] - xmin
polygon[:, 1] = polygon[:, 1] - ymin

这两行代码的作用用可视化的方式解释如下:

image-20230119101343639

image-20230119101432170

上面代码中,

xs = np.broadcast_to(
    np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
ys = np.broadcast_to(
    np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))

用可视化的方式解释如下:

image-20230119102904937

定义xs和ys的目的是为了计算每一个点到原始文本框每一条边的距离,具体体现在absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) 这一行代码

进入distance()方法

def distance(self, xs, ys, point_1, point_2):
	'''
	compute the distance from point to a line
	ys: coordinates in the first axis
	xs: coordinates in the second axis
	point_1, point_2: (x, y), the end of the line
	'''
	# height, width = xs.shape[:2]
	square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
	square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
    # 文本框一条边的长度
	square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])   
	# 余弦距离,不知道为什么是这个公式? 如果有理解的大佬还请解释一下。 shape:(22*39)
	cosin = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2))
	square_sin = 1 - np.square(cosin)       # sin² x + cos² x = 1
	square_sin = np.nan_to_num(square_sin)  # 用0替代nan值

	result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)
    # 不知道这行代码的意思?如果有理解的大佬还请解释一下
	result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0]
	# self.extend_line(point_1, point_2, result)
	return result

上面代码中,

square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])

用可视化的方式解释如下:

image-20230119132850069

result的shape为膨胀之后的多边形的shape,表示每一个点到一条边的距离。

再退回到draw_border_map()方法

来到下面这行代码

distance_map[i] = np.clip(absolute_distance / distance, 0, 1)

用可视化的方式解释如下:

image-20230119144403988

再继续,来到下面这行代码

distance_map = distance_map.min(axis=0)  # 只保留某个点到距离最近的边的距离

到这里,已经求出了每一个点到每一条边的距离,这时候只选最近的那个距离,distance_map的shape就变为了(height, width),即膨胀之后的多边形的高度和宽度。

再往下,

canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
	1 - distance_map[
		ymin_valid - ymin:ymax_valid - ymax + height,
		xmin_valid - xmin:xmax_valid - xmax + width],
	canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

这段代码其实就是在往canvas图像上填充值了,用1减去distance_map上每一个像素点的值,相当于把距离做了一下变化,之前最大的距离1,现在变成了0,之前最小的距离0,现在变成了1。然后再同原始的canvas图像上的像素点(初始的时候都是0)计算最大值。

image-20230119144421424

最后,退回到make_border_map对象的调用方法

来到这行代码,

canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min

该行代码的作用就是将canvas里面的值压缩到self.thresh_min到self.thresh_max之间,即0.3-0.7。

可视化结果如下:

image-20230119145050093

到这里threshold_map的标签就制作好了。

制作probability map 标签

image-20230120090139268

make_shrink_map.py

还是以../../datasets/train/img/img_41.jpg这张图片为例

调用MakeShrinkMap对象

def __call__(self, data: dict) -> dict:
	"""
	从scales中随机选择一个尺度,对图片和文本框进行缩放
	:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
	:return:
	"""
	image = data['img']
	text_polys = data['text_polys']
	ignore_tags = data['ignore_tags']

	h, w = image.shape[:2]
    # 验证文本框坐标的有效性以及变换坐标顺序
	text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
    # 真实标签
	gt = np.zeros((h, w), dtype=np.float32)
	mask = np.ones((h, w), dtype=np.float32)
	for i in range(len(text_polys)):
		polygon = text_polys[i]
        # 计算文本框的高度
		height = max(polygon[:, 1]) - min(polygon[:, 1])
        # 计算文本框的宽度
		width = max(polygon[:, 0]) - min(polygon[:, 0])
		if ignore_tags[i] or min(height, width) < self.min_text_size:
            # 在mask上填充
			cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
			ignore_tags[i] = True
		else:
            # 完成收缩的方法
			shrinked = self.shrink_func(polygon, self.shrink_ratio)
			if shrinked.size == 0:
				cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
				ignore_tags[i] = True
				continue
			cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)

	data['shrink_map'] = gt
	data['shrink_mask'] = mask
	return data

上面代码中,

text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)

这行代码的作用是验证文本框的边框是否是超出了原始图片的范围,并且将文本框的坐标顺序进行了一个调换。例如,[[533, 134], [562, 133], [561, 145], [532, 146]]这个文本框坐标调换顺序之后变为[[532, 146], [561, 145], [562, 133], [533, 134]]

gt = np.zeros((h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32)

gt的可视化结果如下:

image-20230120102015403

mask的可视化结果如下:

image-20230120101944648

shrinked = self.shrink_func(polygon, self.shrink_ratio)

这行代码的作用是完成边框的收缩

进入shrink_func()方法

def shrink_polygon_pyclipper(polygon, shrink_ratio):
    from shapely.geometry import Polygon
    import pyclipper
    polygon_shape = Polygon(polygon)
    # 计算需要缩放的偏移量,其公式来源于论文
    distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length  
    subject = [tuple(l) for l in polygon]
    padding = pyclipper.PyclipperOffset()
    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    # 缩放边框
    shrinked = padding.Execute(-distance)
    if shrinked == []:
        shrinked = np.array(shrinked)
    else:
        shrinked = np.array(shrinked[0]).reshape(-1, 2)
    return shrinked

上面代码中,

distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length  

改行代码的作用是计算需要缩放的偏移量,其计算公式如下:
$$
D=\frac{A\left(1-r^{2}\right)}{L}
$$

最后,回到MakeShrinkMap对象

来到这行代码

cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)

这行代码的作用就是在gt上用白色填充缩放后的多边形,可视化效果如下:

第一个文本框的缩放gt

image-20230120103814784

第一个和第二个文本框的缩放gt

image-20230120103841096

第一个、第二个和第三个文本框的缩放gt

image-20230120103912287

第一个、第二个、第三个和第四个文本框的缩放gt

image-20230120103927303

到这里,probability map的标签就制作好了。

未完待续......

标签:distance,square,polygon,map,DBNet,point,源码,np,详解
From: https://www.cnblogs.com/kyle-blog/p/17062517.html

相关文章

  • Spring AOP中@Pointcut切入点表达式详解
    目录一、瞅一眼标准的AspectJAop的pointcut的表达式二、SpringAop的十一种AOP表达式三、演示使用1、execution:2、within:3、this:4、target:5、args:6、@target:7、......
  • KMP 详解
    简介KMP这个名字的由来:它是三个人:D.E·Knuth、J.H·Morris和V.R·Pratt同时发现的。KMP是一种字符串单模匹配算法,复杂度\(\operatorname{O(n+k)}\)。其中......
  • Target EDI 对接详解 – ECGrid AS2 连接
    Target塔吉特是美国仅次于Walmart沃尔玛的第二大巨型折扣零售百货集团。Target在2020财年实现零售收入同比增长19.8%,赶超了CVS和Tesco,并在2020财年的销售额增长超过......
  • 【Django drf】 序列化类常用字段类和字段参数 定制序列化字段的两种方式 关系表外键
    目录序列化类常用字段类和字段参数常用字段类常用字段参数选项参数通用参数序列化类高级用法之sourcesource填写类中字段source填写模型类中方法source支持跨表查询定制序......
  • ecs-lite 源码简单分析
    初学typescript,分析的不到位欢迎指正。 ecs-lite基于ts实现的纯ecs库,可用于学习交流及H5游戏开发!https://gitee.com/aodazhang/ecs-lite?_from=gitee_search文......
  • KMP算法详解(逻辑分析&数学证明&代码实现)
    前言KMP算法是Knuth、Morris、Pratt三人在BF算法的基础上同时提出的模式匹配的高效算法。本文以字符串匹配问题为例,以通俗易懂的语言对KMP算法进行逻辑分析、数学证明和代码......
  • 单调队列详解
    简介单调栈和单调队列都是思维难度比较大的数据结构,但只要想明白了就会觉得很简单。要理解单调队列,首先得明白“单调”是指它存储的内容单调,而不是指它简单。实现模板......
  • DBNet源码详解
    参考项目:https://github.com/WenmuZhou/DBNet.pytorch标签制作制作thresholdmap标签make_border_map.py程序入口if__name__=='__main__'if__name__=='__main......
  • win10下python3.9的代理报错问题解决(附web3的polygon爬虫源码)
    背景因为工作中经常需要代理访问,而开了代理,request就会报错SSLError,如下:requests.exceptions.SSLError:HTTPSConnectionPool(host='test-admin.xxx.cn',port=443):Ma......
  • 浅谈SpringAOP功能源码执行逻辑
    如题,该篇博文主要是对Spring初始化后AOP部分代码执行流程的分析,仅仅只是粗略的讲解整个执行流程,喜欢细的小伙伴请结合其他资料进行学习。在看本文之前请一定要有动态代理的......