首页 > 其他分享 >pytorch数据集加载Dataset

pytorch数据集加载Dataset

时间:2024-02-04 11:01:34浏览次数:32  
标签:__ self torch dataset pytorch Dataset data 加载

一、Dataset基类介绍

在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,可以快速实现对数据的加载

torch.utils.data.Dataset的源码如下:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

我们需在自定义的数据集类中继承Dataset类,同时还需要实现以下方法:

1、__getitem__,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据

二、torch.utils.data.Dataloader

 DataLoader(dataset=my_dataset,batch_size=2,shuffle=True)

 

1、dataset:提前定义的dataset实例

2、batch_size:传入数据的batch的大小,常用有128,256等等

3、shuffle:bool类型,表示是否在每次获取数据的时候提前打扰数据

4、num_workers:加载数据的线程数

三、数据加载案例

下载国外正常短信和骚扰短信数据集,数据下载地址:

http://archive.ics.uci.edu/dataset/228/sms+spam+collection     

代码示例:

 

import torch
from torch.utils.data import Dataset,DataLoader

data_path = r"D:\coding\learning\python\pytorchtest\data\SMSSpamCollection"

#完成数据集类
class MyDataset(Dataset):
    def __init__(self):
        self.lines = open(data_path,encoding='utf-8').readlines()

    def __getitem__(self, index):
        #获取索引对应位置的一条数据
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()   #取短信内容类型,前4个字符
        content = cur_line[4:].strip()  #短信内容
        return label,content

    def __len__(self):
        #返回数据集总量
        return  len(self.lines)
my_dataset = MyDataset()
data_loader = DataLoader(dataset=my_dataset,batch_size=2,shuffle=True)

if __name__ == '__main__':
    my_dataset = MyDataset()
    print(my_dataset[0])   #取第0个数据
    print(len(my_dataset))  #数据数量
    for i in data_loader:
        print(i)  #循环输出

    print(len(my_dataset))
    print(len(data_loader))  #math.ceil(len(my_dataset)/batch_size) 向上取整

 运行结果:

 

标签:__,self,torch,dataset,pytorch,Dataset,data,加载
From: https://www.cnblogs.com/handsomeziff/p/18005782

相关文章

  • 手撸代码:从零开始的 AlexNet 图像分类(PyTorch框架)
    摘要:本文在PyTorch框架下搭建了AlexNet,并在CIFAR10上完成了图片分类。同时,更正了一些原论文中的小错误(如:输入图像尺寸)。由于CIFAR10没有验证集,本文将训练集的10%当作验证集。完整代码已上传至GitHub:https://github.com/TiezhuXing01/AlexNet_in_PyTorch1.引入库i......
  • PyTorch神操作:一图秒懂Tensor变形记!
    亲爱的码农小伙伴们,你们是否还在为Tensor的各种变换头大如斗?别怕,今天给大家送上一张超实用的PyTorch变换秘籍图,让你的Tensor操作如行云流水,CPU和GPU之间的切换如穿梭自如!......
  • Python中用PyTorch机器学习神经网络分类预测银行客户流失模型|附代码数据
    阅读全文:http://tecdat.cn/?p=8522最近我们被客户要求撰写关于神经网络的研究报告,包括一些图形和统计输出。分类问题属于机器学习问题的类别,其中给定一组特征,任务是预测离散值。分类问题的一些常见示例是,预测肿瘤是否为癌症,或者学生是否可能通过考试在本文中,鉴于银行客户的某些......
  • tacotron2:深度学习语音合成模型--pytorch
    https://www.python100.com/html/83067.html 一、tacotron2环境搭建如要安装tacotron2环境,需要完成以下步骤:1、安装CUDA。CUDA是Nvidia开发的并行计算平台和编程模型,需要前往官网下载并安装对应版本的CUDA,同时保证显卡支持CUDA。2、安装cuDNN。cuDNN是针对深度神经网络加速......
  • pytorch的模型推理:TensorRT的使用
    相关教程视频:TRTorch真香,一键启用TensorRT图片来源:https://www.bilibili.com/video/BV1TY411h7xC/图片来源:https://www.bilibili.com/video/BV1TY411h7xC/......
  • 华为显卡已经支持pytorch计算框架
    相关链接:https://support.huawei.com/enterprise/zh/doc/EDOC1100079287/a21c08dehttps://www.zhihu.com/question/624955377/answer/3240350483https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies/pies_00004.htmlAscend/pytorch项目地址:https:......
  • PyTorch中实现Transformer模型
    前言关于Transformer原理与论文的介绍:详细了解Transformer:AttentionIsAllYouNeed对于论文给出的模型架构,使用PyTorch分别实现各个部分。引入的相关库函数:importcopyimporttorchimportmathfromtorchimportnnfromtorch.nn.functionalimportlog_softmax......
  • 如何将PyTorch模型迁移到昇腾平台
    https://bbs.huaweicloud.com/blogs/399602?utm_source=cnblog&utm_medium=bbs-ex&utm_campaign=other&utm_content=content如何将PyTorch模型迁移到昇腾平台举报 昇腾CANN 发表于2023/04/1809:54:50  5k+  0  1 【摘要】本文介绍将PyTorch网络模型迁移到昇......
  • 【极简】Pytorch中的register_buffer()
    registerbuffer定义模型能用torch.save保存的、但是不更新参数。使用:只要是nn.Module的子类就能直接self.调用使用:classA(nn.Module):#...self.register_buffer('betas',torch.linspace(beta_1,beta_T,T).double())#...手动定义参数上述的参数显然可以......
  • 1/31JVM虚拟机 类加载
    loading加载  JAVA。lang包底下的reflect。反编译从应用破解源码,盗版!加载已经在内存中有大的class文件 验证 准备阶段静态变量都初始化为0,常量都已经初始化好符号引用 一个字节码文件不可能全装下各种需要用到的类,而是用一个符号代指,解析就是把符号引用变成指......