首页 > 编程问答 >如何使用 HuggingFace 的 run_translation.py 脚本从头开始训练翻译?

如何使用 HuggingFace 的 run_translation.py 脚本从头开始训练翻译?

时间:2024-07-22 10:53:34浏览次数:15  
标签:python huggingface-transformers

我尝试了各种 HuggingFace 脚本来构建语言模型,例如 run_mlm.py ( link )、 run_clm.py ( link ) 和 run_translation.py ( link ) 。对于前两个脚本,它可以从头开始训练语言模型(即没有基本模型)。

但是,我无法使用 run_translation.py 从头开始构建模型,因为命令中需要模型名称或路径线。鉴于我有一个相当大的并行翻译数据集,如何从头开始构建翻译模型?


你说的对,Hugging Face 的 run_translation.py 脚本不像用于语言建模的脚本那样直接支持从头开始训练。它需要一个预训练的模型作为起点。

但是,你仍然可以使用 Hugging Face Transformers 库从头开始训练翻译模型,但这需要一些额外步骤:

1. 准备你的数据:

  • 清理和预处理: 清理你的并行数据集,并将其预处理成适合训练的格式。这包括标记化、将句子转换为数字 ID 以及创建源语言和目标语言数据的对。
  • 数据集类: 使用 Hugging Face 的 Dataset 类创建自定义数据集类来加载和处理你的数据。你可以在官方文档中找到更多信息: https://huggingface.co/docs/datasets/
  • 数据加载器: 使用 DataLoader 类为你的训练和验证数据创建数据加载器。

2. 初始化模型:

  • 选择模型架构: 从 Hugging Face Transformers 库中选择你想要的编码器-解码器模型架构。一些流行的选择包括 BART、T5 和 MarianMT。
  • 从配置初始化: 不是从预训练的检查点加载模型,而是使用所选模型架构的配置对象从头开始初始化它。这将为你提供一个随机初始化的模型。
from transformers import AutoConfig, AutoModelForSeq2SeqLM

# 使用所需的配置
config = AutoConfig.from_pretrained("t5-small")  
# 从配置初始化模型
model = AutoModelForSeq2SeqLM.from_config(config)

3. 设置训练参数:

  • 优化器和调度器: 选择一个优化器(如 AdamW)和一个学习率调度器。
  • 训练参数: 定义批大小、epoch 数量和其他相关训练参数。

4. 编写训练循环:

  • 迭代 epoch: 创建一个在你的数据上迭代指定 epoch 数量的循环。
  • 迭代批次: 在每个 epoch 内,迭代你的训练数据加载器。
  • 前向传递: 对于每个批次,通过模型执行前向传递,计算损失。
  • 反向传递: 计算相对于模型参数的损失梯度。
  • 更新参数: 使用优化器更新模型参数。
  • 评估: 在每个 epoch 结束时(或定期),在你的验证集上评估模型,以跟踪进度。

5. 保存你的模型:

  • 保存检查点: 在训练后,保存你的模型检查点,以便你可以稍后加载和使用它进行翻译。

虽然这需要编写比使用 run_translation.py 脚本更多的代码,但它允许你从头开始完全控制训练过程和模型初始化。

请记住,从头开始训练翻译模型需要大量的计算资源和时间,尤其是在大型数据集上。考虑使用强大的 GPU 来加速训练过程。

标签:python,huggingface-transformers
From: 78770014

相关文章

  • Python学习计划——2.4列表推导式(List Comprehensions)
    列表推导式是Python的一种简洁且强大的语法,用于生成新的列表。它可以用更少的代码、更清晰的方式来创建列表,特别是在处理简单的循环和条件操作时。1.基本语法列表推导式的基本语法如下:[expressionforiteminiterable]expression:表达式,计算结果用于生成列表的元素。ite......
  • Python学习计划——2.3常用内置函数(len, max, min, sum, etc.)
    Python提供了许多内置函数,用于简化对数据结构的操作。以下是一些常用的内置函数及其详细说明。1.len()len()函数用于返回对象(如列表、元组、字符串、字典等)的长度(元素个数)。示例:#列表fruits=["apple","banana","cherry"]print(len(fruits))#输出:3#元组c......
  • 哪个 Python 框架可以在 Google Collab 中显示和更改图像?
    我希望能够在使用GoogleCollab时为RL绘制高fps的位图。我现在可以使用OpenCV绘制图像cv2_imshowgoogle替换cv2.imshow但是,它无法替换现有图像,它下面绘制了新的我能够在替换imshow函数中使用一些JavaScript来修复它。但刷新率约为......
  • VSCode 自动建议 python 导入而不依赖 Intellisense
    我正在使用Transformer中的AutoModel之类的对象,并且经常遇到自动导入建议无法找到的对象。我总是希望VSCode建议“从Transformer中执行”,而不是费心寻找它找不到的原因每当看到未定义的“AutoModel”时,都会导入AutoModel”,因此无需扫描任何python导入目录。这......
  • 如何使用Python计算位移自相关函数?
    我正在使用python来分析粒子的异常扩散。我已经得到了粒子轨迹的位移,我想计算并绘制位移自相关与滞后时间t的关系。我认为可能存在使用t和位移(如deltar)的自相关函数的一般函数,但我不能没找到。我可以得到函数或代码吗?可以使用numpy和matplotlib库在Python......
  • 一天一点点,第四天Python基础
    第一天:一天一点点。Python基础-CSDN博客第二天:一天一点点,接上章Python基础-CSDN博客第三天:一天一点点,第三天Python基础(循环语句)-CSDN博客推导式推导式是一种独特的数据处理方式,可以从一个数据序列构建另一个新的数据序列的结构体。推导式是一种强大且简洁的语法,适用于生......
  • Python - for循环不使用正则表达式附加数组
    以下代码从URL获取版本号,然后对于每个版本号,转到该版本号的页面并使用文件名的特定模式填充数组。生成的数组应包含每个版本号的文件名列表,但它似乎只包含早期版本(2.6)。使用print语句,我可以看到代码的工作原理是它获取sha256sums.asc文件-所有这些文件,所有版本。我猜......
  • 使用 callable_iterator (re.finditer) 导致 Python 冻结
    我有一个为文本的每一行调用的函数。deftokenize_line(line:str,cmd=''):matches=re.finditer(Patterns.SUPPORTED_TOKENS,line)tokens_found,not_found,start_idx=[],[],0print(matches)formatchinmatches:pass#Rest......
  • Python 的 time.sleep - 永远不会醒来
    我认为这将是那些简单的问题之一,但它让我感到困惑。[停止媒体:我是对的。找到了解决方案。查看答案。]我正在使用Python的单元测试框架来测试多线程应用程序。很好而且很直接-我有5个左右的工作线程监视一个公共队列,以及一个为它们制作工作项的生产者线程......
  • python中使用mitmproxy的http模块出错
    我有一个使用mitmproxyhttp函数的代码,它在这里惨败:defmain(stdscr):try:parser=argparse.ArgumentParser(description='NetSourNetworkAnalyzer')parser.add_argument('--proxy',action='store_true',help='EnableH......