首页 > 其他分享 >【人人都能学得会的NLP - 文本分类篇 06】基于 Prompt 的小样本文本分类实践

【人人都能学得会的NLP - 文本分类篇 06】基于 Prompt 的小样本文本分类实践

时间:2024-12-04 22:58:34浏览次数:5  
标签:NLP 训练 模型 分类 news model 文本 ds

【人人都能学得会的NLP - 文本分类篇 06】基于 Prompt 的小样本文本分类实践


NLP Github 项目:


1、任务说明

随着预训练语言模型规模的增长,“预训练-微调”范式在下游自然语言处理任务上的表现越来越好,但与之相应地对训练数据量和计算存储资源的要求也越来越高。为了充分利用预训练语言模型学习到的知识,同时降低对数据和资源的依赖,提示学习(Prompt Learning)作为一种可能的新范式受到了越来越多的关注,在 FewCLUE、SuperGLUE 等榜单的小样本任务上取得了远优于传统微调范式的结果。

提示学习(Prompt Learning)的核心思想是将下游任务转化为预训练阶段的掩码预测(MLM)任务。实现思路包括通过模板(Template)定义的提示语句,将原有任务转化为预测掩码位置的词,以及通过标签词(Verbalizer)的定义,建立预测词与真实标签之间的映射关系。

以情感分类任务为例,“预训练-微调”范式和“预训练-提示”范式(以 PET 为例)之间的区别如下图所示

【微调学习】使用 [CLS] 来做分类,需要训练随机初始化的分类器,需要充分的训练数据来拟合。

【提示学习】通过提示语句和标签词映射的定义,转化为 MLM 任务,无需训练新的参数,适用于小样本场景。

2、预训练语言模型的学习范式

2.1 预训练模型 + 参数微调

参数微调方法存在的问题:

  • 下游任务的数据稀缺,导致过拟合的问题
  • 预训练任务和下游任务目标不一致
  • 预训练模型的参数量越来越大
    • 时间成本越大
    • 存储空间越大

缓解参数微调问题的方法:提示学习

2.2 预训练模型 + 提示学习

提示学习的优势:

  • (1)降低预训练任务(掩码语言模型或自回归语言模型)与下游任务之间的任务差距
  • (2)更好地适用于少样本(few-shot)单个样本(one-shot) 甚至零样本(zero-shot) 的情况
  • (3)当预训练模型参数量较大时,降低存储空间

3、提示学习的重要过程

  1. 预训练模型的选择:基于编码器的模型、基于解码器的模型、基于编码器-解码器的模型
  2. 提示工程(Prompt Engineering):构建较好的提示函数,更好地提升下游任务
  3. 答案工程(Answer Engineering/ Verbalizer):选择更好地答案集合,并映射到对应的标签
  4. 多提示学习:融合不同提示的各自优势
  5. 提示学习的参数更新策略:如何更新预训练模型的参数

3.1 预训练模型的选择

3.2 提示工程(Prompt Engineering)


3.3 答案工程(Answer Engineering/ Verbalizer)

3.4 多提示学习

3.5 提示学习的参数更新策略

4. 实现思路及流程

根据上边介绍,基于 Prompt API 实现文本分类的思路如下所示,模型的输入文本根据模板(Template)进行预处理,模型的输出结果经过标签词映射(Verbalizer)得到预测的映射词。

在建模过程中,对于输入文本,首先将其处理为模板 API 能够处理的标准形式,根据任务定义模板和标签词映射,调用模板 API 进行文本模板组合和文本序列编码,获得文本的语义向量表示;然后经过预训练语言模型得到预测向量,调用标签词映射的 API 取出标签词对应的概率。

基于 Prompt API 实现小样本提示学习文本分类的过程主要包括以下6个步骤:

(1)模型构建:确定文本分类使用的模型,本实践使用ERNIE-3.0 Base模型进行文本编码和标签词预测。

(2)数据准备:对于输入的文本进行相应的处理,包括数据标准化、模板定义、标签词映射、文本编码等。

(3)训练配置:配置训练参数,使用 PromptTrainer API 进行环境、模型、优化器、训练预测等流程的自动初始化。

