
sMLP

Date: 2024-11-12 12:30:38
Tags: __, sMLP, nn, self, torch, channels, permute

paper
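The point of the "sparse" mixing below is cost: a dense token-mixing MLP over an H×W grid needs (HW)² weights, while sMLP's separate row and column projections need only H² + W². A quick back-of-envelope check (the grid size here is an arbitrary illustration, not from the paper):

```python
H = W = 56  # arbitrary feature-map size, for illustration only

dense_params = (H * W) ** 2     # one fully connected layer over all H*W tokens
sparse_params = H * H + W * W   # sMLP's height projection + width projection

print(dense_params, sparse_params, dense_params // sparse_params)
# 9834496 6272 1568
```

The gap grows quadratically with the grid area, which is why factorized row/column mixing scales to larger feature maps.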

import torch.nn as nn
import torch
class sMLPBlock(nn.Module):
    '''
    Sparse MLP block: instead of passing all tokens of a sample through one
    fully connected layer, each token only mixes with the tokens in its own
    row and its own column.
    '''
    def __init__(self, W, H, channels):
        super().__init__()
        assert W == H
        self.channels = channels
        self.activation = nn.GELU()
        self.BN = nn.BatchNorm2d(channels)
        self.proj_h = nn.Conv2d(H, H, (1, 1))  # mixes tokens along the height axis
        self.proj_w = nn.Conv2d(W, W, (1, 1))  # mixes tokens along the width axis
        self.fuse = nn.Conv2d(channels * 3, channels, (1, 1), (1, 1), bias=False)
        # Equivalent formulation with nn.Linear:
        # self.proj_h2 = nn.Linear(H, H)
        # self.proj_w2 = nn.Linear(W, W)
        # self.fuse2 = nn.Linear(channels * 3, channels)

    def forward(self, x):
        x = self.activation(self.BN(x))
        # width mixing: (B, C, H, W) -> (B, W, C, H) so W becomes the conv's channel dim
        x_w = self.proj_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # height mixing: (B, C, H, W) -> (B, H, C, W) so H becomes the conv's channel dim
        x_h = self.proj_h(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
        # concatenate identity, height, and width branches, then fuse back to C channels
        x = self.fuse(torch.cat([x, x_h, x_w], dim=1))
        # Equivalent formulation with nn.Linear:
        # x_h2 = self.proj_h2(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
        # x_w2 = self.proj_w2(x)
        # x_fuse = torch.cat([x_h2, x_w2, x], dim=1)
        # x_2 = self.fuse2(x_fuse.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 3, 2, 2).to(device)  # input is B, C, H, W
    model = sMLPBlock(2, 2, 3).to(device)
    res = model(x)
    print(res.shape)  # torch.Size([1, 3, 2, 2]) — the block preserves the input shape
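The commented-out `nn.Linear` variant works because a 1×1 `Conv2d`, applied after permuting the target axis into the channel slot, computes exactly the same per-axis linear map as `nn.Linear` over that axis. A minimal sketch checking the equivalence numerically (the variable names are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4
x = torch.randn(B, C, H, W)

conv = nn.Conv2d(H, H, (1, 1))  # height mixing via 1x1 conv
lin = nn.Linear(H, H)           # height mixing via Linear over the H axis
# copy the conv weights into the linear layer so both compute the same map
with torch.no_grad():
    lin.weight.copy_(conv.weight.squeeze(-1).squeeze(-1))
    lin.bias.copy_(conv.bias)

# conv path: move H into the channel slot, convolve, move it back
out_conv = conv(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
# linear path: move H to the last axis, apply Linear, move it back
out_lin = lin(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)

print(torch.allclose(out_conv, out_lin, atol=1e-6))  # True
```

The same argument applies to the width projection and to `fuse` (a 1×1 conv over channels versus a `Linear` over a channels-last layout), so the two formulations in the class are interchangeable.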

From: https://www.cnblogs.com/plumIce/p/18541580
