首页 > 其他分享 >pytorch训练简单加减验证码(一):数据加载器实现

pytorch训练简单加减验证码(一):数据加载器实现

时间:2024-05-07 10:37:02浏览次数:20  
标签:__ tensor img self torch 验证码 pytorch data 加载

1、torch.utils.data.Dataset

torch.utils.data.Dataset 是代表自定义数据集方法的类,用户可以通过继承该类来自定义自己的数据集类,在继承时要求用户重载__len__()和__getitem__()这两个魔法方法。

len():返回的是数据集的大小。我们构建的数据集是一个对象,而数据集不像序列类型(列表、元组、字符串)那样可以直接用len()来获取序列的长度,魔法方法__len__()的目的就是方便像序列那样直接获取对象的长度。如果A是一个类,a是类A的实例化对象,当A中定义了魔法方法__len__(),len(a)则返回对象的大小。

getitem():实现索引数据集中的某一个数据。我们知道,序列可以通过索引的方法获取序列中的任意元素,getitem()则实现了能够通过索引的方法获取对象中的任意元素。此外,我们可以在__getitem__()中实现数据预处理

示例:

import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """
    TensorDataset继承Dataset, 重载了__init__(), __getitem__(), __len__()
    实现将一组Tensor数据对封装成Tensor数据集
    能够通过index得到数据集的数据,能够通过len,得到数据集大小
    """

    def __init__(self, data_tensor, target_tensor):
        super().__init__()
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __len__(self):
       return self.data_tensor.size(0)

    def __getitem__(self, idx):
        return self.data_tensor[idx], self.target_tensor[idx]


if __name__ == '__main__':
    # 生成数据
    data_tensor = torch.randn(4, 3)
    target_tensor = torch.rand(4)
    # 将数据封装成Dataset
    tensor_dataset = MyDataset(data_tensor, target_tensor)

    # 可使用索引调用数据
    print(tensor_dataset[1])
    # 输出:(tensor([-1.0351, -0.1004,  0.9168]), tensor(0.4977))

    # 获取数据集大小
    print(len(tensor_dataset))
    # 输出:4

实现数据加载器

1、将验证码向量化

api说明:

1.1 torch.zeros 返回一个形状为为size,类型为torch.dtype,里面的每一个值都是0的tensor
    import torch
    torch.zeros(3,5) # 输出3行5列每一个值都是0的tensor

示例:

def text2Vec(text):
    """

    :param text:  验证码
    :return:  向量化数据
    """
    captcha_array = list("0123456789+-×÷=?abcdefghijklmnopqrstuvwxyz")
    vec = torch.zeros(5, 42)  # 返回一个形状为为size,类型为torch.dtype,里面的每一个值都是0的tensor
    # text: xxb
    for i in range(len(text)):
        # print(common.captcha_array.index(text[i]))
        # 例子: i=0 , text[i] =x,
        # common.captcha_array.index(text[i] 在字符串中查找x, 返回下标
        # vec[i, captcha_array.index(text[i])] = 1  将每个字符,在向量中进行标记 ==> vec[0, 39] 向量中0行39列被标记为1
        vec[i, captcha_array.index(text[i])] = 1
    return vec

2、 将向量化的验证码还原

def vec2Text(vec):
    
    captcha_array = list("0123456789+-×÷=?abcdefghijklmnopqrstuvwxyz")
    # torch.argmax 取出维度最大值的序号
    # torch.argmax(input, dim=None, keepdim=False)
    # input;输入向量, dim:dim的不同值表示不同维度。特别的在dim=0表示二维中的列,dim=1在二维矩阵中表示行。
    vec = torch.argmax(vec, dim=1)  # 把为1的取出来
    # print(vec)
    text = ''
    for i in vec:
        text += captcha_array[i]
    return text

3、实现数据加载器

def make_dataset(data_path):
    """
    处理数据集路径
    :param data_path: 数据集路径
    :return:  一个列表 列表里有一个集合,集合中包含图片路劲和image_name,
    [('D:\\img_evn\\recognition\\dataset\\train\\9÷9=?_ce12847432adccafc1684a6e1351317e.jpg', '9÷9=?')]
    """
    img_names = os.listdir(data_path)  # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
    samples = []
    for img_name in img_names:
        img_path = data_path + '\\' + img_name
        target_str = img_name.split('_')[0]  # 以下划线切割,取出真实值[9×0=?,fdb20a0dea5b8e092cb7a9e846001730.jpg][0]
        samples.append((img_path, target_str))  # 写进列表
    return samples


