首页 > 其他分享 >PytorchOCR工程的CRNN文本识别模型训练

PytorchOCR工程的CRNN文本识别模型训练

时间:2023-03-29 16:06:29浏览次数:48  
标签:PytorchOCR txt self batch CRNN path model 文本 type

环境:python3.9+pytorch1.8.1+opencv4.5.2+cuda11.1

pyTorchOCR工程:

https://github.com/WenmuZhou/PytorchOCR

 

PytorchOCR工程的CRNN文本识别模型训练_OCR


1、准备训练数据:(这里是生成的数据

生成:https://blog.51cto.com/u_8681773/6004679

生成工具:https://blog.51cto.com/u_8681773/6157100

)

1、这里以日期数据为例子:

PytorchOCR工程的CRNN文本识别模型训练_OCR_02

 

PytorchOCR工程的CRNN文本识别模型训练_python_03

 

 

2、根据文档要求,数据集列表格式如下:

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_04

 

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_05

 

 

3、准备图像标签列表文件:

PytorchOCR工程的CRNN文本识别模型训练_CRNN_06

注意文件内容只有两列,路径和内容;同时注意文件的编码格式要求utf-8的:

PytorchOCR工程的CRNN文本识别模型训练_pytorch_07

 

Utf-8-BOM还不行:

转换成utf-8再保存即可:

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_08

 

准备的校验测试文件列表也一样:

PytorchOCR工程的CRNN文本识别模型训练_pytorch_09

 

准备字符标签文件:(这里和之前的不同,这里第一个没有Blank),同时要注意编码是utf-8的。

PytorchOCR工程的CRNN文本识别模型训练_pytorch_10

 

 

这样数据就准备好了。


1、在有一个训练列表文件时,拆分成一个训练的、一个测试的。

PytorchOCR工程的CRNN文本识别模型训练_python_11

 

2、使用脚本:

import os
import glob
import pathlib
import random

# 将-生成的数字数据train.txt列表,分成两个,train.txt和test.txt
#适配pytorchOCR的工程所需

data_path = r'E:\datasets\gen_mini3_charset'

save_path = r'E:\datasets\gen_mini3_charset'

for txt_path in glob.glob(data_path + '/train.txt', recursive=True):
    d = pathlib.Path(txt_path)
    if os.path.exists(txt_path):
        print(txt_path)
    else:
        print('不存在', txt_path)
        continue

    save_train_path = save_path + '/pytorchocr_train'+ '.txt'
    save_test_path = save_path + '/pytorchocr_test' + '.txt'
    cnt = 0
    f_w_train = open(save_train_path, 'w', encoding='utf8')
    f_w_test = open(save_test_path, 'w', encoding='utf8')

    try:
        with open(txt_path, "r", encoding='utf8') as f:
            for line in f.readlines():
                line = line.strip('\n')  # 去掉列表中每一行元素的换行符
                lineData = line.split('\t')
                image = lineData[0]
                text_value = lineData[1]
                idxlist = []  ## 空列表
                idxlist.append(image)
                idxlist.append(text_value)
                outlinestr = '\t'.join(idxlist)
                if cnt < 1000000 and random.randint(1, 100) % 5 == 0:
                    cnt = cnt + 1
                    f_w_test.write('{}\n'.format(outlinestr))
                else:
                    f_w_train.write('{}\n'.format(outlinestr))
    except ValueError:
        f_w_train.close()


    f_w_train.close()
    f_w_test.close()

PytorchOCR工程的CRNN文本识别模型训练_pytorch_12

 

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_13



3、准备好配置文件:

1、复制一份CRNN配置文件:

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_14

 

 

2、修改配置:

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_15

 

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_16

 

PytorchOCR工程的CRNN文本识别模型训练_OCR_17

 

PytorchOCR工程的CRNN文本识别模型训练_python_18

 

PytorchOCR工程的CRNN文本识别模型训练_python_19

 

3、复制一份训练的脚本文件到PytorchOCR目录下:

