首页 > 编程问答 >Pytorch Dataloader 添加批量维度

Pytorch Dataloader 添加批量维度

时间:2024-07-25 04:34:21浏览次数:16  
标签:python pytorch pytorch-dataloader

我认为这个问题已经被问过几次了,但我还没有在这里找到一个好的答案。

所以我有一个由 2 个 numpy 数组组成的 Pytorch 数据集。

以下是维度。

特征 = [10000, 450, 28] numpy 数组。 dim_0 = 样本数,dim_1 = 时间序列,dim_2 = 特征。基本上我有一个 450 帧长的数据,其中每帧包含 28 个特征,并且我有 10000 个样本。

label = [10000,450] numpy 数组。 dim_0 = 样本数,dim_1 = 每帧标签。

任务是我需要对每帧进行分类。

我使用以下函数创建了一个 Pytorch 自定义数据集和数据加载器。

label_length = label.size
label = torch.from_numpy(label)
features = torch.from_numpy(features)

train_dataset = Dataset(label, features, label_length)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

As预期,train_dataloader.dataset.data 返回大小为 [10000,450,28] 的张量 伟大的!现在只需从 10000 个样品中取出批次并进行循环! 所以我运行如下代码 - 假设优化器/损失函数都已设置。

train_loss = 0
EPOCHS = 3
for epoch_idx in range(EPOCHS):
    for i, data in enumerate(train_dataloader):
        inputs, labels = data
        print(inputs.size())
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

但是我收到此错误:

ValueError: LSTM: Expected input to be 2D or 3D, got 4D 相反

当我检查了输入的维度,它给出了[64 x 10000 x 450 x 28]

