常用剪枝工具
pytorch官方案例
import torch.nn.utils.prune as prune
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
print(torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
module = model.conv1
prune.random_structurd(module, name="weight", amount=0.3, dim=1)
#对同一层进行连续不同的剪枝
prune.l1_unstructured(module, name="weight", amount=3)
prune.l1_unstructured(module, name="bias", amount=3)
prune.ln_structured(module, name="bias", amount=0.5, n=3, dim=0)
序列化剪枝后的模型
在PyTorch中,named_buffers()
是一个模型的方法,它返回一个迭代器,这个迭代器包含了模型中所有持久化的缓冲区。在每次迭代中,它返回一个包含缓冲区名(name
)和缓冲区的张量(tensor
)的元组。
在神经网络中,有些数据虽然不是模型参数(也就是不会在反向传播中被更新),但是这些数据在前向传播过程中是需要的,这些数据就被称为缓冲区(buffer)。缓冲区通常用于存储不参与梯度计算,但需要在训练过程中持久化的数据。例如,批归一化(Batch Normalization)层中的运行平均值和运行方差就是存储在缓冲区中的。
对于剪枝操作来说,剪枝的掩码通常会被保存为一个缓冲区。这个掩码的作用是在前向传播过程中把被剪枝的权重(也就是被设为0的权重)从计算中排除出去。
所以,named_buffers()
函数就是用来获取模型中所有缓冲区的名称和对应的数据。这在进行剪枝操作时,可以用来检查剪枝的掩码是否已经被正确地添加到模型中。
#state_dict()是一个PyTorch模型的方法,它返回一个字典,其中包含了模型的所有参数,包括权重和偏置。字典的键是参数的名称,值是参数的值。这个字典可以用于保存和加载模型的参数。
#keys()是Python字典的一个方法,它返回字典的所有键的列表。
#所以,model.state_dict().keys()返回的是一个包含模型中所有参数名称的列表。weight和bias
print(model.state_dict().keys())
new_model = LeNet()
#这行代码开始遍历模型中的所有模块(或层)。named_modules()函数返回一个迭代器,每次迭代返回一个包含模块名(name)和模块实例(module)的元组。
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
global pruning
model = LeNet()
#第一个元素是model,第二个元素是这个model里哪一些参数要被剪掉
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
#进行全局无结构剪枝
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
"Sparsity"(稀疏性)是一个数学概念,用于描述一个矩阵中零元素的比例。在深度学习中,稀疏性通常用来描述模型权重矩阵中零值的比例。
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
自定义pruning functions
下面是每隔一个就进行一次非结构化剪枝
pytorch源码参考: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/prune.py#:~:text=%40abstractmethod,method recipe.
#该类是prune.BasePruningMethod的子类
class ImplEveryOtherPruningMethod(prune.BasePruningMethod):
#定义剪枝类型
PRUNING_TYPE = 'unstructured'
#重写了基类中的抽象方法compute_mask。该方法接收两个参数,一个是待剪枝的张量t,另一个是默认的掩码default_mask。
def compute_mask(self, t, default_mask):
#创建一个default_mask的副本,这是为了避免改变原始的default_mask。
mask = default_mask.clone()
#这个操作首先将掩码的形状改为一维mask.view(-1),然后选择索引为偶数的所有元素[::2],将它们设置为0。这样就达到了每隔一个元素剪枝的效果。
mask.view(-1)[::2] = 0
return mask
def Ieveryother_unstructured_prune(module, name):
#生成一个想要的mask,并且apply到module的元素上
ImplEveryOtherPruningMethod.apply(module, name)
return module
model = LeNet()
Ieveryother_unstructured_prune(model.fc3, name='bias')
print(model.fc3.bias_mask)
标签:剪枝,prune,weight,03,torch,module,pytorch,model
From: https://www.cnblogs.com/125418a/p/17519499.html