1、torch.utils.data.Dataset
torch.utils.data.Dataset 是代表自定义数据集方法的类,用户可以通过继承该类来自定义自己的数据集类,在继承时要求用户重载__len__()和__getitem__()这两个魔法方法。
len():返回的是数据集的大小。我们构建的数据集是一个对象,而数据集不像序列类型(列表、元组、字符串)那样可以直接用len()来获取序列的长度,魔法方法__len__()的目的就是方便像序列那样直接获取对象的长度。如果A是一个类,a是类A的实例化对象,当A中定义了魔法方法__len__(),len(a)则返回对象的大小。
getitem():实现索引数据集中的某一个数据。我们知道,序列可以通过索引的方法获取序列中的任意元素,getitem()则实现了能够通过索引的方法获取对象中的任意元素。此外,我们可以在__getitem__()中实现数据预处理
示例:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
"""
TensorDataset继承Dataset, 重载了__init__(), __getitem__(), __len__()
实现将一组Tensor数据对封装成Tensor数据集
能够通过index得到数据集的数据,能够通过len,得到数据集大小
"""
def __init__(self, data_tensor, target_tensor):
super().__init__()
self.data_tensor = data_tensor
self.target_tensor = target_tensor
def __len__(self):
return self.data_tensor.size(0)
def __getitem__(self, idx):
return self.data_tensor[idx], self.target_tensor[idx]
if __name__ == '__main__':
# 生成数据
data_tensor = torch.randn(4, 3)
target_tensor = torch.rand(4)
# 将数据封装成Dataset
tensor_dataset = MyDataset(data_tensor, target_tensor)
# 可使用索引调用数据
print(tensor_dataset[1])
# 输出:(tensor([-1.0351, -0.1004, 0.9168]), tensor(0.4977))
# 获取数据集大小
print(len(tensor_dataset))
# 输出:4
实现数据加载器
1、将验证码向量化
api说明:
1.1 torch.zeros 返回一个形状为为size,类型为torch.dtype,里面的每一个值都是0的tensor
import torch
torch.zeros(3,5) # 输出3行5列每一个值都是0的tensor
示例:
def text2Vec(text):
"""
:param text: 验证码
:return: 向量化数据
"""
captcha_array = list("0123456789+-×÷=?abcdefghijklmnopqrstuvwxyz")
vec = torch.zeros(5, 42) # 返回一个形状为为size,类型为torch.dtype,里面的每一个值都是0的tensor
# text: xxb
for i in range(len(text)):
# print(common.captcha_array.index(text[i]))
# 例子: i=0 , text[i] =x,
# common.captcha_array.index(text[i] 在字符串中查找x, 返回下标
# vec[i, captcha_array.index(text[i])] = 1 将每个字符,在向量中进行标记 ==> vec[0, 39] 向量中0行39列被标记为1
vec[i, captcha_array.index(text[i])] = 1
return vec
2、 将向量化的验证码还原
def vec2Text(vec):
captcha_array = list("0123456789+-×÷=?abcdefghijklmnopqrstuvwxyz")
# torch.argmax 取出维度最大值的序号
# torch.argmax(input, dim=None, keepdim=False)
# input;输入向量, dim:dim的不同值表示不同维度。特别的在dim=0表示二维中的列,dim=1在二维矩阵中表示行。
vec = torch.argmax(vec, dim=1) # 把为1的取出来
# print(vec)
text = ''
for i in vec:
text += captcha_array[i]
return text
3、实现数据加载器
def make_dataset(data_path):
"""
处理数据集路径
:param data_path: 数据集路径
:return: 一个列表 列表里有一个集合,集合中包含图片路劲和image_name,
[('D:\\img_evn\\recognition\\dataset\\train\\9÷9=?_ce12847432adccafc1684a6e1351317e.jpg', '9÷9=?')]
"""
img_names = os.listdir(data_path) # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
samples = []
for img_name in img_names:
img_path = data_path + '\\' + img_name
target_str = img_name.split('_')[0] # 以下划线切割,取出真实值[9×0=?,fdb20a0dea5b8e092cb7a9e846001730.jpg][0]
samples.append((img_path, target_str)) # 写进列表
return samples
class CaptchaData(Dataset):
def __init__(self, data_path, transform=None):
"""
初始化方法
:param data_path: 数据集路径
:param transform: 数据预处理
"""
super(Dataset, self).__init__()
self.transform = transform
self.samples = make_dataset(data_path)
def __len__(self):
return len(self.samples) # 返回数据集图片数量
def __getitem__(self, idx):
img_path, target = self.samples[idx]
target = text2Vec(target)
# view: 该函数返回一个有__相同数据__但不同大小的 Tensor。通俗一点,就是__改变矩阵维度__
# x = torch.randn(4, 4)
# print(x.size())
# y = x.view(16)
# print(y.size())
# z = x.view(-1, 8) # -1表示该维度取决于其它维度大小,即(4*4)/ 8
# print(z.size())
# m = x.view(2, 2, 4) # 也可以变为更多维度
# print(m.size())
target = target.view(1, -1)[0]
img = Image.open(img_path)
img = img.resize((160, 60))
img = img.convert('RGB') # img转成向量
if self.transform is not None:
img = self.transform(img)
return img, target
标签:__,tensor,img,self,torch,验证码,pytorch,data,加载
From: https://www.cnblogs.com/zdl-spider/p/18175546