首页 > 其他分享 >pytorch collate_fn测试用例

pytorch collate_fn测试用例

时间:2023-11-01 14:26:07浏览次数:52  
标签:img torch shape label pytorch 测试用例 collate np Size

collate_fn 函数用于处理数据加载器(DataLoader)中的一批数据。在PyTorch中使用 DataLoader 时,通过设置collate_fn,我们可以决定如何将多个样本数据整合到一起成为一个 batch。在某些情况下,该函数需要由用户自定义以满足特定需求。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class MyDataset(Dataset):
    def __init__(self, imgs, labels):
        self.imgs = imgs
        self.labels = labels

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        out_img = img.astype(np.float32)
        out_img = out_img.transpose(2, 0, 1) #[3, 300, 150]h,w,c  -->>  c,h,w
        out_label = self.labels[idx] #[4, 5] or [2, 5]
        return out_img, out_label

#if batchsize=3
#batch is list, [3]
#batch0 tuple2  (np[3, 300, 150], np[4, 5])
#batch1 tuple2  (np[3, 300, 150], np[2, 5])
#batch2 tuple2  (np[3, 300, 150], np[4, 5])
def my_collate_fn(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on
                                 0 dim
    """
    targets = []
    imgs = []
    for sample in batch:
        imgs.append(torch.FloatTensor(sample[0]))
        targets.append(torch.FloatTensor(sample[1]))

    imgs_out = torch.stack(imgs, 0) #[3, 3, 300, 150]
    return imgs_out, targets




img_data = []
label_data = []

nums = 34
H=300
W=150
for _ in range(nums):
    random_img = np.random.randint(low=0, high=255, size=(H, W, 3))
    nums_target = np.random.randint(low=0, high=10)
    random_xyxy_label = np.random.random((nums_target, 5))
    img_data.append(random_img)
    label_data.append(random_xyxy_label)

dataset = MyDataset(img_data, label_data)
dataloader = DataLoader(dataset, batch_size=3, collate_fn=my_collate_fn)

for cnt, (img, label) in enumerate(dataloader):
    print("==>>", cnt, ",  img shape=", img.shape)
    for i in range(len(label)):
        print("label shape=", label[i].shape)

打印如下:

==>> 0 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([8, 5])
label shape= torch.Size([2, 5])
label shape= torch.Size([5, 5])
==>> 1 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([3, 5])
label shape= torch.Size([8, 5])
label shape= torch.Size([5, 5])
==>> 2 ,  img shape= torch.Size([3, 3, 300, 150])
label shape= torch.Size([7, 5])
label shape= torch.Size([1, 5])
label shape= torch.Size([8, 5])

标签:img,torch,shape,label,pytorch,测试用例,collate,np,Size
From: https://www.cnblogs.com/yanghailin/p/17803000.html

相关文章

  • 解码注意力Attention机制:从技术解析到PyTorch实战
    在本文中,我们深入探讨了注意力机制的理论基础和实际应用。从其历史发展和基础定义,到具体的数学模型,再到其在自然语言处理和计算机视觉等多个人工智能子领域的应用实例,本文为您提供了一个全面且深入的视角。通过Python和PyTorch代码示例,我们还展示了如何实现这一先进的机制。关......
  • 解码注意力Attention机制:从技术解析到PyTorch实战
    在本文中,我们深入探讨了注意力机制的理论基础和实际应用。从其历史发展和基础定义,到具体的数学模型,再到其在自然语言处理和计算机视觉等多个人工智能子领域的应用实例,本文为您提供了一个全面且深入的视角。通过Python和PyTorch代码示例,我们还展示了如何实现这一先进的机制。关......
  • 代码 测试用例 测试用例 测试结果 26. 删除有序数组中的重复项
    给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。考虑 nums 的唯一元素的数量为 k ,你需要做以下事情确保你的题解可以被通......
  • pytorch中squeeze()和unsqueeze()函数
     下面使用一个二维矩阵看下dim不同时呈现出的效果:   #创建一个3*4的全1二维tensor   a=torch.ones(3,4)   '''   运行结果tensor([[1.,1.,1.,1.],[1.,1.,1.,1.],[1.,1.,1.,1.]])'''在0维度上插入一个维度,可以看到现在......
  • pytorch 学习记录
    model.train():启用BatchNormalization和Dropout。作用:对BN层,保证BN层能够用到每一批数据的均值和方差,并进行计算更新;对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。model.eval():不启用BatchNormalization和Dropoutwithtorch.no_grad():with语句块内......
  • 刘老师《Pytorch深度学习实践》第三讲:梯度下降
    1.分治法不能用局部点干扰性大2.梯度下降3.随机梯度下降随机梯度下降法(StochasticGradientDescent,SGD):由于批量梯度下降法在更新每一个参数时,都需要所有的训练样本,所以训练过程会随着样本数量的加大而变得异常的缓慢。随机梯度下降法正是为了解决批量梯度下降法这一......
  • pytorch深度学习入门
    参考:1、Pytorch最全入门介绍,Pytorch入门看这一篇就够了2、torch.nn模块torch.nn模块是PyTorch中用于构建神经网络的核心模块,包含了各种不同类型的层(如全连接层、卷积层、池化层)、损失函数、优化器等。下面介绍torch.nn中常用的一些类和函数:nn.module:所有神经网络层的基类,定义了......
  • ResNet详解:网络结构解读与PyTorch实现教程
    本文深入探讨了深度残差网络(ResNet)的核心概念和架构组成。我们从深度学习和梯度消失问题入手,逐一解析了残差块、初始卷积层、残差块组、全局平均池化和全连接层的作用和优点。文章还包含使用PyTorch构建和训练ResNet模型的实战部分,带有详细的代码和解释。关注TechLead,分享AI与......
  • ResNet详解:网络结构解读与PyTorch实现教程
    本文深入探讨了深度残差网络(ResNet)的核心概念和架构组成。我们从深度学习和梯度消失问题入手,逐一解析了残差块、初始卷积层、残差块组、全局平均池化和全连接层的作用和优点。文章还包含使用PyTorch构建和训练ResNet模型的实战部分,带有详细的代码和解释。关注TechLead,分享AI与......
  • pytorch:1.12-gpu-py39-cu113-ubuntu20.04
    docker-compose安装unbuntu20.04version:'3'services:ubuntu2004:image:ubuntu:20.04ports:-'2256:22'-'3356:3306'-'8058:80'volumes:-my-volume:/datacommand:tail......