首页 > 其他分享 >一起学Hugging Face Transformers(13)- 模型微调之自定义训练循环

一起学Hugging Face Transformers(13)- 模型微调之自定义训练循环

时间:2024-07-08 09:55:01浏览次数:11  
标签:loss 13 Transformers 自定义 训练 模型 batch dataset 循环

文章目录


前言

Hugging Face Transformers 库为 NLP 模型的预训练和微调提供了丰富的工具和简便的方法。虽然 Trainer API 简化了许多常见任务,但有时我们需要更多的控制权和灵活性,这时可以实现自定义训练循环。本文将介绍什么是训练循环以及如何使用 Hugging Face Transformers 库实现自定义训练循环。


一、什么是训练循环

在模型微调过程中,训练循环是指模型训练的核心过程,通过多次迭代数据集来调整模型的参数,使其在特定任务上表现更好。训练循环包含以下几个关键步骤:

1. 训练循环的关键步骤

1) 前向传播(Forward Pass)

  • 模型接收输入数据并通过网络进行计算,生成预测输出。这一步是将输入数据通过模型的各层逐步传递,计算出最终的预测结果。

2) 计算损失(Compute Loss)

  • 将模型的预测输出与真实标签进行比较,计算损失函数的值。损失函数是一个衡量预测结果与真实值之间差距的指标,常用的损失函数有交叉熵损失(用于分类任务)和均方误差(用于回归任务)。

3) 反向传播(Backward Pass)

  • 根据损失函数的值,计算每个参数对损失的贡献,得到梯度。反向传播使用链式法则,将损失对每个参数的梯度计算出来。

4) 参数更新(Parameter Update)

  • 使用优化算法(如梯度下降、Adam 等)根据计算出的梯度调整模型的参数。优化算法会更新每个参数,使损失函数的值逐步减小,模型的预测性能逐步提高。

5) 重复以上步骤

  • 以上过程在整个数据集上进行多次(多个epoch),每次遍历数据集被称为一个epoch。随着训练的进行,模型的性能会不断提升。

2. 示例

假设你在微调一个BERT模型用于情感分析任务,训练循环的步骤如下:

1) 前向传播

  • 输入一条文本评论,模型通过各层网络计算,生成预测的情感标签(如正面或负面)。

2) 计算损失

  • 将模型的预测标签与实际标签进行比较,计算交叉熵损失。

3) 反向传播

  • 计算损失对每个模型参数的梯度,确定每个参数需要调整的方向和幅度。

4) 参数更新

  • 使用Adam优化器,根据计算出的梯度调整模型的参数。

5) 重复以上步骤

  • 在整个训练数据集上进行多次迭代,不断调整参数,使模型的预测精度逐步提高。

3. 训练循环的重要性

训练循环是模型微调的核心,通过多次迭代和参数更新,使模型能够从数据中学习,逐步提高在特定任务上的性能。理解训练循环的各个步骤和原理,有助于更好地调试和优化模型,获得更好的结果。

在实际应用中,训练循环可能会包含一些额外的步骤和技术,例如:

  • 批量训练(Mini-Batch Training):将数据集分成小批量,每次训练一个批量,降低计算资源的需求。
  • 学习率调度(Learning Rate Scheduling):动态调整学习率,以提高训练效率和模型性能。
  • 正则化技术(Regularization Techniques):如Dropout、权重衰减等,防止模型过拟合。

这些技术和方法结合使用,可以进一步提升模型微调的效果和性能。

二、使用 Hugging Face Transformers 库实现自定义训练循环

1. 前期准备

1)安装依赖

首先,确保已经安装了必要的库:

pip install transformers datasets torch

2)导入必要的库

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from datasets import load_dataset
from tqdm.auto import tqdm

2. 加载数据和模型

1) 加载数据集

这里我们以 IMDb 电影评论数据集为例:

dataset = load_dataset("imdb")

2) 加载预训练模型和分词器

我们将使用 distilbert-base-uncased 作为基础模型:

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

3) 预处理数据

定义一个预处理函数,并将其应用到数据集:

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

4) 创建数据加载器

train_dataloader = DataLoader(encoded_dataset["train"], batch_size=8, shuffle=True)
eval_dataloader = DataLoader(encoded_dataset["test"], batch_size=8)

3. 自定义训练循环

1) 定义优化器和学习率调度器

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

2) 定义训练和评估函数

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

def train_loop():
    model.train()
    for batch in tqdm(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

def eval_loop():
    model.eval()
    total_loss = 0
    correct_predictions = 0

    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            logits = outputs.logits
            total_loss += loss.item()

            predictions = torch.argmax(logits, dim=-1)
            correct_predictions += (predictions == batch["labels"]).sum().item()

    avg_loss = total_loss / len(eval_dataloader)
    accuracy = correct_predictions / len(eval_dataloader.dataset)
    return avg_loss, accuracy

3) 运行训练和评估

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_loop()
    avg_loss, accuracy = eval_loop()
    print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

