首页 > 编程语言 >基于Python的自然语言处理系列(22):模型剪枝(Pruning)

基于Python的自然语言处理系列(22):模型剪枝(Pruning)

时间:2024-10-04 09:49:12浏览次数:12  
标签:剪枝 nn 22 dim Python text torch batch

        在深度学习领域,尤其是当模型部署到资源有限的环境中时,模型压缩技术变得尤为重要。剪枝(Pruning)是一种常见的模型压缩方法,通过减少模型中不重要的参数,可以在不显著降低模型性能的情况下提升效率。在本文中,我们将详细介绍如何在PyTorch中使用剪枝技术,并通过一些实验展示其效果。

1. 加载数据集与预处理

        我们将使用TorchText库加载常用的AG_NEWS数据集,并进行预处理。首先,导入必要的库并设置随机种子以保证实验的可重复性。        

import torch, torchdata, torchtext
from torch import nn
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

        接下来,加载AG_NEWS数据集,并将其拆分为训练集、验证集和测试集。我们使用TorchText的random_split方法来进行数据划分。

from torchtext.datasets import AG_NEWS
train, test = AG_NEWS()

train_size = len(list(iter(train)))
too_much, train, valid = train.random_split(total_length=train_size, weights = {"too_much": 0.7, "smaller_train": 0.2, "valid": 0.1}, seed=999)

数据预处理

        我们将使用Spacy作为分词器,并将文本转换为整数表示。这里,我们使用build_vocab_from_iterator来为数据集生成词汇表。

from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

from torchtext.vocab import build_vocab_from_iterator
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train), specials=['<unk>', '<pad>', '<bos>', '<eos>'])
vocab.set_default_index(vocab["<unk>"])

2. FastText 预训练词向量

        接下来,我们将加载FastText的预训练词向量,并将其应用到我们的词汇表中。

from torchtext.vocab import FastText
fast_vectors = FastText(language='simple') # 使用FastText预训练词向量

fast_embedding = fast_vectors.get_vecs_by_tokens(vocab.get_itos()).to(device)
fast_embedding.shape

3. 数据加载器

        我们需要定义一个数据加载器collate_fn,来确保批处理中的序列长度一致(通过填充)。在这里,我们还会生成序列长度信息,以便后续用于LSTM中的打包序列处理。

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

pad_idx = vocab['<pad>']

def collate_batch(batch):
    label_list, text_list, length_list = [], [], []
    for (_label, _text) in batch:
        label_list.append(int(_label) - 1)  # 标签从0开始
        processed_text = torch.tensor([vocab[token] for token in tokenizer(_text)], dtype=torch.int64)
        text_list.append(processed_text)
        length_list.append(processed_text.size(0))
    return torch.tensor(label_list, dtype=torch.int64), pad_sequence(text_list, padding_value=pad_idx, batch_first=True), torch.tensor(length_list, dtype=torch.int64)

