class Basic(nn.Module):
    """A single 2-D convolution followed by a ReLU activation.

    Args:
        in_planes: Number of channels of the input tensor.
        out_planes: Number of channels produced by the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d`` (default 0).
        bias: Whether the convolution has a learnable bias (default False).
    """

    def __init__(self, in_planes, out_planes, kernel_size, padding=0, bias=False):
        super().__init__()
        # Kept as a plain attribute so callers can query the output width.
        self.out_channels = out_planes
        # groups=1 means an ordinary (non-grouped) convolution.
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            padding=padding,
            groups=1,
            bias=bias,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        return self.relu(self.conv(x))
class ChannelPool(nn.Module):
    """Collapse dimension 1 of the input into two maps: its max and its mean.

    The output has size 2 along dimension 1: index 0 holds the maximum over
    the original dimension 1, index 1 holds the mean. All other dimensions
    are unchanged. (Dimension 1 is the channel axis for the N x C x H x W
    tensors this file feeds through ``nn.Conv2d``.)
    """

    def forward(self, x):
        """Return ``cat([max over dim 1, mean over dim 1], dim=1)``."""
        peak = x.max(dim=1, keepdim=True).values
        average = x.mean(dim=1, keepdim=True)
        return torch.cat([peak, average], dim=1)
class SAB(nn.Module):
    """Spatial attention block.

    The input is reduced along the channel dimension to a 2-channel map
    (max and mean, via ``ChannelPool``), passed through a 5x5 conv + ReLU
    (``Basic``) that yields a single channel, and squashed with a sigmoid.
    The resulting 1-channel weight map is multiplied into every channel of
    the input by broadcasting.

    NOTE: because ``Basic`` applies a ReLU before the sigmoid, the sigmoid
    input is never negative, so the attention weights lie in [0.5, 1].
    """

    def __init__(self):
        super().__init__()
        ksize = 5
        # "Same" padding for an odd kernel keeps the spatial size unchanged.
        same_pad = (ksize - 1) // 2
        self.compress = ChannelPool()
        self.spatial = Basic(2, 1, ksize, padding=same_pad, bias=False)

    def forward(self, x):
        """Return ``x`` re-weighted by its spatial attention map."""
        pooled = self.compress(x)
        attention = torch.sigmoid(self.spatial(pooled))
        return x * attention
class RAB(nn.Module):
    """Residual attention block.

    A single conv-ReLU-conv residual branch (``self.res``) is applied four
    times with shared weights and several skip connections, after which a
    spatial attention block (``self.sab``) re-weights the result and a final
    skip connection adds the block input back.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor. Must equal
            ``in_channels``: the branch output is added to its own input and
            the branch is re-applied to that sum.
        bias: Whether the two convolutions have a learnable bias.

    Raises:
        ValueError: If ``in_channels != out_channels``.
    """

    def __init__(self, in_channels=64, out_channels=64, bias=True):
        super(RAB, self).__init__()
        # Fail early with a clear message instead of a shape error deep
        # inside forward(): the residual sums require matching widths.
        if in_channels != out_channels:
            raise ValueError(
                "RAB requires in_channels == out_channels for its residual "
                f"connections, got {in_channels} and {out_channels}"
            )
        kernel_size = 3
        stride = 1
        padding = 1  # keeps the spatial size unchanged for a 3x3 kernel
        self.res = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            nn.ReLU(inplace=True),
            # The second conv consumes the first conv's output, so its input
            # width is out_channels (not in_channels).
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
        )
        self.sab = SAB()

    def forward(self, x, path=1):
        """Run the block on ``x``.

        Args:
            x: Input tensor with ``in_channels`` channels.
            path: ``1`` selects the explicit formulation that names every
                intermediate; any other value selects the compact
                formulation that reuses one variable. Both compute the same
                sums in the same order.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # path 1 and path 2 produce exactly the same result; path 2 simply
        # does not keep references to intermediates it no longer needs.
        if path == 1:
            x1 = x + self.res(x)
            x2 = x1 + self.res(x1)
            x3 = x2 + self.res(x2)
            x3_1 = x1 + x3
            x4 = x3_1 + self.res(x3_1)
            x4_1 = x + x4
            # SAB: take the mean and the max over the channel dimension,
            # stack them as two channels, reduce them to one channel with a
            # convolution, apply a sigmoid to obtain weights, then multiply
            # every channel of the input by that 1-channel weight map.
            x5 = self.sab(x4_1)
            x5_1 = x + x5
            return x5_1
        else:
            x1 = x + self.res(x)
            x2 = x1 + self.res(x1)
            x2 = x2 + self.res(x2) + x1
            x2 = x2 + self.res(x2) + x
            x2 = self.sab(x2) + x
            return x2
if __name__ == '__main__':
    # Smoke test: run the same random input through both forward paths of
    # one RAB instance (3 channels in and out, 9x9 spatial size).
    block1 = RAB(in_channels=3, out_channels=3, bias=True)
    # Named `sample` rather than `input` to avoid shadowing the builtin.
    sample = torch.rand(1, 3, 9, 9)
    output1 = block1(sample, 1)
    output2 = block1(sample, 2)
    print("")
# Tags: __, kernel, RAB, self, module, x2, DRANet, out, size
# From: https://www.cnblogs.com/plumIce/p/18568524