首页 > 其他分享 >昇思MindSpore 应用学习-基于 MindSpore 实现 BERT 对话情绪识别

昇思MindSpore 应用学习-基于 MindSpore 实现 BERT 对话情绪识别

时间:2024-08-01 20:53:03浏览次数:15  
标签:BERT 训练 text 模型 dataset 识别 数据 self MindSpore

基于 MindSpore 实现 BERT 对话情绪识别

模型简介

BERT全称是来自变换器的双向编码器表征量(Bidirectional Encoder Representations from Transformers),它是Google于2018年末开发并发布的一种新型语言模型。与BERT模型相似的预训练语言模型例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构,因此一定要熟练掌握Transformer的Encoder的结构。
BERT模型的主要创新点都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation。
在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。
因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句。
BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。
对话情绪识别(Emotion Detection,简称EmoTect),专注于识别智能对话场景中用户的情绪,针对智能对话场景中的用户文本,自动判断该文本的情绪类别并给出相应的置信度,情绪类型分为积极、消极、中性。 对话情绪识别适用于聊天、客服等多个场景,能够帮助企业更好地把握对话质量、改善产品的用户交互体验,也能分析客服服务质量、降低人工质检成本。
下面以一个文本情感分类任务为例子来说明BERT模型的整个应用过程。

import os  # 导入os模块,用于与操作系统交互

import mindspore  # 导入MindSpore深度学习框架
from mindspore.dataset import text, GeneratorDataset, transforms  # 从mindspore.dataset导入文本处理和数据集相关的功能
from mindspore import nn, context  # 从mindspore导入神经网络模块和上下文管理器

from mindnlp._legacy.engine import Trainer, Evaluator  # 导入MindNLP的训练和评估模块
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback  # 导入用于模型保存和最佳模型回调的功能
from mindnlp._legacy.metrics import Accuracy  # 导入准确率评估指标

代码解析

  1. 导入模块
    • import os:引入操作系统相关功能,通常用于文件和路径操作。
    • import mindspore:引入MindSpore框架,提供深度学习的基础设施。
    • from mindspore.dataset import text, GeneratorDataset, transforms
      • text:处理文本数据的功能模块。
      • GeneratorDataset:可以通过生成器动态生成数据集。
      • transforms:用于数据预处理和转换的工具。
    • from mindspore import nn, context
      • nn:包含神经网络构建所需的各类模块和层。
      • context:用于设置MindSpore的运行环境和上下文配置。
    • from mindnlp._legacy.engine import Trainer, Evaluator
      • Trainer:用于模型的训练过程管理。
      • Evaluator:用于评估模型性能的工具。
    • from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
      • CheckpointCallback:用于在训练过程中保存模型的回调。
      • BestModelCallback:用于记录最佳模型的回调。
    • from mindnlp._legacy.metrics import Accuracy:引入准确率计算的指标模块。

API 解析

  • mindspore.dataset
    • 该模块提供了用于处理数据集的工具,支持文本数据的加载和转换。
  • mindspore.nn
    • 提供构建神经网络的基础组件,如层、损失函数等。
  • mindspore.context
    • 用于设置运行环境,例如选择计算设备(CPU/GPU/Ascend)。
  • mindnlp._legacy.engine
    • 提供训练和评估框架,可以简化模型的训练流程。
  • mindnlp._legacy.engine.callbacks
    • 提供回调机制,帮助用户在训练过程中实现模型保存、学习率调整等功能。
  • mindnlp._legacy.metrics
    • 提供性能评估指标,如准确率等,帮助在训练和评估阶段监测模型表现。

Building prefix dict from the default dictionary … Loading model from cache /tmp/jieba.cache Loading model cost 1.019 seconds. Prefix dict has been built successfully.

# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path  # 存储数据集的路径
        self._labels, self._text_a = [], []  # 初始化标签和文本列表
        self._load()  # 调用加载数据集的方法

    def _load(self):
        # 从指定路径加载数据集
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()  # 读取数据集文件内容
        lines = dataset.split("\n")  # 按行分割数据
        for line in lines[1:-1]:  # 遍历每一行,跳过第一行和最后一行
            label, text_a = line.split("\t")  # 按制表符分割标签和文本
            self._labels.append(int(label))  # 将标签转换为整数并添加到标签列表
            self._text_a.append(text_a)  # 将文本添加到文本列表

    def __getitem__(self, index):
        # 根据索引返回标签和文本
        return self._labels[index], self._text_a[index]

    def __len__(self):
        # 返回数据集的大小
        return len(self._labels)

