创建模型
创建一个具有三级嵌套的模型,结构如图:
import torch
import torch.nn as nn
# 定义子子模块
class SubSubModule(nn.Module):
def __init__(self):
super(SubSubModule, self).__init__()
self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
def forward(self, x):
return self.conv(x)
# 定义子模块
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.sub_sub_module = SubSubModule() # 实例化子子模块
self.pool = nn.MaxPool2d(2)
def forward(self, x):
x = self.sub_sub_module(x) # 使用子子模块
x = torch.relu(x)
x = self.pool(x)
return x
# 定义主模块
class MainModule(nn.Module):
def __init__(self):
super(MainModule, self).__init__()
self.sub_module = SubModule() # 实例化子模块
self.fc = nn.Linear(3 * 16 * 16, 10) # 假设输入图像大小为 32x32
def forward(self, x):
x = self.sub_module(x) # 使用子模块
x = x.view(x.size(0), -1) # 展平特征图
x = self.fc(x)
return x
# 实例化主模块
model = MainModule()
# 打印模型结构
print(model)
使用print直接打印
直接使用print函数打印,会以整个模型为单位打印
# 实例化主模块
model = MainModule()
# 打印模型结构
print(model)
使用named_children()函数打印模型的子模块
named_children()只会打印children,也就是子模块,至于孙子,曾孙子...一律不打印,即 子子模块及以下的都都不会打印
#打印模型的子模块
for name, module in model.named_children():
print(name, module)
使用named_modules函数打印模型的子模块
named_modules从命名就可以看出,会遍历模型中的所有模块(与named_children()恰恰相反),从主模块到子模块到子子模块到子子...子模块,每一个模块都会打印出来
#打印模型的所有模块
for name, module in model.named_modules():
print(name, module)
使用named_parameters()函数打印模型的可学习参数
#打印模型的可学习参数
for name, param in model.named_parameters():
print(name, param.size())
标签:__,named,打印,nn,self,torch,模块
From: https://www.cnblogs.com/seekwhale13/p/18278306