首页 > 其他分享 >深度学习pytorch——nn.Module(持续更新)

深度学习pytorch——nn.Module(持续更新)

时间:2024-03-30 12:59:36浏览次数:29  
标签:__ features nn self Module pytorch net Linear

作为一个初学者,发现构建一个简单的线性模型都能看到nn.Module的身影,初学者疑惑了,nn.Module到底是干什么的,如此形影不离,了解之后,很牛。

1、nn.Module是所有层的父类,比如Linear、BatchNorm2d、Conv2d、ReLU、Sigmoid、ConvTranposed、Dropout等等这些都是它的儿子(子类),你可以直接拿来使用。

2、nn.Module还支持一个nn.Module嵌套另一个nn.Module。

3、并且可以自动完成forward,你只需要nn.Sequential()这个容器就可以了,代码示例如下:

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.net = nn.Sequential(BasicNet(),
                                 nn.ReLU(),
                                 nn.Linear(3, 2))

    def forward(self, x):
        return self.net(x)

4、深度学习中参数可谓是量产,如果手动进行参数管理将会是一个庞大的工程,导师问你在干什么就不是在训练模型了,而是处理那些无处安放的参数,而使用nn.Module就提供的parameters就可以秒出结果,代码示例如下:

    for name, t in net.named_parameters():
        print('parameters:', name, t.shape)
    # parameters: net.0.net.weight torch.Size([3, 4])
    # parameters: net.0.net.bias torch.Size([3])
    # parameters: net.2.weight torch.Size([2, 3])
    # parameters: net.2.bias torch.Size([2])

然后将参数直接传入优化器进行优化,代码示例如下:

optimizer=optim.SGD(net.parameters(),lr=1e-3)

5、有很多的孩子,并且你还可以很简单的知道他孩子长什么样,我先介绍一下他的孩子们:

我们可以通过 net.named_children()了解他的亲孩子(children),也就是直系亲属,net.named_modules()了解他所有的孩子(modules),直系亲属外亲都算,代码示例如下:

    for name, m in net.named_children():
        print('children:', name, m)
    # children: net Sequential(
    #     (0): BasicNet(
    #     (net): Linear(in_features=4, out_features=3, bias=True)
    # )
    # (1): ReLU()
    # (2): Linear(in_features=3, out_features=2, bias=True)
    # )

    for name, m in net.named_modules():
        print('modules:', name, m)

# modules:  Net(
#   (net): Sequential(
#     (0): BasicNet(
#       (net): Linear(in_features=4, out_features=3, bias=True)
#     )
#     (1): ReLU()
#     (2): Linear(in_features=3, out_features=2, bias=True)
#   )
# )
# modules: net Sequential(
#   (0): BasicNet(
#     (net): Linear(in_features=4, out_features=3, bias=True)
#   )
#   (1): ReLU()
#   (2): Linear(in_features=3, out_features=2, bias=True)
# )
# modules: net.0 BasicNet(
#   (net): Linear(in_features=4, out_features=3, bias=True)
# )
# modules: net.0.net Linear(in_features=4, out_features=3, bias=True)
# modules: net.1 ReLU()
# modules: net.2 Linear(in_features=3, out_features=2, bias=True)

6、可以非常方便的将网络运行在不同的设备,这里的设备是指cuda、gpu之类的,代码示例如下:

device = torch.device('cuda')
net = Net()
net.to(device)

7、方便对模型进行保存和加载,一个模型一般需要训练好久,但是我们并没有如此连续的时间,我们可以将现在训练好的模型进行保存,下次加载继续训练,代码示例如下:

# 加载
net.load_state_dict(torch.load('ckpt.mdl'))
# 保存
torch.save(net.state_dict(),'ckpt.mdl')

8、方便进行训练和测试状态的切换,代码示例如下:

# 训练
net.train()
# 测试
net.eval()

9、可以实现自己构建的模型,代码示例如下:

# 以下是一个展平的实现
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input):
        return input.view(input.size(0), -1)  #[b,打平],保留b,其他的全部打平



class TestNet(nn.Module):

    def __init__(self):
        super(TestNet, self).__init__()

        self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
                                 nn.MaxPool2d(2, 2),
                                 Flatten(),
                                 nn.Linear(1*14*14, 10))

    def forward(self, x):
        return self.net(x)