总结

通过上述步骤,我们实现了使用 Hugging Face Transformers 库的自定义训练循环。这种方法提供了更大的灵活性,可以根据具体需求调整训练过程。无论是优化器、学习率调度器,还是其他训练策略,都可以根据需要进行定制。希望这篇文章能帮助你更好地理解和实现自定义训练循环,为你的 NLP 项目提供更强大的支持。

标签:loss,13,Transformers,自定义,训练,模型,batch,dataset,循环
From: https://blog.csdn.net/kljyrx/article/details/140153000

相关文章

  • Day 41 | 322. 零钱兑换 、 279.完全平方数、139.单词拆分
    322.零钱兑换如果求组合数就是外层for循环遍历物品,内层for遍历背包。如果求排列数就是外层for遍历背包,内层for循环遍历物品。这句话结合本题大家要好好理解。视频讲解:https://www.bilibili.com/video/BV14K411R7yvhttps://programmercarl.com/0322.零钱兑换.html给定不同......
  • [LeetCode] 134. Gas Station
    想到了提前判断和小于0的情况,懒得写,果然被阴间用例10万个加油站坑了。classSolution:defcanCompleteCircuit(self,gas:List[int],cost:List[int])->int:#1n=len(gas)ifn==1:ifgas[0]>=cost[0]:ret......
  • 通信协议_C#实现自定义ModbusRTU主站
    背景知识:modbus协议介绍相关工具mbslave:充当从站。虚拟串口工具:虚拟出一对串口。VS2022。实现过程以及Demo打开虚拟串口工具:打开mbslave:此处从站连接COM1口。Demo实现创建DLL库,创建ModbusRTU类,进行实现:usingSystem;usingSystem.Collections.Generic;usi......
  • 前端JS特效第22集:html5音乐旋律自定义交互特效
    html5音乐旋律自定义交互特效,先来看看效果:部分核心的代码如下(全部代码在文章末尾):<!DOCTYPEhtml><htmllang="en"><head><metacharset="UTF-8"><title>ChimeTime™</title><linkrel="stylesheet"href="css/style.css......
  • Simple WPF: WPF 自定义按钮外形
    最新内容优先发布于个人博客:小虎技术分享站,随后逐步搬运到博客园。WPF的按钮提供了Template模板,可以通过修改Template模板中的内容对按钮的样式进行自定义,完整代码Github自取。使用Style定义扁平化的按钮样式定义一个ButtonStyleDictonary.xaml资源字典文件,在ControlTemplate......
  • SpringSecurity简单自定义配置
    初学者对于学习SpringSecurity相关的一些简单自定义配置总结。由于自身能力并不能和大佬相比较,以下的一些内容有误或有可改进地方,希望指出,我抱有一颗谦虚好学的心保持热情,并感谢指正。实现案例:1.基于内存的用户认证2.基于数据库的用户认证3.添加用户(数据库)4.自定义密......
  • srpingboot 自定义 start
    自动配置工程绑定配置文件,上逼格的start都支持自定义配置,我们也装像点~~@ConfigurationProperties("cyrus.hello")publicclassCyrusHelloProperties{//绑定配置文件cyrus.hello.username属性privateStringusername;publicStringgetUsernam......
  • Android 13.0 mt6771新增分区功能实现一
    1.前言 在13.0的系统ROM定制化开发中,在对某些特殊模块中关于数据的存储方面等需要新增分区来保存,所以就需要在系统分区新增相关的分区,来实现功能,接下来就来实现这个功能,来新增分区功能2.mt6771新增分区功能实现一的核心类build/make/core/Makefilebuild/make/cor......
  • Android面试题自定义View之Window、ViewRootImpl和View的三大流程
    本文首发于公众号“AntDream”,欢迎微信搜索“AntDream”或扫描文章底部二维码关注,和我一起每天进步一点点View的三大流程指的是measure(测量)、layout(布局)、draw(绘制)。下面我们来分别看看这三大流程View的measure(测量)MeasureSpecMeasureSpec是View的一个内部静......
  • 美食分享交流网站 毕业设计-附源码10913
    摘 要大数据时代下,数据呈爆炸式地增长。为了迎合信息化时代的潮流和信息化安全的要求,利用互联网服务于其他行业,促进生产,已经是成为一种势不可挡的趋势。在美食分享的要求下,开发一款整体式结构的美食分享交流网站,将复杂的系统进行拆分,能够实现对需求的变化快速响应、系统稳......