PyTorch二分类时BCELoss,CrossEntropyLoss,Sigmoid等的选择和使用
这里就总结一下使用PyTorch做二分类时的几种情况:
总体上来讲,有三种实现形式:
- 最后分类层降至一维,使用sigmoid输出一个0-1之间的分数,使用torch.nn.BCELoss作为loss function
self.outputs = nn.Linear(NETWORK_WIDTH, 1)
def forward(self, x):
# other layers omitted
x = self.outputs(x)
return torch.sigmoid(x)
criterion = nn.BCELoss()
net_out = net(data)
loss = criterion(net_out, target)
- 最后分类层降至一维,但是不显式使用sigmoid,而使用torch.nn.BCEWithLogitsLoss作为loss function
self.outputs = nn.Linear(NETWORK_WIDTH, 1)
def forward(self, x):
# other layers omitted
x = self.outputs(x)
return x
###############################################################
criterion = nn.BCEWithLogitsLoss()
net_out = net(data)
loss = criterion(net_out, target)
- 最后分类层nn.Linear输出维度为2维,这时候使用的 loss function 是 torch.nn.CrossEntropyLoss,其已经包含了softmax作为激活函数
self.dense = nn.Linear(hidden_dim,2)
################################################
criterion = nn.CrossEntropyLoss()
net_out = net(data)
loss = criterion(net_out, target)
所以总结一下,在PyTorch中进行二分类,有三种主要的全连接层,激活函数和loss function组合的方法,分别是:
- torch.nn.Linear + torch.sigmoid + torch.nn.BCELoss
- torch.nn.Linear + BCEWithLogitsLoss(集成了Sigmoid)
- torch.nn.Linear(输出维度为2)+ torch.nn.CrossEntropyLoss(集成了Softmax)。