基于 MindSpore 实现 BERT 对话情绪识别
模型简介
BERT全称是来自变换器的双向编码器表征量(Bidirectional Encoder Representations from Transformers),它是Google于2018年末开发并发布的一种新型语言模型。与BERT模型相似的预训练语言模型例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构,因此一定要熟练掌握Transformer的Encoder的结构。
BERT模型的主要创新点都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation。
在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。
因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句。
BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。
对话情绪识别(Emotion Detection,简称EmoTect),专注于识别智能对话场景中用户的情绪,针对智能对话场景中的用户文本,自动判断该文本的情绪类别并给出相应的置信度,情绪类型分为积极、消极、中性。 对话情绪识别适用于聊天、客服等多个场景,能够帮助企业更好地把握对话质量、改善产品的用户交互体验,也能分析客服服务质量、降低人工质检成本。
下面以一个文本情感分类任务为例子来说明BERT模型的整个应用过程。
import os # 导入os模块,用于与操作系统交互
import mindspore # 导入MindSpore深度学习框架
from mindspore.dataset import text, GeneratorDataset, transforms # 从mindspore.dataset导入文本处理和数据集相关的功能
from mindspore import nn, context # 从mindspore导入神经网络模块和上下文管理器
from mindnlp._legacy.engine import Trainer, Evaluator # 导入MindNLP的训练和评估模块
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback # 导入用于模型保存和最佳模型回调的功能
from mindnlp._legacy.metrics import Accuracy # 导入准确率评估指标
代码解析
- 导入模块:
import os
:引入操作系统相关功能,通常用于文件和路径操作。import mindspore
:引入MindSpore框架,提供深度学习的基础设施。from mindspore.dataset import text, GeneratorDataset, transforms
:text
:处理文本数据的功能模块。GeneratorDataset
:可以通过生成器动态生成数据集。transforms
:用于数据预处理和转换的工具。
from mindspore import nn, context
:nn
:包含神经网络构建所需的各类模块和层。context
:用于设置MindSpore的运行环境和上下文配置。
from mindnlp._legacy.engine import Trainer, Evaluator
:Trainer
:用于模型的训练过程管理。Evaluator
:用于评估模型性能的工具。
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
:CheckpointCallback
:用于在训练过程中保存模型的回调。BestModelCallback
:用于记录最佳模型的回调。
from mindnlp._legacy.metrics import Accuracy
:引入准确率计算的指标模块。
API 解析
- mindspore.dataset:
- 该模块提供了用于处理数据集的工具,支持文本数据的加载和转换。
- mindspore.nn:
- 提供构建神经网络的基础组件,如层、损失函数等。
- mindspore.context:
- 用于设置运行环境,例如选择计算设备(CPU/GPU/Ascend)。
- mindnlp._legacy.engine:
- 提供训练和评估框架,可以简化模型的训练流程。
- mindnlp._legacy.engine.callbacks:
- 提供回调机制,帮助用户在训练过程中实现模型保存、学习率调整等功能。
- mindnlp._legacy.metrics:
- 提供性能评估指标,如准确率等,帮助在训练和评估阶段监测模型表现。
Building prefix dict from the default dictionary … Loading model from cache /tmp/jieba.cache Loading model cost 1.019 seconds. Prefix dict has been built successfully.
# prepare dataset
class SentimentDataset:
"""Sentiment Dataset"""
def __init__(self, path):
self.path = path # 存储数据集的路径
self._labels, self._text_a = [], [] # 初始化标签和文本列表
self._load() # 调用加载数据集的方法
def _load(self):
# 从指定路径加载数据集
with open(self.path, "r", encoding="utf-8") as f:
dataset = f.read() # 读取数据集文件内容
lines = dataset.split("\n") # 按行分割数据
for line in lines[1:-1]: # 遍历每一行,跳过第一行和最后一行
label, text_a = line.split("\t") # 按制表符分割标签和文本
self._labels.append(int(label)) # 将标签转换为整数并添加到标签列表
self._text_a.append(text_a) # 将文本添加到文本列表
def __getitem__(self, index):
# 根据索引返回标签和文本
return self._labels[index], self._text_a[index]
def __len__(self):
# 返回数据集的大小
return len(self._labels)
代码解析
- 类定义:
class SentimentDataset
:定义一个用于情感分析的数据集类,主要用于加载和提供数据。
- 初始化方法:
def __init__(self, path)
:self.path = path
:将数据集文件路径存储到实例变量中。self._labels, self._text_a = [], []
:初始化两个空列表,用于存储标签和文本数据。self._load()
:调用私有方法_load
来加载数据。
- 数据加载方法:
def _load(self)
:with open(self.path, "r", encoding="utf-8") as f
:以只读模式打开数据集文件,指定编码为UTF-8。dataset = f.read()
:读取文件内容。lines = dataset.split("\n")
:将文件内容按行切割。for line in lines[1:-1]
:循环遍历每行数据,跳过第一行(通常是表头)和最后一行(可能为空行)。label, text_a = line.split("\t")
:将每行数据按制表符分割,获取标签和文本。self._labels.append(int(label))
:将标签转换为整数并添加到_labels
列表。self._text_a.append(text_a)
:将文本添加到_text_a
列表。
- 索引获取方法:
def __getitem__(self, index)
:根据给定的索引返回对应的标签和文本。return self._labels[index], self._text_a[index]
:返回标签和文本的元组。
- 长度获取方法:
def __len__(self)
:返回数据集中样本的数量。return len(self._labels)
:返回标签列表的长度。
API 解析
__init__
:构造函数,用于初始化类的实例,并设置初始状态。- 文件读取:使用Python内置的
open
函数来读取文件内容,通过with
语句确保文件在使用后正确关闭。 - 列表操作:使用列表的
append
方法动态添加数据。 __getitem__
** 和 **__len__
:这两个方法是Python的数据模型方法,允许类的实例像列表一样被索引和测量长度,使得SentimentDataset
类可以很方便地与其他数据处理库(如PyTorch或MindSpore)配合使用。
数据集
这里提供一份已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符(‘\t’)分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。
label–text_a
0–谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?
1–我有事等会儿就回来和你聊
2–我见到你很高兴谢谢你帮我
这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。
# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
# 使用wget命令从指定URL下载情感检测数据集并保存为emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz
# 解压下载的tar.gz文件,提取内容
代码解析
- 下载数据集:
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
:!wget
:使用shell命令wget
下载文件。https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz
:指定要下载的数据集的URL。-O emotion_detection.tar.gz
:将下载的文件保存为emotion_detection.tar.gz
。
- 解压文件:
!tar xvf emotion_detection.tar.gz
:!tar
:使用shell命令tar
来处理归档文件。xvf
:这是tar
命令的选项:x
:表示解压缩。v
:表示显示解压缩过程中的文件(verbose模式)。f
:表示后面跟的是文件名。
emotion_detection.tar.gz
:要解压的文件名。
API 解析
wget
:- 一个用于从网络下载文件的命令行工具,支持 HTTP、HTTPS 和 FTP 协议。
tar
:- 用于打包和解压缩文件的命令行工具,常用于Linux和Unix系统。
.tar.gz
格式是经过gzip压缩的tar档案,结合了两种工具的优点。
- 用于打包和解压缩文件的命令行工具,常用于Linux和Unix系统。
注意事项
- 在执行上述命令时,确保你的环境支持
!
前缀的shell命令,这通常在Jupyter Notebook或某些支持魔法命令的环境中有效。 - 下载和解压的操作需要网络连接,并且保存路径需要有写入权限。
数据加载和数据预处理
新建 process_dataset 函数用于数据加载和数据预处理,具体内容可见下面代码注释。
import numpy as np # 导入NumPy库,用于数值计算和操作
def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
# 获取当前设备目标,判断是否为Ascend
is_ascend = mindspore.get_context('device_target') == 'Ascend'
column_names = ["label", "text_a"] # 定义数据集的列名
# 创建生成器数据集
dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
# 定义类型转换操作
type_cast_op = transforms.TypeCast(mindspore.int32)
def tokenize_and_pad(text):
# 根据设备类型进行分词和填充操作
if is_ascend:
# 在Ascend上进行填充和截断
tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
else:
# 在非Ascend设备上只进行分词
tokenized = tokenizer(text)
return tokenized['input_ids'], tokenized['attention_mask'] # 返回输入ID和注意力掩码
# 数据集映射操作,应用分词和填充函数
dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
# 为标签应用类型转换
dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
# 批处理数据集
if is_ascend:
dataset = dataset.batch(batch_size) # 在Ascend上使用常规批处理
else:
dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
'attention_mask': (None, 0)}) # 在非Ascend上使用填充批处理
return dataset # 返回处理后的数据集
代码解析
- 导入:
import numpy as np
:导入NumPy库,通常用于数值运算,但在此代码中未直接使用。
- 函数定义:
def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True)
:source
:数据源,可以是文本文件、数据集对象等。tokenizer
:用于将文本转换为模型输入格式的分词器。max_seq_len
:设置文本的最大序列长度,超过该长度的文本将被截断。batch_size
:每个批次的样本数量。shuffle
:是否在生成数据集时打乱数据顺序。
- 设备判断:
is_ascend = mindspore.get_context('device_target') == 'Ascend'
:判断当前执行环境是否为Ascend设备。
- 数据集创建:
column_names = ["label", "text_a"]
:定义数据集中包含的列名。dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
:创建一个生成器数据集。
- 类型转换操作:
type_cast_op = transforms.TypeCast(mindspore.int32)
:创建一个类型转换操作,将标签转换为32位整数。
- 分词和填充:
def tokenize_and_pad(text)
:定义一个内部函数用于对输入文本进行分词和填充。tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
:在Ascend设备上进行分词,填充到最大长度。- 返回分词后的
input_ids
和attention_mask
。
- 数据集映射:
dataset = dataset.map(...)
:将文本列应用分词和填充操作,同时将标签列应用类型转换操作。
- 批处理:
if is_ascend
:根据设备类型选择合适的批处理方式。dataset.batch(batch_size)
:在Ascend设备上使用常规批处理。dataset.padded_batch(batch_size, pad_info=...)
:在其他设备上使用填充批处理,指定填充值。
- 返回数据集:
return dataset
:返回处理后的数据集对象。
API 解析
- GeneratorDataset:MindSpore中的一个数据集类,允许用户通过生成器动态生成数据。
- map:数据集的映射方法,可以对每个样本应用给定的操作。
- TypeCast:用于将数据类型转换为指定类型的操作。
- batch / padded_batch:用于将数据集分成批次,支持标准批处理和填充批处理,以处理不同长度的输入。
昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:
from mindnlp.transformers import BertTokenizer # 导入BertTokenizer类,用于加载和使用BERT分词器
# 从预训练的BERT模型加载分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# 获取分词器的填充标记ID
tokenizer.pad_token_id
# 创建训练数据集,使用自定义的SentimentDataset类和分词器
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
# 创建验证数据集
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
# 创建测试数据集,禁用打乱
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
# 获取训练数据集的列名称
dataset_train.get_col_names()
# 从训练数据集中获取一个迭代器并打印下一个样本
print(next(dataset_train.create_tuple_iterator()))
代码解析
- 导入分词器:
from mindnlp.transformers import BertTokenizer
:从MindNLP库导入BERT分词器的类。
- 加载预训练的BERT分词器:
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
:- 使用指定的模型名称加载预训练的BERT分词器,这里是中文BERT模型。
- 获取填充标记ID:
tokenizer.pad_token_id
:获取分词器的填充标记的ID,用于后续的填充操作。
- 创建数据集:
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
:- 创建训练数据集,使用自定义的
SentimentDataset
类,将数据集路径和分词器传入。
- 创建训练数据集,使用自定义的
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
:- 创建验证数据集。
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
:- 创建测试数据集,并设置
shuffle=False
以保持数据顺序。
- 创建测试数据集,并设置
- 获取列名称:
dataset_train.get_col_names()
:调用方法获取训练数据集中列的名称,通常是标签和文本列。
- 创建迭代器并打印样本:
print(next(dataset_train.create_tuple_iterator()))
:- 创建一个迭代器,使用
next()
获取下一个样本并打印出来,通常返回的是一个元组,包含标签和分词后的文本。
- 创建一个迭代器,使用
API 解析
- BertTokenizer:
- 用于处理BERT模型的文本输入,负责将文本转换为模型可以接受的ID格式,并执行必要的填充和截断。
- from_pretrained:
- 类方法,用于加载预训练模型的分词器,支持多种语言和任务。
- pad_token_id:
- 分词器的填充标记ID,通常用于处理不同长度的输入,确保输入形状一致。
- SentimentDataset:
- 自定义的数据集类,用于从指定文件加载情感分析相关的数据。
- process_dataset:
- 处理数据集的函数,执行分词、填充和批处理等操作,返回已处理的数据集。
- create_tuple_iterator:
- 数据集的方法,用于创建一个迭代器,可以返回数据集中的样本。
- next():
- Python内置函数,用于获取迭代器的下一个值。
模型构建
通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。
from mindnlp.transformers import BertForSequenceClassification, BertModel # 导入BERT模型和用于序列分类的特定模型
from mindnlp._legacy.amp import auto_mixed_precision # 导入自动混合精度的工具
# 设置BERT配置并定义训练参数
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
# 从预训练的BERT模型加载序列分类模型,设置输出标签数量为3
model = auto_mixed_precision(model, 'O1')
# 应用自动混合精度以提高训练性能,'O1'为混合精度的优化策略
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
# 创建Adam优化器,设置学习率为2e-5,并将模型可训练参数作为优化目标
metric = Accuracy() # 定义准确率作为评估指标
# 定义回调函数以保存检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
# 创建检查点回调,设置保存路径、检查点名称、保存频率和最大保留检查点数量
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)
# 创建最优模型回调,自动加载最佳模型
# 创建训练器
trainer = Trainer(network=model, train_dataset=dataset_train,
eval_dataset=dataset_val, metrics=metric,
epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])
%%time # 记录训练时间
# 开始训练
trainer.run(tgt_columns="labels")
代码解析
- 导入必要的库:
from mindnlp.transformers import BertForSequenceClassification, BertModel
:导入MindNLP中的BERT序列分类模型。from mindnlp._legacy.amp import auto_mixed_precision
:导入混合精度训练工具,以优化模型训练的性能和内存使用。
- 模型创建:
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
:- 从预训练的中文BERT模型加载一个用于序列分类的模型,设置标签数量为3(例如,情感分析中的三种情感)。
- 混合精度训练:
model = auto_mixed_precision(model, 'O1')
:- 应用自动混合精度以节省内存和加速训练,其中 ‘O1’ 是适用于混合精度训练的优化策略。
- 优化器定义:
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
:- 使用Adam优化器,学习率设为2e-5,优化目标是模型的可训练参数。
- 指标定义:
metric = Accuracy()
:- 定义准确率作为模型评估的指标。
- 定义回调:
ckpoint_cb = CheckpointCallback(...)
:- 创建检查点回调,以便在训练过程中定期保存模型检查点,设置保存路径、名称和最大检查点数量。
best_model_cb = BestModelCallback(...)
:- 创建最佳模型回调,以自动加载保存的最佳模型。
- 训练器创建:
trainer = Trainer(...)
:- 初始化训练器,传入模型、训练数据集、验证数据集、评估指标、训练周期、优化器和回调列表。
- 时间记录:
%%time
:在Jupyter Notebook中使用魔法命令记录代码块的执行时间。
- 开始训练:
trainer.run(tgt_columns="labels")
:- 开始模型的训练过程,指定目标列为标签列。
API 解析
- BertForSequenceClassification:
- BERT模型的变体,专门用于序列分类任务,能够处理文本分类任务的输入。
- auto_mixed_precision:
- 用于自动应用混合精度训练,结合使用不同的数据类型以提高训练效率。
- nn.Adam:
- Adam优化器,常用于深度学习中的参数优化。
- Accuracy:
- 评估指标类,用于计算模型的准确率。
- CheckpointCallback:
- 回调类,用于在训练过程中保存模型检查点。
- BestModelCallback:
- 回调类,用于自动保存和加载最佳模型。
- Trainer:
- MindSpore中用于管理整个训练过程的类,负责模型训练和评估的实施。
- run:
- 启动训练过程的方法,传入训练所需的参数。
模型验证
将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
# 创建评估器,传入训练好的模型、测试数据集和评估指标
evaluator.run(tgt_columns="labels")
# 运行评估,指定目标列为标签列
代码解析
- 创建评估器:
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
:- 使用训练好的模型、测试数据集和指定的评估指标初始化评估器。
network=model
:传入已经训练好的模型。eval_dataset=dataset_test
:传入要评估的测试数据集。metrics=metric
:传入评估所使用的指标(如准确率)。
- 运行评估:
evaluator.run(tgt_columns="labels")
:- 调用评估器的
run
方法开始评估过程,使用tgt_columns="labels"
指定要评估的目标列为标签列。
- 调用评估器的
API 解析
- Evaluator:
- 用于评估模型性能的类,能够帮助用户在特定数据集上计算和输出模型的评估指标。
- run:
- 方法用于执行评估过程,通常会输出评估结果和性能指标,比如准确率、F1分数等。
- tgt_columns:
- 指定在评估过程中需要关注的标签列,通常是模型预测的目标列。
模型推理
遍历推理数据集,将结果与标签进行统一展示。
# 加载待预测数据集
dataset_infer = SentimentDataset("data/infer.tsv")
def predict(text, label=None):
label_map = {0: "消极", 1: "中性", 2: "积极"} # 定义标签映射
# 将输入文本进行分词并转换为Tensor格式
text_tokenized = Tensor([tokenizer(text).input_ids])
# 使用模型进行预测
logits = model(text_tokenized)
# 获取预测标签(取最大值的索引作为预测结果)
predict_label = logits[0].asnumpy().argmax()
# 格式化输出信息
info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
if label is not None:
info += f" , label: '{label_map[label]}'" # 如果提供真实标签,则输出真实标签
print(info) # 打印信息
# 遍历待预测数据集并进行预测
for label, text in dataset_infer:
predict(text, label)
代码解析
- 加载待预测数据集:
dataset_infer = SentimentDataset("data/infer.tsv")
:- 从指定的文件路径加载待预测的数据集,这里是
infer.tsv
文件。
- 从指定的文件路径加载待预测的数据集,这里是
- 定义预测函数:
def predict(text, label=None):
:- 定义一个
predict
函数,接受文本输入和可选的真实标签。
- 定义一个
- 标签映射:
label_map = {0: "消极", 1: "中性", 2: "积极"}
:- 创建一个字典,将数值标签映射到对应的情感描述。
- 文本分词和转换:
text_tokenized = Tensor([tokenizer(text).input_ids])
:- 使用提前定义的分词器对输入文本进行分词,并将生成的ID转换成Tensor格式,以便输入到模型中。
- 模型预测:
logits = model(text_tokenized)
:- 将分词后的文本输入到已训练的模型中获取预测结果(logits)。
- 获取预测标签:
predict_label = logits[0].asnumpy().argmax()
:- 从模型输出的logits中获取预测标签,通过取最大值的索引(
argmax()
)来确定最可能的情感类别。
- 从模型输出的logits中获取预测标签,通过取最大值的索引(
- 格式化输出信息:
info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
:- 创建一个包含输入文本和预测结果的字符串。
if label is not None:
语句用于检查是否提供了真实标签,如果有,则将真实标签添加到输出信息中。
- 打印信息:
print(info)
:输出格式化的信息到控制台。
- 遍历数据集并进行预测:
for label, text in dataset_infer:
:- 遍历待预测的数据集
dataset_infer
,对于每一对(label, text)
,调用predict
函数进行预测。
- 遍历待预测的数据集
API 解析
- SentimentDataset:
- 自定义的数据集类,用于加载情感分析任务的输入数据。
- Tensor:
- MindSpore中的数据结构,用于存储和处理多维数据,尤其是在深度学习中作为输入或输出。
- tokenizer:
- 分词器实例,负责将文本转换为模型可以接受的ID格式。
- model:
- 已训练的情感分类模型,用于对输入文本生成预测结果。
- logits:
- 模型的输出,通常是每个类别的未归一化的得分,使用
argmax()
获取预测标签。
- 模型的输出,通常是每个类别的未归一化的得分,使用
- argmax():
- NumPy中的函数,用于返回数组中最大值的索引,在此用来确定最可能的情感类别。
自定义推理数据集
自己输入推理数据,展示模型的泛化能力。
predict("家人们咱就是说一整个无语住了 绝绝子叠buff")
标签:BERT,训练,text,模型,dataset,识别,数据,self,MindSpore
From: https://blog.csdn.net/qq_43638033/article/details/140807002