
Implementing CRNN with PyTorch

Date: 2023-07-14 17:34:28
Tags: __, implementation, self, label, pytorch, file, CRNN, image, size

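I recently started working on a gauge (instrument dial) recognition project, and a quick survey suggested that CRNN would be a good fit. I had no gauge dataset at hand, though, so I tried it on ICDAR2013 first.

I ran into a series of pitfalls along the way. To spare readers, and my future self, the same pain, I am posting the working code below.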

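Please do not change the hyperparameters! CRNN trains extremely slowly in the early stage, and the loss only starts to come down gradually when the settings are well tuned.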

Whatever you do, do not use the open-source CTC from Baidu!
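For reference, PyTorch ships its own CTC loss as `torch.nn.CTCLoss`, which is also what the training code in this post calls. The snippet below is only a minimal sketch of that call: the shapes (26 time steps, a batch of 2, 53 classes with the blank at index 0) and the random tensors are placeholders chosen for illustration, not values taken from a real run.

```python
import torch

torch.manual_seed(0)
T, N, C = 26, 2, 53  # time steps, batch size, classes (index 0 is the CTC blank)

# Stand-in for the network output: log-probabilities of shape (T, N, C)
log_probs = torch.randn(T, N, C).log_softmax(2)

# Two targets of lengths 3 and 5, concatenated into one 1-D tensor (no blank inside)
targets = torch.tensor([3, 1, 20, 8, 5, 12, 12, 15], dtype=torch.long)
target_lengths = torch.tensor([3, 5], dtype=torch.long)
# Every sample in the batch uses all T time steps
input_lengths = torch.full((N,), T, dtype=torch.long)

ctc = torch.nn.CTCLoss(blank=0, reduction='mean')
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(float(loss))
```

The targets are passed in the concatenated 1-D form because the words in a batch have different lengths; `target_lengths` tells the loss where each word ends.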

Network and training code:

#crnn.py
import torch.nn as nn
import torch.nn.functional as F

class BidirectionalLSTM(nn.Module):
    # Inputs hidden units Out
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


class CRNN(nn.Module):
    #                   32    1   37     256
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Shape comments below are channels x height x width for a 1x32x100 input
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x50
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x25
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x26
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x27
        convRelu(6, True)  # 512x1x26

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        #print('---forward propagation---')
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2) # b *512 * width
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        output = F.log_softmax(self.rnn(conv), dim=2)
        return output
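
As a quick check on the network above, the number of time steps the RNN sees can be worked out from the layer settings alone. The sketch below traces height and width through the pooling layers and the last convolution using the standard output-size formula; `conv_out` and `crnn_feature_size` are hypothetical helper names introduced here, not part of the original code.

```python
def conv_out(size, kernel, stride, pad):
    """Output size of a conv or max-pool layer along one axis (floor mode, no dilation)."""
    return (size + 2 * pad - kernel) // stride + 1


def crnn_feature_size(img_h, img_w):
    """Trace (height, width) through the CNN layers of the CRNN class above."""
    h, w = img_h, img_w
    # conv0..conv5 use kernel 3, stride 1, padding 1, so they leave the size unchanged.
    # pooling0 and pooling1: MaxPool2d(2, 2)
    for _ in range(2):
        h, w = conv_out(h, 2, 2, 0), conv_out(w, 2, 2, 0)
    # pooling2 and pooling3: MaxPool2d((2, 2), (2, 1), (0, 1))
    for _ in range(2):
        h, w = conv_out(h, 2, 2, 0), conv_out(w, 2, 1, 1)
    # conv6: kernel 2, stride 1, padding 0
    h, w = conv_out(h, 2, 1, 0), conv_out(w, 2, 1, 0)
    return h, w


print(crnn_feature_size(32, 100))  # (1, 26)
```

With the 100x32 crops used in train.py this gives 26 time steps, which is the upper bound on how long a word (plus the blanks CTC needs between repeated letters) can be.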
#train.py
import os
import torch
import cv2
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import crnn
import time
import re
import matplotlib.pyplot as plt
dic={" ":0,"a":1,"b":2,"c":3,"d":4,"e":5,"f":6,"g":7,"h":8,"i":9,"j":10,"k":11,"l":12,"m":13,"n":14,"o":15,"p":16,"q":17,"r":18,"s":19,"t":20,"u":21,"v":22,"w":23,"x":24,"y":25,"z":26,
     "A":27,"B":28,"C":29,"D":30,"E":31,"F":32,"G":33,"H":34,"I":35,"J":36,"K":37,"L":38,"M":39,"N":40,"O":41,"P":42,"Q":43,"R":44,"S":45,"T":46,"U":47,"V":48,"W":49,"X":50,"Y":51,"Z":52}

STR=" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
n_class=53
# Resize the image and convert it to a normalized tensor
class resizeAndNormalize():
    def __init__(self, size, interpolation=cv2.INTER_LINEAR):
        # Note: OpenCV expects size as (w, h)
        self.size = size
        self.interpolation = interpolation
        # ToTensor converts a PIL Image or numpy.ndarray to a tensor
        self.toTensor = transforms.ToTensor()

    def __call__(self, image):
        # In OpenCV's (x, y) convention, image width is the x axis and height is the y axis
        image = cv2.resize(image, self.size, interpolation=self.interpolation)
        # Convert to a tensor; for a uint8 array ToTensor also scales values to [0, 1]
        image = self.toTensor(image)
        # Optional normalization to [-1, 1] (disabled)
        #image = image.sub_(0.5).div_(0.5)
        return image

def load_data(label_folder,image_folder,label_suffix_name=".txt",image_suffix_name=".jpg"):
    # Build one sample per ground-truth line: (image path, label path, line index)
    image_file,label_file,num_file=[],[],[]
    for parent_folder, _, file_names in os.walk(label_folder):
        # Walk every file in the current folder
        for file_name in file_names:
            # Only handle label files named gt_<image name><suffix>
            if file_name.startswith("gt_") and file_name.endswith(label_suffix_name):
                # Strip the "gt_" prefix and the suffix to recover the image name
                c=file_name[len("gt_"):-len(label_suffix_name)]
                image_name=os.path.join(image_folder, c + image_suffix_name)
                if os.path.exists(image_name):
                    label_name = os.path.join(parent_folder, file_name)
                    # Count the lines; each line describes one word box
                    with open(label_name,'rb') as txt:
                        txtl=txt.readlines()
                    for line in range(len(txtl)):
                        image_file.append(image_name)
                        label_file.append(label_name)
                        num_file.append(line)
    return image_file,label_file,num_file

def zl2lable(zl):
    # Map each character of the word to its class index
    label_list=[]
    for ch in zl:  # avoid shadowing the built-in str
        label_list.append(dic[ch])
    return label_list

class NewDataSet(Dataset):
    def __init__(self, label_source,image_source,train=True):
        super(NewDataSet, self).__init__()
        self.image_file,self.label_file,self.num_file= load_data(label_source,image_source)

    def __len__(self):
        return len(self.image_file)

    def __getitem__(self, index):
        # Read the ground-truth file as bytes and pick the line for this sample
        with open(self.label_file[index], 'rb') as txt:
            wordL = txt.readlines()
        img=cv2.imread(self.image_file[index],cv2.IMREAD_GRAYSCALE)
        # str() of a bytes object looks like b'...', so the first letter run found
        # below is the literal prefix "b"; index 1 is the first run in the line itself.
        word=str(wordL[self.num_file[index]])
        pl = re.findall(r'\d+',word)
        # Only the first alphabetic run of the transcription is kept;
        # digits and punctuation in the word are dropped.
        zl = re.findall(r"[a-zA-Z]+", word)[1]

        # image tensor: crop the word box (x1, y1, x2, y2)
        x1, y1, x2, y2 = pl[:4]
        img= img[int(y1):int(y2),int(x1):int(x2), ]
        (height, width)=img.shape
        # The CRNN expects input height 32; every crop is resized to a fixed 100x32
        size_height = 32
        # ratio = 32 / float(height)
        size_width =100
        transform = resizeAndNormalize((size_width, size_height))
        # Image preprocessing
        imageTensor = transform(img)

        # label tensor
        l = zl2lable(zl)
        labelTensor = torch.IntTensor(l)
        return imageTensor,labelTensor




class CRNNDataSet(Dataset):
    def __init__(self, imageRoot, labelRoot):
        self.image_root = imageRoot
        self.image_dict = self.readfile(labelRoot)
        self.image_name = [fileName for fileName, _ in self.image_dict.items()]

    def __getitem__(self, index):
        image_path = os.path.join(self.image_root, self.image_name[index])
        keys = self.image_dict.get(self.image_name[index])
        label = [int(x) for x in keys]

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # if image is None:
        #     return None,None
        (height, width) = image.shape

        # The CRNN expects input height 32, so resize to that height and scale the width by the same ratio
        size_height = 32
        ratio = 32 / float(height)
        size_width = int(ratio * width)
        transform = resizeAndNormalize((size_width, size_height))
        # Image preprocessing
        image = transform(image)
        # Convert the label to an IntTensor
        label = torch.IntTensor(label)
        return image, label

    def __len__(self):
        return len(self.image_name)

    def readfile(self, fileName):
        res = []
        with open(fileName, 'r') as f:
            lines = f.readlines()
            for line in lines:
                res.append(line.strip())
        dic = {}
        total = 0
        for line in res:
            part = line.split(' ')
            # A missing image would raise an exception during training, so check that each image exists at init time
            if not os.path.exists(os.path.join(self.image_root, part[0])):
                print(os.path.join(self.image_root, part[0]))
                total += 1
            else:
                dic[part[0]] = part[1:]
        print(total)

        return dic
label_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task1_GT"
image_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task12_Images"

trainData =NewDataSet(label_sources,image_sources)

trainLoader = DataLoader(dataset=trainData, batch_size=1, shuffle=True, num_workers=0)




# valData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",
#                       labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data_t.txt")
#
# valLoader = DataLoader(dataset=valData, batch_size=1, shuffle=True, num_workers=1)


#
# def decode(preds):
#     pred = []
#     for i in range(len(preds)):
#         if preds[i] != 5989 and ((i == 5989) or (i != 5989 and preds[i] != preds[i - 1])):
#             pred.append(int(preds[i]))
#     return pred
#
#
def toSTR(l):
    str_l=[]
    if isinstance(l, int):
        l=[l]
    for i in range(len(l)):
        str_l.append(STR[l[i]])
    return str_l
def toRES(l):
    new_l=[]
    new_str=' '
    for i in range(len(l)):
        if(l[i]==' '):
            new_str = ' '
            continue
        elif new_str!=l[i]:
            new_l.append(l[i])
            new_str=l[i]
    return new_l

def val(model=None):
    # Load the saved model lazily: a torch.load() default argument would run at
    # import time and fail before the first training run has written the file.
    # On PyTorch >= 2.6, torch.load needs weights_only=False to unpickle a full model.
    if model is None:
        model = torch.load("pytorch-crnn.pth")
    # Switch to evaluation mode
    loss_func = torch.nn.CTCLoss(blank=0, reduction='mean')
    model.eval()

    test_n=10

    # No gradients are needed for evaluation
    with torch.no_grad():
        for i, (data, label) in enumerate(trainLoader):
            if i > test_n:
                break
            output = model(data.cuda())
            # Greedy prediction: most likely class at every time step
            pred_label=output.max(2)[1]
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))
            # forward(self, log_probs, targets, input_lengths, target_lengths)
            #log_probs = output.log_softmax(2).requires_grad_()
            targets = label
            loss = loss_func(output.cpu(), targets.cpu(), input_lengths, target_lengths)

            pred_l=np.array(pred_label.cpu().squeeze()).tolist()
            label_l=np.array(targets.cpu().squeeze()).tolist()
            print(i,":",loss,"pred:",toRES(toSTR(pred_l)),"label_l",toSTR(label_l))




def train():
    use_gpu = True
    learning_rate = 0.0001

    max_epoch = 100
    n_class=53
    batch_size=20

    model = crnn.CRNN(32, 1, n_class, 256)
    if torch.cuda.is_available() and use_gpu:
        model.cuda()

    loss_func = torch.nn.CTCLoss(blank=0,reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,betas=(0.9, 0.999))

    lossTotal = 0.0
    k = 0
    printInterval = 100
    start_time = time.time()
    loss_list=[]
    total_list=[]
    for epoch in range(max_epoch):
        n=0
        data_list = []
        label_list = []
        label_len=[]
        for i, (data, label) in enumerate(trainLoader):
            #
            data_list.append(data)
            label_list.append(label)
            label_len.append(label.size(1))
            n=n+1
            if n%batch_size!=0:
                continue
            k=k+1
            data=torch.cat(data_list, dim=0)
            data_list.clear()

            label = torch.cat(label_list, dim=1).squeeze(0)
            label_list.clear()

            target_lengths=torch.tensor(np.array(label_len))
            label_len.clear()
            # Switch to training mode
            model.train()


            if torch.cuda.is_available() and use_gpu:
                data = data.cuda()
                loss_func = loss_func.cuda()
                label = label.cuda()

            output = model(data)
            log_probs = output
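            # The CTCLoss docs example suggests the line below; the model's forward() already
            # applies log_softmax, and passing output straight to the loss has shown no problems.
            #log_probs = output.log_softmax(2).requires_grad_()
            targets = label  # already moved to the GPU above when one is in use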
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            # forward(self, log_probs, targets, input_lengths, target_lengths)
            #targets =torch.zeros(targets.shape)
            loss = loss_func(log_probs.cpu(), targets, input_lengths, target_lengths)/batch_size
            lossTotal += float(loss)
            print("epoch:",epoch,"num:",i,"loss:",float(loss))
            loss_list.append(float(loss))
            if k % printInterval == 0:
                print("[%d/%d] [%d/%d] loss:%f" % (
                    epoch, max_epoch, i + 1, len(trainLoader), lossTotal / printInterval))
                total_list.append( lossTotal / printInterval)
                lossTotal = 0.0
                torch.save(model, 'pytorch-crnn.pth')

            optimizer.zero_grad()
            loss.backward()
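            optimizer.step()

    # Save once more at the end so the final weights are not lost between checkpoints
    torch.save(model, 'pytorch-crnn.pth')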

    plt.figure()
    plt.plot(loss_list)
    plt.savefig("loss.jpg")

    plt.clf()
    plt.figure()
    plt.plot(total_list)
    plt.savefig("total.jpg")
    end_time = time.time()
    print("takes {}s".format((end_time - start_time)))
    return model

# val() expects pytorch-crnn.pth to exist; call train() first to create it.
if __name__ == '__main__':
    val()
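
The `toSTR` and `toRES` helpers in train.py together perform greedy (best-path) CTC decoding: take the most likely class at each time step, merge consecutive repeats, and drop the blank. The same rule can be sketched in one function; `greedy_ctc_decode` is a hypothetical name, and the `steps` list is a made-up argmax sequence rather than real model output.

```python
STR = " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"  # same alphabet as train.py


def greedy_ctc_decode(indices, blank=0):
    """Collapse consecutive repeated indices, then drop blanks."""
    chars = []
    prev = blank
    for idx in indices:
        # Emit a character only when it is not blank and differs from the previous step
        if idx != blank and idx != prev:
            chars.append(STR[idx])
        prev = idx
    return "".join(chars)


# Per-time-step argmax for a hypothetical 12-step output
steps = [0, 8, 8, 0, 5, 12, 12, 0, 12, 15, 15, 0]
print(greedy_ctc_decode(steps))  # hello
```

The blank between the two runs of `12` is what lets the double "l" survive; without it the repeats would be merged into a single letter.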

The test results are as follows:

Finally, some references:

Implementing CRNN with PyTorch, wzw12315's blog, CSDN

ctc loss - azheng333 - cnblogs (cnblogs.com)

I am also sharing some datasets and my trained model:

Link: https://pan.baidu.com/s/1-jTA22bLKv2ut_1EJ1WMKA?pwd=jvk8
Extraction code: jvk8

 

From: https://www.cnblogs.com/xmds/p/17554561.html
