首页 > 其他分享 >HuggingFace | huggingface中遇到的坑

HuggingFace | huggingface中遇到的坑

时间:2023-07-07 12:23:12浏览次数:31  
标签:__ 遇到 self labels HuggingFace huggingface train model data

一、不要尝试使用huggingface的Trainer函数加载自定义模型

理论上说,Hugging Face的Trainer函数可以加载自定义模型,只要您的模型是基于PyTorch或TensorFlow实现的,并且实现了必要的方法(如forward方法和from_pretrained方法)。

要将您的自定义模型与Hugging Face的Trainer函数一起使用,您需要使用Transformers库中的Trainer类,该类提供了训练、评估和预测模型的方法,以及与PyTorch和TensorFlow模型集成的功能。

代码示例:

import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # 假设数据是一个 tuple,第一个元素是输入,第二个元素是标签
        input_ids = self.data[idx][0]
        labels = self.data[idx][1]
        return {"input_ids": input_ids, "labels": labels}

# 假设您有一个包含 100 个样例的数据集,每个样例是一个 tuple,包含一个长度为 10 的输入和一个标签
data = [([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 0) for i in range(100)]

# 创建 train_dataset 和 eval_dataset,这里使用前 80 个样例作为训练集,后 20 个样例作为评估集
train_dataset = MyDataset(data[:80])
eval_dataset = MyDataset(data[80:])

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)
    
    def forward(self, input_ids, labels=None):
        logits = self.linear(input_ids)
        if labels is not None:
            labels = labels.to(torch.float)
            labels = labels.view(-1)
            loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fn(logits.view(-1, 1), labels)
            return loss
        else:
            return logits
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        model = cls(*model_args, **kwargs)
        state_dict = torch.load(pretrained_model_name_or_path)
        model.load_state_dict(state_dict)
        return model

model = MyModel()
training_args = TrainingArguments(
    output_dir='./results',          # 训练输出目录
    num_train_epochs=3,              # 训练轮数
    per_device_train_batch_size=16,  # 训练批次大小
    per_device_eval_batch_size=64,   # 评估批次大小
    warmup_steps=500,                # 预热步数
    weight_decay=0.01,               # 权重衰减
    logging_dir='./logs',            # 日志目录
    logging_steps=10,                # 日志步数
)

trainer = Trainer(
    model=model,                         # 自定义模型实例
    args=training_args,                  # 训练参数
    train_dataset=train_dataset,         # 训练数据集
    eval_dataset=eval_dataset            # 评估数据集
)

trainer.train()  # 开始训练

我们上面都可以运行,但是到了trainer.train()函数中,就有各种错误,比如:

RuntimeError: expected scalar type Float but found Long

这个错误通常是由于张量数据类型不匹配引起的。在 PyTorch 中,张量数据类型非常重要,因为它们指定了张量中存储的数值的精度和类型。如果您在模型的前向传递中使用了错误的数据类型,就会出现这个错误。

但是代码中我们找不到这个错误就说执行在trainer.train()函数这里,报了错。这样对debug代码不友好,而且这个错误解决了,可能还有更多错误。

使用Trainer函数的正确用法

想要使用Trainer函数进行训练模型,那么模型应该使用huggingface中的可以查到的模型。我们可以使用这些预训练模型进行微调。

比如,我们使用多语言预训练模型进行翻译的下游任务,而且我们可以训练领域的数据的embedding,然后可以进行领域的翻译训练,然后我们可以通过我们训练好的预训练模型进行使用。

在比如,我们可以进行数据的处理,把单语数据进行翻译。(待定)

问题继续发现。。。。。。

标签:__,遇到,self,labels,HuggingFace,huggingface,train,model,data
From: https://www.cnblogs.com/zhangxuegold/p/17534627.html

相关文章

  • 1.安装Rocky8.8 Ubuntu20.04版本中遇到的一些问题
    1.VMware的监视器看不到Rocky的全部图像,所以我在安装过程中改变了监视器的最大分辨率,这样不会影响系统的功能吧?2.Ubuntu系统安装中Instalcomplete界面中有个rooting运行中,我直接关机,又开机,影响不影响系统完整?3.在VMware中Ubuntu系统root登录的密码与XShell中Ubuntu系统root登录......
  • 实习中遇到的问题(3)
    np格式的数组在展示图片时需要设置的格式转换。如果是(1,28,28)格式,则不可以转换成图片展示 如果是(28,28,1)格式,则可以转换成图片展示 同时,(1,28,28,1)也是不能够展示的。我想有可能是最后一个数字是channel即通道数。关于np.newaxis()函数np.newaxis的功能是增加新的维度......
  • 实习中遇到的问题(3)
    transpose的使用 有关contiguous()的一些解释https://www.zhihu.com/tardis/bd/art/64551412?source_id=1001是关于设置张量中的数据连续化的操作np.set_printoptions的使用方法 ......
  • HuggingFace | 如何下载预训练模型
    本例我们在Linux上进行下载,下载的模型是bert-base-uncased。下载网址为:https://www.huggingface.co/bert-base-uncasedhuggingface的transformers框架,囊括了BERT、GPT、GPT2、ToBERTa、T5等众多模型,同时支持pytorch和tensorflow2,代码非常规范,使用也非常简单,但是模型使用的时候,......
  • 学迭代器遇到的一个问题
    事情是这样的functioniter(arr,index)index=index+1ifarr[index]thenreturnarr[index],indexendenda={1,2,3,5,4}fork,viniter,{1,2,5,6,5,4},0doprint(k)end运行结果:寄将手打的table改成a得......
  • rsync 遇到中文乱码文件名无法同步,并报错:rsync: rename "/test1/abc/abc/opt/abc/abc/
    rsync遇到中文文件名乱码报错报错如下:rsync:rename"/test1/abc/def/efg/abc-V2/img_abc/.δ\#261\#352\#314\#342-3.jpg.wdPu5C"->"event/abc-V2/img_abc/δ\#261\#352\#314\#342-3.jpg":Input/outputerror(5)rsync:rename"/test1/abc/def/e......
  • 实习中遇到的问题(1)
    什么是BatchNormalization?1、先取平均值2、计算sigama2.1、sigama计算方式是见图中公式3、每一项减去平均值然后除以sigama什么是Softmax?  什么是Attention和Transformer?最近在重新学习和认识Attention和Transformer,看到一个视频讲的很详细,是从矩阵计算角度讲......
  • 当使用POI打开Excel文件遇到out of memory时该如何处理?
    摘要:本文由葡萄城技术团队于博客园原创并首发。转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具、解决方案和服务,赋能开发者。当我们开发处理Excel文件时,ApachePOI是许多人首选的工具。但是,随着需求的增加、工程复杂,在打开复杂的Excel文件的时候可能会出现一些异......
  • 记录一下最近遇到的UE5 BUG
    1.UE5.2打包后,打开项目崩溃,提示:Assertionfailed:CastResult[File:D:\build\++UE5\Sync\Engine\Source\Runtime\CoreUObject\Public\UObject\Field.h][Line:961] CastFieldCheckedfailedwith0x0000015001062400  0x00007ff69dd254b6YH.exe!FRigVMMemoryHandle::......
  • 开发中MongoDB遇到的各种问题
    目录一、安装6版本以下二、安装6版本及以上三、安装6版本以下(解压版)四、配置本地WindowsMongoGB服务五、navicat连接远程mongodb数据库六、ip不一致问题一、安装6版本以下安装MongoDB6版本以下的可以参考以下博主->自动安装版(26条消息)MongoDB安装(超详细)_AIbro的博客-C......