首页 > 其他分享 >pytorch-Dataset-Dataloader

pytorch-Dataset-Dataloader

时间:2023-07-15 17:34:47浏览次数:51  
标签:__ tensor img self Dataloader batch Dataset pytorch data

pytorch-Dataset-Dataloader

目录

pyTorch为我们提供的两个Dataset和DataLoader类分别负责可被Pytorh使用的数据集的创建以及向训练传递数据的任务。

data.Dataset

torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。

只负责数据的抽象,一次只是返回一个数据

Dataset是用来解决数据从哪里读取以及如何读取的问题。pytorch给定的Dataset是一个抽象类,所有自定义的Dataset都要继承它,

并且复写__getitem__()__len__()类方法,__getitem__()的作用是接受一个索引,返回一个样本或者标签。

__len__ 前者提供了数据集的大小,

_getitem__ 后者支持整数索引,范围从0到len(self)

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)
    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,))  # 标签是0或1

print(data_tensor.shape)
print(target_tensor.shape)

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
print(my_dataset)
print(my_dataset[0])    

# 执行结果
torch.Size([10, 3])
torch.Size([10])
<__main__.MyDataset object at 0x000002496FF46820>
(tensor([1.0655, 1.4536, 1.0800]), tensor(1))

MyDataset 的特点
1、继承于torch.utils.data.Dataset。
2、通过读取任意格式的数据、预处理、数据增强、以及数据转换、将数据以tensor输出
3、输出的结果有两个。tensor格式的数据和数据标签
4、主要是实现了三个函数__init__,__len__,__getitem__

from torch.utils.data import Dataset
from PIL import Image
import os
import json


class GetData(Dataset):

    def __init__(self, img_dir, labelfile):
        # self
        self.img_dir = img_dir
        self.img_list = os.listdir(self.img_dir)

        with open(str(labelfile)) as f:
            label = json.load(f)
        self.label = label

    def __getitem__(self, idx):
        imgname = self.img_list[idx]  # 只获取了文件名
        img_path = os.path.join(self.img_dir, imgname)  # 每个图片的位置
        img = Image.open(img_path)

        label = self.label[imgname]
        return img, label

    def __len__(self):
        return len(self.img_list)


root_dir = "../../assets/datasets"
ants_label_dir = "../../assets/classes.json"
ants_dataset = GetData(root_dir, ants_label_dir)
print(len(ants_dataset))
img, lable = ants_dataset[0]  # 返回一个元组,返回值就是__getitem__的返回值
print(img.size)
print(lable)


3
(473, 266)
dog
===================================
数据目录结构
|---datasets/
	|----/1.jpg
	|----/2.jpg   
---classes.json

classes.json标注结构
{
  "1.jpg": "dog",
  "2.jpg": "dog",
  "3.jpg": "dog"
}

该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用,比较重要

data.DataLoader

Dataset这个类中的__getitem__的返回值,应该是某一个样本的数据和标签,在训练的过程中,一般是需要将多个数据组成batch。所以PyTorch中存在DataLoader这个迭代器。

形成batch数据,并且可以使用shuffe和加速

数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。 返回迭代器。
在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')


"""输入参数

dataset:      		数据集的储存的路径位置等信息,定义好的Map式或者Iterable式数据集
batch_size: 		每次取数据的数量,比如batchi_size=2
shuffle 			default: False  打乱数据
sampler: 			如果指定,"shuffle"必须为false,提取样本的策略
batch_sampler 		None,和batch_size、shuffle 、sampler and drop_last参数
num_workers 		加载数据的进程,多进程会更快
collate_fn 		    如何将多个样本数据拼接成一个batch,自定义数据读取方式,可以用来过滤数据	
pin_memory 			张量复制到CUDA内存,能够加快内存访问速度
drop_last 			如何处理数据集长度除于batch_size余下的数据。True就抛弃
timeout 			default:0	   读取超时,超时报错


在数据处理中,有时会出现某个样本无法读取等问题,如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。

返回参数
迭代器 tensor_loader
for data, target in tensor_dataloader: 
    print(data, target)
"""


# 定义加载器
tensor_loader=DataLoader(mydataset, batch_size=64,)

# 训练或者测试加载器
for i,(data, target) in enumerate(tensor_dataloader): 
    print(data, target)
    
import torch
from torch.utils.data import Dataset,DataLoader

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)

    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]


data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,))  # 标签是0或1

print(data_tensor.shape)
print(target_tensor.shape)

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
tensor_loader=DataLoader(my_dataset, batch_size=3,)

# 训练或者测试加载器
for i,(data, target) in enumerate(tensor_loader):
    print(len(data), len(target))
    
torch.Size([10, 3])
torch.Size([10])
3 3
3 3
3 3
1 1

自定义collate_fn 过滤失效数据

'''
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,
此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,
然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。
'''
import os, json
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset


class NewDogCat(Dataset):  # 继承前面实现的DogCat数据集
    # 构造函数
    def __init__(self, img_dir, labelfile, transform):
        # self
        self.img_dir = img_dir
        self.img_list = os.listdir(self.img_dir)
        self.transform = transform

        with open(str(labelfile)) as f:
            label = json.load(f)
        self.label = label

    def __getitem__(self, idx):
        try:
            imgname = self.img_list[idx]  # 只获取了文件名
            img_path = os.path.join(self.img_dir, imgname)  # 每个图片的位置

            label = self.label[imgname]
            img = Image.open(img_path)
            img = self.transform(img)

            return img, label
        except Exception as e:
            print(e,"数据读取错误")
            return None,None

    def __len__(self):
        return len(self.img_list)


