首页 > 其他分享 >加入自定义块对fashion_mnist数据集进行softmax分类

加入自定义块对fashion_mnist数据集进行softmax分类

时间:2023-07-24 14:11:39浏览次数:35  
标签:__ fashion nn 自定义 self torch Module init softmax

在之前,我们实现了使用torch自带的层对fashion_mnist数据集进行分类。这次,我们加入一个自己实现的block,实现一个四层的多层感知机进行softmax分类,作为对“自定义块”的代码实现的一个练习。

我们设计的多层感知机是这样的:输入维度为784,在展平层过后,第一层为全连接层,输入输出维度分别为784,256;第二层为全连接层,输入输出维度分别为256,128;第三层为全连接层,输入输出维度分别为128,64;第四层为全连接层(输出层),输入输出维度分别为64,10.代码如下:

import torch
from d2l import torch as d2l
from torch import nn
from torch.nn import functional as F

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
num_inputs = 784
num_outputs = 10

#输入层784; 隐藏层一784,256;隐藏层二256,128; 隐藏层三128,64; 输出层64,10
#我们用自定义Module实现隐藏层二、隐藏层三。
class practice_Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(256,128)
        self.lin2 = nn.Linear(128,64)
        nn.init.normal_(self.lin1.weight,std=0.01)
        nn.init.normal_(self.lin2.weight,std=0.01)
    def forward(self,X):
        X = self.lin1(X)
        X = F.relu(X)
        X = self.lin2(X)
        X = F.relu(X)
        return X
    
manual_block = practice_Module()
net = nn.Sequential(nn.Flatten(),
                   nn.Linear(784,256),
                   nn.ReLU(),
                   nn.Dropout(0.2),
                    manual_block,
                    nn.Dropout(0.3),
                    nn.Linear(64,10)
                   )

def init_weight(m):
    if m == nn.Linear:
        nn.init.normal_(m.weight,std=0.01)
    return 
net.apply(init_weight)

loss = torch.nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs = 20
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)

 首先在我们自定义的模块中,初始化函数__init__中定义我们需要的两个层lin1和lin2.上面的代码在抽象的类practice_Module中的初始化函数__init__中进行了参数初始化,也就是说默认情况下用这个类创建的所有对象都会进行这样的默认初始化。

当然,也可以按我们的需要对具体的模块对象进行参数初始化。

然后在forward函数中定义这个模块进行的操作,即先让数据经过线性层lin1,激活,再经过线性层lin2,激活,然后输出。return X语句,return的X值作为输出,就会作为nn.Sequential中的下一层输入。

 注意:这里面的前向传播函数名必须是forward,而不能是其他的,改成其他的就会报错:

Module [practice_Module] is missing the required "forward" function

这也是为什么practice_Module类的实例在nn.Sequential中可以自动计算的原因,是因为系统会自动找到该实例中的方法forward并执行。

下面的语句初始化了一个practice_Module类的实例。可以这样理解:practice_Module是一个抽象的网络结构,而manual_block这个实例才是一个具体的我们需要的模型。

 可以用如下代码对模块实例进行初始化:

 在nn.Sequential中加入我们自定义的模块是非常简单的:

 init_weight()函数对torch中定义好了的层进行参数初始化:

 

标签:__,fashion,nn,自定义,self,torch,Module,init,softmax
From: https://www.cnblogs.com/pkuqcy/p/17576884.html

相关文章

  • Python list里面定义自定义类型
    PythonList中定义自定义类型在Python中,List(列表)是一种非常常见且强大的数据结构。它允许我们以有序的方式存储和访问多个元素。在List中,我们可以存储各种类型的数据,包括整数、浮点数、字符串等。但是,Python的灵活性还允许我们在List中存储自定义的数据类型,从而提供更高的灵活性和......
  • uboot添加自定义命令 CMD
    原文:https://blog.csdn.net/weixin_41252596/article/details/128317180有些用户玩uboot比较花,除了引导系统还要做一堆驱动,有些驱动除了按流程执行还要留出命令行接口用于调试。比如我现在的设备外接了个fpga,fpga和cpu的接口已经做好了,但是为了调试要新增个命令,在命令行下手动与f......
  • 自定义View wrap_content不起作用
    设置wrap_content后,自定义View依然是match_parent的效果ref:Android自定义View:为什么你设置的wrap_content不起作用?-简书(jianshu.com)问题描述:最近实现了一个QQ消息气泡的功能,但是在测试的时候发现尽管自定义View设置了wrap_content的宽高,但是依然占据了所有的父容器空......
  • Elasticsearch自定义分词器
    分词发生时期分词器的处理过程发生在IndexTime和SearchTime两个时期IndexTime:文档写入并创建倒排索引时期,其分词逻辑取决于映射参数analyzer。SearchTime:搜索发生时期,其分词仅对搜索词产生作用。分词器的组成切词器(Tokenizer):用于定义切词(分词)逻辑。词项过滤器(TokenF......
  • 【易语言】自定义数据类型排序
    .版本2.子程序自定义类型数组排序.参数排序组,特殊成员,参考数组.局部变量交换,逻辑型.局部变量未比数据,整数型.局部变量交换变量,特殊成员.局部变量N,整数型交换=真未比数据=取数组成员数(排序组).判断循环首(交换=真)交换=假.变量循......
  • Unity3D 自定义类的数组初始化
    实现功能:1.自定义类,用于保存数据等2.初始化数组代码:publicclasstree_elem{//位置publicintx,y;//大小【相对于原始大小的比例】最后随机分配publicfloatsize;//树的类型,最后随机分配publictree_kindkind;publictree_ele......
  • spring boot 自定义组件
    SpringBoot自定义组件SpringBoot是一个用于快速构建独立的、生产级别的Spring应用程序的框架。它提供了许多开箱即用的组件,可以简化开发流程并提高开发效率。但是,在某些情况下,我们可能需要自定义一些组件来满足特定的需求。本文将介绍如何在SpringBoot中自定义组件,并提......
  • element ui 分页组件自定义每页条数page-size
       参考代码:<divstyle="display:flex;"><el-pagination:total="total":pager-count="5":page-size="searchForm.pageSize":current-page=&q......
  • 自定义异常类
    1'''21.语法说明3自定义异常类是指在编程中,根据实际需要创建的用于表示特定错误或异常情况的类。4通过自定义异常类,我们可以更好地组织和处理代码中可能出现的异常情况。5classCustomException(Exception):6def__init__(self,message):7......
  • WPF .net6 自定义启动入口 、 自定义Main函数、自定义 STAThread 方法
    前言:  为了解决程序开启自启动问题参考资料  CustomEntryPointsinWPFon.NETCore链接https://blog.magnusmontin.net/2020/01/31/custom-entry-point-wpf-net-core/  CreatingacustomMainmethodinaWPFapplication链接https://www.meziantou.net/creat......