首页 > 其他分享 >cbam.py

cbam.py

时间:2023-06-12 17:00:29浏览次数:31  
标签:__ nn cbam self py att pool size

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
self.relu = nn.ReLU() if relu else None

def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x

class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
super(ChannelGate, self).__init__()
self.gate_channels = gate_channels
self.mlp = nn.Sequential(
Flatten(),
nn.Linear(gate_channels, gate_channels // reduction_ratio),
nn.ReLU(),
nn.Linear(gate_channels // reduction_ratio, gate_channels)
)
self.pool_types = pool_types
def forward(self, x):
channel_att_sum = None
for pool_type in self.pool_types:
if pool_type=='avg':
avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( avg_pool )
elif pool_type=='max':
max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( max_pool )
elif pool_type=='lp':
lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
channel_att_raw = self.mlp( lp_pool )
elif pool_type=='lse':
# LSE pool only
lse_pool = logsumexp_2d(x)
channel_att_raw = self.mlp( lse_pool )

if channel_att_sum is None:
channel_att_sum = channel_att_raw
else:
channel_att_sum = channel_att_sum + channel_att_raw

scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
return x * scale

def logsumexp_2d(tensor):
tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
return outputs

class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
scale = F.sigmoid(x_out) # broadcasting
return x * scale

class CBAM(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
super(CBAM, self).__init__()
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
self.no_spatial=no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_out = self.ChannelGate(x)
if not self.no_spatial:
x_out = self.SpatialGate(x_out)
return x_out ##############################################################

这段代码定义了几个类,它们实现了一个名为CBAM(Convolutional Block Attention Module)的注意力模块。这个模块用于卷积神经网络中,可以增强模型的表达能力。

代码中定义了五个类:BasicConvFlattenChannelGateSpatialGateCBAM

BasicConv类实现了一个基本的卷积层,包括卷积、批量归一化和激活操作。

Flatten类实现了一个展平层,用于将多维张量展平为一维张量。

ChannelGate类实现了一个通道注意力门,用于计算每个通道的注意力权重。它接受一个参数gate_channels,表示输入通道数。类中定义了一个多层感知器(MLP),用于计算每个通道的注意力权重。在前向传播过程中,首先对输入数据进行全局池化操作,然后使用MLP计算每个通道的注意力权重。最后,使用sigmoid函数将注意力权重转换为0到1之间的数值,并返回结果。

SpatialGate类实现了一个空间注意力门,用于计算每个空间位置的注意力权重。它包括一个通道池化层和一个卷积层。在前向传播过程中,首先对输入数据进行通道池化操作,然后使用卷积层计算每个空间位置的注意力权重。最后,使用sigmoid函数将注意力权重转换为0到1之间的数值,并返回结果。

CBAM类整合了通道注意力门和空间注意力门,实现了一个完整的CBAM模块。它接受三个参数:gate_channels表示输入通道数;reduction_ratio表示通道注意力门中MLP的压缩比例;pool_types表示通道注意力门中使用的池化类型。在前向传播过程中,首先使用通道注意力门计算每个通道的注意力权重,并对输入数据进行加权。然后,使用空间注意力门计算每个空间位置的注意力权重,并对加权后的数据进行进一步加权。最后,返回加权后的结果。

这些类与前面的代码有关系,因为它们被用于构建卷积神经网络模型。在前面的代码中,有一段定义了一个名为VGG16_cml的类,它实现了一个卷积神经网络模型。在这个类的__init__方法中,在特定位置添加了自定义的CBAM类。

 

标签:__,nn,cbam,self,py,att,pool,size
From: https://www.cnblogs.com/wzbzk/p/17475509.html

相关文章

  • 0基础学python
    Python学习路线 精品Python学习书籍 技能对照表 ......
  • python 序列化模块
    一、jsonJson模块提供了四个功能:dumps、dump、loads、load1、前景什么叫序列化——将原本的字典、列表等内容转换成一个字符串的过程就叫做序列化。序列化的目的以某种存储形式使自定义对象持久化;将对象从一个地方传递到另一个地方。使程序更具维护性2、loads和dumps......
  • pytest + yaml 框架 -37.mark 标记对用例运行时长断言
    前言pytest执行用例的时候,我们希望对用例的运行时间断言,当用例执行时长大于预期标记此用例失败。@pytest.mark.runtime(1)运行时长单位是秒此插件已打包上传到pypihttps://pypi.org/project/pytest-runtime-yoyo/1.0.0/环境准备pipinstallpytest-yaml-yoyo此功能在v1.......
  • 装pytorch环境
    第一步:先装cuda,装完就可以在cmd显示,cudnn。第二步:在anaconda里安装,加环境,create-namepython=3.10等。第三步,进去环境里,安装的pytorch要对应cudnn版本,还有python版本对应。pytorch安装的时候看仔细,是GPU,不要cpu版本的。结束......
  • scrcpy——Android投屏神器(使用教程)
    scrcpy简介简单地来说,scrcpy就是通过adb调试的方式来将手机屏幕投到电脑上,并可以通过电脑控制您的Android设备。它可以通过USB连接,也可以通过Wifi连接(类似于隔空投屏),而且不需要任何root权限,不需要在手机里安装任何程序。scrcpy同时适用于GNU/Linux,Windows和macOS。它的一些特......
  • 手机在线玩Python的15种方法!
    /手机写代码 /android安卓 QPython.apk链接:https://pan.baidu.com/s/1S2mFHsqa3Zuyxiua6nGsbg 提取码:b1g2  Pydroid.apk链接:https://pan.baidu.com/s/10Bnyl6AdUI2mBRZEuLMB6g 提取码:678f Python教程.apk链接:https://pan.baidu.com/s/1iRJC4mAUTCGBounShuXxdg?pw......
  • python使用HTTP隧道代理代码示例模板
    以下是使用HTTP隧道代理的Python代码示例模板:```pythonimportrequests#设置代理服务器地址和端口号proxy_host="your_proxy_host"proxy_port="your_proxy_port"#设置代理服务器的用户名和密码(如果需要)proxy_username="your_proxy_username"proxy_password="your_proxy_p......
  • python的shell用法
    python的shell用法python[-bBdEhiIOqsSuvVWx?][-ccommand|-mmodule-name|script|-][args]Python-mpython-mmodule名args检索对应的模块名去执行,对于一个普通的模块,可能下面两种写法实际上是等效的:python-mtestpythontest.py两种写法都是将对应的py文......
  • 实验6 turtle绘图与python库应用编程体验
    实验任务1fromturtleimport*defmove(x,y):'''画笔移动到坐标(x,y)处'''penup()goto(x,y)pendown()defdraw(n,size=100):'''绘制边长为size的正n边形'''foriinrange(n):fd(siz......
  • 为什么很多人自学Python都放弃了?
     有些人学Python并不是因为对编程有浓厚的兴趣,或者没有经验尝试认为入门容易。因此,当他们开始学习Python时,可能会遇到一些困难或感到挫败,导致他们放弃。总的来说,python并不是你想想的那样简单。以下是可能导致放弃的原因:缺乏计算机科学和编程基础:学习Python需要具备一定的计......