nn.Sequential
和 nn.ModuleList()
是 PyTorch
中用于管理神经网络模型中的子模块的两种不同的方式。
nn.Sequential
是一个用于构建顺序模型的容器类。它允许按照给定的顺序添加一系列的子模块,并将它们串联在一起形成一个顺序的网络结构。nn.Sequential
可以简化模型的定义和前向传播的编写,特别适用于那些没有复杂控制流程的简单网络结构。通过向 nn.Sequential
中添加子模块,这些子模块会自动按照添加的顺序连接在一起,并形成一个整体的模型。在调用 nn.Sequential
的 forward
方法时,输入数据将按照添加的顺序经过每个子模块,从而实现整个模型的前向传播。
示例使用 nn.Sequential
构建一个简单的模型:
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 10)
)
input_tensor = torch.randn(32, 10)
output_tensor = model(input_tensor)
在这个示例中,我们通过 nn.Sequential
定义了一个顺序模型。顺序模型包含三个子模块:一个线性层、一个 ReLU
激活函数和另一个线性层。当我们调用模型的 forward
方法时,输入数据 input_tensor
将按照添加的顺序依次经过每个子模块,并生成输出数据 output_tensor
。
相比之下,nn.ModuleList()
是一个类似于 Python
列表的容器,用于存储和管理任意数量的子模块。与 nn.Sequential
不同的是,nn.ModuleList()
并不自动连接子模块,而是将其存储为列表的形式。因此,在使用 nn.ModuleList()
定义模型时,我们需要自己定义子模块之间的连接关系。这使得 nn.ModuleList()
更加灵活,适用于那些具有复杂控制流程或需要自定义连接方式的网络结构。
示例使用 nn.ModuleList()
构建一个简单的模型:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.module_list = nn.ModuleList([
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 10)
])
def forward(self, x):
for module in self.module_list:
x = module(x)
return x
model = MyModel()
input_tensor = torch.randn(32, 10)
output_tensor = model(input_tensor)
在这个示例中,我们定义了一个自定义的模型类 MyModel
,其中使用了 nn.ModuleList()
来存储三个子模块:一个线性层、一个ReLU
激活函数和另一个线性层。在模型的 forward
方法中,我们通过迭代 module_list
中的子模块,依次将输入数据 x
传递给它们,并获取最终的输出。
因此,nn.Sequential
和 nn.ModuleList()
的区别在于自动连接子模块的能力。nn.Sequential
自动按照添加的顺序连接子模块,适用于简单的顺序模型。而 nn.ModuleList()
则需要手动定义子模块之间的连接方式,适用于具有复杂控制流程或自定义连接的模型。
此外,nn.Sequential
还提供了更简洁的语法来定义模型,因为它可以直接通过传入子模块的列表来创建模型。而 nn.ModuleList()
则需要显式地在模型类中定义和初始化子模块。
nn.Sequential
和 nn.ModuleList()
都是 nn.Module
的子类,因此它们都可以作为模型的属性进行注册和管理。