首页 > 其他分享 >使用不同函数打印torch.nn模型——print(model),named_children(),named_modules():

使用不同函数打印torch.nn模型——print(model),named_children(),named_modules():

时间:2024-07-01 16:43:32浏览次数:19  
标签:__ named 打印 nn self torch 模块

创建模型

创建一个具有三级嵌套的模型,结构如图:
image

import torch
import torch.nn as nn

# 定义子子模块
class SubSubModule(nn.Module):
    def __init__(self):
        super(SubSubModule, self).__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

# 定义子模块
class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.sub_sub_module = SubSubModule()  # 实例化子子模块
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.sub_sub_module(x)  # 使用子子模块
        x = torch.relu(x)
        x = self.pool(x)
        return x

# 定义主模块
class MainModule(nn.Module):
    def __init__(self):
        super(MainModule, self).__init__()
        self.sub_module = SubModule()  # 实例化子模块
        self.fc = nn.Linear(3 * 16 * 16, 10)  # 假设输入图像大小为 32x32

    def forward(self, x):
        x = self.sub_module(x)  # 使用子模块
        x = x.view(x.size(0), -1)  # 展平特征图
        x = self.fc(x)
        return x

# 实例化主模块
model = MainModule()

# 打印模型结构
print(model)

使用print直接打印

直接使用print函数打印,会以整个模型为单位打印

# 实例化主模块
model = MainModule()

# 打印模型结构
print(model)

image

使用named_children()函数打印模型的子模块

named_children()只会打印children,也就是子模块,至于孙子,曾孙子...一律不打印,即 子子模块及以下的都都不会打印

#打印模型的子模块
for name, module in model.named_children():
    print(name, module)

image
image

使用named_modules函数打印模型的子模块

named_modules从命名就可以看出,会遍历模型中的所有模块(与named_children()恰恰相反),从主模块到子模块到子子模块到子子...子模块,每一个模块都会打印出来

#打印模型的所有模块
for name, module in model.named_modules():
    print(name, module)

image

使用named_parameters()函数打印模型的可学习参数

#打印模型的可学习参数
for name, param in model.named_parameters():
    print(name, param.size())

image

标签:__,named,打印,nn,self,torch,模块
From: https://www.cnblogs.com/seekwhale13/p/18278306

相关文章

  • WhaleStudio 2.6正式发布,WhaleTunnel同步性能与连接器数量再创新高!
    在这个数据驱动的大模型时代,数据集成的作用和意义愈发重要。数据不仅仅是信息的载体,更是推动企业决策和创新的关键因素。作为全球最流行的批流一体数据集成工具,WhaleTunnel随着WhaleStudio2.6版本正式发布,带来了多项功能增强和新特性,性能大幅提升,连接器和功能方面也有大量更新......
  • MKLDNN
    mkldnn的文件目录结构如下:doc/:文档说明,基本在http://intel.github.io/mkl-dnn/index.html中已经展示 advanced/:关于int8量化和版本更新的说明 build/:关于build、build_options和link design/:关于memorylayout(format)的图片 performance_considerations/:关于性能调试、......
  • CentOS 7报错Erro:NetworkManager is not running怎么处理?
    CentOS7系统报错Error:NetworkManagerisnotrunning,意思是NetworkManager未在运行,NetworkManager是Linux系统上管理网络设置的守护进程,负责自动处理和配置网络连接,未运行可能会导致网络连接问题。遇到报错Error:NetworkManagerisnotrunning我们该如何处理呢?今天飞飞和你分......
  • AtCoder Beginner Contest 360
    A-AHealthyBreakfast(abc360A)题目大意给定一个字符串包含RMS,问R是否在S的左边。解题思路比较R和S的下标,谁小即谁在左边。神奇的代码#include<bits/stdc++.h>usingnamespacestd;usingLL=longlong;intmain(void){ios::sync_with_stdio(false);......
  • 独家原创 | Matlab实现CNN-Transformer多变量回归预测
    独家原创|Matlab实现CNN-Transformer多变量回归预测目录独家原创|Matlab实现CNN-Transformer多变量回归预测效果一览基本介绍程序设计参考资料效果一览基本介绍1.Matlab实现CNN-Transformer多变量回归预测;2.运行环境为Matlab2023b及以上;3.data为数......
  • webrtc 的datachannel在golang中的使用
    因为在发送端需要接收一些接收端的统计信息,而且具有不可丢失的需求,所以采取利用datachannel进行传输。datachannel是基于sctp协议的传输通道,sctp可提供按需可靠到达的服务,在datachannel中可以设置是否按序,是否可靠,最大重传次数,数据最大保存时间(当数据超过保存时间仍未发出时将被丢......
  • 【CNN】用MNIST测试各种CNN网络模型性能
    使用MNIST测试各类CNN网络性能,在此记录,以便按需选择网络。除了第一个CNN为自己搭的以外,其余模型使用Pytorch官方模型,这些模型提出时是在ImageNet上进行测试,在此补充在MNIST上的测试。另外时间有限,每种模型只跑一次得出测试数据,实验结果仅供参考各种参数:训练集60000、测......
  • mysql默认存储引擎--innodb存储引擎(详解)
    官方解释:    InnoDB,是MySQL的数据库引擎之一,现为MySQL的默认存储引擎,为MySQLAB发布binary的标准之一。InnoDB由InnobaseOy公司所开发,2006年五月时由甲骨文公司并购。与传统的ISAM与MyISAM相比,InnoDB的最大特色就是支持了ACID兼容的事务(Transaction)功能,类似于Postgre......
  • 10分钟安装好torch的GPU版本(Windows)
    pytorch-gpu1.确定cuda版本2.确定Python版本3开始下载-cu118-cp383.1下载cuda3.2下载torchvision4.下载好了5.开始安装6.开始验证1.确定cuda版本nvcc-V版本为11.8,一会下载的版本为cu1182.确定Python版本确定python版本为为3.8,一会下载为cp38......
  • Fastapi 项目第二天首次访问时数据库连接报错问题Can't connect to MySQL server
    问题描述Fastapi项目使用sqlalchemy连接的mysql数据库,每次第二天首次访问数据库相关操作,都会报错:sqlalchemy.exc.OperationalError:(pymysql.err.OperationalError)(2003,"Can'tconnecttoMySQLserveron'x.x.x.x'([Errno111]Connectionrefused)")问题分析从出......