首页 > 其他分享 >yolov8的模型剪枝教程

yolov8的模型剪枝教程

时间:2024-08-05 10:28:16浏览次数:14  
标签:剪枝 教程 prune conv self yolov8 model data

        模型剪枝是用在模型的一种优化技术,旨在减少神经网络中不必要的参数,从而降低模型的复杂性和计算负载,进一步提高模型的效率。

        模型剪枝的流程:约束训练(constained training)、剪枝(prune)、回调训练(finetune)

        本篇主要记录自己YOLOv8模型剪枝的全过程,主要参考:YOLOv8剪枝全过程

目 录

一、约束训练(constrained training)

1、参数设置

2、稀疏训练

二、剪枝(prune)

三、回调训练(finetune)

1、代码修改

2、再训练


 

一、约束训练(constrained training)

1、参数设置

         设置./ultralytics/cfg/default.yaml中的amp=False

2、稀疏训练

        主要方式:在BN层添加L1正则化

        具体步骤:在./ultralytics/engine/trainer.py中添加以下内容:

                # Backward
                self.scaler.scale(self.loss).backward()
 
                # ========== added(新增) ==========
                # 1 constrained training
                l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
                for k, m in self.model.named_modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
                        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
                # ========== added(新增) ==========
 
                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

 

        然后启动训练(/yolov8/train.py):

  

from ultralytics import YOLO
 
model = YOLO('yolov8n.yaml')
 
results = model.train(data='./data/data_nc5/data_nc5.yaml', batch=8, epochs=300, save=True)

 

二、剪枝(prune)

        一该部分选用上一步训练得到的模型./runs/detect/train2/weight/last.pt进行剪枝处理。在/yolov8/下新建文件prune.py,具体内容如下:

  

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
 
# Load a model
yolo = YOLO("./runs/detect/train2/weights/last.pt")
model = yolo.model
 
ws = []
bs = []
 
for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        # print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
 
# keep
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)
 
 
def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta = conv1.bn.bias.data.detach()
    keep_idxs = []
    local_threshold = threshold
    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        local_threshold = local_threshold * 0.5
    n = len(keep_idxs)
    # n = max(int(len(idxs) * 0.8), p)
    # print(n / len(gamma) * 100)
    # scale = len(idxs) / n
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n
 
    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]
 
    if not isinstance(conv2, list):
        conv2 = [conv2]
 
    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]
 
 
def prune(m1, m2):
    if isinstance(m1, C2f):  # C2f as a top conv
        m1 = m1.cv2
 
    if not isinstance(m2, list):  # m2 is just one module
        m2 = [m2]
 
    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1
 
    prune_conv(m1, m2)
 
 
for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)
 
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i + 1])
 
detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])
 
for name, p in yolo.model.named_parameters():
    p.requires_grad = True
 
yolo.val()  # 剪枝模型进行验证 yolo.val(workers=0)
yolo.export(format="onnx")  # 导出为onnx文件
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100)  # 剪枝后直接训练微调
 
torch.save(yolo.ckpt, "./runs/detect/train2/weights/prune.pt")
print("done")

 

其中,factor=0.8 表示的是保持率,factor越小,裁剪的就越多,一般不建议裁剪太多。

        运行prune.py,可得到剪枝后的模型prune.pt,保存在./runs/detect/train2/weight/中。同文件夹下,还有last.onnx,可以看到onnx文件的大小比剪枝前变小了,具体结构(onnx模型结构查看)也和剪枝前的onnx相比有了轻微变化。

三、回调训练(finetune)

1、代码修改

        首先,将先前在./ultralytics/engine/trainer.py中添加的L1正则化部分注释掉:

  

                # Backward
                self.scaler.scale(self.loss).backward()
 
                # # ========== added(新增) ==========
                # # 1 constrained training
                # l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
                # for k, m in self.model.named_modules():
                #     if isinstance(m, nn.BatchNorm2d):
                #         m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
                #         m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
                # # ========== added(新增) ==========
 
                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

 

        然后,在该文件第543行左右添加代码 “self.model = weights” :

  

    def setup_model(self):
        """Load/create/download model for any task."""
        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return
 
        model, weights = self.model, None
        ckpt = None
        if str(model).endswith(".pt"):
            weights, ckpt = attempt_load_one_weight(model)
            cfg = weights.yaml
        else:
            cfg = model
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
        # ========== added(新增) ==========
        # 2 finetune 回调训练
        self.model = weights
        # ========== added(新增) ==========
        return ckpt

 

2、再训练

         利用已经剪枝好的模型prune.pt,我们再次启动训练(/yolov8/train.py):

  

from ultralytics import YOLO
 
model = YOLO('./runs/detect/train5/weights/prune.pt')
results = model.train(data='./data/data_nc5/data_nc5.yaml', batch=8, epochs=100, save=True)

 

