首页 > 其他分享 >softmax回归的简洁实现

softmax回归的简洁实现

时间:2023-04-28 15:13:09浏览次数:39  
标签:简洁 函数 nn 回归 init softmax 模型 ###

softmax回归的简洁实现

通过深度学习框架的高级API能够使实现softmax回归模型更方便地实现

继续使用Fashion-MNIST数据集,并保持批量大小为256。



import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型参数

softmax回归的输出层是一个全连接层。

为了实现我们的模型, 我们只需在Sequential中添加一个带有10个输出的全连接层。 同样,在这里Sequential并不是必要的, 但它是实现深度模型的基础。 我们仍然以均值0和标准差0.01随机初始化权重。

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
###nn.Flatten():展平层 
###nn.Linear():全连接层 有784(28x28)个输入特征,输出10个类别

##定义网络模型
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

###初始化全连接层权重
def init_weights(m):
    ###若这个模块是全连接层,
    if type(m) == nn.Linear:
        ###将全连接层的权重元素随机初始化为均值为0,方差为0.01的正态分布
        nn.init.normal_(m.weight, std=0.01)
###网络层调用init_weights函数
net.apply(init_weights);

softmax实现

计算了模型的输出,然后将此输出送入交叉熵损失

###定义交叉熵损失函数

loss = nn.CrossEntropyLoss(reduction='none')

###nn.CrossEntropyLoss() 是 PyTorch 中的一个损失函数,用于多分类问题。它结合了 nn.LogSoftmax() 和 nn.NLLLoss() 两个函数,可以用于解决分类问题。

##在 nn.CrossEntropyLoss() 中,输入张量 y_pred 表示模型的预测值,即真实标签。
##reduction 参数指定了损失函数的处理方式,可以设置为"none"(默认)则表示不进行任何处理,即不考虑损失函数的值

###用于多分类问题的重要损失函数,可以用于评估模型的预测结果与真实标签之间的差异,并用于优化模型的参数。

优化算法

###使用学习率为0.1的小批量随机梯度下降作为优化算法。
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

###训练次数
num_epochs = 10
###开始训练
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

标签:简洁,函数,nn,回归,init,softmax,模型,###
From: https://www.cnblogs.com/idazhi/p/17362249.html

相关文章

  • 一个WPF开发的、界面简洁漂亮的音频播放器
    今天推荐一个界面简洁、美观的、支持国际化开源音频播放器。项目简介这是一个基于C#+WPF开发的,界面外观简洁大方,操作体验良好的音频播放器。支持各种音频格式,包括:MP4、WMA、OGG、FLAC、M4A、AAC、WAV、APE和OPUS;支持标记、实时显示歌词等功能;支持换肤、中英文等主流语言。......
  • 多元线性回归
    1绪论2预备知识2.1多元线性回归分析法基本思想2.2多元线性回归分析法的理论模型2.3多元线性回归分析的计算步骤2.3.1参数估计2.3.2假设检验2.4Python语言操作步骤3多元线性回归模型的建立与分析3.1数据收集与分析3.......
  • cnblogs 简洁模式 All In One
    cnblogs简洁模式AllInOnecnblogs简洁模式,显示文章阅读量总量和博客排名操作步骤打开一个非自己的账号的cnblogs博客首页⚠️注意博客皮肤一定要选择一个与自己的博客皮肤一致的才可以例如:CodingLifehttps://i.cnblogs.com/settingshttps://www.cnblogs.com/ano......
  • 线性回归
    线性回归线性模型利用特征的线性函数进行预测,这里的线性指的是参数是线性的。一、普通最小二乘法线性回归(OLS)是最简单&最经典的线性方法,模型寻找截距和系数,使得模型对训练集的预测值与真实值之间的均方误差(MSE)最小,但是线性回归没有办法控制模型的复杂度(模型有大量的非0参数)。......
  • 多元时间序列滚动预测:ARIMA、回归、ARIMAX模型分析|附代码数据
    原文链接:http://tecdat.cn/?p=22849最近我们被客户要求撰写关于多元时间序列滚动预测的研究报告,包括一些图形和统计输出。当需要为数据选择最合适的预测模型或方法时,预测者通常将可用的样本分成两部分:内样本(又称"训练集")和保留样本(或外样本,或"测试集")。然后,在样本中估计模型,并......
  • 数据分享|逻辑回归、随机森林、SVM支持向量机预测心脏病风险数据和模型诊断可视化|附
    原文链接:http://tecdat.cn/?p=24973最近我们被客户要求撰写关于心脏病的研究报告,包括一些图形和统计输出。世界卫生组织估计全世界每年有1200万人死于心脏病。在美国和其他发达国家,一半的死亡是由于心血管疾病简介心血管疾病的早期预后可以帮助决定改变高危患者的生活方式,从......
  • 机器学习之——回归(regression)、梯度下降(gradient descent)
      本文由LeftNotEasy所有,发布于http://leftnoteasy.cnblogs.com。如果转载,请注明出处,在未经作者同意下将本文用于商业用途,将追究其法律责任。前言:  上次写过一篇关于贝叶斯概率论的数学,最近时间比较紧,coding的任务比较重,不过还是抽空看了一些机器学习的书和视频,其中很推荐两......
  • 我国能源消耗的影响模型分析—基于多元线性回归与岭回归模型
    我国能源消耗的影响模型分析—基于多元线性回归与岭回归模型⭕AdamCY888文章目录我国能源消耗的影响模型分析—基于多元线性回归与岭回归模型一、引言二、回归模型简介(一)多元线性回归模型原理(二)建模步骤三、实证分析(一)构建指标及获取数据(二)......
  • 贝叶斯分位数回归、lasso和自适应lasso贝叶斯分位数回归分析免疫球蛋白、前列腺癌数据
    原文链接:http://tecdat.cn/?p=22702最近我们被客户要求撰写关于贝叶斯分位数回归的研究报告,包括一些图形和统计输出。贝叶斯回归分位数在最近的文献中受到广泛关注,本文实现了贝叶斯系数估计和回归分位数(RQ)中的变量选择,带有lasso和自适应lasso惩罚的贝叶斯摘要还包括总结结果、......
  • 2-2线性回归实现
    线性回归实现%matplotlibinlineimportrandomfrommxnetimportautograd,np,npxfromd2limportmxnetasd2l生成数据集根据带有噪声的线性模型构造一个人造数据集。任务是使用这个有限样本的数据集来恢复这个模型的参数。##使用线性模型参数w=[2,-3.4]T,b=4.2和噪......