首页 > 其他分享 >Torch-Pruning工具箱

Torch-Pruning工具箱

时间:2022-12-04 18:55:26浏览次数:64  
标签:剪枝 Torch tp strategy Pruning idxs 工具箱 model pruning

Torch-Pruning
通道剪枝网络实现加速的工作。
image
Torch pruning是进行结构剪枝的pytorch工具箱,和pytorch官方提供的基于mask的非结构化剪枝不同,工具箱移除整个通道剪枝,自动发现层与层剪枝的依赖关系,可以处理Densenet、ResNet和DeepLab

特性

卷积网络通道剪枝 CNNs (e.g. ResNet, DenseNet, Deeplab) 和 Transformers (即 Bert, @horseee贡献代码)

  • 网络图跟踪以及依赖关系.
  • 支持网络层: Conv, Linear, BatchNorm, LayerNorm, Transposed Conv, PReLU, Embedding 和 扩展层.
  • 支持操作: split, concatenation, skip connection, flatten, 等等.
  • 剪枝策略: Random, L1, L2, 等等.

它是怎样工作的

Torch-Pruning 使用 fake inputs输入网络和torch.jit一样收集网络信息.

dependency graph 用来表示计算图和层之间的关系. 由于裁剪一层会影响若干层 , dependecy会自动传播剪枝到其他层并且保存在PruningPlan.

如果模型中有 torch.split或者torch.cat,所有剪枝的indices都会做一些变换的

Conv-Conv:\(n_{i+1}\) oc中减少1个通道,下一个卷积每个通oc通道中ic通道\(n_{i+1}\)少一个
Skip Connection: 需要考虑ic和上一层的oc互相关联,所以这里shortcutadd都需要传递这种关联。

依赖关系 可视化 例子
Conv-Conv image AlexNet
Conv-FC(Global Pooling or Flatten) image ResNet,VGG
Skip Connection image ResNet
Concatenation image DenseNet, ASPP
Split image torch.chunk

一个例子

先来看下torchpruning 的流程图:
image

# 1. setup strategy (L1 Norm)
strategy = tp.strategy.L1Strategy() # or tp.strategy.RandomStrategy()

# 2. build layer dependency for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4, round_to=16) # or manually selected pruning_idxs=[2, 6, 9, ...]
pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )
print(pruning_plan)

# 4. execute this plan (prune the model)
pruning_plan.exec()

print(model)

image

pruning_plan = DG.get_pruning_plan( pruning_idxs ):
image

底层剪枝函数

使用一层一层的固定剪枝和上面是等价的

tp.prune_conv( model.conv1, idxs=[2,6,9] )

# fix the broken dependencies manually
tp.prune_batchnorm( model.bn1, idxs=[2,6,9] )
tp.prune_related_conv( model.layer2[0].conv1, idxs=[2,6,9] )

运行结果:

(Conv2d(36, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
 3456)

对设备友好的通道对齐剪枝

可以通过设置round_to参数,下例可以使得通道对16取整(即,16,32,48,64)

strategy = tp.strategy.L1Strategy()
pruning_idxs = strategy(model.conv1.weight, amount=0.2, round_to=16)

image

image

本文暂时没有对torch pruning源码进行分析,先学会使用,后续如果有需要、有时间会再进行源码分析

标签:剪枝,Torch,tp,strategy,Pruning,idxs,工具箱,model,pruning
From: https://www.cnblogs.com/whiteBear/p/16930896.html

相关文章

  • pytorch安装
    pytorch安装1、查看本机的CUDA版本cmd命令行输入nvidia-smi,在第一行最右边可以看到CUDA的版本号![version](C:\Users\nice7\Pictures\SavedPictures\version.png)2、......
  • pytorch 如何从checkpoints中继续训练
    左1:从头开始训练时,lr的变化。左2:从epoch100时开始训练......
  • win10 中 anaconda3 安装 pytorch 教程
    anaconda中自带python,所以不需要提前安装python。1.安装anaconda3下载链接:https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/下载文件:Anaconda3-2021.11-Windo......
  • 小丸工具箱经典版下载,开发必备
    关注微信公众号【工控羊】或者微信号【gksheep】,微信公众号后台输入数字编号【2032】即可获取下载链接。......
  • Pytorch tensor操作 gather、expand、repeat、reshape、view、permute、transpose
    文章目录​​tensor.gather​​​​tensor.expand​​​​tensor.repeat​​​​reshape()和view()​​​​permute()和transpose()​​​​torch.matmul()​​​​torc......
  • torch.nn.CrossEntropyLoss
    文章目录​​交叉熵损失函数`torch.nn.CrossEntropyLoss`​​​​F.cross_entropy​​​​F.nll_loss​​交叉熵损失函数​​torch.nn.CrossEntropyLoss​​weight(Tensor......
  • Pytorch mask:上三角和下三角
    上三角triuPytorch上三角和下三角的调用与numpy是相同的。np.triu(np.ones((5,5)),k=0)#k控制对角线开始的位置Out[25]:array([[1.,1.,1.,1.,1.],[0.,1.,1......
  • 矩池云 | GPU 分布式使用教程之 Pytorch
    GPU分布式使用教程之PytorchPytorch官方推荐使用DistributedDataParallel(DDP)模块来实现单机多卡和多机多卡分布式计算。DDP模块涉及了一些新概念,如网络(WorldSize......
  • torch.autograd.Function 用法及注意事项
    众所周知,作为深度学习框架之一的PyTorch和其他深度学习框架原理几乎完全一致,都有着自动求导机制,当然也可以说成是自动微分机制。有些时候,我们不想要它自带的求导机制,需要......
  • 关于hutool工具箱进行RSA非对称加密的使用笔记
    首先是导入hutool工具包的maven依赖<!--huTool工具箱--><dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</arti......