首页 > 其他分享 >PyTorch下,使用list放置模块,导致计算设备不一的报错

PyTorch下,使用list放置模块,导致计算设备不一的报错

时间:2024-02-04 19:25:44浏览次数:26  
标签:return list module PyTorch 报错 模块 model

报错

在复现 Transformer 代码的训练阶段时,发生报错:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

解决方案

通过next(linear.parameters()).device确定 model 已经在 cuda:0 上了,同时输入 model.forward()的张量也位于 cuda:0。输入的张量没什么好推敲的,于是考虑到模型具有多层结构,遂输出每层结构的设备信息,model.encoder -> model.encoder.sublayer[0] ··· ···

测试发现,model.encoder.sublayer[0] 之后的模块的设备信息均位于 cpu,原因是构造这部分模块时,由于需要多个相同的模块,使用了 list 来存放模块:

# module: 需要深拷贝的模块
# n: 拷贝的次数
# return: 深拷贝后的模块列表
def clones(module, n: int) -> list:
    return [copy.deepcopy(module) for _ in range(n)]

显然 list 不支持 GPU,需要用 PyTorch 提供的代替:

def clones(module, n: int):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

ModuleList 把子模块存入列表,能像 Python 里普通的列表被索引,最重要的是能使内部的模块被正确注册,并对所有的 Module 方法可见。[Source]

成功解决!

相关环境

python                    3.11.7               he1021f5_0
pytorch                   2.1.2           py3.11_cuda12.1_cudnn8_0    

标签:return,list,module,PyTorch,报错,模块,model
From: https://www.cnblogs.com/zh-jp/p/18006850

相关文章

  • PyTorch 2.2 中文官方教程(十六)
    介绍torch.compile原文:pytorch.org/tutorials/intermediate/torch_compile_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0注意点击这里下载完整的示例代码作者:WilliamWentorch.compile是加速PyTorch代码的最新方法!torch.compile通过将PyTorch代码JIT编译成优化的......
  • PyTorch 2.2 中文官方教程(十八)
    开始使用完全分片数据并行(FSDP)原文:pytorch.org/tutorials/intermediate/FSDP_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0作者:HamidShojanazeri,YanliZhao,ShenLi注意在github上查看并编辑本教程。在大规模训练AI模型是一项具有挑战性的任务,需要大量的计算能力和资源......
  • PyTorch 2.2 中文官方教程(十九)
    使用RPC进行分布式管道并行原文:pytorch.org/tutorials/intermediate/dist_pipeline_parallel_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0作者:ShenLi注意在github中查看并编辑本教程。先决条件:PyTorch分布式概述单机模型并行最佳实践开始使用分布式RPC框......
  • PyTorch 2.2 中文官方教程(二十)
    移动设备在iOS上进行图像分割DeepLabV3原文:pytorch.org/tutorials/beginner/deeplabv3_on_ios.html译者:飞龙协议:CCBY-NC-SA4.0作者:JeffTang审阅者:JeremiahChung介绍语义图像分割是一种计算机视觉任务,使用语义标签标记输入图像的特定区域。PyTorch语义图像分割De......
  • PyTorch 2.2 中文官方教程(十一)
    使用PyTorchC++前端原文:pytorch.org/tutorials/advanced/cpp_frontend.html译者:飞龙协议:CCBY-NC-SA4.0PyTorchC++前端是PyTorch机器学习框架的纯C++接口。虽然PyTorch的主要接口自然是Python,但这个PythonAPI坐落在一个庞大的C++代码库之上,提供了基础数据......
  • PyTorch 2.2 中文官方教程(十二)
    自定义C++和CUDA扩展原文:pytorch.org/tutorials/advanced/cpp_extension.html译者:飞龙协议:CCBY-NC-SA4.0作者:PeterGoldsboroughPyTorch提供了大量与神经网络、任意张量代数、数据处理和其他目的相关的操作。然而,您可能仍然需要更定制化的操作。例如,您可能想使用在论......
  • PyTorch 2.2 中文官方教程(十三)
    在C++中注册一个分发的运算符原文:pytorch.org/tutorials/advanced/dispatcher.html译者:飞龙协议:CCBY-NC-SA4.0分发器是PyTorch的一个内部组件,负责确定在调用诸如torch::add这样的函数时实际运行哪些代码。这可能并不简单,因为PyTorch操作需要处理许多“层叠”在彼此之......
  • windows查看端口占用,通过端口找进程号(查找进程号),通过进程号定位应用名(查找应用)(netstat
     文章目录通过端口号查看进程号`netstat`通过进程号定位应用程序`tasklist` 通过端口号查看进程号netstat在Windows系统中,可以使用netstat命令来查看端口的占用情况。以下是具体的步骤:打开命令提示符(CMD):按Win+R组合键打开运行对话框,输入cmd并按Enter键。......
  • PyTorch 2.2 中文官方教程(十四)
    参数化教程原文:译者:飞龙协议:CCBY-NC-SA4.0作者:MarioLezcano注意点击这里下载完整示例代码在本教程中,您将学习如何实现并使用此模式来对模型进行约束。这样做就像编写自己的nn.Module一样容易。对深度学习模型进行正则化是一项令人惊讶的挑战。传统技术,如惩罚方法,通......
  • PyTorch 2.2 中文官方教程(十)
    使用整体追踪分析的追踪差异原文:pytorch.org/tutorials/beginner/hta_trace_diff_tutorial.html译者:飞龙协议:CCBY-NC-SA4.0作者:AnupamBhatnagar有时,用户需要识别由代码更改导致的PyTorch操作符和CUDA内核的变化。为了支持这一需求,HTA提供了一个追踪比较功能。该......