首页 > 其他分享 >PyTorch数据集处理

PyTorch数据集处理

时间:2023-01-01 10:31:21浏览次数:64  
标签:__ img 处理 data self labels transform PyTorch 数据


数据样本处理的代码可能会变得杂乱且难以维护,因此理想状态下我们应该将模型训练的代码和数据集代码分开封装,以获得更好的代码可读性和模块化代码。

PyTorch 提供了两个基本方法 ​​torch.utils.data.DataLoader​​和​​torch.utils.data.Dataset​​可以让你预加载数据集或者你的数据。

​Dataset​​存储样本及其相关的标签, ​​DataLoader​​封装了关于 ​​Dataset​​的迭代器,让我们可以方便地读取样本。

PyTorch库中也提供了一些常用的数据集可以方便用户做预加载可以通过​​torch.utils.data.Dataset​​调用,还提供了一些对应数据集的方法。它们可以用于模型的原型和基准测试。

详细可以戳这里:


加载数据集

接下来我们看一下怎么从TorchVision加载​​Fashion-MNIST​​数据集。

Fashion-MNIST是Zalando的一个数据集,包含6万个训练样例和1万个测试样例。

每个样例由两部分组成,一个28×28灰度图像和一个十分类标签中的某一个标签。

我们要加载 ​​FashionMNIST Dataset​​需要用到以下几个参数:

  • ​root​​ 数据集的存储地址
  • ​train​​ 指定你要取训练集还是测试集
  • ​download=True​​ 如果你指定的 ​​root​​中没有数据集,会自动从网上下载数据集
  • ​transform​​ 、 ​​target_transform​​ 指定特征和标签转换

下边这段代码是取FashionMNIST的训练集和测试集,root设置了一个data文件,运行下边这段代码以后你可以看到当前目录下边应该多了一个data文件夹,里边就是FashionMNIST数据集文件了。

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)

test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
复制代码

迭代和可视化数据集

我们可以像列表索引一样查看​​Datasets​​。 可以使用​​matplotlib​​可视化我们的数据集。

其他代码解析看注释。

至于画子图有两个方法,二者的区别仅在于一个面向方法,一个面向对象,别的完全一样。

  1. subplot
figure = plt.figure()
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
plt.subplot(rows, cols, i)

plt.show()
复制代码
  1. add_subplot
figure = plt.figure()
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
figure.subplot(rows, cols, i)

plt.show()
复制代码
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item() # 从数据集中随机采样
img, label = training_data[sample_idx] # 取得数据集的图和标签
figure.add_subplot(rows, cols, i) # 画子图,也可以plt.subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray") # 是黑白图,这里做一个维度压缩,把1通道的1压缩掉
plt.show()
复制代码

最后随机采样的结果大概是这样的:

PyTorch数据集处理_加载


使用DataLoader

​Dataset​​可以检索我们数据集中一个样本的特征和标签。但是在训练模型的时候,我们通常希望数据以小批量(minibatch)的方式作为输入,在每个epoch中重新调整数据以防止过拟合,并且还能使用Python的​​multiprocessing​​加速数据检索。

​DataLoader​​是一个迭代器,将刚才提到的复杂方法抽象成简单的API。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
复制代码

通过DataLoader迭代获取数据

我们已经将数据集加载到​​DataLoader​​中,并可以根据需要迭代数据集。

下面的每次迭代返回一个批量数据的​​train_features​​和​​train_labels​​(分别包含​​batch_size=64​​个特征和标签)。

