本文重点
前面我们学习了pytorch中两种模式的损失函数,一种是nn,另外一种是functional,本文将讲解pytorch中已经封装好的损失函数。其实nn的方式就是类,而functional的方式就是方法。nn中使用的也是functional。
损失函数中的参数
无论是nn还是functional,大多数的损失函数都有size_average和reduce两个布尔类型的参数,因为一般损失函数都是直接计算 batch 的数据,因此返回的 loss 结果都是维度为 (batch_size, ) 的向量。
如果 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss
如果 reduce = True,那么 loss 返回的是标量 ,此时:
如果 size_average = True,返回 loss.mean();
如果 size_average =False,返回 loss.sum();
般损失函数默认:求所有损失的均值