from torch.utils.data.dataloader import default_collate  # 导入默认的拼接方式
from torch.utils.data import DataLoader
from torchvision import transforms


def my_collate_fn(batch):
    '''
    batch中每个元素形如(data, label)
    '''
    # 过滤为None的数据
    batch = list(filter(lambda x: x[0] is not None, batch))
    if len(batch) == 0: return torch.Tensor()
    return default_collate(batch)  # 用默认方式拼接过滤后的batch数据

transform = transforms.Compose([
    transforms.Resize(224),  # 缩放图片,保持长宽比不变,最短边的长为224像素,
    transforms.CenterCrop(224),  # 从中间切出 224*224的图片
    transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1]
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])  # 标准化至[-1,1]
])
root_dir = "../../assets/datasets"
ants_label_dir = "../../assets/classes.json"
dataset = NewDogCat(root_dir, ants_label_dir, transform=transform)

dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, shuffle=True)
for batch_datas, batch_labels in dataloader:
    print(batch_datas[0].shape, len(batch_labels))
    
'5.jpg' 数据读取错误
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 2

总结

  1. 首先我们要去构建自己继承Dataset的MyDataSet

  2. 传入到Dataloader中,最后进行enumerate遍历每个batchsize

  3. Dataset通过index输出的最好是tensor

  4. 整体的Dataset和Dataloader中,基本上是Dataloader每次给你返回一个shuffle过的index

参考资料

https://zhuanlan.zhihu.com/p/340465632

标签:__,tensor,img,self,Dataloader,batch,Dataset,pytorch,data
From: https://www.cnblogs.com/tian777/p/17556545.html

相关文章

  • pytorch+CRNN实现
    最近接触了一个仪表盘识别的项目,简单调研以后发现可以用CRNN来做。但是手边缺少仪表盘数据集,就先用ICDAR2013试了一下。 结果遇到了一系列坑。为了不使读者和自己在以后的日子继续遭罪。我把正确的代码发到下面了。超参数请不要调整!!!!CRNN前期训练极其慢,需要良好的调参,loss才会......
  • albumentations 的数据增强为什么是 先 Normalize, 再 ToTensorV2,而 pytorch 正好相反
    albumentations:T+=[A.Normalize(mean=mean,std=std),ToTensorV2()]#NormalizeandconverttoTensortorchvision:T.ToTensor(),T.Normalize(IMAGENET_MEAN,IMAGENET_STD),原因:A.Normalize已经包含了将8位图像(0-255)转换为(0-1)(将mean和stdx255,然后再......
  • AI_Pytorch—内容回顾
    pytorch基本结构与组件-基本流程与步骤-基本方法和应用组件PyTorch都是用C++和CUDA编写的modulesandclassestorch.nn, torch.optim, Dataset,andDataLoader学、练、训、赛、研、用device=torch.device("cuda:0"iftorch.cuda.is_available()else......
  • pytorch保存模型及加载模型
    ClassTestModle(nn.Module): def__init__(self): self.conv=nn.Conv(3,6,5) self.pool=nn.MaxPool2d(2,2) ... defforward(self,x): ... ....假如有这样一个模型一、使用状态字典保存模型参数(官方推荐用法)保存模型torch.save(model.state_dict(),PATH......
  • 机器学习洞察 | 分布式训练让机器学习更加快速准确 分布式 机器学习 PyTorch Amazon S
    机器学习能够基于数据发现一般化规律的优势日益突显,我们看到有越来越多的开发者关注如何训练出更快速、更准确的机器学习模型,而分布式训练(DistributedTraining)则能够大幅加速这一进程。亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案......
  • 如何实现pso优化神经网络pytorch的具体操作步骤
    PSO优化神经网络(PyTorch)实现流程介绍本文将介绍如何使用粒子群优化(ParticleSwarmOptimization,PSO)算法来优化神经网络模型,并使用PyTorch框架来实现。PSO算法是一种基于群体智能的优化算法,通过模拟鸟群觅食行为,来搜索最优解。在神经网络中,我们可以将待优化的参数作为粒子,利用......
  • pytorch学习笔记
    1环境 opencv和pytorchpipinstallopencv-python==4.5.1.48pipinstalltorch==1.7.1+cu101torchvision==0.8.2+cu101torchaudio===0.7.2-fhttps://download.pytorch.org/whl/torch_stable.htmlDevTools安装非常方便,直接通过官方脚本命令行选择安装即可,唯一需要注意......
  • Jetson配置pytorch出现的问题
    由于无法安装Anaconda因此使用miniforge进行虚拟环境搭建,具体方法参照: 几个重要网站①JetsonZoo-eLinux.org 包含深度学习需要的下载资源配置② 安装pytorch后进行验证:1importtorch23defSettingTest():4print(torch.__version__)5print(torch.......
  • pytorch
    model.train()的作用是启用BatchNormalization和Dropout。model.eval()的作用是不启用BatchNormalization和Dropout。训练流程:deftrain(model,optimizer,epoch,train_loader,validation_loader):forbatch_idx,(data,target)inexperiment.batch_loop(it......
  • 【本周特惠课程】深度学习6大模型部署场景(Pytorch+NCNN+MNN+Tengine+TensorRT+微信小
    前言欢迎大家关注有三AI的视频课程系列,我们的视频课程系列共分为5层境界,内容和学习路线图如下:第1层:掌握学习算法必要的预备知识,包括Python编程,深度学习基础,数据使用,框架使用。第2层:掌握CV算法最底层的能力,包括模型设计基础,图像分类,模型分析。第3层:掌握CV算法最核心的方向,包括图像分......