首页 > 其他分享 >pytorch学习(八)Dataset加载分类数据集

pytorch学习(八)Dataset加载分类数据集

时间:2024-07-21 21:00:22浏览次数:9  
标签:__ img self label pytorch path Dataset dir 加载

我们之前用torchvision加载了pytorch的网络数据集,现在我们用Dataset加载自己的数据集,并且使用DataLoader做成训练数据集。

图像是从网上下载的,网址是 点这里,标签是图像文件夹名字。下载完成后作为自己的数据集。

1.加载自己的数据集的思路

    1)要完成继承自Dataset的类的构建

          由于Dataset是一个包含了虚函数的类,因此继承Dataset后,必须实现这些虚函数。

   2)第一个要完成的是__init__的构建,一般的方法是在__init__(self,root_dir, label_dir)中设置数据集的根目录root_dir,和类别数据集label_dir,然后用os.listdir得到label_dir中的图像名字

    3)第二个要完成的就是

__getitem__(self, item):

       item就是所要取数据的索引,这个函数主要是返回一个训练数据(比如一个图像),和一个结果数据,比如(该图像的分类结果是一个ant),因此用到刚os.listdir所列出的文件名字,用os.path.join加入路径,得到图像的绝对路径,用PIL导入图像,并给label赋值,返回图像和;abel即可。

   4)第三个要实现的就是数据集的长度

  __len__(self):

可以直接len(os.listdir所列出的文件名的数组),就可以得到数据集的长度。

2.需要注意的问题

   我在调试的时候发现

for imgs, labels  in train_loader:

一直报错,查找原因,发现是该数据集中的图像存在两个问题,第一个是大小不一,第二个貌似通道个数也不一致。

大小不一

因此使用transform做了处理

transform=transforms.Compose([ transforms.Resize((320,320),interpolation=Image.BILINEAR),
                                transforms.Grayscale(),
                                transforms.ToTensor()])

3.代码如下:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os
import torch
from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter("logs")
transform=transforms.Compose([ transforms.Resize((320,320),interpolation=Image.BILINEAR),
                                transforms.Grayscale(),
                                transforms.ToTensor()])


class MyDataLoader(Dataset):
    def __init__(self,root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, item):
        img_name = self.img_path[item]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        img = Image.open(img_item_path)
        img = transform(img)
        label = self.label_dir
        return img,label
    def __len__(self):
        return len(self.img_path)

root_dir = "E:/TOOLE/slam_evo/pythonProject/data/hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"

ants_dataset = MyDataLoader(root_dir,ants_label_dir)
bees_dataset = MyDataLoader(root_dir,bees_label_dir)
train_data = ants_dataset + bees_dataset

img0, label0 = train_data[12]
# img0.show()
img1, label1 = train_data[124]
# img1.show()
# 一次处理数据10个
BATCH_SIZE = 10
# 把数据集装载到DataLoader里
train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)

A = len(train_loader)
num_iter = 0
for imgs, labels  in train_loader:

    print(imgs.shape)
    print(labels)
    # print(train_data.classes)
    writer.add_images("ant-bees",imgs,num_iter)
    num_iter = num_iter +1

writer.close()


用tensorboard显示,batch_size= 10,因此每次迭代有10张图像

标签为:

标签:__,img,self,label,pytorch,path,Dataset,dir,加载
From: https://blog.csdn.net/hero_heart/article/details/140561019

相关文章

  • 【PyTorch】图像多分类项目
    【PyTorch】图像二分类项目【PyTorch】图像二分类项目-部署【PyTorch】图像多分类项目【PyTorch】图像多分类项目部署多类图像分类的目标是为一组固定类别中的图像分配标签。目录加载和处理数据搭建模型定义损失函数定义优化器训练和迁移学习用随机权重进行训......
  • android audio 相机按键音加载与修改
    相机按键音资源,加载文件路径:frameworks/av/services/camera/libcameraservice/CameraService.cpp按键音,加载函数: voidCameraService::loadSoundLocked(sound_kindkind){   ATRACE_CALL();     LOG1("CameraService::loadSoundLockedref=%d",mSoundRe......
  • bug处理--antdesign中umi升级后无法加载子页面
    bug处理--antdesign中umi升级后无法加载子页面historyconstAdmin:React.FC=(props)=>{ const{children}=props; return( <PageHeaderWrapper> {children} </PageHeaderWrapper> );};now升级到Umi4后,之前的一些组件不能用了,获取不到props,props......
  • 如何在 kivy 中的应用程序文件中保存和加载设置?
    我一直在使用配置对象来存储和检索kivy中应用程序的设置。问题是配置对象是全局的,并且是为系统中的所有kivy应用程序设置的。我当前使用的内容:fromkivy.configimportconfig...AppcodeclassUserAPP(App);defbuild(self)defbuild_config(self,con......
  • 在 PowerShell 中,可以编写脚本来检测本地加载和远程加载的情况。这通常涉及到检查计算
    在PowerShell中,可以编写脚本来检测本地加载和远程加载的情况。这通常涉及到检查计算机上的特定服务或应用程序的状态或配置。以下是一些示例脚本和方法,可以用来实现这些检测:检测本地加载示例:检查本地服务的运行状态powershellCopyCode#检查本地服务状态$serviceName="M......
  • 在 PowerShell 中,"本地加载"和"远程加载"通常指的是运行脚本或命令的位置或方式。以下
    在PowerShell中,"本地加载"和"远程加载"通常指的是运行脚本或命令的位置或方式。以下是关于本地加载和远程加载的一些基本概念和示例:本地加载本地加载指的是在当前计算机上执行PowerShell脚本或命令。这些脚本和命令直接在本地计算机上运行,无需通过网络连接到其他计算机或服......
  • 同时加载 2 个 Tkinter 窗口。一个有动画的
    以下脚本独立运行以运行场景:首先打印结果,然后以动画结束绘图。importnumpyasnpimportmatplotlib.pyplotaspltimporttkinterastkfromtkinterimportttkfrommatplotlib.animationimportFuncAnimationdefrun_model():#Inputparameters(examplev......
  • 易优CMS模板标签load文件加载导入外部的css样式文件
    【基础用法】标签:load描述:资源文件加载,比如:css/js用法:{eyou:loadhref='/static/js/common.js'ver='on'/}属性:file=''资源文件路径href=''远程资源文件URLver=''开启版本号自动刷新浏览器缓存涉及表字段:无【更多示例】-------------------------------示例1------......
  • 加载 iptables 相关模块
    在现代Linux系统中,连接跟踪(ConnectionTracking)功能已经集成到nf_conntrack模块中,不再需要单独加载ip_conntrack模块。相应的nf_conntrack模块负责处理IPv4的连接跟踪功能,而IPv6的连接跟踪功能由nf_conntrack_ipv6模块处理。如果你需要加载与连接跟踪相关的模块,可以使用......
  • android audio不同音频流,(三)各音频流默认音量加载过程
    各音频流默认值,定义文件路径:frameworks/base/media/java/android/media/AudioSystem.java默认音量定义数组: /**@hide*/ publicstaticint[]DEFAULT_STREAM_VOLUME=newint[]{     4, //STREAM_VOICE_CALL     7, //STREAM_SYSTEM ......