首页 > 其他分享 >RuntimeError: “nll_loss_forward_reduce_cuda_kernel_2d_index“ not implemented for ‘Int‘

RuntimeError: “nll_loss_forward_reduce_cuda_kernel_2d_index“ not implemented for ‘Int‘

时间:2023-10-23 23:24:13浏览次数:33  
标签:loss target kernel Int labels cross entropy input

RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Int'

Traceback (most recent call last):
  File "E:/MyWorkspace/EEG/Pytorch/Train.py", line 79, in <module>
    opti='Adam')
  File "E:\MyWorkspace\EEG\Pytorch\Utils.py", line 133, in TrainTest_Model
    validation_loss, validation_acc = Test_Model(net, testloader, criterion,True)
  File "E:\MyWorkspace\EEG\EEGLearn-Pytorch\Utils.py", line 82, in Test_Model
    loss = criterion(outputs, labels.cuda()) # GPU
  File "D:\coson\anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\coson\anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\loss.py", line 1166, in forward
    label_smoothing=self.label_smoothing)
  File "D:\coson\anaconda3\envs\pytorch\lib\site-packages\torch\nn\functional.py", line 3014, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Int'

Process finished with exit code 1

 

关键错误在criterion(outputs, labels.cuda()) ,在本工程中criterion运行时给的值CrossEntropyLoss类实例,即: criterion = nn.CrossEntropyLoss()

因此该错误是在loss计算的时候发生的,原因就是类型不匹配,那个参数类的类型不匹配呢?(其实就是labels类型不匹配,

手动给的参数中,我这labels是int32即Int,所以很好判定,但如果不知道,怎么办,接着往下看)
看到 torch\nn\modules\loss.py 的1166行,即 label_smoothing=self.label_smoothing,该1166行是函数调用部分代码,完整代码如下:

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction,
                               label_smoothing=self.label_smoothing)

打开 cross_entropy 定义头,看到如下:

def cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
) -> Tensor:

我们传入的参数对应 cross_entropy 的 input 和 target,定义中没有指出input和target数值类型,到底是哪个参数不匹配,接着在 cross_entropy 函数往下看,发现其调用了C函数,芭比Q了,

 

torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

看不到这个C函数的底层实现,如何办,看看官方给函数说明和例子吧,峰回路转,发现在 cross_entropy 函数说明中有以下例子:

 

  # Example of target with class indices
  input = torch.randn(3, 5, requires_grad=True)
  target = torch.randint(5, (3,), dtype=torch.int64)
  loss = F.cross_entropy(input, target)
  loss.backward()
``
官方给的target用的int64,即long类型
所以可以断定`criterion(outputs, labels.cuda())`中的labels参数类型造成。
由上,我们可以对labels参数类型做转为:
```python
labels.long().cuda()

所以:

criterion(outputs, labels.long().cuda())

修改后,代码正常运行。

 

标签:loss,target,kernel,Int,labels,cross,entropy,input
From: https://www.cnblogs.com/ltkekeli1229/p/17783721.html

相关文章

  • 宝塔:续签SSL证书报错string indices must be integers
    网站SSL证书过期,续签的时候,报错stringindicesmustbeintegers。  处理方法:1.点击左侧首页,选择“修复”; 2.修复之后,重新点击网站,设置>>>SSL>>>续签证书,等待流程通过,点击保存即可。 ......
  • const int* 、int* const、const int* const
    2.4.4修饰指针或引用2.4.4.1指向只读变量的指针constinta;//const关键字修饰的是指针所指向的变量,而不是指针本身//不能通过指针修改所指向的变量 指针本身可以被修改 constint*p*p=30;//非法,无法通过指针修改只读变量的值inta=10;constintb......
  • c++ int数组存储long long元素
    高往低存,可能造成数据截断。如longlong64位,int32位,高32位被丢弃。可以将int数组每两个元素分别存储低32位和高32位inta[4];//隐式转换*a=2;//目标格式是int,2默认值默认值默认值*(longlong*)a=2;//目标格式是......
  • [转]setTimeout 和 setInterval 的定时时间深入研究
    原文地址:setTimeout和setInterval的定时时间深入研究-知乎setInterval() -间隔指定的毫秒数不停地执行指定的代码(一直执行)。setTimeout() -在指定的毫秒数后执行指定代码(只执行一次)。使用setInterVal:functiondoStuff(){//此处为需要执行一段时间T......
  • CF1479B1 Painting the Array I
    如果两种方案末尾两数有一数相同,那么答案较大的方案不劣于答案较小的方案。答案较大的方案只需\textbf{模仿}答案较小的方案即可,在状态变成相同之前答案最多只会少\(1\)。所以只需要考虑末尾两数\(a,b\)与新进来的数\(c\)各不相同时该替换哪个。假设\(a\)下次出现的位置......
  • Jlink V8 Interface Description
     JTAGInterfaceConnection(20pin) J-LinkandJ-TracehaveaJTAGconnectorcompatibletoARM'sMulti-ICE.TheJTAGconnectorisa20wayInsulationDisplacementConnector(IDC)keyedboxheader(2.54mmmale)thatmateswithIDCsocketsmou......
  • unity 使用interface 判断 null错误的问题
     在使用Interface,并且由Monobehaviour继承Interface情况下,判断interface的实际UnityEngine.Object是否null,出现错误,没有成功的判断出已经Destroy https://gamedev.stackexchange.com/questions/128971/unity-c-interface-object-never-equals-null解决方案:https://discuss......
  • double转int
    intdoubleToInt(doubledVal,intnMagnification){ doubledCorrect=1.0/(nMagnification*100); intn1=(int)(dVal*nMagnification+dCorrect); returnn1;}测试代码:doubled1=1.0; intk0=100; doubledStep=1.0/k0; doubledCorrect=1.0/(k0*100......
  • 在C#中如何将int转换为枚举?
    内容来自DOChttps://q.houxu6.top/?s=在C#中如何将int转换为枚举?在C#中如何将一个int类型转换为enum类型?从整数类型:YourEnumfoo=(YourEnum)yourInt;从字符串类型:YourEnumfoo=(YourEnum)Enum.Parse(typeof(YourEnum),yourString);//对于带有[Flags]属性的枚......
  • 为什么Java的 +=, -=, *=, /= 复合赋值运算符不需要将long转换为int?
    内容来自DOC[https://q.houxu6.top/?s=为什么Java的+=,-=,*=,/=复合赋值运算符不需要将long转换为int?](https://q.houxu6.top/?s=为什么Java的+=,-=,*=,/=复合赋值运算符不需要将long转换为int?)直到今天,我认为例如:i+=j;只是一个i=i+j;的快捷方式。但......