(4)模型训练:训练模型参数,以达到最优效果。

(5)模型评估:对训练好的模型进行评估测试,观察准确率和损失函数的变化情况。

(6)模型预测:选取一段新闻,判断新闻类别。

以下分别介绍每个步骤的具体实现过程。

5. 模型构建

我们使用ERNIE 3.0 Base作为预训练模型用于新闻分类。提示学习本质上是掩码预测(MLM)任务,因此可以使用 AutoModelForMaskedLM 来加载模型参数。

from paddlenlp.transformers import AutoTokenizer, AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("ernie-3.0-base-zh")
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-base-zh")

6. 数据准备

数据准备过程包括数据集确定、数据标准化、模板定义、标签词映射定义等步骤。本实践使用 PromptTrainer 进行训练,该 API 封装了 Prompt 相关的数据预处理过程,如模板文本组合和文本分词、编码的过程,因此不需要构造 DataLoader。

(1)数据集确定

FewCLUE 是专门用于中文小样本学习能力测评的榜单,涵盖了情感分析、新闻分类、语义匹配、指代消歧等阅读理解任务。这里我们使用其中的新闻分类数据集 TNEWS 作为示例,共包括15个新闻类别,每个类别有16条标注数据用于训练。除此之外,有240条标注数据用于验证,2010条数据用于测试。

PaddleNLP 中内置了该数据集,可直接调用 load_dataset 加载数据。

from paddlenlp.datasets import load_dataset

train_ds, dev_ds, test_ds = load_dataset("fewclue", "tnews", splits=["train_0", "dev_0", "test_public"])

(2)数据标准化

Prompt API 规定了输入数据的格式,我们需要先将已有数据转化为 InputExample 封装的标准格式。以 TNEWS 为例,转换代码如下

from paddlenlp.datasets import MapDataset
from paddlenlp.prompt import InputExample

def convert_tnews_to_example(data_ds):
    std_data = []
    for sample in data_ds:
        std_sample = InputExample(uid=sample["id"],
                                  text_a=sample["sentence"],
                                  text_b=None,
                                  labels=sample["label_desc"])
        std_data.append(std_sample)
    std_data_ds = MapDataset(std_data)
    return std_data_ds

train_ds = convert_tnews_to_example(train_ds)
dev_ds = convert_tnews_to_example(dev_ds)
test_ds = convert_tnews_to_example(test_ds)

(3)定义模版

模板(Template)的功能是在原有输入文本上增加提示语句,从而将原任务转化为 MLM 任务,可以分为离散型和连续型两种。更多信息可参考 Prompt 文档介绍。

本实践使用了 AutoTemplate API,支持快速定义手工初始化的连续模板,同时支持自动切换离散型和连续型模板。

  • 只定义用于初始化连续型向量的文本提示,即可得到拼接到句尾的连续型模板输入。例如,
"这条新闻标题的主题是"

等价于

"{'text': 'text_a'}{'soft': '这条新闻标题的主题是'}{'mask'}"

模板关键字

  • text :数据集中原始输入文本对应的关键字,包括text_atext_b
  • hard :自定义的文本提示语句。
  • mask :待预测词的占位符。
  • soft 表示连续型提示。若值为 None ,则随机初始化;若值为文本,则使用对应长度的连续性向量作为提示,并预训练词向量中文本对应的向量进行初始化。
from paddlenlp.prompt import AutoTemplate

prompt = "这条新闻标题的主题是"
template = AutoTemplate.create_from(
        prompt,
        tokenizer,
        max_seq_length=512,
        model=model,
        prompt_encoder="lstm",
        encoder_hidden_size=200)

(4)定义标签词映射

标签词映射(Verbalizer)也是提示学习中的重要模块,用于建立预测词和标签之间的映射,从而在下游任务与预训练任务间建立联系。更多信息可参考标签词映射 API 文档。

