首页 > 其他分享 >01修建结构

01修建结构

时间:2023-07-01 16:45:34浏览次数:46  
标签:layer 01 prune nn self pruned 修建 np 结构

1非结构化剪枝

1.1.1细粒度剪枝

细粒度剪枝是一种特定类型的剪枝方法,它指的是单个权重级别的剪枝。在细粒度剪枝中,模型中的每一个权重都会被独立地考虑是否需要被剪枝。这种方法的优点是可以非常精确地控制模型的大小和复杂性,因为可以精确地选择哪些权重需要被剪枝。然而,这也是一种计算复杂度较高的方法,因为需要对每一个权重都进行评估。

下面的代码直接对权重按绝对值大小来评估重要性,剪掉绝对值大小小的

以多层感知机(MLP)为例,下面是一个稍微复杂点的MLP

import torch.nn as nn
import torch
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # 输入通道数为3,输出通道数为64,卷积核大小为3,padding为1
        self.bn1 = nn.BatchNorm2d(64)  # 批标准化层,输入通道数为64
        self.relu1 = nn.ReLU(inplace=True)  # ReLU激活函数,inplace=True表示直接修改输入的张量,而不是返回一个新的张量
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)  # 全连接层,输入大小为128*4*4,输出大小为1024
        self.fc2 = nn.Linear(1024, 10)  # 全连接层,输入大小为1024,输出大小为10

    def forward(self, x):
        x = self.conv1(x)  # 第一层卷积
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)  # 第二层卷积
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv3(x)  # 第三层卷积
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.conv4(x)  # 第四层卷积
        x = self.bn4(x)
        x = self.relu4(x)
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc1(x)  # 第一个全连接层
        x = self.fc2(x)  # 第二个全连接层
        return x

接下来构建一个剪枝函数,输入是layer和prune_rate,指的是按多少比例去裁剪当前这个卷积层的权重

def prune_conv_layer(layer, prune_rate):
    ## 按比例prune掉当前卷积层的权重
    if isinstance(layer, nn.Conv2d): ## 检查对象是否是nn.Conv2d
       	##取到它里面的每一个权重,取权重里面的一个data(因为weight里面除了data还有梯度),一个是数据一个是梯度
        ##然后运回cpu,然后用numpy做实现,后续讲解pytorch实现
        weights =  layer.weight.data.cpu().numpy()
        print(weights.shape)
        ##由于我们想按比例剪枝,所以先取weight里的元素,假设有15个,先排序
        ## 将这15个元素展成一维,然后排序,只取前80%,剪掉后面的20%
        num_weights = weights.size ## 计算大小
        num_prune = int(num_weights * prune_rate) ## 计算需要剪掉多少个元素,取整
        flat_weights = np.abs(weights.reshape(-1)) ## 将权重展平,并取绝对值
        treshold = np.sort(flat_weights)[num_prune] ## 找到下标是num_prune的元素,保留大于等于treshold的元素
        weights[abs(weights) < treshold] = 0## 将小于treshold的元素置为0
        ##将剪枝后的权重转换为torch张量并赋值给卷积层的权重
        layer.weight.data = torch.from_numpy(weights).to(layer.weight.device)
        

net = Net()
prune_rate = 0.2 ## 每一层都去掉20%
for layer in net.modules():##访问每一层
	prune_conv_layer(layer, prune_rate)

1.1.2向量剪枝

向量剪枝,是把当前权重的同一行,同一列全部剪掉,赋值为0

import numpy as np
import matplotlib.pyplot as plt

def vector_pruning(matrix, idx):
    #取到idx的row和col
    row, col = idx
    pruned_matrix = matrix.copy()
    #当前行和列都置为0
    pruned_matrix[row, :] = 0
    pruned_matrix[:, col] = 0
    return pruned_matrix

matrix = np.random.randn(3, 4)
idx = (1, 2)
# prune the matrix
pruned_matrix = vector_pruning(matrix, idx)
print(matrix)
print("")
print(pruned_matrix)

