首页 > 其他分享 >深度学习Pytorch中组卷积的参数存储方式与剪枝的问题

深度学习Pytorch中组卷积的参数存储方式与剪枝的问题

时间:2023-04-15 23:14:47浏览次数:52  
标签:剪枝 nn conv 中组 卷积 self Pytorch 64 256

写这个主要是因为去年做项目的时候 需要对网络进行剪枝 普通卷积倒没问题 涉及到组卷积通道的裁剪就对应不上 当时没时间钻研 现在再看pytorch 钻研了一下 仔细研究了一下卷积的weight.data的存储

1.搭建网络

这里先随便搭建一下网络 放几个深度可分离卷积和普通卷积

import torch.nn as nn


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class DWConv(nn.Module):
    def __init__(self, in_plane, out_plane):
        super(DWConv, self).__init__()
        self.depth_conv = nn.Conv2d(in_channels=in_plane,
                                    out_channels=in_plane,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1,
                                    groups=in_plane)
        self.point_conv = nn.Conv2d(in_channels=in_plane,
                                    out_channels=out_plane,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    groups=1)
 
    def forward(self, x):
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x



class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))

class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = Conv(3, 64, 3, 2, 1)
        self.dwconv = DWConv(64, 256)
        self.conv1 = DWConv(256, 4)
        self.conv2 = Conv(4, 512, 8, 2, 1)
    def forward(self, x):
        x = self.conv(x)
        x = self.dwconv(x)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

2.测试代码,查看网络层

model.modules()迭代遍历模型的所有子层,所有子层即指nn.Module子类,nn.xxx构成的卷积,池化,ReLU, Linear, BN, 等都是nn.Module子类,也就是model.modules()会迭代的遍历它们所有对象。

model.named_modules() 就是有名字的model.modules()。

import torch.nn as nn
import testM
model = testM.TestModule()
for modelName, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        print(modelName)
        print(layer)
        print(layer.weight.data.shape)
        print(1)

这里只输出了有关于卷积层的内容

conv.conv
Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
torch.Size([64, 3, 3, 3])
1
dwconv.depth_conv
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
torch.Size([64, 1, 3, 3])
1
dwconv.point_conv
Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
torch.Size([256, 64, 1, 1])
1
conv1.depth_conv
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
torch.Size([256, 1, 3, 3])
1
conv1.point_conv
Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1))
torch.Size([4, 256, 1, 1])
1
conv2.conv
Conv2d(4, 512, kernel_size=(8, 8), stride=(2, 2), padding=(1, 1), bias=False)
torch.Size([512, 4, 8, 8])
1

3.普通卷积

拿conv.conv来说, 参数的size是[64, 3, 3, 3] 第二个是卷积前的通道数 ,第一个是卷积后的通道数。正如图中所示,中间部分就是卷积核,共有64个3×3×3的卷积核,每一个进行卷积操作,最终生成64个通道的特征图。

conv2.conv 就是有512个4×8×8的卷积核进行卷积操作,最终生成512个通道的特征图。

剪枝裁剪通道的时候 直接根据上层卷积修改第二个卷积前的通道数再留下本层有效的卷积通道数即可。

4.深度可分离卷积


拿conv1来说,先说depth_conv,参数的size是[256, 1, 3, 3],这里第一个参数就是卷积后的通道数,如深度可分离卷积图1所展示,由原来的256个通道各自与对用的3×3大小的卷积核进行卷积操作,最终得到256通道的特征图,这里每个通道都只和自己对应的一个3×3大小的卷积核进行一次卷积操作,可以当作普通卷积中filter为1的情况。所以第二个参数就是1,这里进行剪枝操作的时候要根据上层卷积裁剪后的通道数修改第一个参数和group即可(当时就是卡在这里了)。

然后是point_conv,这里就和普通的卷积原理一样了,修改通道也是根据普通卷积那里修改即可。

标签:剪枝,nn,conv,中组,卷积,self,Pytorch,64,256
From: https://www.cnblogs.com/lisuhang/p/17322194.html

