class ConvBN(nn.Module):
    """2-D convolution (no bias) followed by BatchNorm2d, with no activation.

    Args:
        c1: number of input channels.
        c2: number of output channels.
        k: kernel size, an int or a pair.
        s: stride.
        p: padding. When None, it is set to half of the *dilated* kernel
            extent, ``(d * (k - 1) + 1) // 2`` per spatial dim, so odd kernels
            at stride 1 keep H and W unchanged (also when ``d > 1``). For
            ``d == 1`` this equals ``k // 2``.
        g: number of groups passed to ``nn.Conv2d`` (``g == c1`` gives a
            depthwise convolution).
        d: dilation, an int or a pair.

    Note:
        ``forward`` returns ``bn(conv(x))`` without any activation.
        ``self.act`` (SiLU) is constructed but never used by ``forward``; it is
        kept only so existing references to ``.act`` keep working.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1):
        super().__init__()
        if p is None:
            p = self._auto_pad(k, d)
        # bias=False: the following BatchNorm has its own learnable shift.
        self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # Not applied in forward(); retained for backward compatibility.
        self.act = nn.SiLU()

    @staticmethod
    def _auto_pad(k, d):
        """Return padding equal to half the dilated kernel extent per spatial dim."""
        if isinstance(k, int) and isinstance(d, int):
            return (d * (k - 1) + 1) // 2
        ks = (k, k) if isinstance(k, int) else tuple(k)
        ds = (d, d) if isinstance(d, int) else tuple(d)
        return [(dd * (kk - 1) + 1) // 2 for kk, dd in zip(ks, ds)]

    def forward(self, x):
        """Return ``bn(conv(x))`` (no activation is applied)."""
        return self.bn(self.conv(x))
class StarBlock(nn.Module):
    """Residual block combining two 1x1 expansions by element-wise product.

    Pipeline: depthwise 7x7 ConvBN -> two parallel 1x1 ConvBN expansions
    (dim -> mlp_ratio * dim) merged as ``ReLU6(f1(y)) * f2(y)`` -> 1x1 ConvBN
    projection back to dim -> depthwise 7x7 ConvBN, then added to the block
    input through ``drop_path``.

    Args:
        dim: channel count of both the input and the output.
        mlp_ratio: expansion factor for the hidden width.
        drop_path: rate passed to ``DropPath`` (defined/imported elsewhere);
            a value of 0 uses ``nn.Identity`` instead.
    """

    def __init__(self, dim, mlp_ratio=3, drop_path=0.):
        super().__init__()
        hidden = mlp_ratio * dim
        # Padding 3 keeps H and W unchanged for a 7x7 kernel at stride 1.
        pad7 = (7 - 1) // 2
        # Submodules are registered in the same order as before so parameter
        # initialization draws from the RNG identically.
        self.dwconv = ConvBN(dim, dim, 7, 1, pad7, g=dim)
        self.f1 = ConvBN(dim, hidden, 1)
        self.f2 = ConvBN(dim, hidden, 1)
        self.g = ConvBN(hidden, dim, 1)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, pad7, g=dim)
        self.act = nn.ReLU6()
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block; the output has the same shape as ``x``."""
        shortcut = x
        y = self.dwconv(x)
        gated = self.act(self.f1(y))
        y = gated * self.f2(y)
        y = self.dwconv2(self.g(y))
        return shortcut + self.drop_path(y)
if __name__ == '__main__':
    # Smoke test: run one StarBlock forward pass and print input/output shapes.
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(3, 32, 64, 64, device=device)  # input: B C H W
    model = StarBlock(32).to(device)
    print(x.shape)
    print(model(x).shape)
# Tags: __, dim, ConvBN, nn, StarConv, path, self
# Source: https://www.cnblogs.com/plumIce/p/18545677