首页 > 其他分享 >03常用pytorch剪枝工具

03常用pytorch剪枝工具

时间:2023-07-01 16:44:22浏览次数:46  
标签:剪枝 prune weight 03 torch module pytorch model

常用剪枝工具

pytorch官方案例

import torch.nn.utils.prune as prune

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
print(torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = LeNet().to(device=device)
module = model.conv1
prune.random_structurd(module, name="weight", amount=0.3, dim=1)

#对同一层进行连续不同的剪枝
prune.l1_unstructured(module, name="weight", amount=3)
prune.l1_unstructured(module, name="bias", amount=3)
prune.ln_structured(module, name="bias", amount=0.5, n=3, dim=0)

序列化剪枝后的模型

在PyTorch中,named_buffers()是一个模型的方法,它返回一个迭代器,这个迭代器包含了模型中所有持久化的缓冲区。在每次迭代中,它返回一个包含缓冲区名(name)和缓冲区的张量(tensor)的元组。

在神经网络中,有些数据虽然不是模型参数(也就是不会在反向传播中被更新),但是这些数据在前向传播过程中是需要的,这些数据就被称为缓冲区(buffer)。缓冲区通常用于存储不参与梯度计算,但需要在训练过程中持久化的数据。例如,批归一化(Batch Normalization)层中的运行平均值和运行方差就是存储在缓冲区中的。

对于剪枝操作来说,剪枝的掩码通常会被保存为一个缓冲区。这个掩码的作用是在前向传播过程中把被剪枝的权重(也就是被设为0的权重)从计算中排除出去。

所以,named_buffers()函数就是用来获取模型中所有缓冲区的名称和对应的数据。这在进行剪枝操作时,可以用来检查剪枝的掩码是否已经被正确地添加到模型中。

#state_dict()是一个PyTorch模型的方法,它返回一个字典,其中包含了模型的所有参数,包括权重和偏置。字典的键是参数的名称,值是参数的值。这个字典可以用于保存和加载模型的参数。
#keys()是Python字典的一个方法,它返回字典的所有键的列表。
#所以,model.state_dict().keys()返回的是一个包含模型中所有参数名称的列表。weight和bias
print(model.state_dict().keys())

new_model = LeNet()
#这行代码开始遍历模型中的所有模块(或层)。named_modules()函数返回一个迭代器,每次迭代返回一个包含模块名(name)和模块实例(module)的元组。
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

global pruning

model = LeNet()
#第一个元素是model,第二个元素是这个model里哪一些参数要被剪掉
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)
#进行全局无结构剪枝
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

"Sparsity"(稀疏性)是一个数学概念,用于描述一个矩阵中零元素的比例。在深度学习中,稀疏性通常用来描述模型权重矩阵中零值的比例。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

自定义pruning functions

下面是每隔一个就进行一次非结构化剪枝

自定义剪枝pytorch官方教程: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#:~:text=Global sparsity%3A 20.00%25-,Extending torch.nn.utils.prune with custom pruning,-functions

pytorch源码参考: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/prune.py#:~:text=%40abstractmethod,method recipe.

#该类是prune.BasePruningMethod的子类
class ImplEveryOtherPruningMethod(prune.BasePruningMethod):
    #定义剪枝类型
    PRUNING_TYPE = 'unstructured'
    #重写了基类中的抽象方法compute_mask。该方法接收两个参数,一个是待剪枝的张量t,另一个是默认的掩码default_mask。
    def compute_mask(self, t, default_mask):
        #创建一个default_mask的副本,这是为了避免改变原始的default_mask。
        mask = default_mask.clone()
        #这个操作首先将掩码的形状改为一维mask.view(-1),然后选择索引为偶数的所有元素[::2],将它们设置为0。这样就达到了每隔一个元素剪枝的效果。
        mask.view(-1)[::2] = 0
        return mask
def Ieveryother_unstructured_prune(module, name):
    #生成一个想要的mask,并且apply到module的元素上
    ImplEveryOtherPruningMethod.apply(module, name) 
    return module
model = LeNet()
Ieveryother_unstructured_prune(model.fc3, name='bias')

print(model.fc3.bias_mask)

