首页 > 其他分享 >3. dataset、dataloader

3. dataset、dataloader

时间:2024-04-01 22:11:55浏览次数:23  
标签:__ index 批次 self dataloader dataset def

dataset 数据集
dataloader 数据加载器

1. AI训练时的需求

  1. 有一个数据集文件来,里面有100w的样本和标签
  2. 训练时,通常希望,一次在100w中随机抓取batch个样本,拿去训练
  3. 如果全部抓取完毕,则重新打乱后,再来一次

2. dataset,数据集

  • 作用:储存数据集的信息self.xxx
  • 获取数据集长度__len__
  • 获取数据集某个特定条目的内容 __getitem__
class ImageDataset:
    def __init__(self, raw_data):
        self.raw_Data = raw_data
    
    def __len__(self):
        return len(self.raw_Data)
    
    def __getitem__(self, index):
        image, lable = self.raw_Data[index]
        return image, lable
        # return self.raw_Data[index]

images = [[f"image{i}", i] for i in range(100)]

3. dataloader

数据加载器

作用:从数据集随机加载数据,并拼接为一个batch
实现选代器,可以让使用时,选代获取数据内容

class DataLoader:
    def __init__(self, dataset, batchsize):
        self.dataset = dataset
        self.batchsize = batchsize
    
    def __iter__(self):
        # 每次准备迭代的时候打乱顺序,并清空指针
        # 正常序列
        self.indexs = np.arange(len(self.dataset))
        self.cursor = 0  # 指针

        # 打乱序列
        np.random.shuffle(self.indexs)
        return self
    
    def __next__(self):
        # 预期在这里返回一个batch的数据,在dataset随机抓取
        # 抓batch个数据前,先抓batch个index
        begin = self.cursor
        end = self.cursor + self.batchsize

        # 如果迭代已经到底
        if end > len(self.dataset):  # 等于也是在范围里
            raise StopIteration()
        
        # 如果还没到底,就更新指针,再取出数据
        self.cursor = end
        batched_data = []
        for index in self.indexs[begin:end]:
            item = self.dataset[index]
            batched_data.append(item)
        
        return batched_data

4. 整个演示

import numpy as np

class ImageDataset:
    def __init__(self, raw_data):
        self.raw_Data = raw_data
    
    def __len__(self):
        return len(self.raw_Data)
    
    def __getitem__(self, index):
        image, lable = self.raw_Data[index]
        return image, lable
        # return self.raw_Data[index]

class DataLoader:
    def __init__(self, dataset, batchsize):
        self.dataset = dataset
        self.batchsize = batchsize
    
    def __iter__(self):
        # 每次准备迭代的时候打乱顺序,并清空指针
        # 正常序列
        self.indexs = np.arange(len(self.dataset))
        self.cursor = 0  # 指针

        # 打乱序列
        np.random.shuffle(self.indexs)
        return self
    
    def __next__(self):
        # 预期在这里返回一个batch的数据,在dataset随机抓取
        # 抓batch个数据前,先抓batch个index
        begin = self.cursor
        end = self.cursor + self.batchsize

        # 如果迭代已经到底
        if end > len(self.dataset):  # 等于也是在范围里
            raise StopIteration()
        
        # 如果还没到底,就更新指针,再取出数据
        self.cursor = end
        batched_data = []
        for index in self.indexs[begin:end]:
            item = self.dataset[index]
            batched_data.append(item)
        
        return batched_data


images = [[f"image{i}", i] for i in range(100)]  # [['image0', 0], ['image1', 1],...]

dataset = ImageDataset(images)
loader = DataLoader(dataset, 5)

for index, batched_data in enumerate(loader):
    print(f"第{index}个批次是", batched_data)