1.1.3卷积核剪枝(kernel level)

对卷积层进行剪枝,将一定比例的权重设置为0

![b0059c8de2b0dd606a1fe47785388dd1](/Users/pyf/Library/Containers/com.tencent.qq/Data/Library/Application Support/QQ/nt_qq_a874608d4504d6acc7b3d847b89e39bd/nt_data/Pic/2023-06/Ori/b0059c8de2b0dd606a1fe47785388dd1.png)

# 5 * 4 * 3 * 3

def prune_conv_layer(layer, prune_rate):
    if isinstance(layer, nn.Conv2d): ## 检查对象是否是nn.Conv2d
        weights =  layer.weight.data.cpu().numpy()
        num_weights = weights.size ## 计算大小
        num_prune = int(num_weights * prune_rate) ## 计算需要剪掉多少个元素,取整
        #对每一组卷积计算L2范数,权重的平方求和,axis是需要求和的维度,5 * 4 * 3 * 3其中5是0维度
        norm_per_filter = np.sum(weight**2, axis = (1, 2, 3))#会得到5个值
        #根据L2范数排序,选择一定比例的filter,将里面的元素置为0
        #先对这5个数排序
        indices = np.argsort(norm_per_filter)[-num_prune:]
        print(indices)
        weight[indices] = 0;
        layer.weight.data = torch.from_numpy(weights).to(layer.weight.device)

下面给出一个例子

import torch
import numpy as np

#定义权重张量 (3 * 2 * 2)

weight1 = np.array([[3, 2],
                  [3, 4]])
weight2 = np.array([[5, 6],
                  [7, 8]])
weight3 = np.array([[9, 0],
                  [1, 2]])
weight = np.stack([weight1, weight2, weight3], axis = 0)
prune_rate = 2/3 # 要prune掉
print(weight.shape)
num_prune = int(weight.shape[0] * prune_rate)
l2norm_per_kernel = np.sqrt(np.sum(weight**2, axis = (1, 2)))
print(l2norm_per_kernel)
indices = np.argsort(l2norm_per_kernel)[:num_prune]
print(indices)
weight[indices]
print(weight)

结构化剪枝

可视化方法简介

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#可视化函数
def visualize_tensor(tensor, batch_spacing=3):#batch_spacing用于控制可视化后张量与张量之间的间隔
    fig = plt.figure()#创建了一个新的空白图形窗口
    ax = fig.add_subplot(111, projection='3d')#在fig窗口上添加一个新的3d子图
    for batch in range(tensor.shape[0]): #遍历张量的第一个维度,称之为batch
        for channel in range(tensor.shape[1]): # 遍历第二个维度,称之为channel
            for i in range(tensor.shape[2]): # height
                for j in range(tensor.shape[3]): # width
                    #计算每个立方体在3d图形中的位置,位置由x,y,z三个坐标决定
                    #tensor.shape[3] 是 "width" 维度的大小
                    x, y, z = j + (batch * (tensor.shape[3] + batch_spacing)), i, channel
                    # 值为0是红色,否则为灰色
                    color = 'red' if tensor[batch, channel, i, j] == 0 else 'gray'
                    #bar3d用于绘制立方体,edgecolor是边缘颜色,alpha是透明度
                    ax.bar3d(x, z, y, 1, 1, 1, shade=True, color=color, edgecolor="black", alpha=0.9)
                    
    ax.set_xlabel('Width')
    # ax.set_ylabel('B & C')
    ax.set_zlabel('Height')
    #ax.set_zlim(ax.get_zlim()[::-1]):这个函数设置了 z 轴的范围。具体来说,ax.get_zlim() 获取了当前 z 轴的范围,然后 [::-1] 将这个范围反转。这意味着 z 轴的方向被反转,也就是说,较大的 "channel" 值将被显示在图形的下方,而较小的 "channel" 值将被显示在图形的上方。这样做的目的是为了使图形的显示更符合常规的 3D 观察习惯。
    ax.set_zlim(ax.get_zlim()[::-1])
    #控制 z 轴标签和 z 轴之间的距离,将其设置为 15 是为了确保 z 轴的标签有足够的空间显示,不会与图形元素重叠。
    ax.zaxis.labelpad = 15 # adjust z-axis label position
    
    plt.show()
    
    
