pytorch-Dataset-Dataloader
目录pyTorch为我们提供的两个Dataset和DataLoader类分别负责可被Pytorh使用的数据集的创建以及向训练传递数据的任务。
data.Dataset
torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。
只负责数据的抽象,一次只是返回一个数据
Dataset是用来解决数据从哪里读取以及如何读取的问题。pytorch给定的Dataset是一个抽象类,所有自定义的Dataset都要继承它,
并且复写__getitem__()
和__len__()
类方法,__getitem__()
的作用是接受一个索引,返回一个样本或者标签。
__len__
前者提供了数据集的大小,
_getitem__
后者支持整数索引,范围从0到len(self)
。
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
print(data_tensor.shape)
print(target_tensor.shape)
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
print(my_dataset)
print(my_dataset[0])
# 执行结果
torch.Size([10, 3])
torch.Size([10])
<__main__.MyDataset object at 0x000002496FF46820>
(tensor([1.0655, 1.4536, 1.0800]), tensor(1))
MyDataset 的特点
1、继承于torch.utils.data.Dataset。
2、通过读取任意格式的数据、预处理、数据增强、以及数据转换、将数据以tensor输出
3、输出的结果有两个。tensor格式的数据和数据标签
4、主要是实现了三个函数__init__,__len__,__getitem__
from torch.utils.data import Dataset
from PIL import Image
import os
import json
class GetData(Dataset):
def __init__(self, img_dir, labelfile):
# self
self.img_dir = img_dir
self.img_list = os.listdir(self.img_dir)
with open(str(labelfile)) as f:
label = json.load(f)
self.label = label
def __getitem__(self, idx):
imgname = self.img_list[idx] # 只获取了文件名
img_path = os.path.join(self.img_dir, imgname) # 每个图片的位置
img = Image.open(img_path)
label = self.label[imgname]
return img, label
def __len__(self):
return len(self.img_list)
root_dir = "../../assets/datasets"
ants_label_dir = "../../assets/classes.json"
ants_dataset = GetData(root_dir, ants_label_dir)
print(len(ants_dataset))
img, lable = ants_dataset[0] # 返回一个元组,返回值就是__getitem__的返回值
print(img.size)
print(lable)
3
(473, 266)
dog
===================================
数据目录结构
|---datasets/
|----/1.jpg
|----/2.jpg
---classes.json
classes.json标注结构
{
"1.jpg": "dog",
"2.jpg": "dog",
"3.jpg": "dog"
}
该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用,比较重要
data.DataLoader
Dataset这个类中的__getitem__
的返回值,应该是某一个样本的数据和标签,在训练的过程中,一般是需要将多个数据组成batch。所以PyTorch中存在DataLoader这个迭代器。
形成batch数据,并且可以使用shuffe和加速
数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。 返回迭代器。
在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')
"""输入参数
dataset: 数据集的储存的路径位置等信息,定义好的Map式或者Iterable式数据集
batch_size: 每次取数据的数量,比如batchi_size=2
shuffle default: False 打乱数据
sampler: 如果指定,"shuffle"必须为false,提取样本的策略
batch_sampler None,和batch_size、shuffle 、sampler and drop_last参数
num_workers 加载数据的进程,多进程会更快
collate_fn 如何将多个样本数据拼接成一个batch,自定义数据读取方式,可以用来过滤数据
pin_memory 张量复制到CUDA内存,能够加快内存访问速度
drop_last 如何处理数据集长度除于batch_size余下的数据。True就抛弃
timeout default:0 读取超时,超时报错
在数据处理中,有时会出现某个样本无法读取等问题,如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。
返回参数
迭代器 tensor_loader
for data, target in tensor_dataloader:
print(data, target)
"""
# 定义加载器
tensor_loader=DataLoader(mydataset, batch_size=64,)
# 训练或者测试加载器
for i,(data, target) in enumerate(tensor_dataloader):
print(data, target)
import torch
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
print(data_tensor.shape)
print(target_tensor.shape)
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
tensor_loader=DataLoader(my_dataset, batch_size=3,)
# 训练或者测试加载器
for i,(data, target) in enumerate(tensor_loader):
print(len(data), len(target))
torch.Size([10, 3])
torch.Size([10])
3 3
3 3
3 3
1 1
自定义collate_fn
过滤失效数据
'''
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,
此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,
然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。
'''
import os, json
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
class NewDogCat(Dataset): # 继承前面实现的DogCat数据集
# 构造函数
def __init__(self, img_dir, labelfile, transform):
# self
self.img_dir = img_dir
self.img_list = os.listdir(self.img_dir)
self.transform = transform
with open(str(labelfile)) as f:
label = json.load(f)
self.label = label
def __getitem__(self, idx):
try:
imgname = self.img_list[idx] # 只获取了文件名
img_path = os.path.join(self.img_dir, imgname) # 每个图片的位置
label = self.label[imgname]
img = Image.open(img_path)
img = self.transform(img)
return img, label
except Exception as e:
print(e,"数据读取错误")
return None,None
def __len__(self):
return len(self.img_list)
from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
from torch.utils.data import DataLoader
from torchvision import transforms
def my_collate_fn(batch):
'''
batch中每个元素形如(data, label)
'''
# 过滤为None的数据
batch = list(filter(lambda x: x[0] is not None, batch))
if len(batch) == 0: return torch.Tensor()
return default_collate(batch) # 用默认方式拼接过滤后的batch数据
transform = transforms.Compose([
transforms.Resize(224), # 缩放图片,保持长宽比不变,最短边的长为224像素,
transforms.CenterCrop(224), # 从中间切出 224*224的图片
transforms.ToTensor(), # 将图片转换为Tensor,归一化至[0,1]
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1,1]
])
root_dir = "../../assets/datasets"
ants_label_dir = "../../assets/classes.json"
dataset = NewDogCat(root_dir, ants_label_dir, transform=transform)
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, shuffle=True)
for batch_datas, batch_labels in dataloader:
print(batch_datas[0].shape, len(batch_labels))
'5.jpg' 数据读取错误
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 2
总结
-
首先我们要去构建自己继承Dataset的MyDataSet
-
传入到Dataloader中,最后进行enumerate遍历每个batchsize
-
Dataset通过index输出的最好是tensor
-
整体的Dataset和Dataloader中,基本上是Dataloader每次给你返回一个shuffle过的index
参考资料
https://zhuanlan.zhihu.com/p/340465632
标签:__,tensor,img,self,Dataloader,batch,Dataset,pytorch,data From: https://www.cnblogs.com/tian777/p/17556545.html