# 第0个批次是 [('image46', 46), ('image79', 79), ('image44', 44), ('image65', 65), ('image50', 50)]
# 第1个批次是 [('image2', 2), ('image34', 34), ('image90', 90), ('image73', 73), ('image17', 17)]
# 第2个批次是 [('image49', 49), ('image55', 55), ('image7', 7), ('image20', 20), ('image31', 31)]
# 第3个批次是 [('image98', 98), ('image70', 70), ('image52', 52), ('image26', 26), ('image47', 47)]
# 第4个批次是 [('image5', 5), ('image37', 37), ('image38', 38), ('image66', 66), ('image81', 81)]
# 第5个批次是 [('image71', 71), ('image1', 1), ('image43', 43), ('image86', 86), ('image35', 35)]
# 第6个批次是 [('image85', 85), ('image61', 61), ('image92', 92), ('image23', 23), ('image16', 16)]
# 第7个批次是 [('image67', 67), ('image69', 69), ('image63', 63), ('image8', 8), ('image21', 21)]
# 第8个批次是 [('image32', 32), ('image0', 0), ('image14', 14), ('image22', 22), ('image42', 42)]
# 第9个批次是 [('image6', 6), ('image40', 40), ('image72', 72), ('image62', 62), ('image39', 39)]
# 第10个批次是 [('image3', 3), ('image10', 10), ('image30', 30), ('image59', 59), ('image97', 97)]
# 第11个批次是 [('image11', 11), ('image36', 36), ('image25', 25), ('image80', 80), ('image84', 84)]
# 第12个批次是 [('image76', 76), ('image96', 96), ('image29', 29), ('image18', 18), ('image94', 94)]
# 第13个批次是 [('image68', 68), ('image24', 24), ('image57', 57), ('image12', 12), ('image13', 13)]
# 第14个批次是 [('image28', 28), ('image91', 91), ('image89', 89), ('image27', 27), ('image58', 58)]
# 第15个批次是 [('image53', 53), ('image82', 82), ('image87', 87), ('image93', 93), ('image33', 33)]
# 第16个批次是 [('image64', 64), ('image83', 83), ('image74', 74), ('image51', 51), ('image60', 60)]
# 第17个批次是 [('image19', 19), ('image41', 41), ('image15', 15), ('image77', 77), ('image56', 56)]
# 第18个批次是 [('image99', 99), ('image9', 9), ('image45', 45), ('image54', 54), ('image78', 78)]
# 第19个批次是 [('image75', 75), ('image48', 48), ('image95', 95), ('image4', 4), ('image88', 88)]

标签:__,index,批次,self,dataloader,dataset,def
From: https://www.cnblogs.com/ratillase/p/18109488

相关文章

  • Pytorch - Dataloader
    BasicallytheDataLoaderworkswiththeDatasetobject.SotousetheDataLoaderyouneedtogetyourdataintothisDatasetwrapper.Todothisyouonlyneedtoimplementtwomagicmethods:__getitem__and__len__.The__getitem__takesanindexandretu......
  • 05-快速理解SparkSQL的DataSet
    1定义一个数据集是分布式的数据集合。Spark1.6增加新接口Dataset,提供RDD的优点:强类型、能够使用强大lambda函数SparkSQL优化执行引擎的优点可从JVM对象构造Dataset,然后函数式转换(map、flatMap、filter等)操作。DatasetAPI在Scala和Java中可用。Python不支持DatasetAPI,......
  • Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss
    省去冗长的数学证明,直接看文章的贡献:提出了新的Loss函数以及延迟re-weighting的trick。并在多个数据集,包括情感分类、图像分类进行实验。Motivation&Methods:LDAM(Label-Distribution-AwareMargie)Losstailclasses的信息基本上较少,而且部署的模型通常很大,因此对tailclasse......
  • 【发疯毕设日志day7】hagrid_dataset_512数据集作者论文原文逐句翻译——大疆tello手
    论文原文::::2206.08219.pdf(arxiv.org)https://arxiv.org/pdf/2206.08219.pdf摘要     本文介绍了一个庞大的手势识别数据集——海格(HAndGestrueRecognitionImagedataset),以简历一个手势识别(HGR)系统,专注于与设备的交互管理。这就是为什么所选的18个手势都呗赋予......
  • C# EPPlus导出dataset----Excel2绘制图像
    一、生成折线图方法 ///<summary>    ///生成折线图    ///</summary>    ///<paramname="worksheet">sheet页数据</param>    ///<paramname="colcount">总列数</param>    ///<paramname="......
  • 5-1Dataset和DataLoader
    Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。而DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个......
  • dataset 判断整列是否有重复,找出重复数据
    DataTabledt=ds.Tables[0];DataViewdv=newDataView(dt);if(dv.Count!=dv.ToTable(true,jsonColumnNameNo).Rows.Count){......
  • cnpack支持调试状态查看TDataSet对象
    在Debug状态下,cnpack支持查看TDataSet对象了!具体用法:在Debug状态下运行项目,如下图:把鼠标放到q对象上,q是一个基于TDataSet继承来的TkbmMWClientQuery对象,也就是他是一个TDataSet,这时候会弹出一个窗口,也就是一个hint。注意左上角的放大镜,下移鼠标,让鼠标进入hint区域,点击放大镜......
  • (23)lazarus memdataset的filter问题
    参考https://www.cnblogs.com/qiufeng2014/p/17388138.html链接:https://pan.baidu.com/s/1ayzgDbXjgXBnw-jM1FR4gA提取码:ogqzunitUnit1;{$modeobjfpc}{$H+}interfaceusesClasses,SysUtils,memds,db,Forms,Controls,Graphics,Dialogs,DBGrids;type{TForm1......
  • 李宏毅2022机器学习HW4 Speaker Identification上(Dataset &Self-Attention)
    Homework4Dataset介绍及处理Datasetintroduction训练数据集metadata.json包括speakers和n_mels,前者表示每个speaker所包含的多条语音信息(每条信息有一个路径feature_path和改条信息的长度mel_len或理解为frame数即可),后者表示滤波器数量,简单理解为特征数即可,由此可知每个.pt......