def prune_conv_layer(conv_layer, prune_method, percentile=20, vis=True):
    pruned_layer = conv_layer.copy()    
    if prune_method == "fine_grained":
        pruned_layer[np.abs(pruned_layer) < 0.05] = 0
    if prune_method == "vector_level":
        #np.linalg.norm默认计算L2范数,计算L1范数可以这样写np.linalg.norm(x, ord=1)
        #axis=-1表示沿着最后一个维度来计算范数
        l2_sum = np.linalg.norm(pruned_layer, axis=-1)
    if prune_method == "kernel_level":
        # 计算每个kernel的L2范数
        l2_sum = np.linalg.norm(pruned_layer, axis=(-2, -1))
    if prune_method == "filter_level":
        # 计算每个filter的L2范数,因为np.linalg.norm无法对3维求范数
        # 这里先平方,再求和,只是没有开根
        l2_sum = np.sum(pruned_layer**2, axis=(-3, -2, -1))
    if prune_method == "channel_level":
        # 计算每个filter的L2范数
        l2_sum = np.sum(pruned_layer**2, axis=(-4, -2, -1))
        #这里对l2_sum做reshape是因为考虑的1/3/4维度,补上第一维度,后面mask才能正确索引
        # add a new dimension at the front
        l2_sum = l2_sum.reshape(1, -1)  # equivalent to l2_sum.reshape(1, 10)

        # repeat the new dimension 8 times
        #np.repeat是指沿着第一个轴(axis=0)重复 l2_sum,重复的次数是 pruned_layer 的第一个维度的长度
        l2_sum = np.repeat(l2_sum, pruned_layer.shape[0], axis=0)
    #np.percentile 是 numpy 库中的一个函数,它用于计算给定数据的指定百分位数。
    #例如,np.percentile(data, 25) 将会计算 data 的第25百分位数。
    #代码的意思是计算l2_sum里百分之percentile的数是多少
    threshold = np.percentile(l2_sum, percentile)
    #以threshold作为阈值如果 l2_sum 中的某个元素小于 threshold,那么对应位置的布尔值就会是 True;否则,它就会是 False。
    mask = l2_sum < threshold
    print(pruned_layer.shape)
    print(mask.shape)
    print("-----------------------------")
    #mask标记了哪些位置为0
    pruned_layer[mask] = 0
    if vis:
        visualize_tensor(pruned_layer)
    return pruned_layer
# 生成一个tensor
#uniform是指按照均匀分布生成随机浮点数
#np.random.uniform(low, high, size)是指生成最小值low,最大值high均匀分布的随机值
tensor = np.random.uniform(low = -1, high = 1, size = (3, 10, 4, 5))

# Prune the conv layer and visualize it
pruned_tensor = prune_conv_layer(tensor, "vector_level", vis=True)
pruned_tensor = prune_conv_layer(tensor, "kernel_level", vis=True)
pruned_tensor = prune_conv_layer(tensor, "filter_level", vis=True)
pruned_tensor = prune_conv_layer(tensor, "channel_level",percentile=40, vis=True)

标签:layer,01,prune,nn,self,pruned,修建,np,结构
From: https://www.cnblogs.com/125418a/p/17519486.html

