首页 > 其他分享 >bert模型数据集加载方式

bert模型数据集加载方式

时间:2024-05-28 23:04:12浏览次数:27  
标签:bert torch True 模型 labels ids data self 加载

数据集构造

无论是机器学习还是深度学习对于数据集的构造都是十分重要。

现记录一下PyTorch 的 torch.utils.data.Dataset 类的子类。Dataset 类是PyTorch框架中用于处理数据的基本组件,它允许用户定义自己的数据集类,以满足特定任务的需求。

Dataset是一个抽象基类,用于创建自定义数据集。它定义了两个核心方法:getitemlen,它们是所有数据集必须实现的方法。

类定子类:
重写 init 方法来初始化数据集,可能需要加载数据、解析数据等。
重写 getitem 方法来根据索引返回数据集中的一个样本,通常会包含数据的加载、解码等操作。
重写 len 方法来返回数据集中样本的数量。

import pandas as pd
from transformers import BertTokenizerFast
import torch


# 读取数据
df = pd.read_csv("./a.csv", encoding="utf-8")
texts = df["content"][:10].tolist()
labels = df["punish_result"][:10].tolist()

texts = list(map(lambda x: str(x), texts))
# texts和labels是一个list,可以自己构造一个

# Hugging Face下载这个模型google-bert/bert-base-chinese
model_name = "./bert-base-chinese" 
# 加载分词器
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# 对文本进行编码
# truncation=True 文本超过max_length进行截断处理
# padding=True 文本不足max_length进行pad处理 补0
train_encodings = tokenizer(texts, truncation=True, padding=True, max_length=32)

# 封装数据为PyTorch Dataset
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # 等价上面注释写法,for循环比较好理解
        item = {}
        for key, val in self.encodings.items():
            item[key] = torch.tensor(val[idx])

        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


train_dataset = TextDataset(train_encodings, labels)

for dta in train_dataset:
    print(dta)
    break

# 打印数据如下:

# {'input_ids': tensor([ 101, 1585,  511,  872, 1962, 8024, 2769, 6821, 6804, 3221,  976, 6858,
#         7599, 6392, 1906, 4638,  511, 2769, 2682, 7309,  671,  678, 8024, 1493,
#         6821, 6804, 7444, 6206, 6821,  671, 1779,  102]),
# 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         0, 0, 0, 0, 0, 0, 0, 0]),
# 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#         1, 1, 1, 1, 1, 1, 1, 1]),
# 'labels': tensor(0)
# }

上述代码主要通过加载bert-base-chinese模型的分词器处理原始数据,之后实现一个Dataset的子类将数据封装到PyTorch框架可识别数据结构。

数据集构造二
# 自定义数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self, path):
        df = pd.read_csv(path, encoding="utf-8")
        self.texts = df["content"].tolist()
        self.labels = df["punish_result"].tolist()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        text = str(self.texts[i])
        label = int(self.labels[i])
        return text, label

path = "./data/abuse.csv"
dataset = Dataset(path)

# len(dataset), type(dataset.texts[0])
# 17182 老板百亿检测干什么

加载词典和分词器

# 加载字典和分词工具
token = BertTokenizer.from_pretrained("./bert-base-chinese")

设置辅助函数

def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    """
    batch_text_or_text_pairs:
    类型: 列表或元组的列表。
    含义: 输入的文本数据,可以是单个文本列表(如果只处理单个句子)或配对的文本(如对话或翻译任务中的源语言和目标语言句子)。
    truncation:
    类型: 布尔值。
    含义: 是否对超过最大长度的文本进行截断。设置为 True 表示会截断超出长度限制的文本。
    padding:
    类型: 字符串。
    含义: 决定如何填充短于最大长度的文本。'max_length' 表示所有样本都会被填充到max_length的长度,以确保批次内的所有元素长度一致。
    max_length:
    类型: 整数。
    含义: 设定的最大序列长度。所有输入的文本将会被截断或填充到这个长度。
    return_tensors:
    类型: 字符串。
    含义: 指定返回的张量类型。'pt' 表示返回 PyTorch 张量,其他可能的选项有 'tf'(TensorFlow 张量)或 'np'(NumPy 数组)。
    return_length:
    类型: 布尔值。
    含义: 如果设置为 True,函数还会返回一个列表,其中包含每个输入文本的原始长度,这对于知道哪些部分是填充的很有用。
    """
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sents,
        truncation=True,
        padding="max_length",
        max_length=500,
        return_tensors="pt",
        return_length=True,
    )

    # input_ids: 编码之后的数字
    input_ids = data["input_ids"]

    # attention_mask是一个与输入tokens相同形状的二维数组
    # 1 表示有效的位置,即非填充的tokens。这些位置在计算注意力分数时会被考虑。
    # 0 表示填充的位置,模型在计算注意力时会忽略这些位置。
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    labels = torch.LongTensor(labels)

    # print(data['length'], data['length'].max())
    # tensor([ 56,  71,  32, 159,  34, 179,  33,  79,  49,  33,  98,  89, 212,  41,
    #      63,  58]) tensor(212)

    return input_ids, attention_mask, token_type_ids, labels

加载数据集

