首页 > 其他分享 >inception.py

inception.py

时间:2023-06-12 17:04:20浏览次数:26  
标签:卷积 self py 35 Channels inception size out

import torch
import torch.nn as nn
import torchvision.utils
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
'''
假如输入为(35, 35, 192)的数据:

第一个branch:
经过branch1x1为带有64个1*1的卷积核,所以生成第一张特征图(35, 35, 64);
第二个branch:
首先经过branch5x5_1为带有48个1*1的卷积核,所以第二张特征图(35, 35, 48),
然后经过branch5x5_2为带有64个5*5大小且填充为2的卷积核,特征图大小依旧不变,因此第二张特征图最终为(35, 35, 64);
第三个branch:
首先经过branch3x3dbl_1为带有64个1*1的卷积核,所以第三张特征图(35, 35, 64),
然后经过branch3x3dbl_2为带有96个3*3大小且填充为1的卷积核,特征图大小依旧不变,因此进一步生成第三张特征图(35, 35, 96),
最后经过branch3x3dbl_3为带有96个3*3大小且填充为1的卷积核,特征图大小和通道数不变,因此第三张特征图最终为(35, 35, 96);
第四个branch:
首先经过avg_pool2d,其中池化核3*3,步长为1,填充为1,所以第四张特征图大小不变,通道数不变,第四张特征图为(35, 35, 192),
然后经过branch_pool为带有pool_features个的1*1卷积,因此第四张特征图最终为(35, 35, pool_features);
最后将四张特征图进行拼接,最终得到(35,35,64+64+96+pool_features)的特征图。'''
# 3 2 2 1
class InceptionA(torch.nn.Module):
def __init__(self, in_Channels, out_Channels ): # inChannels表输入通道数
super(InceptionA, self).__init__()
# 2.1 第一层池化 + 1*1卷积
self.branch1_1x1 = nn.Conv2d(in_Channels, # 输入通道
out_Channels//8*3, # 输出通道
kernel_size=1) # 卷积核大小1*1
# 2.2 第二层1*1卷积
self.branch2_1x1 = nn.Conv2d(in_Channels,out_Channels//8*2, kernel_size=1)

# 2.3 第三层
self.branch3_1_1x1 = nn.Conv2d(in_Channels, out_Channels//8*2, kernel_size=1)
self.branch3_2_5x5 = nn.Conv2d(out_Channels//8*2, out_Channels//8*2, kernel_size=5, padding=2)
# padding=2,因为要保持输出的宽高保持一致

# 2.4 第四层
self.branch4_1_1x1 = nn.Conv2d(in_Channels, out_Channels//8*1, kernel_size=1)
self.branch4_2_3x3 = nn.Conv2d(out_Channels//8*1, out_Channels//8*1, kernel_size=3, padding=1)
self.branch4_3_3x3 = nn.Conv2d(out_Channels//8*1, out_Channels//8*1, kernel_size=3, padding=1)

def forward(self, X_input):
# 第一层
branch1_pool = F.avg_pool2d(X_input, # 输入
kernel_size=3, # 池化层的核大小3*3
stride=1, # 每次移动一步
padding=1)
branch1 = self.branch1_1x1(branch1_pool)
# 第二层
branch2 = self.branch2_1x1(X_input)
# 第三层
branch3_1 = self.branch3_1_1x1(X_input)
branch3 = self.branch3_2_5x5(branch3_1)
# 第四层
branch4_1 = self.branch4_1_1x1(X_input)
branch4_2 = self.branch4_2_3x3(branch4_1)
branch4 = self.branch4_3_3x3(branch4_2)
# 输出
output = [branch2, branch3, branch4, branch1]
# (batch_size, channel, w, h) dim=1: 即安装通道进行拼接。
# eg: (1, 2, 3, 4) 和 (1, 4, 3, 4)按照dim=1拼接,则拼接后的shape为(1, 2+4, 3, 4)
return torch.cat(output, dim=1)


# # 3. 整合模型
# class Net(torch.nn.Module):
# def __init__(self):
# super(Net, self).__init__()
# self.conv_1 = nn.Conv2d(in_Channels=1, out_Channels=10, kernel_size=5)
# self.conv_2 = nn.Conv2d(in_Channels=88, out_Channels=20,
# kernel_size=5)
# self.inceptionA_1 = InceptionA(in_Channels=10)
# # self.inceptionA_2 = InceptionA(in_Channels=20)
#
# self.maxPool = nn.MaxPool2d(kernel_size=2)
# # self.fullConnect = nn.Linear(in_features=1408,
# # out_features=10)
#
# def forward(self, X_input):
# batchSize = X_input.size(0)
# # 第一层: 卷积
# x = self.conv_1(X_input) # 卷积
# x = self.maxPool(x) # 池化
# x = F.relu(x) # 激活
# # 第二层: InceptionA
# out = self.inceptionA_1(x)
# # # 第三层: 再卷积
# # x = self.conv_2(x)
# # x = self.maxPool(x)
# # x = F.relu(x)
# # # 第四层: 再InceptionA
# # x = self.inceptionA_2(x)
# # # 第五层,全连接层
# # x = x.view(batchSize, -1)
# # # 表示将(batch_size, channels, w, h)按照batch_size进行拉伸成shape=(batchSize, chanenls*w*h)
# # # eg: 原x.shape=(64, 2, 3, 4),调用 y =x.view(x.size(0), -1)后,y.shape = (64, 2*3*4)=(64, 24)
# # y_pred = self.fullConnect(x)
#
# return out ##################################################################################################

这段代码定义了一个名为InceptionA的类,它继承自torch.nn.Module。这个类实现了一个Inception模块,用于卷积神经网络中。

类中定义了两个方法:__init__forward__init__方法用于初始化对象,它接受两个参数in_Channelsout_Channels,分别表示输入通道数和输出通道数。方法首先定义了四个分支,每个分支都包含若干个卷积层和激活层。第一个分支包含一个池化层和一个1x1卷积层;第二个分支包含一个1x1卷积层;第三个分支包含一个1x1卷积层和一个5x5卷积层;第四个分支包含一个1x1卷积层和两个3x3卷积层。

forward方法用于前向传播,它接受一个参数X_input,表示输入数据。方法首先对输入数据进行池化操作,并使用第一个分支中的1x1卷积层对池化后的数据进行卷积操作。然后,使用第二、三、四个分支对输入数据进行卷积操作。最后,将四个分支的输出在通道维度上拼接在一起,并返回结果。

此外,代码中还有一些被注释掉的代码,它们定义了一个名为Net的类,用于整合模型。但是这些代码并未被使用。

这段代码定义了一个名为InceptionA的类,它实现了一个Inception模块。这个类与前面的代码有关系,因为它被用于构建卷积神经网络模型。在前面的代码中,有一段定义了一个名为VGG16_cml的类,它实现了一个卷积神经网络模型。在这个类的__init__方法中,使用了InceptionA类替换了一些卷积层。因此,这段代码定义的InceptionA类是用于构建卷积神经网络模型的一个组件。

 

标签:卷积,self,py,35,Channels,inception,size,out
From: https://www.cnblogs.com/wzbzk/p/17475477.html

相关文章

  • cbam.py
    importtorchimportmathimporttorch.nnasnnimporttorch.nn.functionalasFclassBasicConv(nn.Module):def__init__(self,in_planes,out_planes,kernel_size,stride=1,padding=0,dilation=1,groups=1,relu=True,bn=True,bias=False):super(Basic......
  • 0基础学python
    Python学习路线 精品Python学习书籍 技能对照表 ......
  • python 序列化模块
    一、jsonJson模块提供了四个功能:dumps、dump、loads、load1、前景什么叫序列化——将原本的字典、列表等内容转换成一个字符串的过程就叫做序列化。序列化的目的以某种存储形式使自定义对象持久化;将对象从一个地方传递到另一个地方。使程序更具维护性2、loads和dumps......
  • pytest + yaml 框架 -37.mark 标记对用例运行时长断言
    前言pytest执行用例的时候,我们希望对用例的运行时间断言,当用例执行时长大于预期标记此用例失败。@pytest.mark.runtime(1)运行时长单位是秒此插件已打包上传到pypihttps://pypi.org/project/pytest-runtime-yoyo/1.0.0/环境准备pipinstallpytest-yaml-yoyo此功能在v1.......
  • 装pytorch环境
    第一步:先装cuda,装完就可以在cmd显示,cudnn。第二步:在anaconda里安装,加环境,create-namepython=3.10等。第三步,进去环境里,安装的pytorch要对应cudnn版本,还有python版本对应。pytorch安装的时候看仔细,是GPU,不要cpu版本的。结束......
  • scrcpy——Android投屏神器(使用教程)
    scrcpy简介简单地来说,scrcpy就是通过adb调试的方式来将手机屏幕投到电脑上,并可以通过电脑控制您的Android设备。它可以通过USB连接,也可以通过Wifi连接(类似于隔空投屏),而且不需要任何root权限,不需要在手机里安装任何程序。scrcpy同时适用于GNU/Linux,Windows和macOS。它的一些特......
  • 手机在线玩Python的15种方法!
    /手机写代码 /android安卓 QPython.apk链接:https://pan.baidu.com/s/1S2mFHsqa3Zuyxiua6nGsbg 提取码:b1g2  Pydroid.apk链接:https://pan.baidu.com/s/10Bnyl6AdUI2mBRZEuLMB6g 提取码:678f Python教程.apk链接:https://pan.baidu.com/s/1iRJC4mAUTCGBounShuXxdg?pw......
  • python使用HTTP隧道代理代码示例模板
    以下是使用HTTP隧道代理的Python代码示例模板:```pythonimportrequests#设置代理服务器地址和端口号proxy_host="your_proxy_host"proxy_port="your_proxy_port"#设置代理服务器的用户名和密码(如果需要)proxy_username="your_proxy_username"proxy_password="your_proxy_p......
  • python的shell用法
    python的shell用法python[-bBdEhiIOqsSuvVWx?][-ccommand|-mmodule-name|script|-][args]Python-mpython-mmodule名args检索对应的模块名去执行,对于一个普通的模块,可能下面两种写法实际上是等效的:python-mtestpythontest.py两种写法都是将对应的py文......
  • 实验6 turtle绘图与python库应用编程体验
    实验任务1fromturtleimport*defmove(x,y):'''画笔移动到坐标(x,y)处'''penup()goto(x,y)pendown()defdraw(n,size=100):'''绘制边长为size的正n边形'''foriinrange(n):fd(siz......