报错信息
在执行nlp自定义模型的训练函数的时候,报如下错误:
RuntimeError: expected scalar type Float but found Long
错误原因
错误信息指出了问题所在:模型期望的数据类型是 float,但实际上传递给模型的数据类型是 long。
这个错误通常是由于张量数据类型不匹配引起的。在 PyTorch 中,张量数据类型非常重要,因为它们指定了张量中存储的数值的精度和类型。如果您在模型的前向传递中使用了错误的数据类型,就会出现这个错误。
例如:
import torch
import torch.nn as nn
v = torch.tensor([0])
m = nn.Linear(1, 10)
m(v)
运行结果:
因为input也就是我们的v是torch.long
类型的而weight是torch.float
类型。所以在做矩阵乘法的时候这两种类型的不一致导致了报错。
解决方案
把v的dtype显示地设置成torch.float代码就成功运行了
import torch
import torch.nn as nn
# dtype=torch.float必不可少
v = torch.tensor([0], dtype=torch.float)
m = nn.Linear(1, 10)
m(v)
运行结果:
tensor([-0.6189, -0.9843, -0.7568, 0.9157, 0.5192, -0.6109, -0.5627, -0.7755,
-0.9522, 0.7771], grad_fn=<AddBackward0>)
标签:float,nn,Python,RuntimeError,torch,数据类型,报错,import
From: https://www.cnblogs.com/zhangxuegold/p/17534668.html