首页 > 编程问答 >使用 RNN 生成 PyTorch 路径 - 与输入、输出、隐藏和批量大小混淆

使用 RNN 生成 PyTorch 路径 - 与输入、输出、隐藏和批量大小混淆

时间:2024-07-24 05:38:01浏览次数:22  
标签:python machine-learning pytorch lstm recurrent-neural-network

我遵循了关于使用 RNN 生成句子的教程,并且尝试修改它以生成位置序列,但是我在定义正确的模型参数(例如 input_size、output_size、hidden_​​dim、batch_size)时遇到了麻烦。

背景:

我有 596 个 x,y 位置序列,每个序列看起来像 [[x1,y1],[x2,y2],...,[xn,yn]]。每个序列代表车辆的 2D 路径。我想训练一个模型,给定一个起点(或部分序列),可以生成这些序列之一。

-我已填充/截断序列,以便它们的长度均为 50,这意味着每个序列是形状为 [50,2]

的数组 - 然后我将此数据分为 input_seq 和 target_seq:

input_seq: torch.Size([596, 49, 2]) 的张量。包含所有 596 个序列,每个序列没有最后一个位置。

target_seq:torch.Size([596, 49, 2]) 的张量。包含所有 596 个序列,每个序列没有第一个位置。

模型类:

class Model(nn.Module):
def __init__(self, input_size, output_size, hidden_dim, n_layers):
    super(Model, self).__init__()
    # Defining some parameters
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    #Defining the layers
    # RNN Layer
    self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)
    # Fully connected layer
    self.fc = nn.Linear(hidden_dim, output_size)

def forward(self, x):
    batch_size = x.size(0)      
    # Initializing hidden state for first input using method defined below
    hidden = self.init_hidden(batch_size)
    # Passing in the input and hidden state into the model and obtaining outputs
    out, hidden = self.rnn(x, hidden)
    # Reshaping the outputs such that it can be fit into the fully connected layer
    out = out.contiguous().view(-1, self.hidden_dim)
    out = self.fc(out)        
    return out, hidden

def init_hidden(self, batch_size):
    # This method generates the first hidden state of zeros which we'll use in the forward pass
    # We'll send the tensor holding the hidden state to the device we specified earlier as well
    hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim)
    return hidden

我使用以下参数实例化模型:

input_size 为 2([x,y] 位置)

output_size 2([x,y] 位置)

hidden_​​dim 2([x,y] 位置)(或者这应该是 50,如完整序列的长度?)

model = Model(input_size=2, output_size=2, hidden_dim=2, n_layers=1)
n_epochs = 100
lr=0.01
# Define Loss, Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training Run
for epoch in range(1, n_epochs + 1):
    optimizer.zero_grad() # Clears existing gradients from previous epoch
    output, hidden = model(input_seq)
    loss = criterion(output, target_seq.view(-1).long())
    loss.backward() # Does backpropagation and calculates gradients
    optimizer.step() # Updates the weights accordingly
    if epoch%10 == 0:
        print('Epoch: {}/{}.............'.format(epoch, n_epochs), end=' ')
        print("Loss: {:.4f}".format(loss.item()))

当我运行训练时循环,它因以下错误而失败:

ValueError                                Traceback (most recent call last)
<ipython-input-9-ad1575e0914b> in <module>
      3     optimizer.zero_grad() # Clears existing gradients from previous epoch
      4     output, hidden = model(input_seq)
----> 5     loss = criterion(output, target_seq.view(-1).long())
      6     loss.backward() # Does backpropagation and calculates gradients
      7     optimizer.step() # Updates the weights accordingly
...

ValueError: Expected input batch_size (29204) to match target batch_size (58408).

我尝试修改 input_size、output_size、hidden_​​dim 和 batch_size 并重塑张量,但我尝试的越多,我就越困惑。有人能指出我做错了什么吗?

此外,由于批次大小在 Model.forward(self,x) 中定义为 x.size(0),这意味着我只有一个批次大小为 596,对吗?拥有多个小批量的正确方法是什么?


让我们来分解一下你的 PyTorch RNN 模型中的维度问题,并澄清如何使用批处理。

