首页 > 其他分享 >pytorch简单学习

pytorch简单学习

时间:2023-01-08 16:46:23浏览次数:33  
标签:img self label 学习 pytorch transforms 简单 dataset dir

目录

这四个东西之间的关系(简单理解)

image-20230108163318476

Dataset

提供一种方式获取数据及其label

主要提供两个功能

  • 如何获取每一个数据及其label
  • 告诉我们总共有多少数据

获取数据集有两种方式:从电脑文件中读入从torchvision官方获取提供的数据集

从电脑文件中读入数据集

  • 创建一个类继承Dataset
  • 确定数据文件所在的位置(路径/目录)和数据的标签
  • 通过os和PIL Image相关操作进行处理

具体代码及细节如下

from torch.utils.data import Dataset
from PIL import Image
import os


# 定义MyData类继承Dataset类
class MyData(Dataset):
    """
        构造函数 __init__(self, arg1, arg2...)
        用该类创建一个对象时先自动调用构造函数,创建对象时可以指定初始化参数
        构造函数中定义的self.root_dir self.label_dir等这些是实例属性

        此处构造函数的具体作用是通过传入参数,获取某个标签下的所有图片信息
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # dataset/train 注意是相对路径
        self.label_dir = label_dir  # ants	
        self.path = os.path.join(self.root_dir, self.label_dir)  # dataset/train/ants
        self.img_list = os.listdir(self.path)  # ants下所有文件名列表

    """
        __getitem__(self, idx)
        通过下标idx获取一张图片
    """

    def __getitem__(self, idx):
        img_name = self.img_list[idx]  # 获取文件名
        # 拼接获得文件的相对地址
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        # 通过文件地址打开文件
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label  # 返回文件及其标签

    def __len__(self):
        return len(self.img_list)


# 定义参数
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)  # 实例化ants_dataset对象
bees_dataset = MyData(root_dir, bees_label_dir)  # 实例化bees_dataset对象
train_dataset = ants_dataset + bees_dataset  # 将两个数据集拼接

从torchvision官方获取提供的数据集

在此过程中可结合transforms(对数据进行处理和变换)和tensorboard(展示数据)

具体代码实现及细节如下

import torchvision
from torch.utils.tensorboard import SummaryWriter

# 指定读取数据集时将图片转为tensor
dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# 读取训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transforms, download=True)

# 读取测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transforms, download=True)

# print(test_set[0])
# print(test_set.classes)
# img, label = test_set[0]
# print(label)
# print(test_set.classes[label])
# img.show()

# 使用tensorboard展示测试数据集前10个图片
writer = SummaryWriter(log_dir="./logs")
for i in range(10):
    img, label = test_set[i]
    writer.add_image("tensor image", img, i)
writer.close()

transforms

什么是transforms

transforms是一个工具包,里面定义了ToTensor这样一个类,可以将PIL image转化为tensor

首先要用ToTensor实例化一个对象tool,再把image作为参数调用tool,返回值就是这个图片的tensor

image-20230108112449392

具体实现代码如下

from torchvision import transforms
from PIL import Image

img_path = "dataset/train/ants/5650366_e22b7e1065.jpg"  # 原始图片相对路径
img = Image.open(img_path)
tensor_trans = transforms.ToTensor()    # 实例化一个ToTensor对象
tensor_img = tensor_trans(img)      # 传入image参数,获取该图片tensor返回值
# print(tensor_img)

什么是tensor

tensor是封装了原始图片信息以训练神经网络所需要的其他的一些新的数据类型

image-20230108140040895

DataLoader

Dataset是获取数据集,而Dataloader用来管理如何从数据集中取数据进行训练,例如:每次取几个数据、取完一轮之后要不要进行打乱、刚刚取过的一批数据要不要放到底部等

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

trans = torchvision.transforms.ToTensor()

# 准备测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=trans, download=True)

# 64个图片为一批,打包为一个整体(batch_size)
# 一趟epoch后打乱顺序(shuffle=True)
# 最后一批不足64个舍去(drop_last=True)
dataloader = DataLoader(test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

writer = SummaryWriter(log_dir="./logs")

step = 0
for data in dataloader:
    imgs, targets = data	# data为一个batch,返回64个图片的tensor和标签
    # print(imgs.shape)
    # print(targets)
    writer.add_images("dataloader-3", imgs, step)
    step += 1
writer.close()

tensorboard

简单理解就是将数据可视化的工具

操作流程

  • 实例化一个SummaryWriter对象writer

    指定输出文件的路径

  • 向writer中添加数据

  • 关闭writer

  • 启动tensorboard,去web端查看数据

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
# 创建编写器,保存日志
# log_dir保存路径 "./logs"当前目录下logs目录
writer = SummaryWriter(log_dir="./logs")

image_path = "data/train/bees_image/16838648_415acd9e3f.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
writer.add_image("test", img_array, 2, dataformats='HWC')

# 关闭
writer.close()

标签:img,self,label,学习,pytorch,transforms,简单,dataset,dir
From: https://www.cnblogs.com/dctwan/p/17034855.html

相关文章