标签:剪枝,prune,weight,03,torch,module,pytorch,model
From: https://www.cnblogs.com/125418a/p/17519499.html

相关文章

  • 算法学习day03链表part01-203、707、206
    packageSecondBrush.LinkedList.LL1;/***203.移除链表元素*删除链表中等于给定值val的所有节点。*自己再次概述一下这个过程:*1.移除元素,要采用设置虚拟节点的方式,因为那样不需要考虑头结点问题*2.设置两个虚拟指向*3.移除元素就是遍历链表,然后碰到目标值......
  • 游戏服务器被攻击怎么办?绍兴高防服务器租用203.135.102.x
    游戏服务器遭受攻击的原因可能有很多。攻击者可能会利用多种方式来入侵服务器,如通过计算机病毒、木马程序、蠕虫程序和社交工程等方式。这些攻击可以让服务器瘫痪,造成用户数据丢失、业务中断,甚至影响到公司的声誉。今天我就来和大家说原因和解决方法一、竞争对手来攻击你的服务器,让......
  • 八期day03-反编译工具和hook框架
    一反编译工具1.1常见反编译工具常见的反编译工具:jadx(推荐)、jeb、GDA反编译工具依赖于java环境,所以我们按照jdk1.2JDK环境安装#官方地址:(需要注册-最新java21)https://www.oracle.com/java/technologies/downloads/#下载地址链接:https://pan.baidu.com/s/1JxmjfGhW......
  • 51.pyinstaller打包后,打开exe程序提示SyntaxError: Non-UTF-8 code starting with '\
    最后开发了一款小工具,然后确定一切测试没有问题,想通过pyinstaller将其打包成exe,像类似的打包以前也经常打包的,复杂一点的也都是打包成功的,但这里感觉程序很简单,打包居然出现了以下错误。我的python版本是3.8.9,然后pyinstaller版本是5.9.0,不知道会不会是版本不兼容的问题,看网上哪......
  • Elasticsearch03
    1.SpringDataElasticsearch高级查询1.1.基本查询/***高级查询-基本查询*@return*/@RequestMapping("/matchQuery")publicIterable<Goods>matchQuery(){//词条查询MatchQueryBuildermatchQueryBuilder=QueryBuilders.matchQuery("title","......
  • 国产MCU-CW32F030开发学习-OLED模块
    国产MCU-CW32F030开发学习-OLED模块硬件平台CW32_48F大学计划板CW32_IOT_EVA物联网开发评估套件0.96IIColed模块软件平台KeilMDK5.31IAR串口调试助手IIC总线处理器和芯片间的通信可以形象的比喻成两个人讲话:1、你说的别人得能听懂:双方约定信号的协议。2、你......
  • 国产MCU-CW32F030开发学习-ST7735 LCD模块
    国产MCU-CW32F030开发学习-ST7735LCD模块硬件平台CW32_48F大学计划板CW32_IOT_EVA物联网开发评估套件0.96IIColed模块ST7735LCD模块硬件接口使用的2.54mm间距的排针接口,这使用杜邦线进行连接.ST7735参数供电电压3.3~5.5V驱动ICST7735分辨率12......
  • pytorch保存单通道灰度图片
    前言importtorchimporttorchvision.transformsastransformsfromtorchvision.utilsimportsave_imageimage=torch.randn(1,256,256)#示例,随机生成一个单通道图像#将图像张量保存为文件save_image(image,"single_channel_image.png",normalize=True)pytorch中......
  • 怎样导入pytorch gpu版本?
    1.下载anaconda2.在anaconda里创建环境create-npytorch_gpu#激活环境condaactivatepytorch_gpu3.在环境里install修改镜像接下来就是关键一步了,把-cpytorch表示的pytorch源,更改为国内的镜像。https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/先浏......
  • BL103BACnet网关配置采集BACnet MS/TP网关
    BL103BACnet网关是一款经济型楼宇自动化、暖通控制系统的物联网关,用于实现ModbusRTU、ModbusTCP、DL/T645、BACnetIP、BACnetMS/TP等多种协议转换为ModbusTCP、OPCUA、MQTT、BACnetIP、华为云IoT、亚马逊云IoT、阿里云IoT、ThingsBoard、金鸽云等协议的网关。BL103下行......