首页 > 其他分享 >欺诈文本分类检测(十三):交叉训练验证

欺诈文本分类检测(十三):交叉训练验证

时间:2024-09-08 19:23:48浏览次数:10  
标签:欺诈 训练 验证 dataset checkpoint train path model 文本

1. 引言

交叉验证主要讨论的是数据集的划分问题。

通常情况下,我们会采用均匀随机抽样的方式将数据集划分成3个部分——训练集、验证集和测试集,这三个集合不能有交集,常见的比例是8:1:1(如同前文我们所作的划分)。这三个数据集的用途分别是:

  • 训练集:用来训练模型,去学习模型的权重和偏置这些参数,这些参数可称为学习参数。
  • 验证集:用于在训练过程中选择超参数,比如批量大小、学习率、迭代次数等,它并不参与梯度下降,也不参与学习参数的确定。
  • 测试集:用于训练完成后评价最终的模型时使用,它既不参与学习参数的确定,也不参数超参数的选择,而仅仅使用于模型的评价。

注:千万不能在训练过程中使用测试集,不论是用于训练还是用于超参数的选择,这会将测试数据无意中提前透露给模型,相当于作弊,使得模型测试时准确率虚高。

而交叉验证与上述不同的地方在于:在手动划分时只分出训练集和测试集,到真正训练时才从训练集中动态抽取一定比例作为验证集,并且在多轮训练中会循环提取不同的训练集和验证集。数据集划分大概如下图:
在这里插入图片描述

  • 第一轮训练时,将训练集平均分成5份,选1份作为验证集,其余4份作为训练集。
  • 第二轮训练时,取另外的1份作为验证集,剩余4份作为训练集。
  • ……
  • 如此循环,直到每份数据都参与过训练和验证。

这样做的好处在于:模型能更充分的利用数据,更全面的学习到数据的整体特征,减少过拟合风险。

2. 训练过程

2.1 初始化

这一部分同前文训练的预设一样,基本没有什么改变。

%run trainer.py
traindata_path = '/data2/anti_fraud/dataset/train0819.jsonl'
evaldata_path = '/data2/anti_fraud/dataset/eval0819.jsonl'
model_path = '/data2/anti_fraud/models/modelscope/hub/Qwen/Qwen2-1___5B-Instruct'
output_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_1'

声明要使用的GPU设备。

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda'

加载模型和tokenizer分词器。

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
train_dataset, eval_dataset = load_dataset(traindata_path, evaldata_path, tokenizer)

在这里插入图片描述

2.2 数据处理

这一部分主要是将前文构造训练/测试数据集所构造的训练集和验证集合并,采用sklearn库中的KFold重新按折子进行数据集分割。

import glob
import gc
import numpy as np
from datasets import Dataset, concatenate_datasets
from sklearn.model_selection import KFold

拼接训练集和验证集作为一个数据集。

datasets = concatenate_datasets([train_dataset, eval_dataset])
len(datasets)
21135

创建KFold对象用于按折子划分数据集。

  • n_splits=5:表示将数据集划分为5份。
  • shuffle=True:表示调用kf.split划分数据集前先将顺序打乱。

KFold是由sklearn库提供的k折交叉验证方法,它通过将数据集分成k个相同大小的子集(称为折),每次迭代数据集时,使用其中一个作为验证集,其余4个作为训练集,并重复这个过程k次。

kf = KFold(n_splits=5, shuffle=True)
kf
KFold(n_splits=5, random_state=None, shuffle=True)

用kfold划分数据集时,实际拿到的是数据在数据集中的索引顺序,如下面示例的效果。

indexes = kf.split(np.arange(len(datasets)))
train_indexes, val_indexes = next(indexes)
train_indexes, val_indexes, len(train_indexes), len(val_indexes)
(array([    0,     2,     3, ..., 21129, 21131, 21134]),
 array([    1,     9,    12, ..., 21130, 21132, 21133]),
 16908,
 4227)

如上所示,训练集的数量16908和验证集的数量4227比例基本是4:1。

2.3 超参数定义