代码解析

  1. 类定义
    • class SentimentDataset:定义一个用于情感分析的数据集类,主要用于加载和提供数据。
  2. 初始化方法
    • def __init__(self, path)
      • self.path = path:将数据集文件路径存储到实例变量中。
      • self._labels, self._text_a = [], []:初始化两个空列表,用于存储标签和文本数据。
      • self._load():调用私有方法 _load 来加载数据。
  3. 数据加载方法
    • def _load(self)
      • with open(self.path, "r", encoding="utf-8") as f:以只读模式打开数据集文件,指定编码为UTF-8。
      • dataset = f.read():读取文件内容。
      • lines = dataset.split("\n"):将文件内容按行切割。
      • for line in lines[1:-1]:循环遍历每行数据,跳过第一行(通常是表头)和最后一行(可能为空行)。
        • label, text_a = line.split("\t"):将每行数据按制表符分割,获取标签和文本。
        • self._labels.append(int(label)):将标签转换为整数并添加到 _labels 列表。
        • self._text_a.append(text_a):将文本添加到 _text_a 列表。
  4. 索引获取方法
    • def __getitem__(self, index):根据给定的索引返回对应的标签和文本。
      • return self._labels[index], self._text_a[index]:返回标签和文本的元组。
  5. 长度获取方法
    • def __len__(self):返回数据集中样本的数量。
      • return len(self._labels):返回标签列表的长度。

API 解析

  • __init__:构造函数,用于初始化类的实例,并设置初始状态。
  • 文件读取:使用Python内置的 open 函数来读取文件内容,通过 with 语句确保文件在使用后正确关闭。
  • 列表操作:使用列表的 append 方法动态添加数据。
  • __getitem__** 和 **__len__:这两个方法是Python的数据模型方法,允许类的实例像列表一样被索引和测量长度,使得 SentimentDataset 类可以很方便地与其他数据处理库(如PyTorch或MindSpore)配合使用。

数据集

这里提供一份已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符(‘\t’)分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。
label–text_a
0–谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?
1–我有事等会儿就回来和你聊
2–我见到你很高兴谢谢你帮我
这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。

# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
# 使用wget命令从指定URL下载情感检测数据集并保存为emotion_detection.tar.gz

!tar xvf emotion_detection.tar.gz
# 解压下载的tar.gz文件,提取内容

代码解析

  1. 下载数据集
    • !wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
      • !wget:使用shell命令 wget 下载文件。
      • https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz:指定要下载的数据集的URL。
      • -O emotion_detection.tar.gz:将下载的文件保存为 emotion_detection.tar.gz
  2. 解压文件
    • !tar xvf emotion_detection.tar.gz
      • !tar:使用shell命令 tar 来处理归档文件。
      • xvf:这是 tar 命令的选项:
        • x:表示解压缩。
        • v:表示显示解压缩过程中的文件(verbose模式)。
        • f:表示后面跟的是文件名。
      • emotion_detection.tar.gz:要解压的文件名。

API 解析

  • wget
    • 一个用于从网络下载文件的命令行工具,支持 HTTP、HTTPS 和 FTP 协议。
  • tar
    • 用于打包和解压缩文件的命令行工具,常用于Linux和Unix系统。.tar.gz格式是经过gzip压缩的tar档案,结合了两种工具的优点。

注意事项

  • 在执行上述命令时,确保你的环境支持 ! 前缀的shell命令,这通常在Jupyter Notebook或某些支持魔法命令的环境中有效。
  • 下载和解压的操作需要网络连接,并且保存路径需要有写入权限。

数据加载和数据预处理

新建 process_dataset 函数用于数据加载和数据预处理,具体内容可见下面代码注释。