"""
dataset:
类型: torch.utils.data.Dataset 的实例。
含义: 指定要加载的数据集。dataset 参数接收之前定义的 TextDataset 实例,包含了预处理过的文本数据和标签。
batch_size:
类型: 整数。
含义: 每个批次(batch)中的样本数量。在这个例子中,设置为 16,意味着数据加载器每次返回的将是包含16个样本的数据批次,用于模型训练或评估。
collate_fn:
类型: 可调用对象(如函数)。
含义: 用于整理一个批次的数据。当从数据集中取出多个样本时,collate_fn 会被调用来将这些样本打包成一个批次。这对于处理变长序列(如文本)特别有用,因为需要对不同长度的序列进行填充或截断以适应批处理。如果没有提供,默认的 collate_fn 可能不适用于所有情况,特别是当数据具有复杂结构时。
shuffle:
类型: 布尔值。
含义: 是否在每个 epoch 开始时对数据集进行随机洗牌。设置为 True 表示在训练过程中数据会随机排序,有助于提高模型的泛化能力。对于验证或测试集,通常应设为 False。
drop_last:
类型: 布尔值。
含义: 如果设置为 True,在最后一个批次不足以填满整个 batch_size 时,这个批次将会被丢弃。如果设为 False,则最后一个批次可能包含少于 batch_size 的样本数量。这在某些模型训练中是有用的,尤其是当模型设计要求固定的批次大小时。
"""
loader = torch.utils.data.DataLoader(
    dataset=dataset, 
    batch_size=16, 
    collate_fn=collate_fn, 
    shuffle=True, 
    drop_last=True
)

# for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
#     print(i, input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape)
#     # 0 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16])
#     break

标签:bert,torch,True,模型,labels,ids,data,self,加载
From: https://blog.csdn.net/weixin_42924890/article/details/139269528

相关文章

  • AI模型 YOLOv8在工业中的应用案例
    YOLOv8在工业中的应用案例一、YOLOv8简介YOLOv8(YouOnlyLookOnce,Version8)是YOLO系列的最新版本,以其高效和实时检测的能力在工业领域得到了广泛应用。本文将介绍YOLOv8在几个具体工业应用中的案例,并提供相关的GitHub资源。二、YOLOv8的工业应用案例案例一:自动化生产......
  • AI大模型技术速成:产品经理的转型之路
    作为一名优秀的产品经理,大模型技术简直是我我们工作中的超级助手,它让我们的产品设计和决策变得更加高效和精准。大模型在自然语言处理、数据分析、预测建模等方面的强大能力,使我能够更深入地理解用户需求,从而设计出更符合用户期望的产品。以下是大模型对产品经理的帮助主要......
  • 开源大模型与闭源大模型比较
    开源大模型与闭源大模型,你更看好哪一方?开源大模型与闭源大模型各有其优势和劣势,选择哪一方,实际上取决于多个维度的考量。以下是对两者进行详细分析的基础上,给出的综合观点。数据隐私一、开源大模型数据隐私保护:透明度:开源大模型的核心优势之一是其高度的透明度。由于源代......
  • 盒模型,百分比设置元素的大小学习
    1.盒模型<!DOCTYPEhtml><htmllang="en"><head><metacharset="UTF-8"><metaname="viewport"content="width=device-width,initial-scale=1.0"><title>盒模型/框模型</title>......
  • 深度学习-nlp-微调BERT--82
    目录importtorchimporttorch.nnasnnfromtorch.utils.dataimportTensorDataset,DataLoader,RandomSampler,SequentialSamplerfromsklearn.model_selectionimporttrain_test_splitfromtransformersimportBertTokenizer,BertConfigfromtransformersimpo......
  • 深度学习——自己的训练集——测试模型(CNN)
    测试模型1.导入新图片名称2.加载新的图片3.加载图片4.使用模型进行预测5.获取最可能的类别6.显示图片和预测的标签名称7.图像加载失败输出导入新的图像,显示图像和预测的类别标签。1.导入新图片名称new_image_path='456.jpg'2.加载新的图片new_image=cv2.i......
  • 介绍图片懒加载的几种实现方法
    在JavaScript中,懒加载(LazyLoading)主要用于延迟加载资源,例如图片、视频、音频、脚本等,直到它们真正需要时才加载。这样可以提高页面的加载速度和性能。以下是几种常见的JavaScript懒加载实现方式:1.监听滚动事件通过监听滚动事件来实现图片懒加载是一种传统并且常见的方......
  • 进程间通信(队列和生产消费模型)
    【一】引入【1】什么是进程间通信进程间通信(Inter-ProcessCommunication,IPC)是指两个或多个进程之间进行信息交换的过程【2】如何实现进程间通信借助于消息队列,进程可以将消息放入队列中,然后由另一个进程从队列中取出这种通信方式是非阻塞的,即发送进程不需要等待接收进......
  • Dolphinscheduler不重启加载Oracle驱动
    转载自刘茫茫看山问题背景某天我们的租户反馈数据库连接缺少必要的驱动,我们通过日志查看确实是缺少部分数据库的驱动,因为DolphinScheduler默认只带了Oracle和MySQL的驱动,并且需要将pom文件中的test模式去掉才可以在打包的时候引入。我们的任务量比较大,在3.0存在容错机制的情况下......
  • 如何使用Python和大模型进行数据分析和文本生成
    如何使用Python和大模型进行数据分析和文本生成Python语言以其简洁和强大的特性,成为了数据科学、机器学习和人工智能开发的首选语言之一。随着大模型(LargeLanguageModels,LLMs)如GPT-4的崛起,我们能够利用这些模型实现诸多复杂任务,从文本生成到智能对话、数据分析等等。在......