相关文章

  • NOC 2022 初中组选择和编程题题解
    NOC2022初中组选择题和编程题题解注意:本文有几个问题:部分题目我也不确定答案,而且我水平不行,有些题目我还真不会,大家就把我的答案当个参考吧。目前有一大半的题目因为作者比较懒,暂时没写,空在那儿,可以下载原题自己做做。1初中组选拔赛原题链接,提取码:efy6。1.1选择题......
  • 第二章(4)Pytorch安装和张量创建
    第二章(4)Pytorch安装和张量创建1.Pytorch基础PyTorch是一个基于Python的科学计算库,也是目前深度学习领域中最流行的深度学习框架之一。PyTorch的核心理念是张量计算,即将数据表示为张量,在计算时使用自动微分机制优化模型。在使用PyTorch进行深度学习时,了解张量的基础操作、类型、......
  • [附CIFAR10炼丹记前编] CS231N assignment 2#5 _ pytorch 学习笔记 & 解析
    pytorch环境搭建课程给你的环境当中,可以直接用pytorch,当时其默认是没有给你安装显卡支持的.如果你只用CPU来操作,那其实没什么问题,但我的电脑有N卡,就不能调用. 考虑到我已有pytorch环境(大致方法就是确认pytorch版本和对应的cuda版本安装cuda,再按照官网即可,建议自......
  • 从零开始配置深度学习环境:CUDA+Anaconda+Pytorch+TensorFlow
    本文适用于电脑有GPU(显卡)的同学,没有的话直接安装cpu版是简单的。CUDA是系统调用GPU所必须的,所以教程从安装CUDA开始。CUDA安装CUDA是加速深度学习计算的工具,诞生于NVIDIA公司,是一个显卡的附加驱动。必须使用NVIDIA的显卡才能安装,可以打开任务管理器查看自己的硬件设备。下载CU......
  • Pytorch one-hot编码
    1.引言在我们做分割任务时,通常会给一个mask,但训练时要进行onehot编码。2.codeimporttorchif__name__=='__main__':label=torch.zeros(size=(1,4,4),dtype=torch.int)label[:,2:4]=1print(label.shape)print(label)label_one_hot......
  • 使用Pytorch实现强化学习——DQN算法
    使用Pytorch实现强化学习——DQN算法强化学习的主要构成强化学习主要由两部分组成:智能体(agent)和环境(env)。在强化学习过程中,智能体与环境一直在交互。智能体在环境里面获取某个状态后,它会利用该状态输出一个动作(action)。然后这个动作会在环境之中被执行,环境会根据智能体采取的动......
  • PyTorch深度学习建模与应用--每日最高温度预测
    1.python2.JupyterLabhttp://jupyter.org/安装jupyterlab只需要在命令提示符中输入pipinstalljupyterlab启动则在命令提示符中输入jupyterlabhttps://jupyter.org/try-jupyter/lab/  可以在这里进行尝试。3.PyTorchpytorch的配置可以看这篇https://blog.csdn.net/m0_7257......
  • Anaconda环境下安装gpu版pytorch
    cuda安装首先到下面的网址下载cude,注意,不要下载最新的,目前pytorch支持的最新版本是11.8。https://developer.nvidia.com/cuda-toolkit-archivepytorch安装打开Anaconda自带的命令行,如下图所示。再到下面的网站获取安装命令。https://pytorch.org/get-started/locally/在安......
  • WSL2安装CUDA & pytorch
     WSL2安装pytorchwsl-ubuntu安装1操作系统,win11开启CPU虚拟化   如果是关闭状态,需要进入到BOIS中打开设置。  开启虚拟机平台搜索栏中搜索功能,即可出现“启用或关闭Windows功能”      升级配置wslhttps://wslstorestorage.blob.core.win......
  • 深度学习之PyTorch实战(5)——对CrossEntropyLoss损失函数的理解与学习
     其实这个笔记起源于一个报错,报错内容也很简单,希望传入一个三维的tensor,但是得到了一个四维。RuntimeError:onlybatchesofspatialtargetssupported(3Dtensors)butgottargetsofdimension:4查看代码报错点,是出现在pytorch计算交叉熵损失的代码。其实在......