首页 > 其他分享 >SCConv:SRU CRU

SCConv:SRU CRU

时间:2024-11-11 20:43:24浏览次数:1  
标签:dim group CRU self SCConv num SRU channel size

paper
`
import torch
import torch.nn.functional as F
import torch.nn as nn
class GroupBatchnorm2d(nn.Module):
def init(self, c_num: int,
group_num: int = 16,
eps: float = 1e-10
):
super(GroupBatchnorm2d, self).init()
assert c_num >= group_num
self.group_num = group_num
self.weight = nn.Parameter(torch.randn(c_num, 1, 1))
self.bias = nn.Parameter(torch.zeros(c_num, 1, 1))
self.eps = eps

def forward(self, x):
    N, C, H, W = x.size()
    x = x.view(N, self.group_num, -1)
    mean = x.mean(dim=2, keepdim=True)
    std = x.std(dim=2, keepdim=True)
    x = (x - mean) / (std + self.eps)
    x = x.view(N, C, H, W)
    return x * self.weight + self.bias

class SRU(nn.Module):
def init(self,
oup_channels: int,
group_num: int = 16,
gate_treshold: float = 0.5,
torch_gn: bool = False
):
super().init()

    self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d(
        c_num=oup_channels, group_num=group_num)
    self.gate_treshold = gate_treshold
    self.sigomid = nn.Sigmoid()

def forward(self, x):
    gn_x = self.gn(x)  # 一个样本的分成若干个group 每个group内部 做归一化
    w_gamma = self.gn.weight / torch.sum(self.gn.weight) # 每层的权重除所有层的权重和
    w_gamma = w_gamma.view(1, -1, 1, 1) #  修改形状
    reweigts = self.sigomid(gn_x * w_gamma) # 将归一化的输入和每层的权重相乘 得到新的权重
    # Gate
    info_mask = reweigts >= self.gate_treshold # 超过阈值
    noninfo_mask = reweigts < self.gate_treshold # 没超过阈值
    x_1 = info_mask * gn_x # 超过阈值的信息
    x_2 = noninfo_mask * gn_x  # 没超过阈值的信息
    x = self.reconstruct(x_1, x_2)  # 重构两部分信息
    return x
def reconstruct(self, x_1, x_2):
    x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1)
    x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1)
    return torch.cat([x_11 + x_22, x_12 + x_21], dim=1)

"""
对输入张量 x 进行分组归一化,得到 gn_x。

计算归一化层的权重 w_gamma,并将其形状调整为 (1, -1, 1, 1)。

使用 sigmoid 激活函数对 gn_x 和 w_gamma 的乘积进行激活,得到 reweigts。

根据 reweigts 和 gate_treshold 生成两个掩码:info_mask 和 noninfo_mask。

使用掩码将 gn_x 分成两部分:x_1 和 x_2。

调用 reconstruct 方法对 x_1 和 x_2 进行重建,并返回结果。"""

class CRU(nn.Module):
'''
alpha: 0<alpha<1
'''

def __init__(self,
             op_channel: int,
             alpha: float = 1 / 2,
             squeeze_radio: int = 2,
             group_size: int = 2,
             group_kernel_size: int = 3,
             ):
    super().__init__()
    self.up_channel = up_channel = int(alpha * op_channel)
    self.low_channel = low_channel = op_channel - up_channel
    self.squeeze1 = nn.Conv2d(up_channel, up_channel // squeeze_radio, kernel_size=1, bias=False)
    self.squeeze2 = nn.Conv2d(low_channel, low_channel // squeeze_radio, kernel_size=1, bias=False)
    # up
    self.GWC = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=group_kernel_size, stride=1,
                         padding=group_kernel_size // 2, groups=group_size)
    self.PWC1 = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=1, bias=False)
    # low
    self.PWC2 = nn.Conv2d(low_channel // squeeze_radio, op_channel - low_channel // squeeze_radio, kernel_size=1,
                          bias=False)
    self.advavg = nn.AdaptiveAvgPool2d(1)

def forward(self, x):
    # Split
    up, low = torch.split(x, [self.up_channel, self.low_channel], dim=1)  # 特征按dim分两份 up_channel  low_channel
    up, low = self.squeeze1(up), self.squeeze2(low)  # up_channel low_channel 的dim 进一步压缩
    # Transform
    Y1 = self.GWC(up) + self.PWC1(up)  # Y1 dim=32  直接还原为初始维度也就是输入的x的dim
    Y2 = torch.cat([self.PWC2(low), low], dim=1)  # 也还原为初始dim
    # Fuse
    out = torch.cat([Y1, Y2], dim=1)  # 拼接完成 得到两倍的x的dim
    out = F.softmax(self.advavg(out), dim=1) * out  # 还是在算每个通道的权重
    out1, out2 = torch.split(out, out.size(1) // 2, dim=1)  # 将两倍的通道的特征图分成两份
    return out1 + out2  # 相加 还原成初始dim

class ScConv(nn.Module):
def init(self,
op_channel: int,
group_num: int = 4,
gate_treshold: float = 0.5,
alpha: float = 1 / 2,
squeeze_radio: int = 2,
group_size: int = 2,
group_kernel_size: int = 3,
):
super().init()
self.SRU = SRU(op_channel,
group_num=group_num,
gate_treshold=gate_treshold)
self.CRU = CRU(op_channel,
alpha=alpha,
squeeze_radio=squeeze_radio,
group_size=group_size,
group_kernel_size=group_kernel_size)

def forward(self, x):
    x = self.SRU(x)
    x = self.CRU(x)
    return x

if name == 'main':
x = torch.randn(3, 32, 64, 64).cuda() # 输入 B C H W
model = ScConv(32).cuda()
print(model(x).shape)

`



