首页 > 其他分享 >多分类问题

多分类问题

时间:2023-01-26 12:22:59浏览次数:42  
标签:loss head torch nn self 分类 问题 train

目录

Softmax

二分类问题

给定一系列特征,输出为0或1,表示是否满足某个条件。具体做法是输出一个概率,表示给定特征满足这个条件的概率,或者不满足这个条件的概率。

多分类问题

给定一系列特征,预测是多个类别中的哪一类,比如手写数组识别、物体识别等。

如果在多分类问题中仍采用二分类问题的解决方法,即输出可能属于每个类别的概率,会出现的问题有

  1. 输出的概率可能为负数
  2. 所有类别概率之和不为1,即不是一个分布

提出Softmax Classifier解决上述问题,最后一个线性层输出的结果是z,包括预测属于k个类别的概率,公式如下

  1. 通过计算指数保证了最终输出结果必为正数
  2. 通过归一化保证了最终输出所有类别概率之和为1

image-20230126105933975

举例如下

image-20230126110646483

多分类损失函数

二分类损失函数cross-entropy(交叉熵)

本质还是损失函数,描述预测结果和真实结果之间的差异程度

image-20230126110903486

y:真实值,y_head:预测值

  • y = 1

    • y_head = 1

      预测值和真实值之间吻合,loss会很小

    • y_head = 0

      预测值和真实值之间差异较大,loss会很大,注意看-log(y_head)

  • y = 0

    • y_head = 1

      预测值和真实值之间差异较大,loss会很大,注意看-log(1-y_head)

    • y_head = 0

      预测值和真实值之间吻合,loss会很小

多分类损失函数

没太明白,看弹幕有什么独热编码,记住公式吧?

image-20230126112051588

pytorch提供的交叉熵损失函数直接包括计算log(y_head)、softmax和损失函数计算

image-20230126112205453

注意:最后一层不做激活,直接使用交叉熵损失函数,传入softmax

实现手写数字识别

import torch
from torch import nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# --------------------------------------- 数据准备 ----------------------------------------
batch_size = 64
transform = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./dataset/mnist', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='./dataset/mnist', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


# --------------------------------------- 定义网络模型 ----------------------------------------
class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()
       self.linear1 = nn.Linear(784, 512)
       self.linear2 = nn.Linear(512, 256)
       self.linear3 = nn.Linear(256, 128)
       self.linear4 = nn.Linear(128, 64)
       self.linear5 = nn.Linear(64, 10)

   def forward(self, x):
       x = x.view(-1, 784)
       x = F.relu(self.linear1(x))
       x = F.relu(self.linear2(x))
       x = F.relu(self.linear3(x))
       x = F.relu(self.linear4(x))
       x = self.linear5(x)         # 注意最后一层不加激活函数
       return x


# -------------------------- 实例化网络模型 定义损失函数和优化器 --------------------------------

device = torch.device("cuda")   # 定义gpu设备
model = Net()
model = model.to(device)

criterion = nn.CrossEntropyLoss()   # 交叉熵损失函数
criterion = criterion.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# --------------------------------------- 定义训练过程 ----------------------------------------

def train(epoch):
   running_loss = 0.0
   for batch_idx, data in enumerate(train_loader, 0):
       inputs, targets = data
       inputs = inputs.to(device)
       targets = targets.to(device)
       optimizer.zero_grad()
       outputs = model(inputs)  # forward
       loss = criterion(outputs, targets)  # get loss
       loss.backward()  # backward
       optimizer.step()  # update

       running_loss += loss.item()
       if batch_idx % 300 == 299:  # 每300次输出
           print('[%d, %5d] loss: %3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
           running_loss = 0.0


# --------------------------------------- 定义测试过程 ----------------------------------------
def test():
   correct = 0
   total = 0
   with torch.no_grad():
       for data in test_loader:
           images, labels = data
           images = images.to(device)
           labels = labels.to(device)
           outputs = model(images)
           _, predicted = torch.max(outputs.data, dim=1)
           total += labels.size(0)
           correct += (predicted == labels).sum().item()
   print('Accuracy on test set: %d %%' % (100 * correct / total))


if __name__ == '__main__':
   for epoch in range(10):
       train(epoch)
       test()

标签:loss,head,torch,nn,self,分类,问题,train
From: https://www.cnblogs.com/dctwan/p/17067688.html

相关文章

  • 解决移动端页面向上移动问题
    在PC上网页是固定的,但是到了移动端滑动到底时,页面总要向上移动一点,查了各种方法都没有效果,后来在CSS里试了试这个方法:    *{            margin......
  • Serverless架构下用Python轻松实现图像分类和预测
    Serverless架构下用Python轻松实现图像分类和预测图像分类是人工智能领域的一个热门话题。通俗解释就是,图像分类是一种根据各自在图像信息中所反映的不同特征,把不同类别的......
  • ajax跨域访问的问题解决
    在web项目中经常用到在ajax中进行跨域访问,比如在a域中访问b域中的服务,却实现不了。原因是:浏览器为了保证服务器数据的安全,对于这种请求,所给予的权限是较低的,通常只允许调用......
  • KMP字符串匹配问题
    KMP算法本文参考资料:https://www.zhihu.com/question/21923021KMP算法是一种字符串匹配算法,可以在\(O(n+m)\)的时间复杂度内实现两个字符串的匹配。字符串匹配问题首......
  • 经典问题 1 —— DAG 上区间限制拓扑序
    问题描述给定一个DAG,求一个拓扑序,使得节点\(i\)的拓扑序\(\in[l_i,r_i]\)。题解首先进行一个预处理:对于所有\(u\),令\(\forall(v,u)\inE,l_u\leftarrow\max(l......
  • IP地址分类
    这学期学习计算机网络,写这篇文章,权当复习之用,这里说的IP地址为IPV4,不是IPV6。1)IP地址的组成一个IP地址由32位二进制数组成,如10000000000010110000001100011111为了方便......
  • Linux运维之解决服务器挖矿木马问题
    目录1挖矿木马1.1定义1.2挖矿特征1.3解决挖矿木马1.3.1阻断异常网络通信(非必需)1.3.2清除定时任务1.3.3清除启动项1.3.4清除SSH公钥1.3.5清除木马进程1.4其他常见......
  • 问题:RuntimeError: Model class LuffyAPI.apps.user.models.UserInfo doesn't declare
    问题截图  报错原因提示app未注册,但实际上已经注册的#1.#settings配置文件移动后要将这个settings添加到环境变量中sys.path.insert(0,BASE_DIR)#将所有app......
  • 安装OpenCV时提示缺少boostdesc_bgm.i文件的问题解决方案
    安装OpenCV时,会遇到下面的错误/home/zhang/slam/opencv-3.4.5/opencv_contrib/modules/xfeatures2d/src/boostdesc.cpp:653:20:fatalerror:boostdesc_bgm.i:没有那个文......
  • WMI and ACPI 问题
    微软的资料如下:​​http://msdn.microsoft.com/en-us/library/windows/hardware/Dn614028(v=vs.85).aspx​​Code:​​http://code.msdn.microsoft.com/windowshardware/WMI-......