首页 > 其他分享 >处理pubtabnet数据集代码

处理pubtabnet数据集代码

时间:2023-06-26 10:12:38浏览次数:42  
标签:pubtabnet 处理 代码 mask y1 np import x1

1.先对该数据集做数据清洗

import cv2
import numpy as np
import json
import jsonlines
import os


def iou(bbox1, bbox2):
    """
    Calculates the intersection-over-union of two bounding boxes.
    """
    bbox1 = [float(x) for x in bbox1]
    bbox2 = [float(x) for x in bbox2]
    (x0_1, y0_1, x1_1, y1_1) = bbox1
    (x0_2, y0_2, x1_2, y1_2) = bbox2
    # get the overlap rectangle
    overlap_x0 = max(x0_1, x0_2)
    overlap_y0 = max(y0_1, y0_2)
    overlap_x1 = min(x1_1, x1_2)
    overlap_y1 = min(y1_1, y1_2)
    # check if there is an overlap
    if overlap_x1 - overlap_x0 <= 0 or overlap_y1 - overlap_y0 <= 0:
        return 0
    # if yes, calculate the ratio of the overlap to each ROI size and the unified size
    size_1 = (x1_1 - x0_1) * (y1_1 - y0_1)
    size_2 = (x1_2 - x0_2) * (y1_2 - y0_2)
    size_intersection = (overlap_x1 - overlap_x0) * (overlap_y1 - overlap_y0)
    # size_union = size_1 + size_2 - size_intersection
    size_union = min(size_1, size_2)
    return size_intersection / size_union


def get_flag(box_list):
    length = len(box_list)
    for i in range(length - 1):
        for j in range(i + 1, length):
            box1 = box_list[i]
            box2 = box_list[j]
            threshold = iou(box1, box2)
            if threshold > 0.3:
                print(threshold)
                return False
    return True


if __name__ == "__main__":
    with jsonlines.open('E:/pubtabnet/pubtabnet/pubtabnet/PubTabNet_2.0.0.jsonl', "r") as f:
        with jsonlines.open("E:/pubtabnet/pubtabnet/pubtabnet/PubTabNet_2.0.0_clean_train.jsonl", "w") as train_f:
            labels = []
            for data in f:
                filename = data["filename"]
                if data['split'] == 'train':
                    img_path = os.path.join("E:/pubtabnet/pubtabnet/pubtabnet/Images/train", filename)
                    img = cv2.imread(img_path)
                    if np.max(img) is None:
                        print("该图片为空")
                        continue
                    cells = data["html"]["cells"]
                    box_list = []
                    for idx, cell in enumerate(cells):
                        if len(cell["tokens"]) == 0 or "bbox" not in cell.keys():
                            continue
                        box = cell['bbox']
                        x1, y1, x2, y2 = box
                        box_list.append(box)
                        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 1)
                    flag = get_flag(box_list)
                    if not flag:
                        cv2.imwrite(os.path.join("badcase", filename), img)
                    else:
                        train_f.write(data)

2.将清洗后的数据读取为模型数据输入格式

import cv2 as cv
import numpy as np
import json
import jsonlines
import os
from html import escape
from convert_html_ann import html_to_davar
cnt = 0
with jsonlines.open('E:/PubTabNet_2.0.0_clean_train.jsonl', "r") as f:
    with open('E:/pubtabnet/pubtabnet/pubtabnet/PubTabNet_2.0.0_train_merge_label.jsonl', 'w', encoding='utf-8-sig') as file:
        images_dir = "E:/pubtabnet/pubtabnet/pubtabnet/Images"
        for content_ann in f:
            labels = {}
            image_path = os.path.join(images_dir, content_ann["file_path"])
            if len(content_ann['bboxes']) != len(content_ann['cells']):
                continue
            img = cv.imread(image_path)
            if np.max(img) is None:
                # print("该图片为空")
                print("image_null")
                continue
            h, w, c = img.shape
            mask_r = np.zeros((h, w)).astype(np.uint8)
            mask_c = np.zeros((h, w)).astype(np.uint8)
            mask_char = np.zeros((h, w)).astype(np.uint8)
            for i in range(len(content_ann['bboxes'])):
                bboxes = content_ann['bboxes'][i]
                cells = content_ann['cells'][i]
                if len(bboxes) != 4 or len(cells) != 4:
                    continue
                x1, y1, x2, y2 = bboxes[0], bboxes[1], bboxes[2], bboxes[3]
                sr, sc, er, ec = cells[0], cells[1], cells[2], cells[3]
                mask_char[y1: y2, x1: x2] = 1
                if sr == er and sc == ec:
                    mask_r[y1: y2, x1: x2] = 1
                    mask_c[y1: y2, x1: x2] = 1
                elif sr == er and sc != ec:
                    mask_r[y1: y2, x1: x2] = 1
                elif sr != er and sc == er:
                    mask_c[y1: y2, x1: x2] = 1
            has_char_row = np.nonzero(np.sum(mask_r, axis=1))[0]
            gt_row = np.ones((h)).astype(np.int8)  # * 255
            gt_row[has_char_row] = 0
            has_char_col = np.nonzero(np.sum(mask_c, axis=0))[0]
            gt_col = np.ones((w)).astype(np.int8)  # * 255
            gt_col[has_char_col] = 0
            labels["image_path"] = content_ann["file_path"]
            labels["gt_row"] = gt_row.tolist()
            labels["gt_col"] = gt_col.tolist()
            labels["mask_char"] = mask_char.tolist()

            json_object = json.dumps(labels, ensure_ascii=False)
            file.write(json_object)
            file.write("\n")
            cnt = cnt + 1
            print(cnt)

