首页 > 其他分享 >2024强化学习的结构化剪枝原理及实践

2024强化学习的结构化剪枝原理及实践

时间:2024-11-16 22:14:12浏览次数:3  
标签:剪枝 结构化 git -- 模型 2024 conda bash

[2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration

目录

一、论文说明

论文标题:使用强化学习进行结构化剪枝用于卷积神经网路压缩和加速

机构:伊利诺伊大学厄巴纳-香槟分校

论文链接:https://arxiv.org/pdf/2411.06463

代码链接:https://github.com/Beryex/RLPruner-CNN

论文简介: 卷积神经网络(ConvolutionalNeural Networks, CNNs)近年来表现出卓越的性能。压缩这些模型不仅减少了存储需求,使其在边缘设备上的部署变得可行,还加速了推理,从而降低了延迟和计算成本。结构化剪枝,它在层级上去除过滤器,直接修改了模型架构。这种方法实现了更紧凑的架构,同时保持目标准确性,确保压缩模型具有较好的兼容性和硬件效率。所提方法基于一个关键观察:

  • 1、神经网络中不同层的过滤器对模型性能的重要性各不相同。

  • 2、当修剪的过滤器数量固定时,不同层之间的最佳修剪分配是不均匀的,以最小化性能损失

  • 3、对修剪敏感的层应该占据更小的修剪分配比例。

为了利用这一洞察,文中提出了RL-Pruner,它使用强化学习来学习最佳修剪分配。RL-Pruner可以自动提取输入模型中过滤器之间的依赖关系并执行修剪,无需特定于模型的修剪实现。在GoogleNet、ResNet和MobileNet 等模型上进行了实验,将所提方法与其他结构化剪枝方法进行了比较,以验证其有效性。

在这里插入图片描述

二、原理

RL-Pruner 首先在模型中的层之间构建依赖图,然后分几个步骤进行剪枝。在每个步骤中:1) 基于基础分布生成一个新的剪枝稀疏分布作为动作 ,这作为策略;2)根据相应的稀疏度,使用泰勒准则(Taylorcriterion)对每一层进行剪枝;3) 评估压缩后的模型以获得奖励,并将动作和奖励存储在回放(replay)经验池中。每个步骤后,基础分布根据经验池更新,如果计算资源足够,则对压缩模型应用后训练阶段,使用知识蒸馏(knowledge distillation),其中原始模型作为教师。具体框图如图2所示。

三、实验与分析

1、环境配置

实验平台及软件

  • Windows 10
  • git bash
  • conda环境

这里主要介绍如何在windows系统上让git bash链接conda环境。

在Windows配置git bash链接conda环境

由于工程代码中需要使用bash命令运行代码,因此需要保证git bash能调用conda环境运行对应的脚本文件。

C:\Users\username\.bashrc文件内设置conda.sh位置(文中示例为:D:\\Anaconda3\\etc\\profile.d\\conda.sh),并激活配置。在git bash界面输入具体命令如下:

echo "D:\\Anaconda3\\etc\\profile.d\\conda.sh" >> ~/.bashrc  
source ~/.bashrc

然后关闭git bash界面,再重新打开一个git bash界面,最后输入命令激活conda环境conda activate 虚拟环境名字。如果命令提示中出现如下图所示的字样,即为配置成功,否则根据提示的要求进行配置,比如输入conda init,重新打开一个新的git bash界面。

在这里插入图片描述

2、项目代码运行

克隆项目文件,具体命令如下:

git clone https://github.com/Beryex/RLPruner-CNN.git --depth 1
cd RLPruner-CNN

安装python第三方包,具体命令如下(如果之前有conda环境,可以不用进行下面这一步,等报错了再根据提示安装对应的包即可):

conda create -n RLPruner python=3.10 -y
conda activate RLPruner
pip install -r requirements.txt

官方代码提供了一步到位的运行脚本,从预训练模型、模型压缩到模型验证,仅需在命令行中输入如下代码:

./scripts/flexible.sh googlenet cifar100 0.20 taylor 0.00 0.00

