首页 > 其他分享 >[PyTorch] 自定义数据集

[PyTorch] 自定义数据集

时间:2022-11-20 11:11:25浏览次数:43  
标签:定义数据 self batch PyTorch train csv data bananas

步骤

  • 自定义Dataset实例:
    • 定义 __init__ 方法:返回feature和label两个部分的数据;
    • 定义 __getitem() ;
    • 定义 _len_()方法;
  • 使用 torch.utils.data.DataLoader 加载数据;

示例

  • 目的:载入自定义的目标检测数据集——“banana-detection”;

  • 数据集格式:包含两个文件夹“bananas_train”和“bananans_val”,

    • 文件夹:两个文件夹包含内容都一样,分别为一个“label.csv”文件和一个“images”文件夹(存放的是图片);

    • label.csv:存放的是每张图片的文件名和目标边框信息,字段如下:

      字段名 img_name label xmin ymin xmax ymax
      含义 图片文件名 具体含义我不太清楚,全为0,好像没用到 目标左上角的x轴坐标 目标左上角的y轴坐标 目标右下角的x轴坐标 目标右下角的y轴坐标
# %%
import os
import pandas as pd
import torch 
import torchvision
from d2l import torch as d2l
import numpy as np

# %%
def read_data_bananas(is_train = True):
    """读取香蕉数据集中的图像和标签"""
    data_dir = '..\\data\\banana-detection'
    csv_fname = os.path.join(data_dir, 
                             'bananas_train' if is_train else 'bananas_val',
                             'label.csv')
    csv_data = pd.read_csv(csv_fname)
    # 指定索引,下边的for循环取到数的index(第一个参数img_name)就为该字段,其他字段的数据就会放到target中
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    # 将图片读取到内存中(数据集大的时候不能用该方法)
    for img_name, target in csv_data.iterrows():                # pandas.DataFrame的iterrows()函数是在数据框中的行进行迭代的一个生成器,它返回每行的索引及一个包含行本身的对象。
        images.append(
            # 读取图片,返回一个三维tensor(channels, height, width)
            torchvision.io.read_image(              
                os.path.join(data_dir,
                             'bananas_train' if is_train else 'bananas_val',
                             'images',
                             f'{img_name}')     # python的print字符串前面加f表示格式化字符串,加f后可以在字符串里面使用用花括号括起来的变量和表达式
            )
        )
        tmp = torchvision.io.read_image(              
                os.path.join(data_dir,
                             'bananas_train' if is_train else 'bananas_val',
                             'images',
                             f'{img_name}')     # python的print字符串前面加f表示格式化字符串,加f后可以在字符串里面使用用花括号括起来的变量和表达式
            )
        targets.append(list(target))
    print(tmp)
    print(tmp.shape)
    # images包含所有图片的张量
    # images是一个列表,每个元素为一个多维tensor(一张图片的张量),转换为tensor会报错。如果是其他数据的话,如果能转换为tensor的话,不知道这个代码会不会报错,但是target这里也转换成tensor了,也没问题。有人知道的话请告诉我
    # targets转换成tensor,每个元素除以256。实际上target是列表也不会报错((torch.tensor(targets).unsqueeze(1) / 256).numpy().tolist())
    return images, torch.tensor(targets).unsqueeze(1) / 256     # target扩一个维度,然后每个数字除以256

# %%
# 创建一个自定义Dataset实例
class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉数据集的自定义数据集"""
    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        # print(self.labels)
        # print(f'feature_type:{type(self.features)},label_type:{type(self.labels)}')
        print('read' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples'))
    
    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])
    
    def __len__(self):
        return len(self.features)

# %%
# 训练集和测试集返回两个数据加载器实例
def load_data_bananas(batch_size):
    """加载香蕉检测数据集"""
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train = True),
                                             batch_size,
                                             shuffle = True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train = False),
                                           batch_size)
    return train_iter, val_iter

# %%
# 读取一个小批量,并打印其中的图像和标签的形状
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))      # 迭代train_iter,取第一轮的结果,也就是一个结果
# 输出第一个batch的维度(batch_size, channels, height, width),第二个是
batch[0].shape, batch[1].shape

# %%
# 演示(用的是深度学习模块 d2l)
# permute()是维度换位,(0,2,3,1)表示把原来的
# 0维 -> 0维
# 2维 -> 1维 
# 3维 -> 2维
# 1维 -> 3维
imgs = (batch[0][0 : 10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors = ['w'])

感谢李沐老哥哥赞助的代码!

标签:定义数据,self,batch,PyTorch,train,csv,data,bananas
From: https://www.cnblogs.com/chasemeng/p/16908059.html

相关文章

  • RNN的PyTorch实现
    官方实现PyTorch已经实现了一个RNN类,就在torch.nn工具包中,通过torch.nn.RNN调用。使用步骤:实例化类;将输入层向量和隐藏层向量初始状态值传给实例化后的对象,获得RNN的......
  • pytorch学习笔记(1)
    pytorch学习笔记(1)   expand向左扩展维度、扩展元素个数a=t.ones(2,3)只能在左侧增加维度,而不能在右侧增加维度,也不能在中间增加维度新增维度的元素个数可以为任......
  • Pytorch基于MNIST数据集简单实现手写数字识别
    """模型训练代码"""importtorchimporttorchvision.datasetsfromtorchimportnnfromtorchvisionimporttransformsfromtorch.utils.dataimportDataLoaderi......
  • pytorch输出tensor张量时显示省略号的问题
    问题描述:由于tensor的数据量过大,再用print输出tensor数据时会将中间的部分用省略号代替,无法看到全部的tensor元素。   解决方案:在程序开头加入下面的代......
  • pytorch使用docker部署后卡死现象
    现象基于pytorch的模型服务,本地裸跑代码都是正常的,一旦上docker服务部署后,程序会出现卡死现象解决原因是,默认情况下,pytorch会启动宿主机当前的CPU核数作为线程数去运行,......
  • 神经网络中的权重初始化方式和pytorch应用
    目录计算增益常数初始化均匀分布初始化正态分布初始化Xavier初始化均匀分布(glorot初始化)正态分布Kaiming初始化均匀分布正态分布具体应用一些问答或tips深度学习模型中的......
  • win10安装cuda、cuDNN和pytorch笔记
    特别注意:由于自己安装时没有做记录,所以下面大部分安装步骤图片都是参考的网络图,但不影响阅读,每一步都讲得很详细1.安装CUDA1.1查看自己显卡最高支持的CUDA版本在桌面......
  • 自定义数据类型
    枚举枚举故名思义就是一一列举把可能的取值一一列举1定义enumDay//星期{//枚举的可能取值Mon,Tus,...};enumSex//星期{//枚举的可能取值——常量......
  • C语言自定义数据类型
    结构体参考视频:https://www.bilibili.com/video/BV1oi4y1g7CF?p=58大纲:结构体的声明结构体的自引用结构体内存对齐结构体传参结构体实现位段(位段的填充&可移植性)charshor......
  • NLP入门之——Word2Vec词向量Skip-Gram模型代码实现(Pytorch版)
    代码地址:https://github.com/liangyming/NLP-Word2Vec.git1.什么是Word2VecWord2vec是Google开源的将词表征为实数值向量的高效工具,其利用深度学习的思想,可以通过训练......