import numpy as np  # 导入NumPy库,用于数值计算和操作

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    # 获取当前设备目标,判断是否为Ascend
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    column_names = ["label", "text_a"]  # 定义数据集的列名
    
    # 创建生成器数据集
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    
    # 定义类型转换操作
    type_cast_op = transforms.TypeCast(mindspore.int32)

    def tokenize_and_pad(text):
        # 根据设备类型进行分词和填充操作
        if is_ascend:
            # 在Ascend上进行填充和截断
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            # 在非Ascend设备上只进行分词
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']  # 返回输入ID和注意力掩码

    # 数据集映射操作,应用分词和填充函数
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
    
    # 为标签应用类型转换
    dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
    
    # 批处理数据集
    if is_ascend:
        dataset = dataset.batch(batch_size)  # 在Ascend上使用常规批处理
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})  # 在非Ascend上使用填充批处理

    return dataset  # 返回处理后的数据集

代码解析

  1. 导入
    • import numpy as np:导入NumPy库,通常用于数值运算,但在此代码中未直接使用。
  2. 函数定义
    • def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True)
      • source:数据源,可以是文本文件、数据集对象等。
      • tokenizer:用于将文本转换为模型输入格式的分词器。
      • max_seq_len:设置文本的最大序列长度,超过该长度的文本将被截断。
      • batch_size:每个批次的样本数量。
      • shuffle:是否在生成数据集时打乱数据顺序。
  3. 设备判断
    • is_ascend = mindspore.get_context('device_target') == 'Ascend':判断当前执行环境是否为Ascend设备。
  4. 数据集创建
    • column_names = ["label", "text_a"]:定义数据集中包含的列名。
    • dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle):创建一个生成器数据集。
  5. 类型转换操作
    • type_cast_op = transforms.TypeCast(mindspore.int32):创建一个类型转换操作,将标签转换为32位整数。
  6. 分词和填充
    • def tokenize_and_pad(text):定义一个内部函数用于对输入文本进行分词和填充。
      • tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len):在Ascend设备上进行分词,填充到最大长度。
      • 返回分词后的 input_idsattention_mask
  7. 数据集映射
    • dataset = dataset.map(...):将文本列应用分词和填充操作,同时将标签列应用类型转换操作。
  8. 批处理
    • if is_ascend:根据设备类型选择合适的批处理方式。
      • dataset.batch(batch_size):在Ascend设备上使用常规批处理。
      • dataset.padded_batch(batch_size, pad_info=...):在其他设备上使用填充批处理,指定填充值。
  9. 返回数据集
    • return dataset:返回处理后的数据集对象。

API 解析

  • GeneratorDataset:MindSpore中的一个数据集类,允许用户通过生成器动态生成数据。
  • map:数据集的映射方法,可以对每个样本应用给定的操作。
  • TypeCast:用于将数据类型转换为指定类型的操作。
  • batch / padded_batch:用于将数据集分成批次,支持标准批处理和填充批处理,以处理不同长度的输入。

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:

from mindnlp.transformers import BertTokenizer  # 导入BertTokenizer类,用于加载和使用BERT分词器

# 从预训练的BERT模型加载分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')   

# 获取分词器的填充标记ID
tokenizer.pad_token_id  

# 创建训练数据集,使用自定义的SentimentDataset类和分词器
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)

# 创建验证数据集
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)

# 创建测试数据集,禁用打乱
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)

# 获取训练数据集的列名称
dataset_train.get_col_names()  

# 从训练数据集中获取一个迭代器并打印下一个样本
print(next(dataset_train.create_tuple_iterator()))  

代码解析

  1. 导入分词器
    • from mindnlp.transformers import BertTokenizer:从MindNLP库导入BERT分词器的类。
  2. 加载预训练的BERT分词器
    • tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
      • 使用指定的模型名称加载预训练的BERT分词器,这里是中文BERT模型。
  3. 获取填充标记ID
    • tokenizer.pad_token_id:获取分词器的填充标记的ID,用于后续的填充操作。
  4. 创建数据集
    • dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
      • 创建训练数据集,使用自定义的 SentimentDataset 类,将数据集路径和分词器传入。
    • dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
      • 创建验证数据集。
    • dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
      • 创建测试数据集,并设置 shuffle=False 以保持数据顺序。
  5. 获取列名称
    • dataset_train.get_col_names():调用方法获取训练数据集中列的名称,通常是标签和文本列。
  6. 创建迭代器并打印样本
    • print(next(dataset_train.create_tuple_iterator()))
      • 创建一个迭代器,使用 next() 获取下一个样本并打印出来,通常返回的是一个元组,包含标签和分词后的文本。