标签:pubtabnet,处理,代码,mask,y1,np,import,x1
From: https://www.cnblogs.com/jingweip/p/17504627.html

相关文章

  • 《编写高质量代码》读书笔记系列开篇
    前言:   时间过的好快,进入这个互联网的fe行业已经快*年了,读书还是一个需要坚持的东东,是一种坚持,因为兴趣所以热爱。 正文:   其实这边书一种在间断地看着,今天组里买了一本,决定开一个系列,静静地品一下,重新审视自己的深度和方向。   1、如何做的更好的Web前端工程师?   ......
  • Flutter延迟执行一段代码的几种方式以及Timer的说明
    延迟执行在Flutter中,可以使用以下方式实现延迟执行一段代码的效果使用Future.delayed方法:Future.delayed(Duration(milliseconds:500),(){//延迟执行的代码});使用Timer类:Timer(Duration(milliseconds:500),(){//延迟执行的代码});使用Future......
  • 目标字符串驼峰化处理
    功能函数的设计初衷是将目标字符串驼峰化的api:比如CSS样式特性与JavaScipt样式属性的切换  background-color与style.backgroundColorfont-weight与fontWeightfont-family与fontFamily  ~~~~~~~~~~~~~~  /**toCamelCase--将目标字符串进行驼峰化处理**@func......
  • 自然语言处理 Paddle NLP - 检索式文本问答-理论
    问答系统(QuestionAnsweringSystem,QA)是信息检索系统的一种高级形式,它能用准确、简洁的自然语言回答用户用自然语言提出的问题。其研究兴起的主要原因是人们对快速、准确地获取信息的需求。问答系统是人工智能.抽取式阅读理解:它的答案一定是段落里的一个片段,所以在训练前,先要......
  • flv.js视频流出错,断流处理
    flv.js视频流出错,断流处理可乐加冰5152023年02月20日17:45 ·  阅读274场景:前端使用flv.js播放视频流Bug表现:视频流播放两分钟左右video标签出现暂停按钮,控制台flv.js报错:Failedtoexecute'appendBuffer'on'SourceBuffer':TheHTMLMediaElement.erroratt......
  • R语言618电商大数据文本分析LDA主题模型可视化报告|附代码数据
    原文链接:http://tecdat.cn/?p=1078最近我们被客户要求撰写关于文本分析LDA主题模型的研究报告,包括一些图形和统计输出。618购物狂欢节前后,网民较常搜索的关键词在微博、微信、新闻三大渠道的互联网数据表现,同时通过分析平台采集618相关媒体报道和消费者提及数据社交媒体指数趋......
  • 零代码量化投资:用ChatGPT提取企业PDF年报中的多页表格
    企业PDF年报中有很多信息,里面表格很多,所以经常需要提取其中的表格。用ChatGPT来编程实现,非常简单。案例1:提取鑫铂股份募集说明书中的行业主要法律法规及政策表格在ChatGPT输入提示语如下:写一段Python代码,实现提取PDF文件中表格的功能。具体步骤如下:打开PDF文件,文件路径是:F:\金属材......
  • Win32k 是 Windows 操作系统中的一个核心组件,它负责处理图形显示、窗口管理和用户交互
    Win32k是Windows操作系统中的一个核心组件,它负责处理图形显示、窗口管理和用户交互等功能。在Windows中,Win32k.sys是一个内核模式驱动程序,它提供了访问图形子系统的接口。因此,Win32k具有较高的权限和特权。作为一个内核模式驱动程序,Win32k有比普通用户程序更高的权限级别......
  • 数字图像处理考试 简答
    数字图像处理图像控件怎么理解?图像控件是指在用户界面中用于显示、处理和交互图像的一类控件。它们通常是一些可视化元素,用户可以通过它们在应用程序中查看图像、编辑图像、调整图像参数、选择图像、上传图像等。图像控件可以是按钮、文本框、滑块、列表框、画布等。例如,一个......
  • 跨平台技术是指能够在不同操作系统和硬件平台上运行的技术。它允许开发人员使用一套代
    跨平台技术是指能够在不同操作系统和硬件平台上运行的技术。它允许开发人员使用一套代码来构建应用程序,然后将该应用程序部署到多个平台上,而无需进行大量的平台特定代码修改。以下是一些常见的跨平台技术:国产的跨平台技术:Weex:Weex是由阿里巴巴开发的跨平台移动应用开发框架。它......