# 构建自己的线性层
class MyLinear(nn.Module):

    def __init__(self, inp, outp):
        super(MyLinear, self).__init__()
        # 这里使用nn.Parameter代表了可以回传给nn.model,进行更新,所以不用写requires_grad = True
        # requires_grad = True
        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.randn(outp))

    def forward(self, x):
        x = x @ self.w.t() + self.b
        return x

如果你还有什么更好的idea,欢迎分享!!!

标签:__,features,nn,self,Module,pytorch,net,Linear
From: https://blog.csdn.net/2201_76139143/article/details/137163372

相关文章

  • MogDB 安装解压错误:cannot run bzip2: No such file or directory
    MogDB安装解压错误:cannotrunbzip2:Nosuchfileordirectory本文出处:https://www.modb.pro/db/403662问题症状MogDB安装时,涉及两个步骤解压,第一步解压缩tar包:[root@enmotech~]#tar-xvfMogDB-2.1.1-CentOS-x86_64.tarupgrade_sql.tar.gzMogDB-2.1.1-CentOS-64bit......
  • MySQL的InnoDB引擎的事务原理以及MVCC
    目录一、事务原理二、redolog三、undolog四、MVCC    1.基础概念    2.隐藏字段    3.undolog        4.readview        5.原理分析一、事务原理        1).事务        事务是一组操作的集合,它......
  • clean maven工程报错: Cannot find JRE '1.8 (1)'. You can specify JRE to run maven
    在双击Maven的clean时,报错:CannotfindJRE'1.8(1)'.YoucanspecifyJREtorunmavengoalsinSettings原因可能是自己之前下载的是JDK17,并且IDEA认为该JDK为默认JDK,而我的Maven项目设置使用的是JDK8,因此报错。解决方案如下:点击File-settingBuild,Execution,Deploy......
  • atcoder beginner 346 题解
      看到别人的视频讲解 AtCoderBeginnerContest346A至G題讲解bydreamoon C如果用sort写,那么再从小到大遍历也需要写几行#include<cstdio>#include<cstdlib>#include<cstring>#include<cmath>#include<cstdbool>#include<string>#include<......
  • windows下nginx-rtmp-module的编译方法
    ForewordLinux为当前nginx添加rtmp模块非常的方便,sudo./configure--add-module+sudomake就完事儿了,但是windows比较复杂,没有包管理器,所以各个模块的源码要自己找,下面是我在windows11下的nginxwithrtmpmodule的编译记录。编译器工具链大概有msvctoolchain,perl......
  • 一行一行讲解深度学习代码(零)如何利用pytorch搭建一个完整的深度学习项目——深度学习
    本文适合没有基础的pytorch深度学习小白和python基础不太好的同学!!建议有基础的同学不要看~文章目录深度学习项目的大致结构(一)数据集加载1.功能2.工具(1)datasets(2)DataLoader(二)数据预处理1.功能2.工具(1)torchvision.transforms(2)Compose()3.实战(1)定义数据集(2)数据预处理......
  • vuex.esm.js:135 Uncaught Error: [vuex] getters should be function but “getters.
    报错vuex.esm.js:135UncaughtError:[vuex]gettersshouldbefunctionbut"getters.mode"inmodule"userModule"is"dark".atassert(vuex.esm.js:135:1)原因:在使用vuex的moulds时index.js中已创建了一个vue实例newVuex.Store,在模块文件中又再创建了一个,导致报......
  • Channel-Wise Autoregressive Entropy Models For Learned Image Compression
    目录简介创新点模型框架信道条件熵模型实验&结果简介熵约束自动编码器的熵模型同时使用前向适应和后向适应。前向自适应利用边信息,可以被有效加入到深度网络中。后向自适应通常基于每个符号的因果上下文进行预测,这需要串行处理,这妨碍了GPU/TPU的有效利用。创新点本文引......
  • Windows安装CUDA 12.1及cudnn
    下载CUDA打开链接(https://developer.nvidia.com/cuda-toolkit-archive)选择 12.1.1 版本 选择Windows->x86_64->10->exe(local)->Download  下载完成后按提示安装到默认路径 下载cudnn点击进入nVidia下载cudnn(https://developer.download.nvidia.com/co......
  • pytorch的基础函数
    [torch.arange]是PyTorch中的一个函数,用于生成一个一维的张量(tensor),其中包含从起始值(包括)到结束值(不包括)的等差数列。这个函数非常类似于Python的内置range函数,但是生成的是PyTorch张量而不是Python列表。torch.arange(start=0,end,step=1,*,out=None,dtype=No......