首页 > 其他分享 >Bert实现情感分析demo

Bert实现情感分析demo

时间:2024-07-13 17:20:30浏览次数:12  
标签:Bert bert demo self dataset 情感 train model save

Bert实现情感分析demo

数据集

IMDB数据集.

代码以及部分讲解

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

没有cuda就启用cpu

class CustomClassifier(nn.Module):
    def __init__(self,bert):
        super(CustomClassifier,self).__init__()
        self.bert = bert
        self.fc1 = nn.Linear(768,512)
        self.fc2 = nn.Linear(512,2)
        self.dropout = nn.Dropout(0.1) #防止泛化
        self.relu = nn.ReLU() #激活函数,防止梯度消失。
        self.softmax = nn.LogSoftmax(dim=1) 

    def forward(self,input_ids,attention_mask,labels = None):
        _, cls_hidden_state = self.bert(input_ids, attention_mask=attention_mask, return_dict=False) #得到隐层状态,可以继续前向传播
        x = self.fc1(cls_hidden_state)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.softmax(x)
        if labels is not None:
            loss_fn = nn.NLLLoss() #NLLLOSS损失函数和LogSoftmax绑定使用
            loss = loss_fn(x,labels)
            return loss,x
        return x


    def save(self,save_directory):
        os.makedirs(save_directory,exist_ok=True)
        torch.sava(self.state_dict(),os.path.join(save_directory,"pytorch_model.bin"))
        self.bert.config.to_json_file(os.path.join(save_directory,"config.json"))

    @classmethod #主要可以通过cls使用类的方法
    def from_pretrained(cls, save_directory):
        bert_model = AutoModel.from_pretrained(save_directory)
        model = cls(bert_model)
        model.load_state_dict(torch.load(os.path.join(save_directory,'pytorch_model.bin')))
        return model

这是主要的类,有关参数可以进行微调。
这个demo使用模型是bert-base,预训练模型以及分词器在huggingface上下载。

主要下载tokenizer_config.json,vocab.txt,config.json,pytorch_model.bin文件
BERT Base:

  • 隐藏层(Transformer 层)数量:12 层
  • 隐藏状态维度:768
  • 自注意力头数量:12
  • 总参数数量:110M
def load_local_dataset(data_dir):
    data = {'train': [], 'test': []}
    labels = {'neg': 0, 'pos': 1}
    
    for split in ['train', 'test']:
        for label in ['neg', 'pos']:
            folder = os.path.join(data_dir, split, label)
            for filename in tqdm(os.listdir(folder), desc=f"Loading {split} {label} data"):
                with open(os.path.join(folder, filename), 'r', encoding='utf-8') as f:
                    data[split].append({'text': f.read(), 'label': labels[label]})
    
    train_dataset = Dataset.from_pandas(pd.DataFrame(data['train']))
    test_dataset = Dataset.from_pandas(pd.DataFrame(data['test']))
    return DatasetDict({'train': train_dataset, 'test': test_dataset})

加载预训练数据,转化为dataset格式方便输入。

data_dir = './imdb/imdb'  # 替换为你的数据路径
dataset = load_local_dataset(data_dir)
model_name = "./bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name)
#加载分词器和预训练模型
# 创建自定义分类模型
model = CustomClassifier(bert_model).to(device)
def preprocess_function(examples):
    return tokenizer(examples['text'],truncation=True,padding=True,max_length=512)

encoded_dataset = dataset.map(preprocess_function,batch = True,desc="Tokenizing")

对数据进行预处理,截断,填充等。

trainingargs = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch"
)

定义训练参数