定义超参构造函数,包括训练参数和Lora微调参数。这里相对于之前作的调整在于:

  • 修改评估和保存模型的策略,由每100step改为每个epoch保存一次,原因是前者保存的checkpoint有太多冗余,节省一些磁盘空间。
  • 将num_train_epochs调整为2,表示每个折子的数据集训练2遍,k=5时数据总共会训练10遍。

注:当per_device_train_batch_size=16时训练过程中会意外发生OOM,所以临时将批次大小per_device_train_batch_size改为8.

def build_arguments(output_path):
    train_args = build_train_arguments(output_path)
    train_args.eval_strategy='epoch'
    train_args.save_strategy='epoch'
    train_args.num_train_epochs = 2
    train_args.per_device_train_batch_size = 8
    
    lora_config = build_loraconfig()
    lora_config.lora_dropout = 0.2   
    lora_config.r = 16
    lora_config.lora_alpha = 32
    return train_args, lora_config

Lora配置和前文最后一次训练的配置相同,秩采用16,dropout采用0.2.

2.4 重新定义模型加载

由于训练过程中需要迭代更换不同的训练集和验证集组合,而更换数据集就需要重新创建训练器,传入新的模型实例,相当于从头开始训练。

为了实现后一次训练能在前一次训练结果的基础上继续训练,就需要找到前一次训练的最新checkpoint。所以定义一个find_last_checkpoint方法,用于从一个目录中查找最新的checkpoint。

# 确定最后的checkpoint目录
def find_last_checkpoint(output_dir):
    checkpoint_dirs = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
    last_checkpoint_dir = max(checkpoint_dirs, key=os.path.getctime)
    return last_checkpoint_dir

find_last_checkpoint("/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0830_1")
'/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0830_1/checkpoint-3522'
  • glob.glob 函数可以在指定目录下查找所有匹配 checkpoint-* 模式的子目录
  • os.path.getctime 返回文件的创建时间(或最近修改时间)
  • max 函数根据这些时间找出最后创建的目录,也就是最新的checkpoint。

定义一个新的加载模型的方法,用于从基座模型和指定的checkpoint中加载最新训练的模型,并根据训练目标来设置参数的require_grad属性。

def load_model_with_checkpoint(model_path, checkpoint_path='', device='cuda'):
    # 加载模型
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device)
    # 加载lora权重
    if checkpoint_path: 
        model = PeftModel.from_pretrained(model, model_id=checkpoint_path).to(device)
    
    # 将基础模型的参数设置为不可训练
    for param in model.base_model.parameters():
        param.requires_grad = False
    
    # 将 LoRA 插入模块的参数设置为可训练
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
    return model

如上代码逻辑所示,将来自lora的参数都设置为需要梯度requires_grad = True,其余原始基座模型的参数设置不可训练requires_grad = False

2.5 构建训练过程

在这个训练过程中,除了第一次训练是从0初始化的微调秩矩阵,后面几次训练则都是从指定checkpoint来初始化微调秩,这导致了原先定义的build_trainer方法不通用。所以定义一个新的训练器构建方法,将加载微调参数的逻辑移到外面。

def build_trainer_v2(model, tokenizer, train_args, train_dataset, eval_dataset):
    # 开启梯度检查点时,要执行该方法
    if train_args.gradient_checkpointing:
        model.enable_input_require_grads()
    return Trainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],  # 早停回调
    )

下面定义交叉训练的主循环。

results = []
last_checkpoint_path = ''

for fold, (train_index, val_index) in enumerate(kf.split(np.arange(len(datasets)))):
    train_dataset = datasets.select(train_index)
    eval_dataset = datasets.select(val_index)

    output_path = f'/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_{fold}'
    train_args, lora_config = build_arguments(output_path)
    # 第一次训练和后面几次训练所采用的模型加载方法不同
    if last_checkpoint_path:
        model = load_model_with_checkpoint(model_path, last_checkpoint_path, device)
    else:
        model = load_model(model_path, device)
        model = get_peft_model(model, load_config)

    model.print_trainable_parameters()
    trainer = build_trainer_v2(model, tokenizer, train_args, train_dataset, eval_dataset)
    train_result = trainer.train()
    results.append(train_result)
    
    last_checkpoint_path = find_last_checkpoint(output_path)

