1. 先对 PubTabNet 数据集做数据清洗(剔除训练集中单元格标注框相互重叠的样本)
import cv2
import numpy as np
import json
import jsonlines
import os
def iou(bbox1, bbox2):
    """
    Return the overlap ratio of two axis-aligned bounding boxes.

    Despite the function name this is NOT the classic intersection-over-union:
    the intersection area is divided by the area of the SMALLER box
    (intersection-over-minimum). A small box fully contained in a large one
    therefore scores 1.0, which makes the measure sensitive to nested boxes.

    Args:
        bbox1: Four values ``(x0, y0, x1, y1)``; each is passed through
            ``float()``, so numbers and numeric strings are both accepted.
        bbox2: The second box, same layout as ``bbox1``.

    Returns:
        ``0.0`` when the boxes do not overlap (boxes that merely touch along
        an edge count as non-overlapping), otherwise the intersection area
        divided by the smaller of the two box areas.

    Raises:
        ValueError: If a box does not contain exactly four values or a value
            cannot be converted to ``float``.
    """
    x0_1, y0_1, x1_1, y1_1 = (float(v) for v in bbox1)
    x0_2, y0_2, x1_2, y1_2 = (float(v) for v in bbox2)

    # Width and height of the overlap rectangle; non-positive means disjoint.
    overlap_w = min(x1_1, x1_2) - max(x0_1, x0_2)
    overlap_h = min(y1_1, y1_2) - max(y0_1, y0_2)
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0

    # A positive overlap in both directions implies both boxes have positive
    # width and height, so the denominator below can never be zero.
    area_1 = (x1_1 - x0_1) * (y1_1 - y0_1)
    area_2 = (x1_2 - x0_2) * (y1_2 - y0_2)
    return (overlap_w * overlap_h) / min(area_1, area_2)
def get_flag(box_list, threshold=0.3):
    """
    Tell whether a set of boxes is free of significant mutual overlap.

    Every unordered pair of boxes is scored with ``iou``; the scan stops at
    the first pair whose score exceeds ``threshold``.

    Args:
        box_list: Sequence of boxes, each in the ``(x0, y0, x1, y1)`` layout
            accepted by ``iou``.
        threshold: Largest overlap score that is still tolerated. Defaults to
            0.3, the value that was previously hard-coded.

    Returns:
        ``True`` if no pair scores above ``threshold`` (including the trivial
        cases of zero or one box), ``False`` otherwise. The offending score is
        printed before ``False`` is returned.
    """
    for i, first in enumerate(box_list):
        # Only compare against later boxes so each pair is scored once.
        for second in box_list[i + 1:]:
            overlap = iou(first, second)
            if overlap > threshold:
                print(overlap)
                return False
    return True
if __name__ == "__main__":
    # Cleaning pass: copy every 'train' record whose annotated cell boxes do
    # not overlap each other into a new JSONL file. Records that fail the
    # check are not copied; their image, with the boxes drawn on it, is saved
    # under ./badcase for manual inspection.

    # cv2.imwrite does not create missing directories and only signals
    # failure through its return value, so make sure the folder exists.
    os.makedirs("badcase", exist_ok=True)

    with jsonlines.open('E:/pubtabnet/pubtabnet/pubtabnet/PubTabNet_2.0.0.jsonl', "r") as f:
        with jsonlines.open("E:/pubtabnet/pubtabnet/pubtabnet/PubTabNet_2.0.0_clean_train.jsonl", "w") as train_f:
            for data in f:
                filename = data["filename"]
                if data['split'] != 'train':
                    continue

                img_path = os.path.join("E:/pubtabnet/pubtabnet/pubtabnet/Images/train", filename)
                img = cv2.imread(img_path)
                # cv2.imread returns None when the file is missing or unreadable.
                if img is None:
                    print("该图片为空")
                    continue

                box_list = []
                for cell in data["html"]["cells"]:
                    # Cells without text or without a bbox carry no geometry.
                    if len(cell["tokens"]) == 0 or "bbox" not in cell:
                        continue
                    box = cell['bbox']
                    x1, y1, x2, y2 = box
                    box_list.append(box)
                    # Draw the box (red in BGR) so a rejected sample can be
                    # inspected visually.
                    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 1)

                if get_flag(box_list):
                    train_f.write(data)
                else:
                    cv2.imwrite(os.path.join("badcase", filename), img)
2. 将清洗后的数据转换为模型的输入格式(行/列标签 gt_row、gt_col 以及字符掩码 mask_char)
import cv2 as cv
import numpy as np
import json
import jsonlines
import os
from html import escape
from convert_html_ann import html_to_davar
# Label-generation pass: for every annotated image, build per-pixel-row and
# per-pixel-column labels (gt_row / gt_col) plus a text mask (mask_char), and
# write one JSON object per line to the output file.
cnt = 0  # number of label records written so far
with jsonlines.open('E:/PubTabNet_2.0.0_clean_train.jsonl', "r") as f:
    # NOTE(review): 'utf-8-sig' writes a BOM at the start of the output file;
    # confirm that whatever reads this JSONL expects it.
    with open('E:/pubtabnet/pubtabnet/pubtabnet/PubTabNet_2.0.0_train_merge_label.jsonl', 'w', encoding='utf-8-sig') as file:
        images_dir = "E:/pubtabnet/pubtabnet/pubtabnet/Images"
        for content_ann in f:
            image_path = os.path.join(images_dir, content_ann["file_path"])
            all_bboxes = content_ann['bboxes']
            all_cells = content_ann['cells']
            # Boxes and cell spans are paired by position; skip records where
            # the two lists cannot be matched up.
            if len(all_bboxes) != len(all_cells):
                continue

            img = cv.imread(image_path)
            # cv.imread returns None when the file is missing or unreadable.
            if img is None:
                print("image_null")
                continue

            h, w = img.shape[:2]
            mask_r = np.zeros((h, w), dtype=np.uint8)     # text of cells confined to one row
            mask_c = np.zeros((h, w), dtype=np.uint8)     # text of cells confined to one column
            mask_char = np.zeros((h, w), dtype=np.uint8)  # text of every cell

            for bbox, cell in zip(all_bboxes, all_cells):
                if len(bbox) != 4 or len(cell) != 4:
                    continue
                x1, y1, x2, y2 = bbox
                # presumably (start_row, start_col, end_row, end_col), judging
                # by the variable names -- TODO confirm against the producer
                # of this file.
                sr, sc, er, ec = cell
                mask_char[y1: y2, x1: x2] = 1
                # A cell that spans several rows must not block the row
                # labels it crosses, and likewise for columns, so each mask
                # only receives cells confined to a single row / column.
                # (The original column test compared `sc` with `er`, i.e. a
                # column index with a row index; it must be `sc == ec`.)
                if sr == er:
                    mask_r[y1: y2, x1: x2] = 1
                if sc == ec:
                    mask_c[y1: y2, x1: x2] = 1

            # 1 marks a pixel row / column that contains no single-span cell
            # text, 0 marks one that does.
            gt_row = (~mask_r.any(axis=1)).astype(np.int8)
            gt_col = (~mask_c.any(axis=0)).astype(np.int8)

            labels = {
                "image_path": content_ann["file_path"],
                "gt_row": gt_row.tolist(),
                "gt_col": gt_col.tolist(),
                "mask_char": mask_char.tolist(),
            }
            file.write(json.dumps(labels, ensure_ascii=False))
            file.write("\n")
            cnt = cnt + 1
            print(cnt)
标签:pubtabnet,处理,代码,mask,y1,np,import,x1
From: https://www.cnblogs.com/jingweip/p/17504627.html