本实践使用了 SoftVerbalizer API,基于 WARP 的思想修改了 ErnieModelForMaskedLM 的模型结构,将预训练模型最后一层“隐藏层-词表”替换为“隐藏层-标签”的映射。该层网络的初始化参数由标签词映射中的预测词词向量来决定,如果预测词长度大于一,则使用词向量均值进行初始化。

from paddlenlp.prompt import SoftVerbalizer

label_word_map = {
    "news_story":  "八卦",
    "news_entertainment": "明星",
    "news_finance": "经济",
    "news_sports": "体育",
    "news_edu": "校园",
    "news_game": "游戏",
    "news_culture": "文化",
    "news_tech": "科技",
    "news_car": "汽车",
    "news_travel": "旅行",
    "news_world": "国际",
    "news_agriculture": "农业",
    "news_military": "军事",
    "news_house": "房子",
    "news_stock": "股票"
}

verbalizer = SoftVerbalizer(tokenizer, 
                            model, 
                            labels=list(label_word_map.keys()),
                            label_words=label_word_map)

def convert_labels_to_ids(data_ds):
    new_data_ds = []
    for sample in data_ds:
        sample.labels = verbalizer.labels_to_ids[sample.labels]
        new_data_ds.append(sample)
    return MapDataset(new_data_ds)

train_ds = convert_labels_to_ids(train_ds)
dev_ds = convert_labels_to_ids(dev_ds)
test_ds = convert_labels_to_ids(test_ds)

示例

按照上述定义,调用 Prompt API 就可以得到模型需要的输入了。为了便于理解模板、标签词映射以及学习任务之间的关系,这里我们举个具体的例子。

  • 给定新闻分类 TNEWS 数据集中的一条标注数据。
{"label": 109, "label_desc": "news_tech", "sentence": "联想被踢出恒生指数,是什么导致了联想现在的这种境地?", "keywords": "", "id": 1522}
  • 将其标准化为 InputExample 实例为
InputExample(text_a="联想被踢出恒生指数,是什么导致了联想现在的这种境地?", labels="news_tech")
  • 调用模板 API 将上述实例与模板拼接,得到预训练模型的输入文本如下所示
联想被踢出恒生指数,是什么导致了联想现在的这种境地?这条新闻标题的主题是[MASK]
  • 标签词映射将 “news_tech” 映射为 “科技”,即我们希望 [MASK] 的位置预测结果为 “科技”,我们期望的完整预测结果如下
联想被踢出恒生指数,是什么导致了联想现在的这种境地?这条新闻标题的主题是科技
  • 在实践中,将定义标签词填入模板[MASK]的位置,得到的句子越通顺自然,学习效果越好。

7. 训练配置

本实践使用了 PromptTrainer 进行模型训练,该 API 封装了文本分类任务的整体训练流程,只需要定义必要模块,无需重复编写模板拼接、标签词映射、优化器、训练流程控制等代码,便于快速开发实践。

PromptTrainer 继承自 Trainer API,训练参数推荐使用命令行进行设置。为了方便在 Notebook 中配置参数,这里使用了列表定义的方式。更多参数配置介绍可参考Trainer文档PromptTrainer文档

from paddlenlp.prompt import PromptTuningArguments
from paddlenlp.trainer import PdArgumentParser

# 训练参数
config = ["--output_dir", "./checkpoints/", 
          "--learning_rate", "3e-5",
          "--ppt_learning_rate", "3e-4",
          "--num_train_epochs", "100",
          "--logging_steps", "5",
          "--per_device_train_batch_size", "4",
          "--per_device_eval_batch_size", "4",
          "--metric_for_best_model", "accuracy",
          "--load_best_model_at_end", "True",
          "--evaluation_strategy", "epoch",
          "--save_strategy", "epoch",
          "--load_best_model_at_end", "True"
         ]
parser = PdArgumentParser((PromptTuningArguments,))
training_args = parser.parse_args_into_dataclasses(args=config,
                                                   look_for_args_file=False)[0]

与提示学习相关的分类模型封装在了 PromptModelForSequenceClassification 中,可以通过 freeze_plm 参数控制训练过程中预训练模型参数是否更新,freeze_dropout 在前者的基础上进一步关闭了 dropout,以降低提示学习相关参数的学习难度。