相关文章

  • 【题解】#119. 最大整数 题解(2023-07-01更新)
    #119.最大整数题解题目传送门更新日志2023-05-2617:20文章完成2023-05-3015:22文章审核通过2023-07-0116:04修改了代码题目知识点字符串+贪心题意说明设有n个正整数($n<20$),将它们连接成一排,组成一个最大的多位整数。(题目简介明了,一看就是出题人懒得写题目背景)......
  • 【题解】P8679 [蓝桥杯 2019 省 B] 填空问题 题解
    P8679[蓝桥杯2019省B]填空问题题解题目传送门欢迎大家指出错误并联系这个蒟蒻更新日志2023-05-2521:02文章完成2023-05-2711:34文章通过审核2023-06-2021:03优化了文章代码格式试题A:组队【解析】本题是一道经典的DFS搜索题,每次对各号位的选手进行DFS,......
  • 【题解】P8684 [蓝桥杯 2019 省 B] 灵能传输 题解
    P8684[蓝桥杯2019省B]灵能传输题解题目传送门欢迎大家指出错误并联系这个蒟蒻更新日志2023-06-2021:46文章完成【解析】本题涉及到了$3$种算法:前缀和,排序以及贪心(1)前缀和本题实际上要求通过某种灵能传输可以使得该序列的最大值最小。而由前缀和可知,当某一个前......
  • 【置顶】FZQOJ题解集(2023-07-01更新)
    #68.「NOIP2004」津津的储蓄计划题解题目传送门欢迎大家指出错误并联系这个蒟蒻更新日志2023-02-0117:20文章完成2023-02-0316:09文章审核通过2023-02-0422:15修改了注释2023-05-2709:27修改了$\LaTeX$2023-07-0115:45修改了代码题目知识点模拟题目分析......
  • 【置顶】luogu题解集(2023-07-01更新)
    P8679[蓝桥杯2019省B]填空问题题解题目传送门欢迎大家指出错误并联系这个蒟蒻更新日志2023-05-2521:02文章完成2023-05-2711:34文章通过审核2023-06-2021:03优化了文章代码格式试题A:组队【解析】本题是一道经典的DFS搜索题,每次对各号位的选手进行DFS,......
  • 01_Maven
    1.Maven是什么?Maven翻译为"专家"、"内行",是Apache下的一个纯Java开发的开源项目(https://maven.apache.org/)。基于项目对象模型(缩写:POM)概念,Maven利用一个中央信息片断能管理一个项目的构建、报告和文档等步骤。Maven是一个项目开发结构管理工具,可以对Java项目结构、......
  • 算法学习day03链表part01-203、707、206
    packageSecondBrush.LinkedList.LL1;/***203.移除链表元素*删除链表中等于给定值val的所有节点。*自己再次概述一下这个过程:*1.移除元素,要采用设置虚拟节点的方式,因为那样不需要考虑头结点问题*2.设置两个虚拟指向*3.移除元素就是遍历链表,然后碰到目标值......
  • 【题解】#373. 「USACO1.1」Friday the Thirteenth 题解(2023-07-01更新)
    #373.「USACO1.1」FridaytheThirteenth题解题目传送门欢迎大家指出错误并联系这个蒟蒻更新日志2023-02-0117:20文章完成2023-02-0318:50文章审核通过2023-02-0319:17修改了注释2023-05-2520:25修改了$\LaTeX$2023-05-2520:32再次修改了$\LaTeX$,感谢ACRU......
  • 10.11 定义枚举结构
    demo1在枚举类中定义成员属性与方法enumColor{ //枚举类 RED("红色"),GREEN("绿色"),BLUE("蓝色"); //枚举对象要写在首行 privateStringtitle;//成员属性 privateColor(Stringtitle){//构造方法初始化属性; this.title=title; } @Override publicStrin......
  • 【题解】#105. 「USACO1.3」Ski Course Design 题解(2023-07-01更新)
    #105.「USACO1.3」SkiCourseDesign题解题目传送门欢迎大家指出错误并联系这个蒟蒻更新日志2023-02-0117:20文章完成2023-02-0316:09文章审核通过2023-02-0422:15修改了注释2023-05-1621:44修改了$\LaTeX$2023-07-0115:59修改了代码题目知识点模拟+搜索......