首页 > 其他分享 >3.6--softmax回归的从零开始实现

3.6--softmax回归的从零开始实现

时间:2024-07-12 14:55:54浏览次数:12  
标签:acc -- sum iter 3.6 train softmax test

softmax回归从零实现

前言

本节介绍softmax和交叉熵损失函数的从零开始实现。


一、导入相关的库

import torch
import torchvision
import numpy as np
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

print(torch.__version__)
print(torchvision.__version__)

2.3.1+cpu
0.14.1+cu117

二、数据和模型参数

1.读取数据

使用上一节提供的方法

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

2.初始化模型参数

定义权重和偏差

num_inputs =784
num_outputs = 10

W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)

为模型参数附上梯度,用于反向传播修改

W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True) 

sum函数求和实例

# sum函数求和实例
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.size())
print(X.sum(dim=0, keepdim=True))
print(X.sum(dim=1, keepdim=True))

三、实现softmax运算

# X.exp()是对tensor中每个变量操作
def softmax(X):
	X_exp = X.exp()
	partition = X_exp.sum(dim=1, keepdim=True)
	# 这里应用了广播机制
	return X_exp/partition

四、定义模型

# 返回256×10的张量
def net(X):
    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

五、定义损失函数

gather()参数dim=1表示,矩阵中数的值是列值,矩阵中数所处位置的行值是行值。dim=0时相反。

y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))

tensor([[0.1000],
[0.5000]])

# 交叉熵损失函数
# 先选出y_hat对应样本标签处的值,然后求对数就是交叉熵损失函数的值
def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))

六、计算分类准确率

本函数已保存在d2lzh_pytorch包中方便以后使用

def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        # net(X)使样本经过模型和softmax,argmax(dim=1)输出样本最大值的索引,然后判断是否与真实标签相等
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() # argmax(dim=1)表示第二个维度中最大值的下标
        n += y.shape[0]
    return acc_sum / n

测试

print(evaluate_accuracy(test_iter, net))

0.0477

七、训练模型

num_epochs, lr = 5, 0.1

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            # 正向传播
            y_hat = net(X)
            # sum()的原因是后面要根据标量反向传播
            l = loss(y_hat, y).sum()
            # 梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            # 反向传播原因可以看这篇文章https://blog.csdn.net/weixin_45021364/article/details/105194187
            l.backward()
            # 优化器修改模型参数,修改params.data梯队不会对其计算
            if optimizer is None:
                d2l.sgd(params, lr, batch_size)
            else:
                optimizer.step()  # “softmax回归的简洁实现”一节将用到
            
            # train_l_sum为交叉熵损失函数的和
            train_l_sum += l.item()
            # 预测准确的数量
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

结果显示:
epoch 1, loss 0.7857, train acc 0.747, test acc 0.790
epoch 2, loss 0.5713, train acc 0.813, test acc 0.812
epoch 3, loss 0.5252, train acc 0.826, test acc 0.820
epoch 4, loss 0.5016, train acc 0.831, test acc 0.824
epoch 5, loss 0.4858, train acc 0.837, test acc 0.825

八、预测

X, y = next( iter(test_iter))
# 真实样本标签
true_labels = d2l.get_fashion_mnist_labels(y.numpy())
# 测试样本标签
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
# 显示前9张图
d2l.show_fashion_mnist(X[0:9], titles[0:9])

在这里插入图片描述


总结

从这节中学到了
1.softmax对模型预测结果的处理,softmax使预测概率为正数,并且概率在范围[0,1]。
2.交叉熵损失函数为softmax处理结果中,把样本真实标签对应位置的预测概率求对数再取反就是对应样本的损失函数。
3.准确率是样本预测正确总数除样本总数。
4.模型每次迭代训练过程为:正向传播、计算损失函数、反向传播(梯度清零)、优化器修改模型参数

标签:acc,--,sum,iter,3.6,train,softmax,test
From: https://blog.csdn.net/qq_53243414/article/details/140373502

相关文章

  • 初识c语言-1
     1.主函数intmain(){return0;} 注:c语言规定main是函数的入口,且只能有一个。  2.数据类型  是用来创建变量的,创建变量的本质是用来向内存申请空间的。char字符数据类型1byetshort短整型2byetin整型4byetlong长整型4byetlonglong更长的整型  8byet f......
  • C++ 多态
    1.多态的概念多态的概念:通俗来说,就是多种形态,具体点就是去完成某个行为,当不同的对象去完成时会产生出不同的状态。比如买票这个行为,当普通人买票时,是全价买票;学生买票时,是半价买票;军人买票时是优先买票。2.多态的定义及实现2.1 虚函数虚函数:即被virtu......
  • 模型加载20G以上的超大语料,无法加载,怎么办呢?
    背景:在做机器翻译的时候,我们的单边语料大约20G大小的纯文本语料,在DataLoader加载的时候不可能一次性加载进来,所以就有了这个超大语料的加载问题。解决方案:data_dealing.py:importosimportsysroot_dir=os.path.dirname(os.path.dirname(os.path.abspath(__file__)))......
  • MySQL与Redis优化
    MySQL优化策略:查询优化:使用EXPLAIN分析查询语句,优化JOIN操作,减少子查询和复杂的WHERE条件。索引优化:合理创建索引以加快查询速度,同时避免过度索引导致写性能下降。数据类型优化:使用合适的数据类型,避免冗余和浪费,例如使用TIMESTAMP代替DATETIME。表结构优化:如垂直分割和水平......
  • 获取数据库表格字段描述
    USE[database1]GO/******Object:StoredProcedure[dbo].[Sp_ObjItems]ScriptDate:2024/7/1213:17:42******/SETANSI_NULLSONGOSETQUOTED_IDENTIFIERONGOcreatePROCEDURE[dbo].[CheckFormDescription]--Addtheparametersforthestoredprocedur......
  • MES 与 PLC 的几种交互方式
       在MES开发领域,想要从PLC获取数据就必须要和PLC有信号交互。高效准确的获取PLC数据一直是优秀MES系统开发的目标之一。初涉相关系统开发的工程师往往不能很好的理解PLC和MES之间编程逻辑的本质差别,在设计交互逻辑是难免顾此失彼。因此本文结合本人这些年来和......
  • 【linux】nmon资源监控与定时任务
    原文:https://www.runoob.com/linux/linux-comm-crontab.htmlcrontab定时任务:【nmon监控稳定性场景】122、126406,14,22***root/home/nomouser/nmon-f-s20-c1620-m/home/nomouser123406,14,22***root/root/nmon-f-s20-c1620-m/root0*/8***......
  • 接触
    接触类型Bonded(绑定):默认接触形式,不允许界面或单元相对滑动或分离,即使加载或移除载荷后也能保持接触。法线方向不可分开,切向也不行NoSeparation(不分离):不允许分离,即使加载或移除载荷后,界面不允许接触面分离,但允许相对滑动。相当于相对面间有小范围滑动,即法向不可分离,......
  • Java性能优化-switch性能优化-用String还是int做比较
    场景Java中使用JMH(JavaMicrobenchmarkHarness微基准测试框架)进行性能测试和优化:https://blog.csdn.net/BADAO_LIUMANG_QIZHI/article/details/131723751参考以上性能测试工具的使用。下面针对Java中对switch-case比较时使用String还是int性能做对比。注:博客:https://bl......
  • WPF中引用不到相对路径图片?
    在wpf中使用相对路径运行项目时却不显示图片怎么解决?新建img文件夹添加所需要的图片选中图片右键属性设置属性重新生成即可运行效果转载请标明出处!......