首页 > 其他分享 >pytorch自定义或自组织数据集

pytorch自定义或自组织数据集

时间:2023-04-25 12:55:41浏览次数:35  
标签:__ img 自定义 组织 self labels label pytorch images

 

import os
from pathlib import Path
from typing import Any, Callable, Optional, Tuple
import numpy as np
import torch
import torchvision
from PIL import Image


class DatasetSelfDefine(torchvision.datasets.vision.VisionDataset):
    def __init__(
            self,
            root: str,
            name: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        super(DatasetSelfDefine, self).__init__(root, transforms, transform, target_transform)
        images_dir = Path(root) / 'images' / name
        labels_dir = Path(root) / 'labels' / name
        self.images = [n for n in images_dir.iterdir()]
        self.labels = []
        for image in self.images:
            base, _ = os.path.splitext(os.path.basename(image))
            label = labels_dir / f'{base}.txt'
            self.labels.append(label if label.exists() else None)

    #  获取数据集大小
    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        img = Image.open(self.images[idx]).convert('RGB')# PIL Image, 大小为 (H, W)

        label_file = self.labels[idx]
        if label_file is not None:  # found
            with open(label_file, 'r') as f:
                labels = [x.split() for x in f.read().strip().splitlines()]
                labels = np.array(labels, dtype=np.float32)
        else:  # missing
            labels = np.zeros((0, 5), dtype=np.float32)

        boxes = []
        classes = []
        for label in labels:
            x, y, w, h = label[1:]
            boxes.append([
                    (x - w / 2) * img.width,
                    (y - h / 2) * img.height,
                    (x + w / 2) * img.width,
                    (y + h / 2) * img.height])
            classes.append(label[0])

        target = {}
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)# 真实标注框 [x1, y1, x2, y2], x 范围 [0,W], y 范围 [0,H]
        target["labels"] = torch.as_tensor(classes, dtype=torch.int64)# 上述标注框的类别标识

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    #  访问第 i 个数据
    def __len__(self) -> int:
        return len(self.images)


if __name__ == '__main__':

    batch_size = 64

    dataset = DatasetSelfDefine('./data/coco128', 'train2017', transform=torchvision.transforms.ToTensor())
    print(f'dataset: {len(dataset)}')
    print(f'dataset[0]: {dataset[0]}')

    dataset = DatasetSelfDefine('./data/coco128', 'train2017',
                     transform=torchvision.transforms.Compose([
                         torchvision.transforms.ToTensor()
                     ]))

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=lambda batch: tuple(zip(*batch)))

    for batch_i, (images, targets) in enumerate(dataloader):
        print(f'batch {batch_i}, images {len(images)}, targets {len(targets)}')
        print(f'  images[0]: shape={images[0].shape}')
        print(f'  targets[0]: {targets[0]}')

  

 

标签:__,img,自定义,组织,self,labels,label,pytorch,images
From: https://www.cnblogs.com/jeshy/p/17352281.html

相关文章

  • Servlet添加自定义的过滤器没有效果?
    在学习HttpServlet的时候有个自定义过滤器的定义类,我们想让特定url走过滤器。publicclassMyFilterimplementsFilter{privateFilterConfigconfig;publicvoidinit(FilterConfigconfig)throwsServletException{this.config=config;}publi......
  • 【HarmonyOS】自定义组件之JavaUI实现通用标题栏组件
    【关键字】标题栏、常用内置组件整合、JavaUI、自定义组件 【1、写在前面】平时我们在开发一个应用时,我们都知道一个完整的项目中会有很多个页面,而这些页面中会有许多通用的部分,比如通用标题栏、通用Dialog、通用下拉菜单等等,在Android开发中我们可以通过LayoutInflater.from......
  • 使用PyTorch和Flower 进行联邦学习
    本文将介绍如何使用Flower构建现有机器学习工作的联邦学习版本。我们将使用PyTorch在CIFAR-10数据集上训练卷积神经网络,然后将展示如何修改训练代码以联邦的方式运行训练。完整文章:https://avoid.overfit.cn/post/8d05a12c208c4f499573c9966d0fe415......
  • 自定义Python版本ESL库访问FreeSWITCH
    环境:CentOS7.6_x64Python版本:3.9.12FreeSWITCH版本:1.10.9一、背景描述ESL库是FreeSWITCH对外提供的接口,使用起来很方便,但该库是基于C语言实现的,Python使用该库的话需要使用源码进行编译。如果使用系统自带的Python版本进行编译,过程会比较流畅,就不描述了。这里记录下使用自定义......
  • Pytorch可视化热力图
     可视化热力图可以有两种方式:1)特征图可视化,将各通道特征的最大值作为热力图像素值,进行可视化——可以参考博客,一种比较灵活的特征图保存方式2)根据梯度值结合特征图计算热力图,热力图的显示的重点是梯度高的地方,也是网络关注的重点 基于梯度进行热力图可视化有一些工作,如grad......
  • 其它权限校验方法 自定义权限校验方法
    我们前面都是使用@PreAuthorize注解,然后在在其中使用的是hasAuthority方法进行校验。SpringSecurity还为我们提供了其它方法例如:hasAnyAuthority,hasRole,hasAnyRole等。​这里我们先不急着去介绍这些方法,我们先去理解hasAuthority的原理,然后再去学习其他方法你就更容易理解,而不是......
  • skywalking自定义插件开发
    skywalking是使用字节码操作技术和AOP概念拦截Java类方法的方式来追踪链路的,由于skywalking已经打包了字节码操作技术和链路追踪的上下文传播,因此只需定义拦截点即可。这里以skywalking-8.7.0版本为例。关于插件拦截的原理,可以看我的另一篇文章:skywalking插件工作原理剖析1.......
  • 第八章 重新组织数据
    8.1自封装字段例如,取值逻辑封装进对象8.2以对象取代数据值如果一个字段不能表达清楚业务含义,还需要添加多个相关联的字段。考虑添加一个新的对象。8.3以对象取代数组,或多个参数 ......
  • excel的练习1--自定义单元格格式(win11)
    题目1.4自定义单元格格式。在EXCEL单元格区域设置自定义数字格式,实现如下效果:在该区域的任意单元格输入数字1,则显示为√,其他数据原样显示。则:1.格式定义形式为【1】.【提示:格式-单元格格式-自定义】2.格式定义之后得到的结果B4:F10复制粘贴为文本到第【2】空。答案步骤选......
  • 自定义方法
          ......