理解维度

  • input_size : 这应该是每个时间步的输入向量的大小。在你描述为 [x, y] 位置的场景中, input_size 应该是 2

  • hidden_size : 这控制着 RNN 隐藏状态的维度,并且是一个可以根据你模型的复杂性和你想要捕获的序列中依赖关系的数量进行调整的超参数。 它不直接对应于序列长度或坐标数量。 更大的 hidden_size 值通常可以表示更复杂的模式,但也需要更多的计算资源。

  • output_size : 这应该是每个时间步的期望输出向量的大小。由于你想预测一个 [x, y] 位置, output_size 也应该是 2

  • batch_size : 这表示在一次迭代(或“批处理”)中处理的序列数。虽然你可以在单个批处理中处理所有 596 个序列,但这对于内存来说可能太多了。将你的数据分成更小的批处理(例如,32 或 64)通常更有利于训练,并允许模型更快地更新。

代码中的错误

你代码中的主要错误在于你如何计算损失并处理目标的形状:

  1. CrossEntropyLoss 用于分类: 你使用的是 nn.CrossEntropyLoss ,它是为分类问题设计的,其中输出代表不同类别的概率。对于回归任务(例如预测 [x, y] 坐标),你应该使用 nn.MSELoss (均方误差)或 nn.L1Loss (平均绝对误差)。

  2. 目标形状不匹配: CrossEntropyLoss 希望目标为每个时间步的类索引(整数)。你正试图提供形状为 [596, 49, 2] 的浮点目标,这会导致不匹配。

修改后的代码和解释

以下是修正后的代码,并解释了如何实现批处理:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.rnn = nn.RNN(input_size, hidden_size, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        # 如果没有提供隐藏状态,则初始化
        if hidden is None:
            hidden = self.init_hidden(x.size(0)) 
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out) 
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(self.n_layers, batch_size, self.hidden_size)

# 超参数
input_size = 2
hidden_size = 128  # 选择一个合适的隐藏大小
output_size = 2
n_layers = 1
batch_size = 32 
n_epochs = 100
lr = 0.001

# 模型、损失和优化器
model = Model(input_size, hidden_size, output_size, n_layers)
criterion = nn.MSELoss()  # 针对回归使用 MSELoss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# 数据加载器(用于批处理)
# 假设你的数据在名为 'dataset' 的张量中
dataset = torch.randn(596, 49, 2)  # 替换为你的实际数据
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 训练循环
for epoch in range(n_epochs):
    for batch_idx, data in enumerate(dataloader):
        # 清除梯度
        optimizer.zero_grad()

        # 将输入数据和目标数据分开(假设最后一个时间步是目标)
        input_seq = data[:, :-1, :] 
        target_seq = data[:, 1:, :]

        # 前向传播
        outputs, _ = model(input_seq)

        # 计算损失
        loss = criterion(outputs, target_seq)

        # 反向传播和优化
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(f'Epoch: {epoch+1}/{n_epochs}, Batch: {batch_idx+1}/{len(dataloader)}, Loss: {loss.item():.4f}')

说明:

  1. DataLoader 代码现在使用 torch.utils.data.DataLoader 来有效地处理你的数据并将其分成批处理。

  2. 批处理循环: 外循环迭代 epoch,内循环迭代 dataloader 提供的批处理。

  3. 输入/目标拆分: 在每个批处理中,输入序列 ( input_seq ) 被视为除最后一个时间步之外的所有时间步,而目标序列 ( target_seq ) 被视为除第一个时间步之外的所有时间步(因为你正尝试预测下一个位置)。

  4. MSELoss: 代码现在使用 nn.MSELoss 来计算预测坐标和目标坐标之间的均方误差。

  5. 形状: 注意,模型的输出 outputs 将具有形状 [batch_size, sequence_length - 1, output_size] ,这与 target_seq 的形状相匹配。

通过这些调整,你的代码应该能够正确处理维度并使用批处理来训练 RNN 模型。记住, hidden_size 是一个超参数,你可能需要根据你的数据的复杂性对其进行实验以找到最佳值。

标签:python,machine-learning,pytorch,lstm,recurrent-neural-network
From: 62305941