PytorchOCR工程的CRNN文本识别模型训练_CRNN_20

修改一下指定的配置脚本文件:

PytorchOCR工程的CRNN文本识别模型训练_OCR_21

 1、运行训练脚本进行模型的训练:

PytorchOCR工程的CRNN文本识别模型训练_pytorch_22

 

PytorchOCR工程的CRNN文本识别模型训练_python_23

 


4、模型测试:

1、复制一份测试脚本到pytorchOCR目录下:

PytorchOCR工程的CRNN文本识别模型训练_CRNN_24

 

2、指定好模型和测试图像:

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_25

 

PytorchOCR工程的CRNN文本识别模型训练_python_26

 

PytorchOCR工程的CRNN文本识别模型训练_pytorch_27

 

3、测试结果:

PytorchOCR工程的CRNN文本识别模型训练_pytorch_28

 

PytorchOCR工程的CRNN文本识别模型训练_pytorch_29

 

PytorchOCR工程的CRNN文本识别模型训练_pytorch_30

 

有结果,效果不好。

可以修改主干网络和数据集增强来提升模型的效果:

PytorchOCR工程的CRNN文本识别模型训练_文本识别模型_31

 

PytorchOCR工程的CRNN文本识别模型训练_pytorch_32

 

譬如改resnet34:

PytorchOCR工程的CRNN文本识别模型训练_pytorch_33

 

PytorchOCR工程的CRNN文本识别模型训练_CRNN_34

 

5、可能报错:

2023-01-20 23:15:34,331 - torchocr - INFO - {'exp_name': 'CRNN_E13B', 'train_options': {'resume_from': '', 'third_party_name': '', 'checkpoint_save_dir': './output/CRNN_E13B/checkpoint', 'device': 'cuda:0', 'epochs': 200, 'fine_tune_stage': ['backbone', 'neck', 'head'], 'print_interval': 10, 'val_interval': 100, 'ckpt_save_type': 'HighestAcc', 'ckpt_save_epoch': 2}, 'SEED': 927, 'optimizer': {'type': 'Adam', 'lr': 0.001, 'weight_decay': 0.0001}, 'lr_scheduler': {'type': 'StepLR', 'step_size': 60, 'gamma': 0.5}, 'model': {'type': 'RecModel', 'backbone': {'type': 'MobileNetV3', 'model_name': 'small'}, 'neck': {'type': 'PPaddleRNN', 'hidden_size': 48}, 'head': {'type': 'CTC', 'n_class': 15}, 'in_channels': 3}, 'loss': {'type': 'CTCLoss', 'blank_idx': 0}, 'dataset': {'alphabet': 'E:/datasets/e13b/labels_E13B.txt', 'train': {'dataset': {'type': 'RecTextLineDataset', 'file': 'E:/datasets/e13b/train.txt', 'input_h': 32, 'mean': 0.5, 'std': 0.5, 'augmentation': False}, 'loader': {'type': 'DataLoader', 'batch_size': 16, 'shuffle': False, 'num_workers': 1, 'collate_fn': {'type': 'RecCollateFn', 'img_w': 640}}}, 'eval': {'dataset': {'type': 'RecTextLineDataset', 'file': 'E:/datasets/e13b/eval.txt', 'input_h': 32, 'mean': 0.5, 'std': 0.5, 'augmentation': False}, 'loader': {'type': 'RecDataLoader', 'batch_size': 4, 'shuffle': False, 'num_workers': 1, 'collate_fn': {'type': 'RecCollateFn', 'img_w': 640}}}}}

2023-01-20 23:15:36,951 - torchocr - INFO - net resume from scratch.

Traceback (most recent call last):

  File "E:\pywork\PytorchOCR_pro\PytorchOCR\rec_train_for_e13b.py", line 346, in <module>

    main()

  File "E:\pywork\PytorchOCR_pro\PytorchOCR\rec_train_for_e13b.py", line 331, in main

    loss_func = build_loss(cfg['loss'])

  File "E:\pywork\PytorchOCR_pro\PytorchOCR\torchocr\networks\losses\__init__.py", line 21, in build_loss

    criterion = eval(loss_type)(**copy_config)