train_loader = DataLoader(train, batch_size=64, shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(valid, batch_size=64, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(test, batch_size=64, shuffle=False, collate_fn=collate_batch)

4. 模型定义

        我们将定义一个双向LSTM模型,并将预训练的FastText词向量作为模型的嵌入层权重初始化。

class LSTM(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, num_layers, bidirectional, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hid_dim * 2, output_dim)

    def forward(self, text, text_lengths):
        embedded = self.embedding(text)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'), enforce_sorted=False, batch_first=True)
        packed_output, (hn, cn) = self.lstm(packed_embedded)
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        hn = torch.cat((hn[-2,:,:], hn[-1,:,:]), dim=1)
        return self.fc(hn)

5. 模型训练与评估

        我们定义模型的训练与评估函数,并加载预训练好的模型进行测试。

criterion = nn.CrossEntropyLoss()

def accuracy(preds, y):
    predicted = torch.max(preds.data, 1)[1]
    return (predicted == y).sum().item() / len(y)

def evaluate(model, loader, criterion):
    model.eval()
    epoch_loss, epoch_acc = 0, 0
    with torch.no_grad():
        for label, text, text_length in loader:
            label, text = label.to(device), text.to(device)
            predictions = model(text, text_length).squeeze(1)
            loss = criterion(predictions, label)
            acc = accuracy(predictions, label)
            epoch_loss += loss.item()
            epoch_acc += acc
    return epoch_loss / len(loader), epoch_acc / len(loader)

test_loss, test_acc = evaluate(model, test_loader, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

6. 剪枝(Pruning)

随机剪枝

        首先,我们使用PyTorch的torch.nn.utils.prune库对模型进行随机剪枝。例如,以下代码将随机剪掉全连接层中95%的连接。

import torch.nn.utils.prune as prune

fc = model.fc
prune.random_unstructured(fc, name="weight", amount=0.95)
print(list(fc.named_buffers()))  # 打印权重掩码

基于L1范数的剪枝

        我们还可以基于权重的L1范数进行剪枝,以下代码展示了如何根据最小的L1范数剪枝95%的连接。

prune.l1_unstructured(fc, name="weight", amount=0.95)
print(fc.weight)

全局剪枝

        全局剪枝是通过在整个模型中移除最低重要性的连接,而不是逐层进行剪枝。我们可以使用global_unstructured来实现这一目标。

parameters_to_prune = [(model.embedding, 'weight'), (model.lstm, 'weight_ih_l0'), (model.lstm, 'weight_hh_l0'), (model.fc, 'weight')]
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.7)

7. 自定义剪枝方法

        我们还可以通过继承torch.nn.utils.prune.BasePruningMethod类来自定义剪枝方法。下面是一个简单的自定义剪枝示例,剪去张量中的每隔一个元素。

class ExamplePruningMethod(prune.BasePruningMethod):
    PRUNING_TYPE = 'unstructured'
    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0 
        return mask

结语

        在本篇文章中,我们探讨了模型剪枝(Pruning)的多种方法,包括随机剪枝、基于L1范数的剪枝和全局剪枝等。这些技术可以有效减少模型参数量,在不明显降低性能的情况下,显著提升模型的推理效率。剪枝方法的选择应根据模型和任务的特点来决定,不同的剪枝策略适用于不同的场景。

        剪枝作为模型压缩的一部分,尤其在部署到计算资源受限的设备时,能够大幅减少计算负担。同时,自定义剪枝方法也提供了灵活性,允许开发者根据需求进行更细粒度的优化。通过本文的实践,大家可以尝试不同的剪枝方法,观察其对模型大小和性能的影响。

        在下一篇文章中,我们将介绍DrQA,这是一个针对问答系统的经典模型。我们会探讨如何构建一个可以回答复杂问题的问答系统,继续深入自然语言处理领域的实际应用。

 

 

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

 

标签:剪枝,nn,22,dim,Python,text,torch,batch
From: https://blog.csdn.net/ljd939952281/article/details/142645757

相关文章

  • 用python写一个脚本:将指定目录下及其所有子文件夹下的视频文件按中间时间切分成两部分
    代码:importosfrommoviepy.editorimportVideoFileClipdefsplit_video(video_path,output_dir):#加载视频文件clip=VideoFileClip(video_path)duration=clip.duration#计算中间时间点midpoint=duration/2#创建输出目录i......
  • 22.响应式网络推广建站公司网页 Web前端网页制作 大学生期末大作业 html+css+js
     目录 一、前言 二、网页文件 三、网页效果四、代码展示1.HTML2.CSS3.JS 五、更多推荐一、前言 本实例应用html+css+js,响应式布局,可以根据不同的设备屏幕大小自动调整页面布局,手机等移动设备自适应界面,提高用户体验;支持包括IE、Firefox、Chrome、Safari等主......
  • 在VS2022上安装pygame模块
    一、安装在vs2022中随便打开或生产一个python项目,找到最右边的“解决方案资源管理器”,并找到“python环境”,点击鼠标右键打开“查看所有python环境”打开以后找到下面的“在PowerShell中打开”,点击打开然后输入”pipinstallpygame“并等待安装即可二、测试输入以下代码并运......
  • P9752 [CSP-S 2023] 密码锁&&P8814 [CSP-J 2022] 解密
    GutenTag!Schön,dichzusehen!今天也是很懒惰的一天呢!所以今天三合一!题目:[CSP-S2023]密码锁题目描述小Y有一把五个拨圈的密码锁。如图所示,每个拨圈上是从$0$到$9$的数字。每个拨圈都是从$0$到$9$的循环,即$9$拨动一个位置后可以变成$0$或$8$,因为校园里......
  • PotPlayer(免费媒体播放器) v1.7.22233.0 多语便携版
    概述PotPlayer是一款由韩国企业Daum开发的免费媒体播放器,它提供了丰富的功能和特点,使其成为许多用户的首选播放器。 软件功能支持多种音视频格式:PotPlayer支持大多数常见的音视频格式,包括MP4、AVI、MKV、MOV、FLV、MP3、WAV等。高质量的音视频播放:PotPlayer采用了先进的解码......
  • 2023-12-15 博士挑战--不完美达成 122918
    目录总纲现状反思未来总纲宇宙万物、世间一切自有因果。自助者天助,然若不自助,神明亦爱莫能助。人的好坏定义并非从言行举止来衡量,而是有无执着。现状老师也不会再强烈要求我们小组每个人都成为顶尖科学家了,也不会执着于己见了。我目前一切都好,正在做想做的理论方向且成......
  • 2023-11-25 Matlab和Python在气象中的常用代码 180401
    目录画图横坐标添加月份PythonMatlab画图横坐标添加月份Pythonimportmatplotlib.pyplotaspltimportpandasaspdimportnumpyasnp#准备时间和温度数据start_date=pd.to_datetime('1996-12-01')#thenextdateend_date=pd.to_datetime('1998-12-01')#the......
  • python基础(二)之字符串
    字符串的定义Python中的字符串可以使用单引号、双引号和三引号(三个单引号或三个双引号)括起来字符串的引号嵌套单引号定义法,可以内含双引号双引号定义法,可以内含单引号可以使用 \转义特殊字符来解除引号效用,变成普通字符串字符串的拼接和重复使用“+”号连接字符串变量......
  • Python异常处理:让你的代码更稳健的魔法
    引言:你是否曾经在代码中迷失?想象一下,你正在编写一个重要的Python程序,突然间,屏幕上弹出一条错误信息,仿佛一只无形的手将你的努力撕得粉碎。你是否曾经感到无助,甚至想要放弃?根据统计,程序员在开发过程中,约有70%的时间都在处理错误和异常。可见,异常处理不仅是编程的“必修课”,更是......
  • 用Python实现运筹学——Day 9: 线性规划的灵敏度分析
    一、学习内容1.灵敏度分析的定义与作用灵敏度分析(SensitivityAnalysis)是在优化问题中,分析模型参数变化对最优解及目标函数值的影响。它帮助我们了解在线性规划模型中,当某些参数(如资源供应量、成本系数等)发生变化时,最优解是否会发生变化,以及这种变化的幅度。灵敏度分析的......