相关文章

  • 如何在 Python 中查看与 Azure OpenAI 助手关联的所有上传文件?
    我正在使用Python对文档中的问题进行基准测试,并在jupyter笔记本中实例化了我的助手。我想确认助手是否有我上传的文件,但似乎找不到有关此功能将使用什么功能的文档。使用适用于AzureOpenAI的最新版本的PythonAPI。目前,无法使用AzureOpenAI的PythonAPI直接查看......
  • 如何在Python中计算小数?
    我正在创建一个计算器来用python计算企业的利润,但到目前为止我只能使用整数。这是我的代码示例:Gross=int(input("PleaseentertotalGrossRevenuefortheFiscalYear"))NetTaxes=int(Gross)*0.1所以我将会计年度的总收入乘以按“税率”计算,但我只能使用......
  • 如何使用 Python 打开 Google Firestore 上的特定数据库?
    我正在使用Firebase并使用以下代码从Firestore设置/检索文档:importfirebase_adminfromfirebase_adminimportcredentials,firestorecred=credentials.ApplicationDefault()firebase_admin.initialize_app(cred,options={"projectId":"huq-jimbo"})fires......
  • 如何使用 Python 和 Numpy 重现 Matlab 文件读取以解码 .dat 文件?
    我有一个Matlab脚本,可以读取编码的.dat文件,对其进行解码并保存。我试图使用numpy将其转换为Python。我发现对于同一个文件,我得到不同的输出结果(python数字没有意义)。该代码最初作为从串行端口读取的脚本的一部分运行,因此是数据的结构。我首先认为位移是问题所在,因为......
  • 在Python中调整pdf页面大小
    我正在使用python裁剪pdf页面。一切正常,但如何更改页面大小(宽度)?这是我的裁剪代码:input=PdfFileReader(file('my.pdf','rb'))p=input.getPage(1)(w,h)=p.mediaBox.upperRightp.mediaBox.upperRight=(w/4,h)output.addPage(p)当我裁剪页面时,我也需要......
  • 如何使用 python 更改资源管理器窗口中的路径?
    没有人知道如何在不使用python打开新实例的情况下更改资源管理器窗口中的当前路径吗?例如,如果用户使用C:\Users\User打开资源管理器窗口。然后我必须将该路径更改为C:\Windows\System32例如。提前致谢。很遗憾,无法直接使用Python更改现有文件资源管理器窗口的......
  • python 以及将数组传递给函数的问题
    我需要求解一些常微分方程$\frac{dy}{dx}=f(x)=x^2ln(x)$并继续在限制0之间创建数组xpt。<=xpt<=2因为我必须小心xpt=0,所以我将函数定义如下deff(x):ifx<=1.e-6:return0.else:returnnp.square(x)*np.log(x)我的调用程序读取Np......
  • 如果 Python 脚本正在使用文件夹,如何在文件资源管理器中进行更改时防止 Windows 的“
    我有一个简单的脚本,显示在QTreeView中的QListView中选择的目录的内容,我想添加打开文件资源管理器的功能,以让用户编辑目录内的内容。但是,添加新的文件夹和文件可以,但删除或移动文件夹或文件会提示“文件夹正在使用”错误:此操作无法完成,因为该文件已在另一个程......
  • 如何使用 Python API 获取每个模型的活跃用户列表、最后登录信息
    我想通过PythonAPI获取我的dbt项目的所有模型中的活动或非活动用户列表。这可能吗?我尝试列出模型,但无法获取用户信息,如用户名、项目、以及上次活动或上次登录。不幸的是,dbt本身并不跟踪你所寻找的用户活动数据(最后登录、活跃用户等)。dbt的主要功能是转换数据,而不......
  • Python tkinter 窗口不断关闭,我不知道为什么
    我正在尝试制作一个有趣的小程序,其中每小时左右就会有一只毛茸茸的动物走过屏幕。我有一个主窗口,它启动一个循环,每小时左右播放一次动画/声音,但是在口袋妖怪第一次完成行走后,整个程序就会结束。我认为这可能与我设置tkinter窗口的方式有关,但我无法弄清楚。我认为在这里包含......