TypeError: __init__() got an unexpected keyword argument 'blank_idx'

 

Process finished with exit code 1

 

PytorchOCR工程的CRNN文本识别模型训练_CRNN_35

 

 

解决:

PytorchOCR工程的CRNN文本识别模型训练_pytorch_36

 

PytorchOCR工程的CRNN文本识别模型训练_CRNN_37

 



6、pytorch模型转onnx模型:

PytorchOCR工程的CRNN文本识别模型训练_OCR_38

 

PytorchOCR工程的CRNN文本识别模型训练_python_39

# -*- coding: utf-8 -*-
# @Time    : 2020/6/16 10:57
# @Author  : zhoujun
import os
import sys
import pathlib

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 将 torchocr路径加到python陆经里
__dir__ = pathlib.Path(os.path.abspath(__file__))

import numpy as np

sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))

import torch
from torch import nn
from torchocr.networks import build_model
from torchocr.datasets.RecDataSet import RecDataProcess
from torchocr.utils import CTCLabelConverter


class RecInfer:
    def __init__(self, model_path, batch_size=16):
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        # state_dict = {}
        # for k, v in ckpt['state_dict'].items():
        #     state_dict[k.replace('module.', '')] = v
        self.model.load_state_dict(ckpt['state_dict'])

        # self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # self.model.to(self.device)
        self.model.eval()

        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
        self.batch_size = batch_size

    def predict(self, imgs):
        # 预处理根据训练来
        if not isinstance(imgs,list):
            imgs = [imgs]
        imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img)) for img in imgs]
        widths = np.array([img.shape[1] for img in imgs])
        idxs = np.argsort(widths)
        txts = []
        for idx in range(0, len(imgs), self.batch_size):
            batch_idxs = idxs[idx:min(len(imgs), idx+self.batch_size)]
            batch_imgs = [self.process.width_pad_img(imgs[idx], imgs[batch_idxs[-1]].shape[1]) for idx in batch_idxs]
            batch_imgs = np.stack(batch_imgs)
            tensor = torch.from_numpy(batch_imgs.transpose([0,3, 1, 2])).float()
            tensor = tensor.to(self.device)
            with torch.no_grad():
                out = self.model(tensor)
                out = out.softmax(dim=2)
            out = out.cpu().numpy()
            txts.extend([self.converter.decode(np.expand_dims(txt, 0)) for txt in out])
        #按输入图像的顺序排序
        idxs = np.argsort(idxs)
        out_txts = [txts[idx] for idx in idxs]
        return out_txts


def init_args():
    import argparse
    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    parser.add_argument('--model_path',  default='pretrainmodel/ch_rec_server_crnn_res34.pth', type=str, help='rec model path')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    import cv2

    args = init_args()
    model = RecInfer(args.model_path)
    batch_size = 1
    dummy_input = torch.randn(batch_size, 3, 32, 320)

    # print(dummy_input)
    input_names = ["input"]
    output_names = ["output"]
    #dynamic_axes = {"input": {3: 'width'}, "output": {0: 'width'}}
    dynamic_axes = {"input": {0: 'batch', 3: 'width'}, "output": {0: 'width'}}
    #dynamic_axes = {"input": {0: 'batch', 3: 'width'}, "output": {0: 'width'}} #这个是动态batch
    torch.onnx.export(model.model, dummy_input, "pretrainmodel/ch_rec_server_crnn_res34.onnx",
                      verbose=True, input_names=input_names, output_names=output_names,
                      dynamic_axes=dynamic_axes,
                      opset_version=11)
    print("convert crnn rec to onnx ok !!!")

标签:PytorchOCR,txt,self,batch,CRNN,path,model,文本,type
From: https://blog.51cto.com/u_8681773/6157207

相关文章