为了更好地了解每一步的设置,下面内容将分为预训练模型、模型压缩、模型验证三个步骤进行介绍。

1、训练预训练权重

训练模型得到对应的预训练权重,这里以resnet32googlenet为例,在git bash输入具体命令(默认使用cuda)如下:

./scripts/train.sh googlenet cifar100
./scripts/train.sh resnet32 cifar100

或者使用参考指定配置命令:

python -m train --model ${MODEL} --dataset ${DATASET} --device cuda \
                --output_dir ${PRETRAINED_MODEL_DIR} \
                --log_dir ${LOG}

其中,
${MODEL}代表backbone的类型([“vgg11”, “vgg13”, “vgg16”, “vgg19”, “resnet18”, “resnet34”, “resnet50”, “resnet101”, “resnet152”, “resnet8”, “resnet14”, “resnet20”, “resnet32”, “resnet44”, “resnet56”, “resnet110”, “densenet121”, “densenet161”, “densenet169”, “densenet201”, “mobilenetv3_small”, “mobilenetv3_large”, “googlenet”]);
${DATASET}代表数据集名称,如cifar10或者cifar100。
${PRETRAINED_MODEL_DIR}代表输出权重文件路径,默认在pretrained_model文件夹下;
${LOG}代表输出日志路径,默认在log文件夹下。

在CIFAR100数据集上训练resnet32的结果(最佳准确率:0.706)如下图所示。
在这里插入图片描述
在CIFAR100数据集上训练googlenet的结果(最佳准确率:0.774)如下图所示。
在这里插入图片描述

2、模型压缩

模型结构化剪枝这里以0.2的稀疏度,taylor剪枝策略和Q_FLOP_coef=0,Q_Para_coef=0的参数进行测试。在git bash输入具体命令(默认使用cuda)如下:

./scripts/flexible.sh googlenet cifar100 0.20 taylor 0.00 0.00

同理,也可以使用参考指定配置命令:

python -m compress --model ${MODEL} --dataset ${DATASET} --device cuda \
                   --sparsity ${SPARSITY} --prune_strategy ${prune_strategy} --ppo \
                   --Q_FLOP_coef ${Q_FLOP_coef} --Q_Para_coef ${Q_Para_coef} \
                   --pretrained_pth ${PRETRAINED_MODEL_PTH} \
                   --compressed_dir ${COMPRESSED_MODEL_DIR} \
                   --checkpoint_dir ${CKPT_DIR} \
                   --log_dir ${LOG} --save_model

测试结果如下图所示:
在这里插入图片描述

3、模型验证

在数据集上验证模型的识别性能,在git bash输入具体命令(默认使用cuda)如下:

./scripts/evaluate.sh googlenet cifar100

同理,也可以使用参考指定配置命令:

python -m evaluate --model ${MODEL} --dataset ${DATASET} --device cuda \
                   --pretrained_pth ${PRETRAINED_MODEL_PTH} \
                   --compressed_pth ${COMPRESSED_MODEL_PTH} \
                   --log_dir ${LOG}

测试结果如下图所示:
在这里插入图片描述

四、总结

本文提出了RL-Pruner,一种结构化剪枝方法,能够学习各层之间的最优稀疏性分布,并支持无模型特定修改的一般剪枝。希望所提方法能够认识到每一层对模型(model)性能的重要性不同,这将影响未来在神经网络压缩领域的工作,包括无结构剪枝和量化。

标签:剪枝,结构化,git,--,模型,2024,conda,bash
From: https://blog.csdn.net/qq_40734883/article/details/143762832