标签:dim,group,CRU,self,SCConv,num,SRU,channel,size
From: https://www.cnblogs.com/plumIce/p/18540533

相关文章

  • 团队项目Scrum冲刺-day1
    一、各个成员在Alpha阶段认领的任务成员Alpha陈国金凌枫陈卓恒谭立业廖俊龙曾平凡曾俊涛薛秋昊二、明日各个成员的任务安排成员任务陈国金凌枫陈卓恒谭立业廖俊龙曾平凡曾俊涛薛秋昊三......
  • ffmpeg Audio Filters acrusher
    Reduceaudiobitresolution.Thisfilterisbitcrusherwithenhancedfunctionality.Abitcrusherisusedtoaudiblyreducenumberofbitsanaudiosignalissampledwith.Thisdoesn’tchangethebitdepthatall,itjustproducestheeffect.Materialre......
  • 第 2 篇 Scrum 冲刺博客
    作业要求这个作业属于哪个课程计科34班这个作业的要求在哪里团队作业4——项目冲刺这个作业的目标1.站立式会议2.发布项目燃尽图3.每人的代码/文档签入记录4.适当的项目程序/模块的最新(运行)截图5.每日每人总结会议照片昨日已完成的工作/今天计划完成的工作......
  • 电脑中丢失 vcruntime140.dll 的五种解决方法
    vcruntime140.dll是MicrosoftVisualC++2015RedistributablePackage的一部分,它是一个动态链接库(DLL)文件,主要负责为使用了C++编译器编写的应用程序提供运行时支持。简而言之,vcruntime140.dll包含了程序运行所需的基础函数和数据结构,如内存管理、输入输出操作等。因此,对于很......
  • 第 1 篇 Scrum 冲刺博客
    作业要求这个作业属于哪个课程计科34班这个作业的要求在哪里团队作业4——项目冲刺这个作业的目标1、认领任务2、规划明天任务3、项目预期任务量4、敏捷开发感想5、团队期望各个成员在Alpha阶段认领的任务成员任务梁俊轩功能测试与管理雷......
  • E. Disrupting Communications
    注意可能出现dpx+1在模意义下为0的情况,此时需要额外维护0的个数而不能求逆元记f[x]表示x子树内包含x的连通子图的个数,g[x]表示全树包含x的连通子图的个数,由于子树的限制,所有fx互斥【子树互斥模型】求出f[x]后换根DP求出g[x]。答案即为u-LCA(u,v)上f的和+g[LCA(u,v)]+v-LCA(u,v......
  • 第1篇Scrum冲刺博客
    软件工程班级链接作业要求作业要求作业目标项目冲刺团队成员姓名学号王睿娴3222003968张颢严3222004426梁恬(组长)3222004467潘思言3222004423一、各个成员在Alpha阶段认领的任务二、明日各个成员的任务安排三、整个项目预期......
  • 为什么找不到vcruntime140_1.dll,无法继续执行代码的原因及五种有效解决方法
    vcruntime140_1.dll是微软VisualC++RedistributableforVisualStudio的一个动态链接库(DLL)文件。它是运行由VisualStudio2015及更高版本编译的C++应用程序所必需的。该DLL文件包含了支持C++标准库和Microsoft特定扩展功能的运行时函数,对于Windows应用程序......
  • win10找不到vcruntime140_1.dll,无法继续执行代码的解决方法
    vcruntime140_1.dll是微软VisualC++RedistributableforVisualStudio的一个动态链接库(DLL)文件。它是运行由VisualStudio2015及更高版本编译的C++应用程序所必需的。该DLL文件包含了支持C++标准库和Microsoft特定扩展功能的运行时函数,对于Windows应用程序......
  • 初学elasticsearch——除了CRUD之外我还需要关注es的哪些问题
    1.倒排索引是如何工作的倒排索引中主要有词条和文档两个概念:词条是分词后产生的词语,每条数据都有对应的文档(被序列化好的json串)倒排索引就是把词条、文档ID记录下来,每当出现一个重复的词条都会追加在文档ID如下图,词条是不会重复的 在查询的时候,我们会先对搜索内容进行分词,根......