def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(p.label_ids, preds, average='binary')
    acc = accuracy_score(p.label_ids, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

定义评估函数,可以返回精确率,召回率,f1分数等。

trainer = Trainer(
    model = model,
    args = trainingargs,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['test'],
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.evaluate()
model.save('./saved_model')
tokenizer.save_pretrained('./saved_model')

模型的训练和保存

model = CustomClassifier.from_pretrained('./saved_model')
model.to(device)

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained('./saved_model')

# 定义预测函数
def predict(text):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs, dim=1)  
    label = torch.argmax(probs, dim=1).item()
    return 'pos' if label == 1 else 'neg'

# 进行预测
sample_text = "I failed the math exam this time."
prediction = predict(sample_text)
print(f"Prediction: {prediction}")

预测结果

训练大概需要30~50分钟
img
img

标签:Bert,bert,demo,self,dataset,情感,train,model,save
From: https://www.cnblogs.com/Sun-Wind/p/18300369

相关文章

  • 基于Python酒店评论情感分析可视化系统
    专业技术开发,收藏关注不迷路文章目录一、项目介绍二、开发环境三、功能介绍四、效果图五、文章目录一、项目介绍随着电商网络经济的兴起,更多的人选择在线上预订酒店出行,电商旅游平台使得旅行者可以通过评论更加自由地选择价格和服务合意的酒店,同时也给人们提供了......
  • Paimon Quick Start Demo
    主要解读:1.Paimon和Hadoop的包放到lib2.此处2中格式均可以:'warehouse'='file:/tmp/paimon''warehouse'='file:///tmp/paimon'3.数据持久化到了2中文件,断开连接。插入目标表任务不会中断,这个任务生命周期应该是服务器级别的流任务。再次连接后,创建catalog即可读取word_......
  • python制作甘特图的基本知识(附Demo)
    目录前言1.matplotlib2.plotly前言甘特图是一种常见的项目管理工具,用于表示项目任务的时间进度直观地看到项目的各个任务在时间上的分布和进度常用的绘制甘特图的工具是matplotlib和plotly主要以Demo的形式展示1.matplotlib功能强大的绘图库,适合制作静态......
  • 一起学Hugging Face Transformers(15)- 使用Transformers 进行情感分析
    文章目录前言一、环境准备二、加载预训练模型三、示例:情感分析四、处理数据集五、自定义模型总结思考前言情感分析(SentimentAnalysis)是自然语言处理(NLP)中的一个重要任务,旨在确定文本的情感倾向,如积极、消极或中性。HuggingFace的Transformers库提供了强大的工......
  • 基于深度学习的情感分析
    基于深度学习的情感分析是一种利用深度学习技术从文本数据中提取情感信息,判断文本的情感倾向(如正面、负面或中性)的方法。这项技术在市场营销、客户服务、社交媒体分析、产品评价和政治分析等领域有广泛应用。以下是对这一领域的系统介绍:1.任务和目标情感分析的主要任务和目......
  • 游戏AI的创造思路-技术基础-情感计算(1)
    游戏中的AI也是可以和你打情感牌的哦,不要以为NPC是没有感情的,不过,不要和NPC打过多的情感牌,你会深陷其中无法自拔的~~~~~~目录1.情感计算算法定义2.发展历史3.公式和函数3.1.特征提取阶段TF-IDF(词频-逆文档频率)公式:3.2.模型训练阶段3.3.情感识别阶段3.4.情感生......
  • 【python生成用例报告】unittest、HTMLTestReport、参数化demo
    使用第三方的报告模版,生成报告HTMLTestReport,本质是TestRunner-安装pipinstallHTMLTestReport-使用1.导包unittest、HTMLTestReport2.组装用例(套件,loader)3.使用HTMLTestReport中的runner执行套件4.查看报告目录结构:app.py:importosBase......
  • Franka Robot demo 关节阻抗控制(joint_impedance_control.cpp)
    //Copyright(c)2023FrankaRoboticsGmbH//UseofthissourcecodeisgovernedbytheApache-2.0license,seeLICENSE#include<array>#include<atomic>#include<cmath>#include<functional>#include<iostream>#include&......
  • Franka Robot demo 真空夹抓控制示例(vacuum_object.cpp)
    //Copyright(c)2019FrankaRoboticsGmbH//UseofthissourcecodeisgovernedbytheApache-2.0license,seeLICENSE#include<iostream>#include<thread>#include<franka/exception.h>#include<franka/vacuum_gripper.h>/**......
  • Franka Robot demo 力控 force_control.cpp
    //Copyright(c)2023FrankaRoboticsGmbH//UseofthissourcecodeisgovernedbytheApache-2.0license,seeLICENSE#include<array>#include<iostream>#include<Eigen/Core>#include<franka/duration.h>#include<franka/......