代码逻辑说明:

  • kf.split函数划分了5份数据索引,以这5份数据索引进行5次迭代。
  • 使用datasets.select基于索引在每次迭代时选择不同的数据作为训练集和验证集。
  • 为了避免前次迭代训练的结果被下次迭代的结果给覆盖,每次迭代训练通过fold来拼接不同的输出目录output_path。
  • 如果存在last_checkpoint_path,则从checkpoint来加载模型,如果不存在,则使用get_peft_model向模型中插入一个新的Lora微调秩。
  • 使用新的build_trainer_v2方法来构建训练器并开始训练。
  • 每次迭代完都找出此次训练中最新的checkpoint,作为下次训练的起点。
2.6 开始训练

运行上面的主循环开始训练。

最终可以收集到5次迭代训练的损失数据如下,每次迭代跑2轮数据集,共跑了10轮数据集。

EpochTraining LossValidation Loss
10.02330.02189
20.01380.01614
30.0088000.011420
40.0046000.013666
50.0032000.004718
60.0030000.004082
70.0072000.001999
80.0000000.000814
90.0049000.002273
100.0102000.002139

对比前面欺诈文本分类微调(七)—— lora单卡二次调优训练进行到2300步左右(大概两遍数据)就开始过拟合(主要现象是验证损失到0.0161就不再下降反而开始升高)。K折交叉训练直到第4次迭代(大概八遍数据)过后才达到损失最低点,第5次迭代才出现了略微的过拟合(相比于第4次),过拟合的现象得到了极大的缓解,验证损失也降到了一个更低的值0.000814,这说明数据相比之前训练来说得到了更充分的使用。

3. 评估测试

由于交叉训练中验证集和训练集都参与了模型学习参数的更新,所以用验证集进行评估已经没有意义。我们直接用测试集进行最后的评估。

第一轮迭代结果的评测:

%run evaluate.py
checkpoint_path='/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_0/checkpoint-4226'
testdata_path = '/data2/anti_fraud/dataset/test0819.jsonl'
evaluate(model_path, checkpoint_path, testdata_path, device, batch=True, debug=True)
progress: 100%|██████████| 2349/2349 [03:19<00:00, 11.75it/s]

tn:1135, fp:32, fn:128, tp:1054
precision: 0.9705340699815838, recall: 0.8917089678510999

第三轮迭代结果的评测:

%run evaluate.py
checkpoint_path='/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_2/checkpoint-4226'
testdata_path = '/data2/anti_fraud/dataset/test0819.jsonl'
evaluate(model_path, checkpoint_path, testdata_path, device, batch=True, debug=True)
progress: 100%|██████████| 2349/2349 [03:21<00:00, 11.64it/s]

tn:1133, fp:34, fn:64, tp:1118
precision: 0.9704861111111112, recall: 0.9458544839255499

第四次迭代结果的评测:

%run evaluate.py
checkpoint_path='/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_3/checkpoint-4226'
testdata_path = '/data2/anti_fraud/dataset/test0819.jsonl'
evaluate(model_path, checkpoint_path, testdata_path, device, batch=True, debug=True)
progress: 100%|██████████| 2349/2349 [03:21<00:00, 11.66it/s]

tn:1128, fp:39, fn:64, tp:1118
precision: 0.9662921348314607, recall: 0.9458544839255499

第五次迭代结果的评测:

%run evaluate.py
checkpoint_path='/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0903_4/checkpoint-4226'
testdata_path = '/data2/anti_fraud/dataset/test0819.jsonl'
evaluate(model_path, checkpoint_path, testdata_path, device, batch=True, debug=True)
progress: 100%|██████████| 2349/2349 [03:22<00:00, 11.58it/s]

tn:1124, fp:43, fn:50, tp:1132
precision: 0.963404255319149, recall: 0.9576988155668359

与之前单卡训练多卡微调的结果相比,精确率有一点点下降(0.9953->0.9634),但召回率却有了一个比较大的提升(0.9129->0.9576),这个测评结果的数据变化与上面损失结果的数据变化基本是一致的。

小结:本文通过引入K折交叉验证方法,循环选择不同的训练集和验证集进行多次迭代训练,将损失降到了一个更低的值,也在很大程度上缓解了[前面每次训练]过程中都出现的过拟合现象。最终在从未见过的测试数据集上进行评测时,召回率指标也有了一个较大的提升。从这个结果来看,K折交叉验证这种方法确实能让模型对数据学习的更充分,有助于模型泛化能力的提升。