根据实验经验,Base/Large规模的模型在训练时同时更新预训练模型参数效果较好。

from paddlenlp.prompt import PromptModelForSequenceClassification

# Prompt 分类模型
prompt_model = PromptModelForSequenceClassification(
        model,
        template,
        verbalizer,
        freeze_plm=False,
        freeze_dropout=False)

除了模型,Trainer的初始化还需要定义损失函数、评估函数、训练策略等模块。这里分别使用了交叉熵作为损失函数、准确度作为评估函数,以及内置的早停 Callback 用于控制训练在何时结束。

import paddle
from paddle.metric import Accuracy
from paddlenlp.prompt import PromptTrainer
from paddlenlp.trainer import EarlyStoppingCallback

# 损失函数
criterion = paddle.nn.CrossEntropyLoss()

# 评估函数
def compute_metrics(eval_preds):
    metric = Accuracy()
    correct = metric.compute(paddle.to_tensor(eval_preds.predictions),
                             paddle.to_tensor(eval_preds.label_ids))
    metric.update(correct)
    acc = metric.accumulate()
    return {'accuracy': acc}

# 早停策略(可选)
callbacks = [
    EarlyStoppingCallback(early_stopping_patience=4,
                          early_stopping_threshold=0.)
]

# Trainer 定义
trainer = PromptTrainer(model=prompt_model,
                        tokenizer=tokenizer,
                        args=training_args,
                        criterion=criterion,
                        train_dataset=train_ds,
                        eval_dataset=dev_ds,
                        callbacks=callbacks,
                        compute_metrics=compute_metrics)

8. 模型训练

Trainer 中封装了模型训练、模型保存、日志打印等模块,直接调用相应的方法即可实现。

train_result = trainer.train(resume_from_checkpoint=None)
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_model()
trainer.save_metrics("train", metrics)
trainer.save_state()

7. 模型评估

通常来讲,我们会划分出测试集来评估模型的泛化效果。

test_ret = trainer.predict(test_ds)
trainer.log_metrics("test", test_ret.metrics)

[2022-10-20 19:58:29,383] [ INFO] - ***** Running Prediction *****
[2022-10-20 19:58:29,386] [ INFO] - Num examples = 2010
[2022-10-20 19:58:29,388] [ INFO] - Pre device batch size = 4
[2022-10-20 19:58:29,391] [ INFO] - Total Batch size = 4
[2022-10-20 19:58:29,393] [ INFO] - Total prediction steps = 503

8. 模型预测

import numpy as np
from paddlenlp.prompt import InputFeatures

def infer(model, text):
    model.eval()
    inputs = [InputExample(text_a=sample) for sample in text]
    inputs = [model.template.wrap_one_example(sample) for sample in inputs]
    inputs = InputFeatures.collate_fn(inputs)

    outputs = model(inputs["input_ids"],
                    inputs["mask_ids"],
                    inputs.get("soft_token_ids", None),
                    return_hidden_states=False)
    preds = np.argmax(outputs, axis=-1).tolist()
    for idx, sample in enumerate(text):
        label = model.verbalizer.ids_to_labels[preds[idx]]
        print(f"新闻文本: {sample}, 预测类别: {label}")

infer(prompt_model, ["炒期货能成亿万富豪吗?", "季后赛最有价值球员榜,浓眉第5 哈登第3,榜首太霸道"])
新闻文本: 炒期货能成亿万富豪吗?, 预测类别: news_stock
新闻文本: 季后赛最有价值球员榜,浓眉第5 哈登第3,榜首太霸道, 预测类别: news_sports

【动手学 RAG】系列文章:

【动手部署大模型】系列文章:

【人人都能学得会的NLP】系列文章:

标签:NLP,训练,模型,分类,news,model,文本,ds
From: https://blog.csdn.net/weixin_44025655/article/details/144203258

