paper
`
import torch
import torch.nn as nn
def build_act_layer(act_type):
    """Build an activation layer from its name.

    Args:
        act_type (str | None): One of ``'GELU'``, ``'ReLU'``, ``'SiLU'``,
            or ``None`` for a pass-through layer.

    Returns:
        nn.Module: ``nn.Identity()`` when ``act_type`` is ``None``, otherwise
        a new instance of the requested activation.

    Raises:
        AssertionError: If ``act_type`` is not one of the supported names.
    """
    if act_type is None:
        return nn.Identity()
    if act_type == 'SiLU':
        return nn.SiLU()
    if act_type == 'ReLU':
        return nn.ReLU()
    if act_type == 'GELU':
        return nn.GELU()
    # Raised explicitly instead of via an `assert` statement so the check
    # still runs under `python -O`; previously an unknown name silently
    # produced a GELU there. The exception type is kept for existing callers.
    raise AssertionError(
        "Unsupported act_type {!r}; expected None or one of "
        "['GELU', 'ReLU', 'SiLU']".format(act_type))
class ElementScale(nn.Module):
    """Learnable per-channel scaling factor.

    Holds a parameter of shape ``(1, embed_dims, 1, 1)`` and multiplies the
    input by it, relying on broadcasting over the other dimensions.

    Args:
        embed_dims (int): Size of dimension 1 of the scale parameter
            (one scale value per channel).
        init_value (float): Initial value of every scale entry.
            Defaults to 0.0.
        requires_grad (bool): Whether the scale parameter receives
            gradients. Defaults to True.
    """

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super().__init__()
        self.scale = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)),
            requires_grad=requires_grad,
        )

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the broadcast scale."""
        return x * self.scale
class ChannelAggregationFFN(nn.Module):
    """Feed-forward network with channel aggregation.

    The forward pass is: 1x1 conv (``embed_dims`` -> hidden width),
    depth-wise conv, activation, dropout, channel aggregation
    (:meth:`feat_decompose`), 1x1 conv (hidden width -> ``embed_dims``),
    dropout.

    Args:
        embed_dims (int): Number of input and output channels.
        kernel_size (int): Kernel size of the depth-wise convolution. It is
            applied with stride 1 and padding ``kernel_size // 2``, so odd
            kernel sizes keep the spatial size unchanged. Defaults to 3.
        act_type (str | None): Activation name passed to
            ``build_act_layer``. Defaults to 'GELU'.
        ffn_drop (float): Dropout probability used after the activation and
            after the output projection. Defaults to 0.0.
        feedforward_channels (int | None): Hidden width of the FFN.
            Keyword-only. ``None`` (the default) keeps the previous
            behaviour of ``4 * embed_dims``.
    """

    def __init__(self,
                 embed_dims,
                 kernel_size=3,
                 act_type='GELU',
                 ffn_drop=0.,
                 *,
                 feedforward_channels=None):
        super().__init__()

        self.embed_dims = embed_dims
        if feedforward_channels is None:
            # Historical default: expand the channel count by a factor of 4.
            feedforward_channels = embed_dims * 4
        self.feedforward_channels = int(feedforward_channels)

        # NOTE: submodules are created in the same order and under the same
        # attribute names as before, so state_dict keys are unchanged.
        self.fc1 = nn.Conv2d(
            in_channels=embed_dims,
            out_channels=self.feedforward_channels,
            kernel_size=1)
        # groups == channels makes this a depth-wise convolution.
        self.dwconv = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=self.feedforward_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=True,
            groups=self.feedforward_channels)
        self.act = build_act_layer(act_type)
        self.fc2 = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        # A single dropout module, applied twice in forward().
        self.drop = nn.Dropout(ffn_drop)

        # 1x1 conv collapsing all hidden channels into one: C -> 1.
        self.decompose = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=1,
            kernel_size=1)
        self.sigma = ElementScale(
            self.feedforward_channels, init_value=1e-5, requires_grad=True)
        self.decompose_act = build_act_layer(act_type)

    def feat_decompose(self, x):
        """Add a scaled channel-residual back onto ``x``.

        Computes ``x + sigma(x - act(decompose(x)))``.

        Args:
            x (torch.Tensor): Hidden feature map; expected to be
                ``[B, C, H, W]`` with ``C == feedforward_channels``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        # [B, C, H, W] -> [B, 1, H, W]: summarise all channels in one map.
        aggregated = self.decompose(x)
        # Non-linearity on the single-channel summary.
        aggregated = self.decompose_act(aggregated)
        # Subtract the (broadcast) summary from every channel, leaving what
        # each channel carries beyond the shared component.
        residual = x - aggregated
        # Learnable per-channel weight on that residual.
        residual = self.sigma(residual)
        # Re-inject the weighted residual into the original features.
        return x + residual

    def forward(self, x):
        """Apply the FFN to ``x``.

        Args:
            x (torch.Tensor): Input feature map with ``embed_dims`` channels
                in dimension 1 (``[B, C, H, W]``).

        Returns:
            torch.Tensor: Output feature map with ``embed_dims`` channels.
        """
        # Projection 1: expand channels, mix spatially, activate.
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        # Projection 2: channel aggregation, then project back.
        x = self.feat_decompose(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
if __name__ == '__main__':
    # Smoke test: push one random batch through the block and print the
    # input and output sizes. Uses the GPU when one is available and falls
    # back to the CPU otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = torch.randn(1, 64, 32, 32).to(device)  # B C H W
    block = ChannelAggregationFFN(embed_dims=64).to(device)
    output = block(inputs)
    print(inputs.size())
    print(output.size())
`