相关文章

  • 2024-2025-1 20241413 《计算机基础与程序设计》第八周学习总结
    这个作业属于哪个课程https://edu.cnblogs.com/campus/besti/2024-2025-1-CFAP这个作业要求在哪里https://www.cnblogs.com/rocedu/p/9577842.html#WEEK08作业目标功能设计与面向对象设计面向对象设计过程面向对象语言三要素汇编、编译、解释、执行--------......
  • 2024/11/15日 日志 关于 会话跟踪技术--- Cookie & Session
    会话跟踪技术--·会话:用户打开浏览器,访问web服务器的资源,会话建立,直到有一方断开连接,会话结束。--在一次会话中可以包含多次请求和响应--·会话跟踪:一种维护浏览器状态的方法,服务器需要识别多次请求是否来自于同一浏览器,--以便在同一次会话的多次请求间共享数据--·......
  • 20222325 2024-2025-1 《网络与系统攻防技术》实验六实验报告
    1.实验内容本实践目标是掌握metasploit的用法。指导书参考Rapid7官网的指导教程。https://docs.rapid7.com/metasploit/metasploitable-2-exploitability-guide/下载官方靶机Metasploitable2,完成下面实验内容。(1)前期渗透①主机发现(可用Aux中的arp_sweep,search一下就可以use)......
  • 吐槽202401113关于监管特殊停牌炒作股票事件--兼此事件复盘
    写在前面:强者从不抱怨环境,弱者只会怨天尤人我仅以此发泄下情绪,我知道不会有人在这里看股市,所以才写在这里。自媒体的宣泄是为了博取不同人的共情,吸引流量。而我仅仅为了宣泄下不满,后面看到的人不要学我。 本次的监管特停事件是一次黑天鹅事件。 复盘2024-11-09日,周五下午......
  • 20222306 2024-2025-1《网络与系统攻防技术》实验六实验报告
    1.实验内容1.1内容回顾总结这周都重点在于Metasploit工具的使用,我深入了解了对其功能和使用流程。Metasploit是一个功能强大的渗透测试框架,广泛应用于网络安全领域。它为安全专家、渗透测试人员和红队提供了一个全面的工具集,支持漏洞利用、攻击模拟和安全评估。Metasploit提......
  • 20222311 2024-2025-1 《网络与系统攻防技术》实验六实验报告
    1.实验内容1.1本周学习内容回顾使用了Metasploit框架,其是一个功能强大的渗透测试框架。在使用的过程当中,Metasploit提供了种类繁多的攻击模块,涵盖了远程代码执行、服务拒绝、提权等多种攻击方式,支持对多种操作系统和应用程序进行测试。除了漏洞利用,它还具备强大的后渗透功能,如......
  • 20222427 2024-2025-1 《网络与系统攻防技术》实验五实验报告
    1.实验内容1.1本周内容总结使用了Metasploit框架,其是一个功能强大的渗透测试框架。在使用的过程当中,Metasploit提供了种类繁多的攻击模块,涵盖了远程代码执行、服务拒绝、提权等多种攻击方式,支持对多种操作系统和应用程序进行测试。除了漏洞利用,它还具备强大的后渗透功能,如......
  • [20241114]建立完善mod_addr.sh脚本.txt
    [20241114]建立完善ext_kglob.sh脚本.txt--//以前考虑使用管道问题,我考虑复杂了,看了gdb文档,实际上gdb-ex参数支持在命令行加入执行命令。--//选择使用mmon后台进程,改写如下:$catext_kglob.sh#/bin/bash#extraceobjectstringfromobjecthandleaddress#arg1=addressarg2=o......
  • ISCTF2024-Crypto(不全)
    Crypto一开始有时间写了一点,原本不打算发的,但详细写了前面几题的wp,还是发一下。我和小蓝鲨的秘密fromPILimportImagefromCrypto.Util.numberimportbytes_to_long,long_to_bytesimportnumpyasnpn=29869349657224745144762606999e=65537original_image_p......
  • 20222426 2024-2025-1 《网络与系统攻防技术》实验五实验报告
    202224262024-2025-1《网络与系统攻防技术》实验五实验报告1.实验内容1信息搜集定义:通过各种方式获取目标主机或网络的信息,属于攻击前的准备阶段。目的:收集目标主机的DNS信息、IP地址、子域名、旁站和C段、CMS类型、敏感目录、端口信息、操作系统版本、网站架构、漏洞信息、......