首页 > 其他分享 > StarConv

StarConv

时间: 2024-11-14 11:42:34  浏览次数: 1
标签:__ dim ConvBN nn StarConv path self

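Paper: Ma et al., "Rewrite the Stars" (CVPR 2024), which introduces StarNet and the element-wise "star" operation used in the block below.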

class ConvBN(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d``, with an optional SiLU.

    By default ``forward`` returns ``bn(conv(x))`` with no activation, exactly as
    before. Pass ``act=True`` to apply the SiLU stored in ``self.act`` after BN.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Kernel size (int or per-dimension sequence).
        s: Stride.
        p: Padding. When ``None``, it is half the *effective* (dilated) kernel
            size per dimension: ``(d * (k - 1) + 1) // 2``. This reduces to
            ``k // 2`` when ``d == 1``, so a stride-1 conv with an odd kernel
            keeps the spatial size for any dilation.
        g: Groups (``g == c1`` gives a depthwise conv).
        d: Dilation (int or per-dimension sequence).
        act: Apply ``self.act`` (SiLU) after BN in ``forward``. Defaults to ``False``.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=False):
        super().__init__()
        if p is None:
            # Dilation enlarges the kernel's footprint to d*(k-1)+1, so the
            # auto-padding must scale with d; plain k//2 is only right for d == 1.
            if isinstance(k, int) and isinstance(d, int):
                p = (d * (k - 1) + 1) // 2
            else:
                ks = (k, k) if isinstance(k, int) else tuple(k)
                ds = (d, d) if isinstance(d, int) else tuple(d)
                p = [(dd * (kk - 1) + 1) // 2 for kk, dd in zip(ks, ds)]
        self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # Always created, to keep the module's attributes unchanged. It is only
        # applied in forward when use_act is set.
        self.act = nn.SiLU()
        self.use_act = act

    def forward(self, x):
        """Return ``bn(conv(x))``, followed by SiLU if the block was built with ``act=True``."""
        y = self.bn(self.conv(x))
        return self.act(y) if self.use_act else y
class StarBlock(nn.Module):
    """Residual block built around an element-wise ("star") product of two branches.

    Data flow for an input ``x`` with ``dim`` channels:
    depthwise ConvBN -> two parallel 1x1 ConvBN expansions to ``mlp_ratio * dim``
    channels -> ``ReLU6(f1) * f2`` -> 1x1 ConvBN projection back to ``dim`` ->
    second depthwise ConvBN -> ``drop_path`` -> add ``x``.
    Every conv is stride 1 with "same" padding (odd kernels only), so the output
    has the same shape as the input and the residual add is valid.

    Args:
        dim: Number of input (and output) channels.
        mlp_ratio: Channel expansion factor of the two 1x1 branches.
        drop_path: Rate passed to ``DropPath``. When it is ``0``, ``nn.Identity``
            is used instead. NOTE(review): ``DropPath`` is not defined in this
            snippet (presumably timm's stochastic-depth layer); it is only
            needed when ``drop_path > 0``.
        kernel_size: Spatial size of both depthwise convs. Must be a positive
            odd integer (default 7).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, dim, mlp_ratio=3, drop_path=0., kernel_size=7):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            # An even kernel cannot be padded symmetrically to keep H and W,
            # which would break the residual add in forward.
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        pad = (kernel_size - 1) // 2
        hidden = mlp_ratio * dim
        self.dwconv = ConvBN(dim, dim, kernel_size, 1, pad, g=dim)
        self.f1 = ConvBN(dim, hidden, 1)
        self.f2 = ConvBN(dim, hidden, 1)
        self.g = ConvBN(hidden, dim, 1)
        self.dwconv2 = ConvBN(dim, dim, kernel_size, 1, pad, g=dim)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply the block and add the input back (residual connection)."""
        shortcut = x  # avoid shadowing the builtin ``input``
        x = self.dwconv(x)
        x1, x2 = self.f1(x), self.f2(x)
        x = self.act(x1) * x2  # "star" operation: gated element-wise product
        x = self.dwconv2(self.g(x))
        return shortcut + self.drop_path(x)

if __name__ == '__main__':
    # Smoke test: StarBlock is residual, so the output shape must equal the input shape.
    # Fall back to CPU so the check also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(3, 32, 64, 64, device=device)  # input layout: B, C, H, W
    model = StarBlock(32).to(device)
    print(x.shape)
    with torch.no_grad():  # shape check only; no autograd graph needed
        y = model(x)
    print(y.shape)

标签:__,dim,ConvBN,nn,StarConv,path,self
From: https://www.cnblogs.com/plumIce/p/18545677

相关文章