API 解析

  • BertTokenizer
    • 用于处理BERT模型的文本输入,负责将文本转换为模型可以接受的ID格式,并执行必要的填充和截断。
  • from_pretrained
    • 类方法,用于加载预训练模型的分词器,支持多种语言和任务。
  • pad_token_id
    • 分词器的填充标记ID,通常用于处理不同长度的输入,确保输入形状一致。
  • SentimentDataset
    • 自定义的数据集类,用于从指定文件加载情感分析相关的数据。
  • process_dataset
    • 处理数据集的函数,执行分词、填充和批处理等操作,返回已处理的数据集。
  • create_tuple_iterator
    • 数据集的方法,用于创建一个迭代器,可以返回数据集中的样本。
  • next()
    • Python内置函数,用于获取迭代器的下一个值。

模型构建

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel  # 导入BERT模型和用于序列分类的特定模型
from mindnlp._legacy.amp import auto_mixed_precision  # 导入自动混合精度的工具

# 设置BERT配置并定义训练参数
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)  
# 从预训练的BERT模型加载序列分类模型,设置输出标签数量为3

model = auto_mixed_precision(model, 'O1')  
# 应用自动混合精度以提高训练性能,'O1'为混合精度的优化策略

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)  
# 创建Adam优化器,设置学习率为2e-5,并将模型可训练参数作为优化目标

metric = Accuracy()  # 定义准确率作为评估指标

# 定义回调函数以保存检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)  
# 创建检查点回调,设置保存路径、检查点名称、保存频率和最大保留检查点数量

best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)  
# 创建最优模型回调,自动加载最佳模型

# 创建训练器
trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])  

%%time  # 记录训练时间

# 开始训练
trainer.run(tgt_columns="labels")  

代码解析

  1. 导入必要的库
    • from mindnlp.transformers import BertForSequenceClassification, BertModel:导入MindNLP中的BERT序列分类模型。
    • from mindnlp._legacy.amp import auto_mixed_precision:导入混合精度训练工具,以优化模型训练的性能和内存使用。
  2. 模型创建
    • model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
      • 从预训练的中文BERT模型加载一个用于序列分类的模型,设置标签数量为3(例如,情感分析中的三种情感)。
  3. 混合精度训练
    • model = auto_mixed_precision(model, 'O1')
      • 应用自动混合精度以节省内存和加速训练,其中 ‘O1’ 是适用于混合精度训练的优化策略。
  4. 优化器定义
    • optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
      • 使用Adam优化器,学习率设为2e-5,优化目标是模型的可训练参数。
  5. 指标定义
    • metric = Accuracy()
      • 定义准确率作为模型评估的指标。
  6. 定义回调
    • ckpoint_cb = CheckpointCallback(...)
      • 创建检查点回调,以便在训练过程中定期保存模型检查点,设置保存路径、名称和最大检查点数量。
    • best_model_cb = BestModelCallback(...)
      • 创建最佳模型回调,以自动加载保存的最佳模型。
  7. 训练器创建
    • trainer = Trainer(...)
      • 初始化训练器,传入模型、训练数据集、验证数据集、评估指标、训练周期、优化器和回调列表。
  8. 时间记录
    • %%time:在Jupyter Notebook中使用魔法命令记录代码块的执行时间。
  9. 开始训练
    • trainer.run(tgt_columns="labels")
      • 开始模型的训练过程,指定目标列为标签列。