因为我们指定了​​shuffle=True​​,在遍历所有批量之后,数据会被打乱(要对数据加载顺序进行更细粒度的控制,戳这里​​pytorch.org/docs/stable…​​ 。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
复制代码

为你的数据创建自定义数据集

自定义Dataset类必须实现三个函数:​​__init__​​, ​​__len__​​和​​__getitem__​​。看看这个FashionMNIST图像存储在img_dir目录中,它们的标签单独存储在CSV文件annotations_file中。 在下一节我们详细分析一下每个函数中发生的事情。

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform

def __len__(self):
return len(self.img_labels)

def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
复制代码

init

​__init__​​函数在实例化Dataset对象时运行一次,帮我们初始化一个目录,其中包含图像、注释文件和两个变换(下一节将详细介绍)。

The labels.csv file looks like:

tshirt1.jpg, 0

tshirt2.jpg, 0

......

ankleboot999.jpg, 9

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
复制代码

len

​__len__​​方法返回我们数据集中的样本数量。

def __len__(self):
return len(self.img_labels)
复制代码

getitem

​__getitem__​​函数当你给定一个索引​​idx​​的时候,用于加载并返回样本。

基于索引,该函数去寻找图像在磁盘上的位置,使用​​read_image​​ 将其转换为一个张量,从​​self​​中的csv数据中检索相应的标签​​img_labels​​,调用它们上的变换函数(如果适用),并返回一个元组,元组中是图像的张量和对应的标签。

def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label


标签:__,img,处理,data,self,labels,transform,PyTorch,数据
From: https://blog.51cto.com/Lolitann/5981285

相关文章

  • 学习用Pandas处理分类数据!
     Datawhale干货 作者:耿远昊,Datawhale成员,华东师范大学分类数据(categoricaldata)是按照现象的某种属性对其进行分类或分组而得到的反映事物类型的数据,又称定类数据。直白......
  • 常用数据分析方法:方差分析及实现!
     Datawhale干货 作者:吴忠强,Datawhale优秀学习者,东北大学一个复杂的事物,其中往往有许多因素互相制约又互相依存。方差分析是一种常用的数据分析方法,其目的是通过数据分析......
  • 基于OpenCV的图像分割处理!
     Datawhale干货 作者:姚童,Datawhale优秀学习者,华北电力大学图像阈值化分割是一种传统的最常用的图像分割方法,因其实现简单、计算量小、性能较稳定而成为图像分割中最基本和......
  • 初探 InfluxDB 篇(六)InfluxDB 修改数据存放路径
    初探InfluxDB篇(六)InfluxDB修改数据存放路径 1、创建数据存放目录mkdir-p/home/data/influxdb说明:目录可以根据实际情况进行修改 2、设置目录访问权限sud......
  • pytorch的基本使用
    1.Anaconda配置pytorch环境1.创建环境在AnacondaPrompt工具中输入condacreate-npyTorch,报如下错误。解决方法:为Anaconda配置国内镜像源。1.方式1:使用conda......
  • 从NCBI中下载SRA数据
     今天测试了fastq-dump直接根据SRA号无法下载。只有下面一种方法测试成功。001、   002、   003、   004、   005、[root@PC1test......
  • 真知灼见|国产分析型数据库技术研究报告
    在国产分析型数据库技术研究报告(上)篇内容中,我们主要阐述了分析型数据库的发展脉络,并从技术维度对其进行了研究。在今天的文章里,我们将围绕分析型数据库的挑战与趋势继续进行......
  • 真知灼见|国产分析型数据库技术研究报告(上)
    一、分析型数据库的定义及发展1.数据库的定义及分类对数据库的分类可以从很多角度展开,同一个数据库分类角度不同,会被归类为不同的类型。本文主要采用了中国计算机学会(CCF)对......
  • 一次SQL调优 聊一聊 SQLSERVER 数据页
    一:背景1.讲故事最近给一位朋友做​​SQL慢语句​​优化,花了些时间调优,遗憾的是SQLSERVER非源码公开,玩起来不是那么顺利,不过从这次经历中我觉得明年的一个重大任务就是......
  • 手把手教你玩转 Excel 数据透视表
    1. 什么是数据透视表数据透视表是一种可以快速汇总、分析大量数据表格的交互式分析工具。使用数据透视表可以按照数据表格的不同字段从多个角度进行透视,并建立交叉表格,用......