class CaptchaData(Dataset):
    def __init__(self, data_path, transform=None):
        """
        初始化方法
        :param data_path: 数据集路径
        :param transform: 数据预处理
        """
        super(Dataset, self).__init__()
        self.transform = transform
        self.samples = make_dataset(data_path)

    def __len__(self):
        return len(self.samples)  # 返回数据集图片数量

    def __getitem__(self, idx):
        img_path, target = self.samples[idx]
        target = text2Vec(target)
        # view: 该函数返回一个有__相同数据__但不同大小的 Tensor。通俗一点,就是__改变矩阵维度__
        # x = torch.randn(4, 4)
        # print(x.size())
        # y = x.view(16)
        # print(y.size())
        # z = x.view(-1, 8)  # -1表示该维度取决于其它维度大小,即(4*4)/ 8
        # print(z.size())
        # m = x.view(2, 2, 4) # 也可以变为更多维度
        # print(m.size())
        target = target.view(1, -1)[0]
        img = Image.open(img_path)
        img = img.resize((160, 60))
        img = img.convert('RGB')  # img转成向量
        if self.transform is not None:
            img = self.transform(img)
        return img, target

标签:__,tensor,img,self,torch,验证码,pytorch,data,加载
From: https://www.cnblogs.com/zdl-spider/p/18175546

相关文章

  • Unity热更学习toLua使用--[1]toLua的导入和默认加载执行lua脚本
    [0]toLua的导入下载toLua资源包,访问GitHub项目地址,点击下载即可。将文件导入工程目录中:导入成功之后会出现Lua菜单栏,如未成功生成文件,可以点击GenerateAll重新生成(注意很可能是路径问题导致的生成失败!)之后就可以开始编写脚本执行第一个lua程序了![1]C#调用Lua脚本编写C#......
  • HarmonyOS 实现下拉刷新,上拉加载更多
    组件介绍PullToRefreshList允许用户通过下拉动作来刷新列表内容,以及通过上拉动作来加载更多的数据。组件内部封装了滚动监听、状态管理和动画效果,使得开发者可以轻松集成到自己的项目中。1.实现思路封装成可复用的公共控件:将下拉刷新和上拉加载更多功能封装为一个可复用的组......
  • 在IDEA中加载OpenJDK源码
    之所以要阅读OpenJDK源码,是因为SunJDK的某些源码是缺失的,以JDK1.8为例,sun.reflect,sun.rmi及其子包下的类都是没有源码的。如下以下载OpenJDK1.8源码为例进行说明。下载OpenJDK源码文件,如下载zip格式的压缩包。解压OpenJDK源码压缩包文件,在IDEA中按如下路径加载:【File】......
  • ubuntu 上安装pytorch-cuda
    安装nvidia驱动不再赘述安装gcc环境sudoapt-getinstallbuild-essentialsudoportaudio19-devunzipx11-utils1build-essential用于安装一个软件包集合,其中包含了编译软件时经常需要使用的工具和库。这个软件包集合通常包括编译器(如gcc)、make工具、头文件等。build......
  • docker pytorch离线安装
    先在ubuntu18.0464位环境里,有联网情况下操作:安装dockerpytorch镜像:dockerpullpytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime下载依赖:bonelee@ubuntu:~/Desktop/pythonProject$sudodockerps-aCONTAINERIDIMAGE......
  • 《深度学习原理与Pytorch实战》(第二版)(三)11-15章
    第11章神经机器翻译器——端到端机器翻译神经机器翻译,google旗下的NMT编码-解码模型:用编码器和解码器组成一个翻译机,先用编码器将源信息编码为内部状态,再通过解码器将内部状态解码为目标语言。编码过程对应了阅读源语言句子的过程,解码过程对应了将其重组为目标语言的过程——......
  • Echarts -- 实现动态加载series
    Echarts--实现动态加载series:https://blog.csdn.net/m0_74444744/article/details/134467184?ops_request_misc=&request_id=&biz_id=102&utm_term=echarts%20series&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-1-1344671......
  • Spring学习之——Bean加载流程
    Spring IOC容器就像是一个生产产品的流水线上的机器,Spring创建出来的Bean就好像是流水线的终点生产出来的一个个精美绝伦的产品。既然是机器,总要先启动,Spring也不例外。因此Bean的加载流程总体上来说可以分为两个阶段:容器启动阶段Bean创建阶段一、容器启动阶段:容器的启动阶......
  • Sxstrace.exe 是 Windows 操作系统提供的一个工具,用于跟踪和分析应用程序的依赖项解析
    sxstrace|MicrosoftLearnSxstrace.exe是Windows操作系统提供的一个工具,用于跟踪和分析应用程序的依赖项解析过程。该工具可以帮助用户诊断应用程序启动或运行时出现的依赖项错误或加载问题。在Windows中,许多应用程序依赖于共享组件和库文件,如动态链接库(DLL)。当应用......
  • 《深度学习原理与Pytorch实战》(第二版)(二)
    第6章手写数字加法器——迁移学习迁移学习允许训练集和测试集的数据有不同的分布、目标、领域;而一般的监督学习要求训练集和测试集上的数据有相同的分布特性一个有意思的想法:大公司运用大数据训练大模型,再将这些模型迁移到小公司擅长的特定垂直领域中,这样就可以将泛化的大模......