To make the block plug-and-play, this post reimplements the non-local block so that you can choose 2D or 3D convolution as you prefer, and optionally enable a BN layer and/or a pooling layer.
[Figure: Non-local block module diagram]
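For reference, the operation the block implements (from the original Non-local paper, which the code comments below also cite) is

\[
y_i = \frac{1}{\mathcal{C}(x)} \sum_{\forall j} f(x_i, x_j)\, g(x_j),
\qquad f(x_i, x_j) = \theta(x_i)^{\top} \phi(x_j),
\]

where $i$ indexes an output position, $j$ ranges over all spatial(-temporal) positions, and $\mathcal{C}(x)$ is the normalizer: a softmax over $j$ in the "softmax" instantiation, or the number of positions in the "dot_product" instantiation. The 1x1(x1) convolutions conv_theta, conv_phi, and conv_g below implement the embeddings $\theta$, $\phi$, and $g$, and conv_out maps back to the input dimension for the residual connection.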
The 3D code is as follows:
import torch
import torch.nn as nn


class Nonlocal_3d(nn.Module):
    def __init__(self, dim, dim_inner, pool_size=None, norm=False, norm_eps=1e-5, norm_momentum=0.1,
                 norm_module=nn.BatchNorm3d, instantiation="softmax"):
        """
        Args:
            dim (int): number of input channels.
            dim_inner (int): number of channels inside the Non-local block.
            pool_size (list): kernel size of the spatio-temporal pooling:
                temporal kernel size, spatial kernel size, spatial kernel
                size, in that order. By default pool_size is None, in which
                case no pooling is used.
            instantiation (string): supports two instantiation methods:
                "dot_product": normalize the correlation matrix by the
                    number of positions.
                "softmax": normalize the correlation matrix with Softmax.
            norm (bool): if True, add a final batch norm to the Non-local block.
            norm_module (nn.Module): nn.Module class for the normalization
                layer. The default is nn.BatchNorm3d.
        """
        super().__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self.norm = norm
        self.pool_size = pool_size
        self.use_pool = pool_size is not None
        self.instantiation = instantiation
        # Build the network.
        self._construct_nonlocal(norm_module)

    def _construct_nonlocal(self, norm_module):
        self.conv_theta = nn.Conv3d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_phi = nn.Conv3d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_g = nn.Conv3d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv3d(self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0)
        if self.norm:
            self.bn = norm_module(num_features=self.dim, eps=self.norm_eps, momentum=self.norm_momentum)
            self.bn.transform_final_bn = True
        if self.use_pool:
            self.pool = nn.MaxPool3d(kernel_size=self.pool_size, stride=self.pool_size, padding=[0, 0, 0])

    def forward(self, x):
        x_identity = x
        N, C, T, H, W = x.size()
        theta = self.conv_theta(x)
        # Pool only phi and g, so the output keeps the input resolution
        # while the affinity matrix shrinks.
        if self.use_pool:
            x = self.pool(x)
        phi = self.conv_phi(x)
        g = self.conv_g(x)
        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)
        # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW).
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
        # The original Non-local paper offers two main ways to normalize
        # the affinity tensor:
        #   1) Softmax normalization (norm on exp).
        #   2) dot_product normalization.
        if self.instantiation == "softmax":
            # Scale the affinity tensor theta_phi before the softmax.
            theta_phi = theta_phi * (self.dim_inner ** -0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError("Unknown norm type {}".format(self.instantiation))
        # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW).
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
        # (N, C, TxHxW) => (N, C, T, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)
        p = self.conv_out(theta_phi_g)
        if self.norm:
            p = self.bn(p)
        # Residual connection keeps the block plug-and-play.
        return p + x_identity
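A quick shape check (a minimal sketch; the tensor sizes are arbitrary) shows that the output matches the input exactly, which is what makes the block plug-and-play:

import torch

x = torch.randn(2, 64, 4, 16, 16)  # (N, C, T, H, W), sizes chosen arbitrarily
nl3d = Nonlocal_3d(dim=64, dim_inner=32, norm=True)
out = nl3d(x)
print(out.shape)  # torch.Size([2, 64, 4, 16, 16]) -- same as the input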
The 2D code is as follows:
class Nonlocal_2d(nn.Module):
    def __init__(self, dim, dim_inner, pool_size=None, norm=False, norm_eps=1e-5, norm_momentum=0.1,
                 norm_module=nn.BatchNorm2d, instantiation="softmax"):
        """
        Args:
            dim (int): number of input channels.
            dim_inner (int): number of channels inside the Non-local block.
            pool_size (list): kernel size of the spatial pooling: height
                kernel size and width kernel size, in that order. By default
                pool_size is None, in which case no pooling is used.
            instantiation (string): supports two instantiation methods:
                "dot_product": normalize the correlation matrix by the
                    number of positions.
                "softmax": normalize the correlation matrix with Softmax.
            norm (bool): if True, add a final batch norm to the Non-local block.
            norm_module (nn.Module): nn.Module class for the normalization
                layer. The default is nn.BatchNorm2d.
        """
        super().__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self.norm = norm
        self.pool_size = pool_size
        self.use_pool = pool_size is not None
        self.instantiation = instantiation
        # Build the network.
        self._construct_nonlocal(norm_module)

    def _construct_nonlocal(self, norm_module):
        self.conv_theta = nn.Conv2d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_phi = nn.Conv2d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_g = nn.Conv2d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0)
        if self.norm:
            self.bn = norm_module(num_features=self.dim, eps=self.norm_eps, momentum=self.norm_momentum)
            self.bn.transform_final_bn = True
        if self.use_pool:
            # Note: 2D pooling takes a 2-element padding, not the 3-element one
            # of the 3D version.
            self.pool = nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size, padding=[0, 0])

    def forward(self, x):
        x_identity = x
        N, C, H, W = x.size()
        theta = self.conv_theta(x)
        if self.use_pool:
            x = self.pool(x)
        phi = self.conv_phi(x)
        g = self.conv_g(x)
        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)
        # (N, C, HxW) * (N, C, HxW) => (N, HxW, HxW).
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
        # The original Non-local paper offers two main ways to normalize
        # the affinity tensor:
        #   1) Softmax normalization (norm on exp).
        #   2) dot_product normalization.
        if self.instantiation == "softmax":
            # Scale the affinity tensor theta_phi before the softmax.
            theta_phi = theta_phi * (self.dim_inner ** -0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_dim
        else:
            raise NotImplementedError("Unknown norm type {}".format(self.instantiation))
        # (N, HxW, HxW) * (N, C, HxW) => (N, C, HxW).
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
        # (N, C, HxW) => (N, C, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, H, W)
        p = self.conv_out(theta_phi_g)
        if self.norm:
            p = self.bn(p)
        return p + x_identity
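The same check for the 2D version, this time with pooling enabled (again a minimal sketch with arbitrary sizes); pooling only shrinks phi and g internally, so the output shape is unchanged:

import torch

x = torch.randn(2, 64, 32, 32)  # (N, C, H, W)
nl2d = Nonlocal_2d(dim=64, dim_inner=32, pool_size=[2, 2], norm=True)
out = nl2d(x)
print(out.shape)  # torch.Size([2, 64, 32, 32]) -- pooling does not change the output shape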
The unified code is as follows:
#!/usr/bin/env python3
import torch
import torch.nn as nn
class Nonlocal(nn.Module):
    def __init__(self, dim, dim_inner, typeis='3D', pool_size=None, norm=False, norm_eps=1e-5, norm_momentum=0.1,
                 norm_module=nn.BatchNorm3d, instantiation="softmax"):
        """
        Args:
            dim (int): number of input channels.
            dim_inner (int): number of channels inside the Non-local block.
            typeis (string): selects between two convolution/pooling classes:
                '2D' uses nn.Conv2d / nn.MaxPool2d,
                '3D' uses nn.Conv3d / nn.MaxPool3d.
            pool_size (list): kernel size of the pooling. For '3D': temporal
                kernel size, spatial kernel size, spatial kernel size, in
                that order; for '2D': height and width kernel sizes. By
                default pool_size is None, in which case no pooling is used.
            instantiation (string): supports two instantiation methods:
                "dot_product": normalize the correlation matrix by the
                    number of positions.
                "softmax": normalize the correlation matrix with Softmax.
            norm (bool): if True, add a final batch norm to the Non-local block.
            norm_module (nn.Module): nn.Module class for the normalization
                layer. The default is nn.BatchNorm3d; pass nn.BatchNorm2d
                when typeis is '2D'.
        """
        super().__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self.norm = norm
        self.pool_size = pool_size
        self.use_pool = pool_size is not None
        self.instantiation = instantiation
        self.type = typeis.lower()
        assert self.type in ['2d', '3d'], "typeis of Nonlocal must be one of ['2D', '3D']."
        # Build the network.
        self._construct_nonlocal(norm_module)

    def _construct_nonlocal(self, norm_module):
        conv = getattr(nn, 'Conv' + self.type)  # picks nn.Conv2d or nn.Conv3d
        self.conv_theta = conv(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_phi = conv(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_g = conv(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_out = conv(self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0)
        if self.norm:
            self.bn = norm_module(num_features=self.dim, eps=self.norm_eps, momentum=self.norm_momentum)
            self.bn.transform_final_bn = True
        if self.use_pool:
            pool = getattr(nn, 'MaxPool' + self.type)  # picks nn.MaxPool2d or nn.MaxPool3d
            # padding=0 works for both the 2D and the 3D pooling class.
            self.pool = pool(kernel_size=self.pool_size, stride=self.pool_size, padding=0)

    def forward(self, x):
        x_identity = x
        if self.type == '3d':
            N, C, T, H, W = x.size()
        else:  # '2d'
            N, C, H, W = x.size()
        theta = self.conv_theta(x)
        if self.use_pool:
            x = self.pool(x)
        phi = self.conv_phi(x)
        g = self.conv_g(x)
        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)
        # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW) in 3D;
        # in 2D the position axis is HxW instead of TxHxW.
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
        # The original Non-local paper offers two main ways to normalize
        # the affinity tensor:
        #   1) Softmax normalization (norm on exp).
        #   2) dot_product normalization.
        if self.instantiation == "softmax":
            # Scale the affinity tensor theta_phi before the softmax.
            theta_phi = theta_phi * (self.dim_inner ** -0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError("Unknown norm type {}".format(self.instantiation))
        # (N, positions, positions) * (N, C, positions) => (N, C, positions).
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
        # Restore the spatial(-temporal) layout.
        if self.type == '3d':
            theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)
        else:  # '2d'
            theta_phi_g = theta_phi_g.view(N, self.dim_inner, H, W)
        p = self.conv_out(theta_phi_g)
        if self.norm:
            p = self.bn(p)
        return p + x_identity
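A minimal smoke test of the unified class under both settings (sizes are arbitrary; this assumes the imports at the top of the file). Note that norm_module has to be switched to nn.BatchNorm2d in the 2D case, since the default is nn.BatchNorm3d:

x3d = torch.randn(2, 64, 4, 16, 16)
nl3d = Nonlocal(dim=64, dim_inner=32, typeis='3D', pool_size=[1, 2, 2], norm=True)
print(nl3d(x3d).shape)  # torch.Size([2, 64, 4, 16, 16])

x2d = torch.randn(2, 64, 32, 32)
nl2d = Nonlocal(dim=64, dim_inner=32, typeis='2D', norm=True,
                norm_module=nn.BatchNorm2d, instantiation="dot_product")
print(nl2d(x2d).shape)  # torch.Size([2, 64, 32, 32])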
Constructor parameters:

def __init__(self, dim, dim_inner, typeis='3D', pool_size=None, norm=False, norm_eps=1e-5, norm_momentum=0.1,
             norm_module=nn.BatchNorm3d, instantiation="softmax"):

dim: number of input channels.
dim_inner: number of channels used inside the non-local block.
typeis: selects the convolution class: '3D' uses nn.Conv3d, '2D' uses nn.Conv2d.
pool_size: pooling kernel size. None means no pooling layer is used; otherwise its length must match typeis: [x, x] for 2D, [x, x, x] for 3D.
norm: whether to use a BN layer.
norm_module: which normalization module to use (an nn.Module class). The default is nn.BatchNorm3d; pass nn.BatchNorm2d when typeis='2D'.
instantiation: how the affinity matrix inside the non-local block is normalized: "softmax" or "dot_product".
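Because the output shape equals the input shape, dropping the block into an existing backbone is a one-line change. A hypothetical sketch (the toy backbone and all layer sizes are made up for illustration):

import torch
import torch.nn as nn

# Toy 2D backbone with a non-local block inserted after the second conv stage.
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
    nn.ReLU(inplace=True),
    Nonlocal(dim=128, dim_inner=64, typeis='2D', norm=True, norm_module=nn.BatchNorm2d),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(128, 10),
)
print(model(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 10])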