首页 > 编程问答 >LSTNet RuntimeError:输入和参数张量不在同一设备上

LSTNet RuntimeError:输入和参数张量不在同一设备上

时间:2024-07-27 06:36:44浏览次数:12  
标签:python pytorch lstm

我克隆了一个 github repo 它运行一个pytorch深度学习模块,我定制了这部分以将模块和数据发送到GPU。

train_dataset = MarketDataset(train_data, history_len=history_len)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

model = LSTNet()
model.to("cuda")

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

train_loss_list = []
for epoch in tqdm(range(epochs)):
    
    epoch_loss_train = 0
    for i, batch in tqdm(enumerate(train_data_loader, start=1), 
                         leave=False, desc="Train", total=len(train_data_loader)):
            
        X, Y = batch
        X = X.to("cuda")
        Y = Y.to("cuda")
        optimizer.zero_grad()
        Y_pred = model(X)
        loss = criterion(Y_pred, Y)
        loss.backward()
        optimizer.step()

        with open(root_dir / 'Log/Running-Loss.txt', 'a+') as file:
            file.write(f'{loss.item()}\n')
        epoch_loss_train += loss.item()
        
    epoch_loss_train = epoch_loss_train / len(train_data_loader)
    train_loss_list.append(epoch_loss_train)
    
    with open(root_dir / 'Log/Epoch-Loss.txt', 'a+') as file:
        file.write(f'{epoch_loss_train}\n')

但是我得到了这个错误

RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at cpu

如果我检查了,两者该模块以及 X 和 Y 已经在 cuda 上。为什么错误仍然提示输入和参数不在同一设备?


此错误消息表示的模型的一部分(可能是某些子模块)的参数没有移动到 GPU,而输入数据 (X 和 Y) 在 GPU 上。

正在正确地将模型本身移动到 GPU ( model.to("cuda") ),但需要确保所有模型 参数 也被移动到 GPU。

问题可能出在定义 LSTNet 模型的方式上。需要检查模型定义并在任何子模块或层中显式地将它们移动到 GPU。

以下是一些可以尝试的解决方案:

  1. 在模型定义中将子模块移动到 GPU:
  2. 检查 LSTNet 类的 __init__ 方法。
  3. 对于每个定义的层或子模块,在实例化后立即添加 .to("cuda")

例如,如果的模型定义如下:

```python class LSTNet(nn.Module): def init (self, ...): super(LSTNet, self). init () self.conv1 = nn.Conv1d(...) self.lstm = nn.LSTM(...) self.linear = nn.Linear(...)

```

需要将其更改为:

python class LSTNet(nn.Module): def __init__(self, ...): super(LSTNet, self).__init__() self.conv1 = nn.Conv1d(...).to("cuda") self.lstm = nn.LSTM(...).to("cuda") self.linear = nn.Linear(...).to("cuda")

  1. 使用 apply 方法递归地将参数发送到 GPU:

  2. 定义一个将模块及其参数发送到 GPU 的函数:

    python def to_cuda(module): module = module.to("cuda") return module

  3. 在创建模型实例后应用该函数:

    python model = LSTNet() model.apply(to_cuda)

通过以上任一方法,可以确保模型的所有参数和输入数据都在同一设备(GPU)上,从而解决 RuntimeError

标签:python,pytorch,lstm
From: 77020548

相关文章

  • Python win32serviceutil QueryServiceStatus:返回值是什么意思?
    我正在学习使用pywin32,并尝试在64位Python3.6.4上使用win32serviceutil模块以下代码:importwin32serviceutilasserviceserviceStatus=service.QueryServiceStatus("WinDefend")print(serviceStatus)返回以下元组:(16,4,197,0,0,0,0)我对wind......
  • Python request-html 未下载 Chromium
    importrequestsfrombs4importBeautifulSoupfromrequests_htmlimportHTMLSessionurl="https://dmarket.com/ingame-items/item-list/csgo-skins?title=recoil%20case"sesion=HTMLSession()response=sesion.get(url)response.html.render()soup=B......
  • VS Code 不改变 python 环境
    我正在使用VS-Code和anaconda环境作为python解释器。我通过ctrl+shift+`选择准确的anaconda基础环境,它也反映在vscode的下侧面板中。但是,当我检查python版本时,它显示我系统的默认python环境3.7.9如果您看到下面的截图,anaconda环境是3.......
  • 使用 Python 打开保存为 Parquet 文件中元数据的 R data.table
    使用R,我创建了一个Parquet文件,其中包含一个data.table作为主要数据,另一个data.table作为元数据。library(data.table)library(arrow)dt=data.table(x=c(1,2,3),y=c("a","b","c"))dt2=data.table(a=22222,b=45555)attr(dt,&......
  • Python 需要 Windows 长路径
    我尝试运行此安装:pip3installmsgraph-sdk它给了我这个错误:它说我需要使用此链接启用Windows长路径:https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation?tabs=registry#enable-long-paths-in-windows-10-versi......
  • Python griddata() 和 Matlab griddata():某些网格点的结果不同
    在将一些(相当大的物理)Matlab代码转换为Python时,我偶然发现了这种情况。当对相同的二维离散数据进行插值时,Python/Scipy的griddata()函数给出的结果与Matlab的对应函数不同。griddata()Matlab示例代码:Python示例代码:%Samplepoints(x,y):7x5=3......
  • Ebay Python SDK 仅在特定项目类别上返回错误
    我在一个项目中使用ebaySDK一段时间了。最近我尝试导入一些商品,例如手表、手机壳等...并且我使用了eBay自己通过eBay返回的英国商店页面上的类别ID他们的“get_category_suggestions”API端点,但eBay似乎有选择地决定拒绝某些项目并引发服务器错误!为了测试,我做了......
  • 使用特定的Python版本(MacOS)制作virtualenv
    我安装了brew,python3(默认和最新版本)和pip3,pyenv。TensorFlow现在不支持python3.7,所以我听说我应该制作一个独立运行3.6或更低版本的virtualenv。我安装了python3.6.7bypyenvinstall3.6.7但无法制作virtualenv-p3.6.7(mydir)因为3.6.7不在P......
  • 使用Python去除图像中的线条
    我正在尝试使用Python和cv2、numpy、skimage等从黑白图像中删除“阴影线”(如果图像中存在“阴影线”)。本质上,我的图像可以有1或2条曲线,如下例所示。但每条线都有一条1-5像素外的阴影线,需要删除。我怎样才能在Python中做到这一点?原始......
  • Python 和 OpenCV:如何裁剪半成形边界框
    我有一个为无网格表创建网格线的脚本:脚本之前:脚本之后:是否有一种简单的方法,使用OpenCV来裁剪“脚本之后”图像,使其仅包含四边边界框?示例输出:编辑:我目前正在研究一种解决方案,该解决方案可以找到垂直/水平方向的第一条/最后一条......