Paper:
https://arxiv.org/abs/2207.14284
Source code:
https://github.com/raoyongming/HorNet
Abstract:
Recently, Vision Transformers have achieved great success on a wide range of tasks, largely thanks to a new spatial modeling mechanism based on dot-product self-attention. In this paper, we show that the key ingredients behind Vision Transformers, namely input-adaptive, long-range, and high-order spatial interactions, can also be implemented efficiently in a convolution-based framework.
We propose Recursive Gated Convolution (gnConv), which performs high-order spatial interactions using gated convolutions and a recursive design. The new operation is highly flexible and customizable: it is compatible with various convolution variants and extends the second-order interactions of self-attention to arbitrary orders without introducing significant extra computation. gnConv can serve as a plug-and-play module to improve various Vision Transformers and convolution-based models.
Based on the gnConv operation, we build a new family of generic vision backbones named HorNet. Extensive experiments on ImageNet classification, COCO object detection, and ADE20K semantic segmentation show that HorNet outperforms Swin Transformer and ConvNeXt under similar overall architectures and training configurations. HorNet also scales well to more training data and larger model sizes.
Beyond its effectiveness in visual encoders, we also show that gnConv can be applied to task-specific decoders, consistently improving dense prediction performance while reducing computation. Our results demonstrate that gnConv can serve as a new basic module for visual modeling that effectively combines the merits of Vision Transformers and CNNs. Code is available at the repository linked above.
Simply put, we found a new convolution method that can capture high-order spatial interactions in images, much like Vision Transformers do. Based on this method, we build a new vision model, HorNet, which outperforms existing methods on several vision tasks. The method works not only in the main visual encoder but also in task-specific decoders, helping to predict image details more accurately while reducing computation.
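To make the gating-plus-recursion idea described above concrete before the full implementation below, here is a minimal sketch of a single gating step, i.e. one second-order interaction. The class name SimpleGatedConv and the 7x7 depth-wise kernel are illustrative choices for this sketch, not the paper's definition; the real gnConv essentially repeats this "multiply by a spatially convolved gate" step order times over progressively wider channel groups (see the GnConv class in the source code below).

import torch
import torch.nn as nn

class SimpleGatedConv(nn.Module):
    # One gating step: value branch * depth-wise-convolved gate branch (a 2nd-order interaction).
    def __init__(self, dim):
        super().__init__()
        self.proj_in = nn.Conv2d(dim, 2 * dim, 1)                    # produce value and gate branches
        self.dwconv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)  # spatial mixing of the gate
        self.proj_out = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        value, gate = self.proj_in(x).chunk(2, dim=1)
        return self.proj_out(value * self.dwconv(gate))              # element-wise gating

# Example: SimpleGatedConv(32)(torch.randn(1, 32, 56, 56)) keeps the shape (1, 32, 56, 56).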
Architecture diagram:
PyTorch source code:
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_dwconv(dim, kernel, bias):
    # Depth-wise convolution: one filter per channel (groups=dim), "same" padding for odd kernel sizes.
    return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel - 1) // 2, bias=bias, groups=dim)
class LayerNorm(nn.Module):
    """LayerNorm supporting both channels_last (N, H, W, C) and channels_first (N, C, H, W) inputs."""

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Normalize over the channel dimension manually for (N, C, H, W) tensors.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
class GnConv(nn.Module):
    """Recursive gated convolution (gnConv): high-order spatial interactions via gating and recursion."""

    def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
        super().__init__()
        self.order = order
        # Channel widths for each interaction order, e.g. dim=32, order=5 -> [2, 4, 8, 16, 32].
        self.dims = [dim // 2 ** i for i in range(order)]
        self.dims.reverse()
        self.proj_in = nn.Conv2d(dim, 2 * dim, 1)
        if gflayer is None:
            # Default spatial mixer: a 7x7 depth-wise convolution over all gating branches at once.
            self.dwconv = get_dwconv(sum(self.dims), 7, True)
        else:
            # Optional alternative spatial mixer (e.g. a global-filter layer).
            self.dwconv = gflayer(sum(self.dims), h=h, w=w)
        self.proj_out = nn.Conv2d(dim, dim, 1)
        # 1x1 convolutions that lift the features to the next (wider) order at each recursion step.
        self.pws = nn.ModuleList(
            [nn.Conv2d(self.dims[i], self.dims[i + 1], 1) for i in range(order - 1)]
        )
        self.scale = s
        print('[gconv]', order, 'order with dims =', self.dims, 'scale = %.4f' % self.scale)

    def forward(self, x, mask=None, dummy=False):
        B, C, H, W = x.shape
        fused_x = self.proj_in(x)
        # Split into the lowest-order features (pwa) and the stacked gating branches (abc).
        pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
        dw_abc = self.dwconv(abc) * self.scale
        dw_list = torch.split(dw_abc, self.dims, dim=1)
        x = pwa * dw_list[0]
        # Recursive gating: each step projects to a wider channel group and multiplies by the next gate.
        for i in range(self.order - 1):
            x = self.pws[i](x) * dw_list[i + 1]
        x = self.proj_out(x)
        return x
class GnBlock(nn.Module):
    """HorNet-style block: gnConv for spatial mixing followed by a ConvNeXt-style MLP for channel mixing."""

    def __init__(self, dim, shortcut=False, layer_scale_init_value=1e-6):
        super().__init__()
        self.shortcut = shortcut
        self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.gnconv = GnConv(dim, order=5)
        self.norm2 = LayerNorm(dim, eps=1e-6, data_format='channels_last')
        self.pwconv1 = nn.Linear(dim, 2 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(2 * dim, dim)
        # Layer-scale parameters for the two sub-blocks (disabled if layer_scale_init_value <= 0).
        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                   requires_grad=True) if layer_scale_init_value > 0 else None
        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                   requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        B, C, H, W = x.shape
        if self.gamma1 is not None:
            gamma1 = self.gamma1.view(1, C, 1, 1)
        else:
            gamma1 = 1
        # Spatial mixing with gnConv (residual connection only if shortcut=True).
        x = (x + gamma1 * self.gnconv(self.norm1(x))) if self.shortcut else gamma1 * self.gnconv(self.norm1(x))
        residual = x
        # Channel mixing: LayerNorm + two point-wise linear layers with GELU in between.
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm2(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        if self.gamma2 is not None:
            x = self.gamma2.view(1, C, 1, 1) * x
        x = (residual + x) if self.shortcut else x
        return x
# Quick test with a sample input.
if __name__ == "__main__":
    # Create a sample input tensor: (batch size, channels, height, width).
    input_tensor = torch.randn(1, 32, 128, 128)
    model = GnBlock(32)  # initialize GnBlock with the same number of channels as the input
    output = model(input_tensor)  # forward pass
    # Print the output shape; it should match the input shape (1, 32, 128, 128).
    print("Output shape:", output.size())
From: https://blog.csdn.net/weixin_45694817/article/details/137153863