import torch.nn as nn
import torch
import torch.nn.functional as F
class ConvolutionalAttention(nn.Module):
def __init__(self,
in_channels,
out_channels,
inter_channels,
num_heads=8):
        super(ConvolutionalAttention, self).__init__()
        assert out_channels % num_heads == 0, \
            "out_channels ({}) should be a multiple of num_heads ({})".format(out_channels, num_heads)
self.in_channels = in_channels
self.out_channels = out_channels
self.inter_channels = inter_channels
self.num_heads = num_heads
self.norm = nn.BatchNorm2d(in_channels)
        self.kv = nn.Parameter(torch.zeros(inter_channels, in_channels, 7, 1))   # 7x1 vertical strip kernel
        self.kv3 = nn.Parameter(torch.zeros(inter_channels, in_channels, 1, 7))  # 1x7 horizontal strip kernel
        nn.init.trunc_normal_(self.kv, std=0.001)   # small init; with all-zero weights both branches would output zeros
        nn.init.trunc_normal_(self.kv3, std=0.001)
def _act_dn(self, x):
x_shape = x.shape # n,c_inter,h,w
h, w = x_shape[2], x_shape[3]
        # (n, c_inter, h, w) -> (n, heads, c_inter // heads, h*w)
        x = x.reshape(
            [x_shape[0], self.num_heads, self.inter_channels // self.num_heads, -1])
        x = F.softmax(x, dim=3)                               # softmax over the flattened spatial positions
        x = x / (torch.sum(x, dim=2, keepdim=True) + 1e-06)   # normalize across each head's channels (see sketch after the class)
x = x.reshape([x_shape[0], self.inter_channels, h, w])
return x
def forward(self, x):
        # Two branches: the first applies a 7x1 vertical strip conv, then _act_dn (each head weighs its own slice of
        # the inter channels), then projects back with the channel-transposed 7x1 kernel; the second branch does the same with the 1x7 horizontal kernel.
x = self.norm(x)
        # vertical branch: 7x1 strip conv, in_channels -> inter_channels
        x1 = F.conv2d(
            x,
            self.kv,
            bias=None,
            stride=1,
            padding=(3, 0))
x1 = self._act_dn(x1)
        x1 = F.conv2d(
            x1, self.kv.transpose(1, 0), bias=None, stride=1,
            padding=(3, 0))  # transposed kernel maps inter_channels back to in_channels
        # horizontal branch: 1x7 strip conv, in_channels -> inter_channels
        x3 = F.conv2d(
            x,
            self.kv3,
            bias=None,
            stride=1,
            padding=(0, 3))
x3 = self._act_dn(x3)
        x3 = F.conv2d(
            x3, self.kv3.transpose(1, 0), bias=None, stride=1, padding=(0, 3))  # back to in_channels
        x = x1 + x3  # fuse the vertical and horizontal strip-attention branches
return x
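# The two sketches below are not part of the original module; the shapes and helper
# names (_act_dn_sketch, _strip_conv_sketch) are illustrative assumptions only.

# 1) What _act_dn computes numerically: softmax over the flattened spatial axis,
#    then a rescaling so that, within each head, the channel responses at every
#    spatial position sum to ~1.
def _act_dn_sketch():
    a = torch.rand(1, 2, 3, 4)                     # (n, heads, c_inter // heads, h*w)
    a = F.softmax(a, dim=3)                        # each channel sums to 1 over space
    a = a / (a.sum(dim=2, keepdim=True) + 1e-6)    # each head's channels now sum to ~1 per position
    return a.sum(dim=2)                            # ~1 at every spatial position

# 2) How the shared strip kernel is reused: self.kv maps in_channels -> inter_channels,
#    while its (out, in)-transposed view maps inter_channels back to in_channels, so each
#    branch preserves the input's channel count and spatial size.
def _strip_conv_sketch():
    kv = torch.randn(64, 32, 7, 1) * 1e-3          # (inter_channels, in_channels, 7, 1)
    x = torch.rand(1, 32, 16, 16)
    y = F.conv2d(x, kv, padding=(3, 0))            # -> (1, 64, 16, 16)
    y = F.conv2d(y, kv.transpose(1, 0), padding=(3, 0))  # -> (1, 32, 16, 16)
    return y.shape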
if __name__ == '__main__':
block = ConvolutionalAttention(in_channels=32,out_channels=32,inter_channels=64).cuda()
input = torch.rand(1, 32, 64, 64).cuda()
output = block(input)
print(input.size())
print(output.size())
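# A device-agnostic variant of the demo above (a sketch; the original assumes CUDA is available):
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     block = ConvolutionalAttention(in_channels=32, out_channels=32, inter_channels=64).to(device)
#     out = block(torch.rand(1, 32, 64, 64, device=device))
#     print(out.shape)  # torch.Size([1, 32, 64, 64])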
From: https://www.cnblogs.com/plumIce/p/18569251