相关阅读

标签:欺诈,训练,验证,dataset,checkpoint,train,path,model,文本
From: https://blog.csdn.net/xiaojia1001/article/details/141928787

相关文章

  • Elasticsearch-5.6版本安装,添加登录验证,修改密码
    一.ES简介Elasticsearch是一个实时的分布式存储,搜索、分析的引擎。他的模糊查询的强大效率,目前被很多企业所青睐。二.本文背景由于从ElasticStack6.8和7.1版本才开始支持登录验证,所以对于用之前版本的无法升级,又需要添加授权访问或者修改密码,可以参考本文中用到的......
  • 一文看懂CAPTCHA验证
    CAPTCHA代表“全自动公共图灵测试,用于区分计算机和人类”。这些测试充当身份验证机制,以确保网站流量来自真人,而不是机器人。CAPTCHA测试可以在您与网站交互的过程中出现在各个时间点,例如在登录尝试、表单提交期间,或者当您的浏览活动模仿机器人流量时(例如,多次刷新浏览器或快......
  • 软件测试-RobotFramework-实战(清除、输入文本;鼠标点击;下拉框select、li;实战演示视频)
    学习笔记记录在用户信息界面,主要涉及头像上传,输入文本,选择按钮,下拉框选项。一、清除、输入文本 上传前一篇已经完成了,对于输入文本,主要就是一行代码inputtext输入框的地址要输入的文本但是如果输入框内还含有文本,就需要进行清除。\8--表示删除一个字符。 pre......
  • 安全验证:AE 2024安装包下载后的完整性检查
    安全验证:AE 2024安装包下载后的完整性检查安全验证:AE2024安装包下载后的完整性检查AdobeAfterEffects2024(AE2024)是一款功能强大的视频后期制作软件,广泛应用于影视特效、动画制作等领域。为了确保软件的完整性和安全性,在下载AE2024安装包后,进行完整性检查至关重要。本......
  • 【NLP自然语言处理】文本的数据分析------迅速掌握常用的文本数据分析方法~
    目录......
  • CSS设置禁止文本复制
    CSS设置禁止复制经常可以看到某些网站网页上的文字无法被选中,除了js控制,通过CSS样式user-select和z-index两个属性都可导致无法复制文字user-select浏览器中双击或点击选中文本,文本会被高亮显示,通过cssuser-select属性则设置是否允许选取元素的文本,该CSS有四个属性值auto:默......
  • 【十五派-注册安全分析报告-滑动验证加载不正常导致安全隐患】
    前言由于网站注册入口容易被黑客攻击,存在如下安全问题:暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞所以大部分网站及App都采取图形验证码或滑动验证码等交互解决方案,但在机器学习能力提......
  • 使用亚马逊Bedrock的Stable Diffusion XL模型实现文本到图像生成:探索AI的无限创意
    引言什么是AmazonBedrock?AmazonBedrock是亚马逊云服务(AWS)推出的一项旗舰服务,旨在推动生成式人工智能(AI)在各行业的广泛应用。它的核心功能是提供由顶尖AI公司(如AI21Labs、Anthropic、Cohere、Meta、MistralAI、StabilityAI以及亚马逊自身)开发的多种基础模型(FoundationMo......
  • Spring Boot 注解探秘:@Validated 开启数据验证之旅(上)
    在JavaWeb项目开发中,数据验证是一项至关重要的环节。Spring框架中的@Validated注解为我们提供了一种方便而强大的方式来实现数据验证。本文将详细介绍@Validated注解的用法及其在SpringBoot应用中的实践。一、基本介绍@Validated是Spring框架提供的用于数据验证......
  • 数字IC验证笔面试常见100题【持续更新】
    【提要】收集整理了一些网络上和我自己在秋招、实习时遇到的题目,适合数字验证方向求职的同学进行差缺补漏或者应对八股时的速成。    对于时间比较充裕并且有条件的同学,还是强烈建议找个实习来提升自己的能力以及校招竞争性,独立完成了一两个真实项目后,能大大加深对验证......