注意,这里把model改成了"prune.pt",而不是原来的"yolov8n.yaml"

        训练后新的模型保存在“./runs/detect/train3/weight/”中。后面可按需要进一步进行模型的推理和部署。

 

标签:剪枝,教程,prune,conv,self,yolov8,model,data
From: https://www.cnblogs.com/chentiao/p/18342742

相关文章

  • pytorch中中的模型剪枝方法
     一,剪枝分类 所谓模型剪枝,其实是一种从神经网络中移除"不必要"权重或偏差(weigths/bias)的模型压缩技术。关于什么参数才是“不必要的”,这是一个目前依然在研究的领域。 1.1,非结构化剪枝 非结构化剪枝(UnstructuredPuning)是指修剪参数的单个元素,比如全连接层中的单个权......
  • 【web3.0】Web3 开发教程与代码资源:探索如何在Web3项目中开发应用
    引言Web3,作为区块链技术和互联网融合的产物,正逐步重塑我们对数字世界的理解与交互方式。它不仅仅是一个技术概念,更是一个去中心化、用户主权的网络愿景,旨在通过智能合约、去信任的交易和加密货币等技术手段,为用户提供前所未有的数据安全性和经济自主权。本教程将引导你从零开......
  • steam使用环境,下载,安装综合教程
    在安装steam前必须先安装steam++(WattToolkit),(在注册steam账号期间,或者登录期间先打开steam++)直接搜索引擎搜索,找到下载打开后,在左边栏点击蓝色小闪电,然后全选点击一键加速。一、Steam下载安装1.1Steam下载Steam官网:https://store.steampowered.com/浏览器进入Steam......
  • 详细教程 MySQL 数据库 下载 安装 连接 环境配置 全面
    数据库就是储存和管理数据的仓库,对数据进行增删改查操作,其本质是一个软件。首先数据有两种,一种是关系型数据库,另一种是非关系型数据库。关系型数据库是以表的形式来存储数据,表和表之间可以有很多复杂的关系,比如:MySQL、Oracle、SQLServer等;非关系型数据库是以数据集的形式存......
  • 怎么在Ubuntu系统云服务器搭建自己的幻兽帕鲁服务器?幻兽帕鲁搭建教程
    《幻兽帕鲁》是一款备受瞩目的开放世界生存建造游戏,近期在游戏界非常火爆。玩家可以在游戏世界中收集神奇的生物“帕鲁”,并利用它们进行战斗、建造、农耕、工业生产等各种活动。与其他开放世界游戏不同,要想实现多人联机游戏,玩家需要自行搭建服务器。目录基本步骤创建和登录主机......
  • Linux设置定时任务命令crontab详解教程
    一、crontab命令介绍crontab是一个在Linux系统中用于设置周期性被执行的任务的工具,‌即可以执行定时任务,它可以帮助用户实现定时间运行程序或脚本的需求。‌/var/spool/cron/目录下存放的是每个用户包括root的crontab任务,每个任务以创建者的名字命名/etc/crontab这个文......
  • 直播自动回复浏览器插件开发-抖音直播自动回复插件-抖音小店飞鸽客服自动回复插件(简单
    抖音直播自动回复插件抖音小店飞鸽客服自动回复插件演示网站:https://gofly.sopans.com/douyin.html开发浏览器插件是一个相对复杂的过程,涉及到前端开发、浏览器API的使用以及插件的架构设计。以下是开发浏览器插件的一般步骤:了解浏览器插件基础:学习浏览器插件的基本概念,包......
  • Python 基础教程:List(列表)的使用
    《Python基础教程:List(列表)的使用》在Python中,列表是最基本的数据结构之一,它是一种有序的、可变的数据集合,可以包含任意类型的元素,包括数字、字符串、其他列表等。1.列表的创建列表使用方括号[]创建,列表中的元素用逗号,分隔。#创建一个包含整数的列表numbers......
  • 五级分销版蝶影全网VIP影视 APP源码 安卓+苹果iOS双端+搭建教程
    ###五级分销版蝶影全网VIP影视APP源码安卓+苹果iOS双端+搭建教程在数字娱乐的浪潮中,影视APP成为了人们生活中不可或缺的一部分。随着技术的不断进步,定制化的影视APP源码成为了市场上的新宠。本文将详细介绍一款名为“蝶影”的全网VIP影视APP源码,它支持五级分销模式,并提供......
  • 《刚刚问世》系列初窥篇-Java+Playwright自动化测试-5-创建首个自动化脚本(详细教程)
     软件测试微信群:https://bbs.csdn.net/topics/618423372 有兴趣的可以扫码加入 1.简介前面几篇宏哥介绍了两种(java和maven)环境搭建和浏览器的启动方法,这篇文章宏哥将要介绍第一个自动化测试脚本。前边环境都搭建成功了,浏览器也驱动成功了,那么我们不着急学习其他内容,首先宏......