首页 > 其他分享 >3.7 softmax回归的简洁实现

3.7 softmax回归的简洁实现

时间:2023-05-30 17:33:19浏览次数:53  
标签:简洁 nn torch iter 3.7 init softmax net d2l

1. 导入包,加载Mnist数据集

 2.

代码:

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# PyTorch不会隐式地调整输入的形状。因此,
# 我们加入一个Flatten()层展平。Flatten()就是把任何维度的Tensor变成一个2D的tensor
# 第0维度保留,剩下的维度全部展成一个向量
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

#init_weights这个函数会对每个layer做一次,m就是我们当前的layer
def init_weights(m):
    #如果m是一个nn.Linear
    if type(m) == nn.Linear:
        #那么我们就把它的weight初始化成(均值默认为0)方差为1
        nn.init.normal_(m.weight, std=0.01)
    return

#把init_weights函数应用到我们的net上来
#即按照每一层去跑一下这个函数
net.apply(init_weights);

#定义损失函数为交叉熵损失
loss = nn.CrossEntropyLoss(reduction='none')

#定义训练器为梯度下降
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

 

标签:简洁,nn,torch,iter,3.7,init,softmax,net,d2l
From: https://www.cnblogs.com/pkuqcy/p/17443398.html

相关文章

  • pytorch1.4.0 CUDA11.0 python3.7安装记录
    参考过程CUDA安装教程CUDA教程2找到自己电脑显卡的cuda版本CUDA是什么版本是11.0.140安装CUDA11.1下载链接,但是我们不用这个我们用的是11.0最新版的下载地址下载选项设置(害,整整2个多G啊)。可以在下载按钮的地方右键,复制链接,然后在迅雷下面下载。虽然慢但是稳定。不过用Chrome复......
  • 3.3 线性回归的简洁实现
    importnumpyasnpimporttorchfromtorch.utilsimportdatafromd2limporttorchasd2lfromtorchimportnn#nn是神经网络(NeuralNetworks)的缩写3.3.1生成数据集true_w=torch.tensor([2,-3.4])#与上一节类似生成数据集true_b=4.2features,labels=......
  • 3.4 softmax回归
    3.4.1分类问题整节理论知识,详见书本。3.4.2网络架构整节理论知识,详见书本。3.4.3全连接层的参数开销整节理论知识,详见书本。3.4.4softmax运算整节理论知识,详见书本。3.4.5小批量样本的向量化整节理论知识,详见书本。3.4.6损失函数整节理论知识,详见书本。3.4.7......
  • Softmax
    Softmax将输出的离散值转换成概率值,且所有情况的概率之和为1。求导pytorch实现......
  • 3.6 Softmax回归的从零开始实现
    我们首先导入相关的包,并读入训练和测试所用的数据集图片的DataLoader: 这里面d2l.load_data_fashion_mnist(batch_size)读入训练和测试所用的图像数据集的DataLoader。 1.初始化模型参数Softmax回归模型参数包括W、b。假设输入特征数量为num_inputs,输出的数量(类别的数量)为n......
  • fastposter v2.15.0 从繁琐到简单,简洁好用的海报生成器
    fastposterv2.15.0从繁琐到简单,简洁好用的海报生成器从繁琐到简单,简洁好用的海报生成器我很高兴向大家推荐一款令人兴奋的工具——Fastposter海报生成器。作为一名开发者,我们深知在项目中创建专业级海报的重要性,但常常面临时间和设计技能的限制。现在,Fastposter海报生成器为我们......
  • fastposter v2.15.0 从繁琐到简单,简洁好用的海报生成器
    fastposterv2.15.0从繁琐到简单,简洁好用的海报生成器从繁琐到简单,简洁好用的海报生成器我很高兴向大家推荐一款令人兴奋的工具——Fastposter海报生成器。作为一名开发者,我们深知在项目中创建专业级海报的重要性,但常常面临时间和设计技能的限制。现在,Fastposter海报生成器为我......
  • 3.7 高次方数的尾数
    #include<stdio.h>intmain(){inti,x,y,last=1;/★变量last保存求得的×的y次方的部分积的后三位*/printf("Inputxandy:An");scanf("%dSd",&x,&y);for(i-1;i<=y;i++)/*×自乘的次数y*/last=last*x%1000;/*将last乘x后对1000取模,即求积的后三位*/printf(&q......
  • hexo-快速、简洁且高效的博客框架
    title:hexo快速、简洁且高效的博客框架abbrlink:38713date:2022-03-0220:30:40tags:Hexo博客框架官方地址:安装代码:npminstallhexo-cli-ghexoinitblogcdblognpminstallhexoserverhexothemeyun食用方法Inyourhexofolder:npminstallhexo-......
  • 002 线性回归的简洁实现
    1.创建数据集数据集的手工创建和上一节一样,人为设置true_w,true_b,以及num_examples(样本的总数量),调用synthetic_data()函数来创建。上一节中我们已经用#@save将这个函数保存在了d2l包中,这里我们直接调用就可以了:2.读取数据集load_array()这个函数接受数据集的features,lab......