API 解析

  • BertForSequenceClassification
    • BERT模型的变体,专门用于序列分类任务,能够处理文本分类任务的输入。
  • auto_mixed_precision
    • 用于自动应用混合精度训练,结合使用不同的数据类型以提高训练效率。
  • nn.Adam
    • Adam优化器,常用于深度学习中的参数优化。
  • Accuracy
    • 评估指标类,用于计算模型的准确率。
  • CheckpointCallback
    • 回调类,用于在训练过程中保存模型检查点。
  • BestModelCallback
    • 回调类,用于自动保存和加载最佳模型。
  • Trainer
    • MindSpore中用于管理整个训练过程的类,负责模型训练和评估的实施。
  • run
    • 启动训练过程的方法,传入训练所需的参数。

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)  
# 创建评估器,传入训练好的模型、测试数据集和评估指标

evaluator.run(tgt_columns="labels")  
# 运行评估,指定目标列为标签列

代码解析

  1. 创建评估器
    • evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
      • 使用训练好的模型、测试数据集和指定的评估指标初始化评估器。
      • network=model:传入已经训练好的模型。
      • eval_dataset=dataset_test:传入要评估的测试数据集。
      • metrics=metric:传入评估所使用的指标(如准确率)。
  2. 运行评估
    • evaluator.run(tgt_columns="labels")
      • 调用评估器的 run 方法开始评估过程,使用 tgt_columns="labels" 指定要评估的目标列为标签列。

API 解析

  • Evaluator
    • 用于评估模型性能的类,能够帮助用户在特定数据集上计算和输出模型的评估指标。
  • run
    • 方法用于执行评估过程,通常会输出评估结果和性能指标,比如准确率、F1分数等。
  • tgt_columns
    • 指定在评估过程中需要关注的标签列,通常是模型预测的目标列。

模型推理

遍历推理数据集,将结果与标签进行统一展示。

# 加载待预测数据集
dataset_infer = SentimentDataset("data/infer.tsv")  

def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}  # 定义标签映射

    # 将输入文本进行分词并转换为Tensor格式
    text_tokenized = Tensor([tokenizer(text).input_ids])  
    # 使用模型进行预测
    logits = model(text_tokenized)  
    # 获取预测标签(取最大值的索引作为预测结果)
    predict_label = logits[0].asnumpy().argmax()  
    # 格式化输出信息
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"  
    if label is not None:
        info += f" , label: '{label_map[label]}'"  # 如果提供真实标签,则输出真实标签
    print(info)  # 打印信息

# 遍历待预测数据集并进行预测
for label, text in dataset_infer:
    predict(text, label)  

代码解析

  1. 加载待预测数据集
    • dataset_infer = SentimentDataset("data/infer.tsv")
      • 从指定的文件路径加载待预测的数据集,这里是 infer.tsv 文件。
  2. 定义预测函数
    • def predict(text, label=None):
      • 定义一个 predict 函数,接受文本输入和可选的真实标签。
  3. 标签映射
    • label_map = {0: "消极", 1: "中性", 2: "积极"}
      • 创建一个字典,将数值标签映射到对应的情感描述。
  4. 文本分词和转换
    • text_tokenized = Tensor([tokenizer(text).input_ids])
      • 使用提前定义的分词器对输入文本进行分词,并将生成的ID转换成Tensor格式,以便输入到模型中。
  5. 模型预测
    • logits = model(text_tokenized)
      • 将分词后的文本输入到已训练的模型中获取预测结果(logits)。
  6. 获取预测标签
    • predict_label = logits[0].asnumpy().argmax()
      • 从模型输出的logits中获取预测标签,通过取最大值的索引(argmax())来确定最可能的情感类别。
  7. 格式化输出信息
    • info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
      • 创建一个包含输入文本和预测结果的字符串。
    • if label is not None: 语句用于检查是否提供了真实标签,如果有,则将真实标签添加到输出信息中。
  8. 打印信息
    • print(info):输出格式化的信息到控制台。
  9. 遍历数据集并进行预测
    • for label, text in dataset_infer:
      • 遍历待预测的数据集 dataset_infer,对于每一对 (label, text),调用 predict 函数进行预测。

