目录
这四个东西之间的关系(简单理解)
Dataset
提供一种方式获取数据及其label
主要提供两个功能
- 如何获取每一个数据及其label
- 告诉我们总共有多少数据
获取数据集有两种方式:从电脑文件中读入、从torchvision官方获取提供的数据集
从电脑文件中读入数据集
- 创建一个类继承Dataset
- 确定数据文件所在的位置(路径/目录)和数据的标签
- 通过os和PIL Image相关操作进行处理
具体代码及细节如下
from torch.utils.data import Dataset
from PIL import Image
import os
# 定义MyData类继承Dataset类
class MyData(Dataset):
"""
构造函数 __init__(self, arg1, arg2...)
用该类创建一个对象时先自动调用构造函数,创建对象时可以指定初始化参数
构造函数中定义的self.root_dir self.label_dir等这些是实例属性
此处构造函数的具体作用是通过传入参数,获取某个标签下的所有图片信息
"""
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # dataset/train 注意是相对路径
self.label_dir = label_dir # ants
self.path = os.path.join(self.root_dir, self.label_dir) # dataset/train/ants
self.img_list = os.listdir(self.path) # ants下所有文件名列表
"""
__getitem__(self, idx)
通过下标idx获取一张图片
"""
def __getitem__(self, idx):
img_name = self.img_list[idx] # 获取文件名
# 拼接获得文件的相对地址
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
# 通过文件地址打开文件
img = Image.open(img_item_path)
label = self.label_dir
return img, label # 返回文件及其标签
def __len__(self):
return len(self.img_list)
# 定义参数
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir) # 实例化ants_dataset对象
bees_dataset = MyData(root_dir, bees_label_dir) # 实例化bees_dataset对象
train_dataset = ants_dataset + bees_dataset # 将两个数据集拼接
从torchvision官方获取提供的数据集
在此过程中可结合transforms(对数据进行处理和变换)和tensorboard(展示数据)
具体代码实现及细节如下
import torchvision
from torch.utils.tensorboard import SummaryWriter
# 指定读取数据集时将图片转为tensor
dataset_transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 读取训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transforms, download=True)
# 读取测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transforms, download=True)
# print(test_set[0])
# print(test_set.classes)
# img, label = test_set[0]
# print(label)
# print(test_set.classes[label])
# img.show()
# 使用tensorboard展示测试数据集前10个图片
writer = SummaryWriter(log_dir="./logs")
for i in range(10):
img, label = test_set[i]
writer.add_image("tensor image", img, i)
writer.close()
transforms
什么是transforms
transforms是一个工具包,里面定义了ToTensor这样一个类,可以将PIL image转化为tensor
首先要用ToTensor实例化一个对象tool,再把image作为参数调用tool,返回值就是这个图片的tensor
具体实现代码如下
from torchvision import transforms
from PIL import Image
img_path = "dataset/train/ants/5650366_e22b7e1065.jpg" # 原始图片相对路径
img = Image.open(img_path)
tensor_trans = transforms.ToTensor() # 实例化一个ToTensor对象
tensor_img = tensor_trans(img) # 传入image参数,获取该图片tensor返回值
# print(tensor_img)
什么是tensor
tensor是封装了原始图片信息以训练神经网络所需要的其他的一些新的数据类型
DataLoader
Dataset是获取数据集,而Dataloader用来管理如何从数据集中取数据进行训练,例如:每次取几个数据、取完一轮之后要不要进行打乱、刚刚取过的一批数据要不要放到底部等
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
trans = torchvision.transforms.ToTensor()
# 准备测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=trans, download=True)
# 64个图片为一批,打包为一个整体(batch_size)
# 一趟epoch后打乱顺序(shuffle=True)
# 最后一批不足64个舍去(drop_last=True)
dataloader = DataLoader(test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
writer = SummaryWriter(log_dir="./logs")
step = 0
for data in dataloader:
imgs, targets = data # data为一个batch,返回64个图片的tensor和标签
# print(imgs.shape)
# print(targets)
writer.add_images("dataloader-3", imgs, step)
step += 1
writer.close()
tensorboard
简单理解就是将数据可视化的工具
操作流程
-
实例化一个SummaryWriter对象writer
指定输出文件的路径
-
向writer中添加数据
-
关闭writer
-
启动tensorboard,去web端查看数据
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
# 创建编写器,保存日志
# log_dir保存路径 "./logs"当前目录下logs目录
writer = SummaryWriter(log_dir="./logs")
image_path = "data/train/bees_image/16838648_415acd9e3f.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
writer.add_image("test", img_array, 2, dataformats='HWC')
# 关闭
writer.close()
标签:img,self,label,学习,pytorch,transforms,简单,dataset,dir
From: https://www.cnblogs.com/dctwan/p/17034855.html