步骤:
- 自定义Dataset实例:
- 定义 __init__ 方法:返回feature和label两个部分的数据;
- 定义 __getitem() ;
- 定义 _len_()方法;
- 使用 torch.utils.data.DataLoader 加载数据;
示例:
-
目的:载入自定义的目标检测数据集——“banana-detection”;
-
数据集格式:包含两个文件夹“bananas_train”和“bananans_val”,
-
文件夹:两个文件夹包含内容都一样,分别为一个“label.csv”文件和一个“images”文件夹(存放的是图片);
-
label.csv:存放的是每张图片的文件名和目标边框信息,字段如下:
字段名 img_name label xmin ymin xmax ymax 含义 图片文件名 具体含义我不太清楚,全为0,好像没用到 目标左上角的x轴坐标 目标左上角的y轴坐标 目标右下角的x轴坐标 目标右下角的y轴坐标
-
# %%
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
import numpy as np
# %%
def read_data_bananas(is_train = True):
"""读取香蕉数据集中的图像和标签"""
data_dir = '..\\data\\banana-detection'
csv_fname = os.path.join(data_dir,
'bananas_train' if is_train else 'bananas_val',
'label.csv')
csv_data = pd.read_csv(csv_fname)
# 指定索引,下边的for循环取到数的index(第一个参数img_name)就为该字段,其他字段的数据就会放到target中
csv_data = csv_data.set_index('img_name')
images, targets = [], []
# 将图片读取到内存中(数据集大的时候不能用该方法)
for img_name, target in csv_data.iterrows(): # pandas.DataFrame的iterrows()函数是在数据框中的行进行迭代的一个生成器,它返回每行的索引及一个包含行本身的对象。
images.append(
# 读取图片,返回一个三维tensor(channels, height, width)
torchvision.io.read_image(
os.path.join(data_dir,
'bananas_train' if is_train else 'bananas_val',
'images',
f'{img_name}') # python的print字符串前面加f表示格式化字符串,加f后可以在字符串里面使用用花括号括起来的变量和表达式
)
)
tmp = torchvision.io.read_image(
os.path.join(data_dir,
'bananas_train' if is_train else 'bananas_val',
'images',
f'{img_name}') # python的print字符串前面加f表示格式化字符串,加f后可以在字符串里面使用用花括号括起来的变量和表达式
)
targets.append(list(target))
print(tmp)
print(tmp.shape)
# images包含所有图片的张量
# images是一个列表,每个元素为一个多维tensor(一张图片的张量),转换为tensor会报错。如果是其他数据的话,如果能转换为tensor的话,不知道这个代码会不会报错,但是target这里也转换成tensor了,也没问题。有人知道的话请告诉我
# targets转换成tensor,每个元素除以256。实际上target是列表也不会报错((torch.tensor(targets).unsqueeze(1) / 256).numpy().tolist())
return images, torch.tensor(targets).unsqueeze(1) / 256 # target扩一个维度,然后每个数字除以256
# %%
# 创建一个自定义Dataset实例
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉数据集的自定义数据集"""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
# print(self.labels)
# print(f'feature_type:{type(self.features)},label_type:{type(self.labels)}')
print('read' + str(len(self.features)) + (f' training examples' if is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
def __len__(self):
return len(self.features)
# %%
# 训练集和测试集返回两个数据加载器实例
def load_data_bananas(batch_size):
"""加载香蕉检测数据集"""
train_iter = torch.utils.data.DataLoader(BananasDataset(is_train = True),
batch_size,
shuffle = True)
val_iter = torch.utils.data.DataLoader(BananasDataset(is_train = False),
batch_size)
return train_iter, val_iter
# %%
# 读取一个小批量,并打印其中的图像和标签的形状
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter)) # 迭代train_iter,取第一轮的结果,也就是一个结果
# 输出第一个batch的维度(batch_size, channels, height, width),第二个是
batch[0].shape, batch[1].shape
# %%
# 演示(用的是深度学习模块 d2l)
# permute()是维度换位,(0,2,3,1)表示把原来的
# 0维 -> 0维
# 2维 -> 1维
# 3维 -> 2维
# 1维 -> 3维
imgs = (batch[0][0 : 10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors = ['w'])
标签:定义数据,self,batch,PyTorch,train,csv,data,bananas From: https://www.cnblogs.com/chasemeng/p/16908059.html感谢李沐老哥哥赞助的代码!