相关文章

  • 分类算法中的样本不平衡问题及其解决方案
    一、样本不平衡问题概述在机器学习的分类任务中,样本不平衡是指不同类别训练样本数量存在显著差异的现象。这一差异会给模型训练和性能评估带来挑战,尤其在处理少数类样本时,模型可能难以有效学习其特征。以二分类为例,理想情况下正负样本数量应相对平衡,如各1000个样本时,模......
  • 从零开始利用coze智能体API提取抖音视频文本内容
    作用:可以将抖音视频说话的内容转成文本。本文从零开始搭建coze智能体到添加解析插件、到开通API、再创建请求密钥全流程讲解。完全从零开始一步步操作,直至达到最终目的。扣子的API能力个人免费使用【注上免费请求说明】免费版和专业版的对比:文档链接:https://www.coze.cn/docs......
  • 论文泛读《PICCOLO : Exposing Complex Backdoors in NLP Transformer Models》
    发表时间:2022期刊会议:IEEESymposiumonSecurityandPrivacy(SP)论文单位:PurdueUniversity论文作者:YingqiLiu,GuangyuShen,GuanhongTao,ShengweiAn,ShiqingMa,XiangyuZhang方向分类:BackdoorAttack论文链接开源代码摘要后门可以被注入到NLP模型中,使得当......
  • 一个简单的图像分类神经网络
     importtorchimporttorch.onnxfromtorchimportnnfromtorch.utils.dataimportDataLoaderfromtorchvisionimportdatasetsfromtorchvision.transformsimportToTensorbatch_size=64device="cuda"#这部分代码加载了FashionMNIST数据集,datasets.Fa......
  • 使用ModelArts VS Code插件调试训练ResNet50图像分类模型
    应用场景Notebook等线上开发工具工程化开发体验不如IDE,但是本地开发服务器等资源有限,运行和调试环境大多使用团队公共搭建的CPU或GPU服务器,并且是多人共用,这带来一定的环境搭建和维护成本。因此使用本地IDE+远程Notebook结合的方式,可以同时享受IDE工程化开发和云上资源的即开......
  • 全球最大分类广告商的Karpenter落地实践:减负运维、减少中断、每月省21万 (下)
    原文链接:https://medium.com/adevinta-tech-blog/the-karpenter-effect-redefining-our-kubernetes-operations-80c7ba90a599编译:CloudPilotAI在上一篇文章中,我们介绍了Adevinta迁移至Karpenter后如何利用这一开源工具为运维团队减负、增强应用稳定性以及实现成本优化(月......
  • 【机器学习】机器学习的基本分类-监督学习-支持向量机(Support Vector Machine, SVM)
    支持向量机是一种强大的监督学习算法,主要用于分类问题,但也可以用于回归和异常检测。SVM的核心思想是通过最大化分类边界的方式找到数据的最佳分离超平面。1.核心思想目标给定训练数据,其中是特征向量,是标签,SVM的目标是找到一个超平面将数据分开,同时最大化分类边界的......
  • android手机的微信H5弹出的软键盘挡住了文本框,如何解决?
    Android微信H5页面中,软键盘弹出挡住输入框的问题,是一个比较常见且棘手的问题。核心原因在于微信内置浏览器对window.resize事件的处理机制与常规浏览器不同,以及Android系统本身的碎片化。以下是一些解决方案,建议结合实际情况选择和组合使用:1.使用scrollIntoView()方......
  • Python基于滑动窗口CNN损伤梁桥数据、故宫城墙图像数据分类可视化|附数据代码
    全文链接:https://tecdat.cn/?p=38442原文出处:拓端数据部落公众号分析师:YufeiGuo在现代土木结构工程领域,结构损伤的准确识别与定位对于保障基础设施的安全性和耐久性具有极为关键的意义。传统的人工检查方法,如目视检查以及借助专业设备进行检测,在很长一段时间内占据着主导地位,......
  • 【新手初步了解】SQL注入按不同方式的分类
    一、按照提交方式分类(一)get提交(按照请求方式划分)1.取数据$_GET(不是针对get请求,也就是不是取get请求携带的数据,而是取得查询参数数据)对于post请求(只要url上面有携带查询参数的)也是可以的,如下是携带查询参数的post请求(注意:post请求也是可以携带查询参数的,post改成get也是可......