为什么dataloader添加了这个批次的维度? (我理解每个文档应该这样做,但我认为它应该从 10000 个样本中取出 64 个样本并创建批次并循环每个批次?

我认为我在某个地方犯了错误,但无法指出我是什么做错了...

编辑:这是我的简单数据集类

class Dataset(torch.utils.data.Dataset):
    def __init__(self, label, data, length):
        self.labels = label
        self.data = data
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # need to create tensor
        #data = torch.from_numpy(self.data)
        #labels = torch.from_numpy(self.labels).type(torch.LongTensor)
        data = self.data
        labels = self.labels
        return data, labels

错误的原因是的 __getitem__ 方法返回整个数据集和标签,而不是单个样本。因此,当 DataLoader 尝试创建大小为 64 的批次时,它会在数据集的开头添加一个额外的批次维度。

需要修改 __getitem__ 方法以返回单个样本及其对应的标签:

class Dataset(torch.utils.data.Dataset):
    def __init__(self, label, data):
        self.labels = label
        self.data = data

    def __len__(self):
        return len(self.data)  # 返回样本数量

    def __getitem__(self, idx):
        data = self.data[idx]
        labels = self.labels[idx]
        return data, labels

通过此更改, __getitem__ 将返回索引为 idx 的单个样本及其标签。然后, DataLoader 将这些单独的样本组合成大小为 64 的批次,从而产生形状为 [64, 450, 28] 的输入张量和形状为 [64, 450] 的标签张量。

此外,请注意 __len__ 方法也已更新为返回数据集中的样本数量,而不是标签长度。

以下是更新后的代码:

import torch

class Dataset(torch.utils.data.Dataset):
    def __init__(self, label, data):
        self.labels = label
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        labels = self.labels[idx]
        return data, labels

# 示例数据
features = torch.randn(10000, 450, 28)
label = torch.randint(0, 2, (10000, 450))  # 假设的标签是二进制的

train_dataset = Dataset(label, features)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

for epoch_idx in range(3):
    for i, data in enumerate(train_dataloader):
        inputs, labels = data
        print(inputs.size())  # 应打印 torch.Size([64, 450, 28])
        # ...模型的其余训练代码

通过这些更改,的 DataLoader 应该按预期工作,并为的 LSTM 模型提供正确形状的输入张量。

标签:python,pytorch,pytorch-dataloader
From: 78328271

相关文章

  • 用于打印脚本输出的 Python 实用程序
    我可以发誓有一个实用程序可以打印一个python脚本,其输出交织在一起。例如,给定一个脚本:a=2b=3print(a+b)print(a*b)该实用程序将输出a=2b=3print(a+b)#>5print(a*b)#>6有人知道该实用程序的名称吗?我最难找到它。谢谢你!描述的实用程序没有标......
  • a method to make some handy tools with python
    Inmyworkingofcomputer,therearealotofsimplejobsthatarefrequentlyrepeated.Itriedtofindawaytomakethesejobbeenprocessedeasily.Method1:Themethodiswritingascripttodothejob,andexecutingthescriptbyutoolsextensionuto......
  • Python网络爬虫详解:实战豆瓣电影信息采集
    文章目录前言一、爬虫是什么?二、常用库及其作用1.Requests2.BeautifulSoup3.lxml4.Scrapy5.Selenium6.PyQuery7.Pandas8.JSON9.Time三、实现步骤步骤一:环境准备步骤二:数据采集步骤三:数据处理步骤四:数据存储总结前言随着互联网的迅猛发展和数据分析需求的不......
  • python学习之内置函数
    Python拥有许多内置函数,这些函数是Python的一部分,不需要额外导入即可直接使用。这些函数提供了对Python解释器功能的直接访问,涵盖了从数学计算到类型检查、从内存管理到异常处理等各个方面。下面是一些常用的Python内置函数及其简要说明:一、Printprint函数大家都不会......
  • Python中以函数为作用域
    点击查看代码#第一题foriteminrange(10):#不报错,没有函数,所有操作在全局作用域里面执行,item最后赋值为:9,此时item在缩进与全局都可以使用passprint(item)#第二题item=10deffunc():foriteminrange(10):#优先在本地查找,找不到在到全局查找p......
  • 掌握IPython宏:%%macro命令的高效使用指南
    掌握IPython宏:%%macro命令的高效使用指南在编程中,宏是一种允许你定义可重用代码片段的强大工具。IPython,这个增强版的Python交互式环境,提供了一个名为%%macro的魔术命令,允许用户创建宏,从而提高代码的可重用性和效率。本文将详细介绍如何在IPython中使用%%macro命令创建宏,并......
  • 7月24号python:库存管理
    7月24号python:库存管理题目:​ 仓库管理员以数组stock形式记录商品库存表。stock[i]表示商品id,可能存在重复。原库存表按商品id升序排列。现因突发情况需要进行商品紧急调拨,管理员将这批商品id提前依次整理至库存表最后。请你找到并返回库存表中编号的最小的元素以便及......
  • IPython的Bash之舞:%%bash命令全解析
    IPython的Bash之舞:%%bash命令全解析IPython的%%bash魔术命令为JupyterNotebook用户提供了一种在单元格中直接执行Bash脚本的能力。这个特性特别适用于需要在Notebook中运行系统命令或Bash特定功能的场景。本文将详细介绍如何在IPython中使用%%bash命令,并提供实际的代码示......
  • Python数据分析与可视化大作业项目说明(含免费代码)
    题目:对全球和中国互联网用户的数据分析与可视化代码下载链接:https://download.csdn.net/download/s44359487yad/89574688一、项目概述1.1.项目背景:互联网是当今时代最重要和最有影响力的技术之一,它已经深刻地改变了人们的生活、工作、学习等方面。互联网用户数据是反映......
  • IPython的跨界魔术:%%javascript命令深度解析
    IPython的跨界魔术:%%javascript命令深度解析IPython,作为Python编程的强大交互式工具,提供了多种魔术命令来扩展其功能。其中,%%javascript魔术命令允许用户在IPythonNotebook中直接执行JavaScript代码,打通了Python和JavaScript两个世界,为数据可视化、Web内容操作等提供了便......