torch.nn 是 pytorch 的一个神经网络库(nn 是 neural network 的简称)。
Containers
torch.nn 构建神经网络的模型容器(Containers,骨架)有以下六个:
- Module
- Sequential
- ModuleList
- ModuleDict
- ParameterList
- ParameterDict
本博文将介绍神经网络的基本骨架——nn.module的使用。
Module
所有神经网络模块的基类。自定义的模型也应该继承该类。
自定义模型继承该类要重写 __init__()
和 forward()
:
- 在
__init__()
里构建子模块,将子模块作为当前模块类的常规属性。一般将网络中具有可学习参数的层放在__init__
中。 forward()
前向传播函数,定义每次调用时执行的计算,应该被所有子类重写。
# 官方案例
import torch.nn as nn
import torch.nn.functional as F
# 自定义模型
class Model(nn.Module):
def __init__(self):
super().__init__() # 在对子类进行赋值之前,必须对父类进行__init__调用。
# 构建子模块
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
# 前向传播函数
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
# 模型调用
x = torch.randn(3, 1, 10, 20)
model = Model()
y = model(x)
为什么 forward()
方法能在model(x)
时自动调用?
在 python 中当一个类定义了 __call__
方法,则这个类实例就成为了可调用对象。而nn.Module
中的 __call__
方法中调用了 forward()
方法,因此继承了 nn.Module
的子类对象就可以通过 model(x)
来调用 forward()
函数。
只要在 nn.Module
的子类中定义了 forward
函数,backward
函数就会被自动实现(利用Autograd)。