API 解析

  • SentimentDataset
    • 自定义的数据集类,用于加载情感分析任务的输入数据。
  • Tensor
    • MindSpore中的数据结构,用于存储和处理多维数据,尤其是在深度学习中作为输入或输出。
  • tokenizer
    • 分词器实例,负责将文本转换为模型可以接受的ID格式。
  • model
    • 已训练的情感分类模型,用于对输入文本生成预测结果。
  • logits
    • 模型的输出,通常是每个类别的未归一化的得分,使用 argmax() 获取预测标签。
  • argmax()
    • NumPy中的函数,用于返回数组中最大值的索引,在此用来确定最可能的情感类别。

自定义推理数据集

自己输入推理数据,展示模型的泛化能力。

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")

标签:BERT,训练,text,模型,dataset,识别,数据,self,MindSpore
From: https://blog.csdn.net/qq_43638033/article/details/140807002

相关文章

  • PHP文档识别接口,文字识别、OCR API
    在数字化浪潮的前沿下,文档识别接口如同一位智慧的在线“编目家”,随时随地工作在浩瀚的字符海洋中。想象一下,当我们面对堆积如山的纸质扫描文档,各种印刷文字以及文本图片时,通过文档识别功能,仅需导入图片,便能快速、精准地让静默的文字跃然于屏,化作清晰可读、可编辑、可归类、可......
  • 面部识别 - 机器学习
    我正在尝试在GoogleColab上使用Tensorflow进行面部识别,但遇到错误。以前工作得很好,但现在却抛出了这个错误。完整的.ipynb文件已链接(请注意,您需要一个包含.jpg文件的负数、正数和锚文件夹才能运行程序。)使暹罗模型出错文件链接:https://www.mediafire.com......
  • C#银行卡ocr识别接口的简单集成方式
    银行卡识别接口是指:以文字识别技术为基础衍生的银行卡卡面信息识别接口,该接口可以快速、精准的将银行卡卡面上包含银行卡号、卡类型、银行名称等文字信息提取成功,以帮助需要支付的平台进行银行卡身份的快速核验。企业又该如何快速的对银行卡识别接口进行集成?可以选择翔云......
  • C#营业执照识别接口、营业执照ocr
    营业执照识别接口,是基于光学字符识别技术的一种将图像中的字符转化为可编辑文本的技术。翔云营业执照识别接口,自主ocr核心技术,可快速精准识别营业执照上的全部字段信息,支持三证合一版营业执照和五证合一版营业执照。翔云营业执照识别接口提供免费测试体验服务,助力企业降......
  • 记一次 JUnit5 问题排查(不识别单测、mock 对象空指针等问题)
    背景最近开始使用JUnit5写单元测试,本地运行成功之后提交代码,触发流水线进进行覆盖率计算。结果出来之后傻眼了,几百个单侧只能识别到2个。先简单说一下具体的环境。本地使用IDEA自带的maven,版本为3.9.6,JUnit版本5.7.0。流水线使用jenkins触发maven命令,用的maven......
  • 图像识别的开源项目列举
    当涉及到图像识别的开源项目和示例代码时,以下是一些适合初学者快速提升能力的项目:TensorFlowModels:TensorFlowModels是一个由TensorFlow团队维护的开源项目,提供了许多经典的图像识别模型的实现代码。你可以从中学习和理解各种图像分类、目标检测和图像分割等任务的实现方式......
  • 论文阅读:引入词集级注意力机制的中文命名实体识别方法
    WSA-CNER方法首先,将输入序列的每个字映射成一个字向量;然后,将外部词汇信息整合到每个字的最终表示中;最后,将字的最终表示输入到序列建模层和标签预测层中,得到最终的预测结果。输入表示层使用SoftLexicon方法将输入序列中每个字的词典匹配结果划分为4个词集(BMES)。输入序列......
  • 基于Matlab的车牌识别系统设计与实现
    基于Matlab的车牌识别系统设计与实现摘要随着智能交通系统的不断演进,车牌识别技术已成为提升交通管理效率与准确性的关键。本文深入探讨了基于Matlab平台的车牌识别系统设计与实现,该系统通过精细的图像预处理、高效的车牌定位算法、精准的字符分割与识别技术,显著提升了车牌识......
  • GPT1-3及BERT的模型概述
    GPT1-3及BERT的模型概述(2020年5月之前LLMs主流模型)GPT-1(2018年6月)......