首页 > 其他分享 >bert分类的代码

bert分类的代码

时间:2024-06-21 15:09:49浏览次数:10  
标签:bert 分类 代码 labels ids train attention tf input

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, TFBertModel
from transformers import RobertaTokenizer, TFRobertaModel
import pandas as pd
from random import shuffle
from sklearn.metrics import confusion_matrix, f1_score
import numpy as np
import random


# 设置 Python 的随机种子
seed_value = 42
np.random.seed(seed_value)
random.seed(seed_value)
# 设置 TensorFlow 的全局随机种子
tf.random.set_seed(seed_value)
os.environ['TF_DETERMINISTIC_OPS'] = '1'

# 加载预训练的BERT模型和tokenizer
bert_model_name = './bert'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = TFBertModel.from_pretrained(bert_model_name)
# bert_model_name = './robert'
# tokenizer = RobertaTokenizer.from_pretrained(bert_model_name)
# bert_model = TFRobertaModel.from_pretrained(bert_model_name)


# 计算详细指标
def action_recall_accuracy(y_pred, y_true):
    cm = confusion_matrix(y_true, y_pred)

    # 计算每个类别的准确率和召回率
    num_classes = cm.shape[0]
    accuracy = []
    recall = []

    for i in range(num_classes):
        # 计算准确率:预测正确的样本数 / 实际属于该类别的样本数
        acc = cm[i, i] / sum(cm[i, :])
        accuracy.append(acc)

        # 计算召回率:预测正确的样本数 / 预测为该类别的样本数
        rec = cm[i, i] / sum(cm[:, i])
        recall.append(rec)

    # 打印结果
    for i in range(num_classes):
        print(f"类别 {i} 的准确率: {accuracy[i]:.3f}")
        print(f"类别 {i} 的召回率: {recall[i]:.3f}")

    scores = []

    for i in range(num_classes):
        # 计算F1分数
        f1 = f1_score(y_true, y_pred, average=None)[i]
        scores.append(f1)

        # 打印F1分数
        print(f"类别 {i} 的F1分数: {scores[i]:.3f}")

    # 打印各类别F1-score的平均值
    average_f1 = sum(scores) / len(scores)
    print(f"各类别F1-score的平均值: {average_f1:.3f}")


# 定义输入处理函数
def encode_texts(query, title, tokenizer, max_length=128):
    encoded_dict = tokenizer.encode_plus(
        query,
        title,
        add_special_tokens=True,  # 添加 [CLS], [SEP] 等标记
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='tf'  # 返回 TensorFlow 张量
    )
    return encoded_dict['input_ids'], encoded_dict['attention_mask']


# 构建模型
def build_model(bert_model):
    input_ids = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name='attention_mask')

    bert_output = bert_model(input_ids, attention_mask=attention_mask)
    cls_output = bert_output.last_hidden_state[:, 0, :]  # 取出 [CLS] 向量

    dense = tf.keras.layers.Dense(256, activation='relu')(cls_output)
    dropout = tf.keras.layers.Dropout(0.5)(dense)
    dense2 = tf.keras.layers.Dense(32, activation='relu')(dropout)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense2)  # 二分类问题用 sigmoid 激活

    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
                  loss='binary_crossentropy',
                  metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])
    return model


# 读取数据集
def load_dataset(file_path, tokenizer, max_length=128):
    queries = []
    titles = []
    labels = []
    data = pd.read_csv(file_path)
    all_data = []
    for query, title, label in zip(data['query'].tolist(), data['title'].tolist(), data["label"].tolist()):
        all_data.append([query, title, int(label)])

    shuffle(all_data)
    for item in all_data:
        query, title, label = item
        queries.append(query)
        titles.append(title)
        labels.append(label)

    input_ids_list = []
    attention_mask_list = []
    for query, title in zip(queries, titles):
        input_ids, attention_mask = encode_texts(query, title, tokenizer, max_length)
        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)

    input_ids = tf.concat(input_ids_list, axis=0)
    attention_masks = tf.concat(attention_mask_list, axis=0)
    labels = tf.convert_to_tensor(labels)

    return {'input_ids': input_ids, 'attention_mask': attention_masks}, labels


# 加载训练和测试数据
train_data, train_labels = load_dataset('train.csv', tokenizer)
test_data, test_labels = load_dataset('test.csv', tokenizer)

# 将TensorFlow张量转换为numpy数组
train_input_ids_np = train_data['input_ids'].numpy()
train_attention_masks_np = train_data['attention_mask'].numpy()
train_labels_np = train_labels.numpy()

# 将训练数据进一步划分为训练集和验证集
train_input_ids, val_input_ids, train_attention_masks, val_attention_masks, train_labels, val_labels = train_test_split(
    train_input_ids_np, train_attention_masks_np, train_labels_np, test_size=0.05, random_state=42, shuffle=False)

# 将numpy数组转换回TensorFlow张量
train_inputs = {'input_ids': tf.convert_to_tensor(train_input_ids), 'attention_mask': tf.convert_to_tensor(train_attention_masks)}
val_inputs = {'input_ids': tf.convert_to_tensor(val_input_ids), 'attention_mask': tf.convert_to_tensor(val_attention_masks)}
train_labels = tf.convert_to_tensor(train_labels)
val_labels = tf.convert_to_tensor(val_labels)

# 模型实例化
model = build_model(bert_model)
model.summary()

# 计算类权重以强调准确性
neg_weight = 10.0
pos_weight = 1.0  # 使正类样本的权重较低,减少召回率
class_weight = {0: neg_weight, 1: pos_weight}

# 训练模型
epochs =3
batch_size = 32
true_labels = pd.read_csv('test.csv')['label'].astype('int32')

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    history = model.fit(
        x={'input_ids': train_inputs['input_ids'], 'attention_mask': train_inputs['attention_mask']},
        y=train_labels,
        validation_data=(
            {'input_ids': val_inputs['input_ids'], 'attention_mask': val_inputs['attention_mask']},
            val_labels
        ),
        epochs=1,  # 每次只训练一个 epoch
        batch_size=batch_size,
        shuffle=True,
        class_weight=class_weight  # 调整类别权重
    )

    # 基于测试数据集进行评估
    loss, accuracy, auc = model.evaluate(test_data, test_labels)
    print(f"Test loss: {loss}, Test accuracy: {accuracy}, Test AUC: {auc}")

    # 调整决策阈值
    threshold = 0.5  # 调高阈值以减少 False Positives 提升准确度

    # 计算精确率和召回率
    predictions = model.predict(test_data)
    pred_labels = [int(i > threshold) for i in predictions[:, 0]]
    true_labels = list(np.array(true_labels))
    action_recall_accuracy(pred_labels, true_labels)
    if epoch == 5:
        with open("pred_rs.txt", "w", encoding="utf-8") as out:
            for label, pred in zip(true_labels, predictions[:, 0]):
                out.write("{}\t{}\n".format(label, pred))

 

标签:bert,分类,代码,labels,ids,train,attention,tf,input
From: https://www.cnblogs.com/qiaoqifa/p/18260563

相关文章

  • 帮企商城10合一万能DIY分销商城小程序源码系统 带源代码包+搭建部署教程
    系统概述这是一款集多种功能于一体的源码系统,旨在为用户提供一站式的商城解决方案。它不仅支持小程序端,还能与其他平台无缝对接,满足不同用户的需求。代码示例系统特色功能一览   1.万能DIY功能:用户可以根据自己的需求和创意,自由定制商城的外观、布局和功能模块,打造......
  • 超级会员卡积分收银系统源码 带完整的安装代码包以及搭建部署教程
    系统概述超级会员卡积分收银系统源码是一款专为商业运营打造的综合性软件解决方案。它集成了会员卡管理、积分管理、收银管理等多种功能,旨在为企业提供高效、便捷、准确的运营管理工具。该系统源码采用先进的技术架构,具有良好的稳定性和扩展性,能够适应不同规模和类型的企业需......
  • 百度在线分销商城小程序源码系统 分销+会员组+新用户福利 前后端分离 带完整的安装代
    系统概述百度在线分销商城小程序源码系统是一款集分销、会员组管理和新用户福利于一体的前后端分离的系统。它采用先进的技术架构,确保系统的稳定性、高效性和安全性。该系统的前端基于小程序开发,为用户提供了便捷的购物体验和交互界面。用户可以通过小程序轻松浏览商品、下单......
  • CF519E A and B and Lecture Rooms(树上倍增 + 分类讨论)
    link一眼看上去没什么思路,手摸一下样例,发现有不同性质的点对求解想法很不一样,考虑先分类讨论看看。从简单的约束到强的约束分类讨论,这样更可做,也更好讨论,比如首先我就想到两点是否重合,然后所求点一定要到两点的距离相等,我就想到路径长度的奇偶性,接着就考虑复杂的深度关系...........
  • 课程设计——基于FPGA的交通红绿灯控制系统(源代码)
    摘要:        本课程设计旨在设计一个基于FPGA(现场可编程门阵列)的交通红绿灯控制系统。该系统模拟了实际道路交叉口的红绿灯工作场景,通过硬件描述语言(如Verilog或VHDL)编写源代码实现。系统包含三个主要部分:红绿灯显示模块、计时控制模块以及状态切换模块。红绿灯显示模......
  • 用Python执行JavaScript代码,这些方法你不可不知!
    目录1、PyExecJS:轻量级桥梁......
  • recastnavigation.Sample_TempObstacles代码注解 - rcBuildHeightfieldLayers
    烘培代码在rcBuildHeightfieldLayers本质上是为每个tile生成高度上的不同layer算法的关键是三层循环:forz轴循环forx轴循环for高度span循环判断span和相邻span的连通性(x/z平面相邻cell)如果联通,则标注为同一个layer,也就是在x/z平面上标注layer,形成像是互不相......
  • 【TensorFlow深度学习】开源社区支持与GitHub上贡献代码的流程
    开源社区支持与GitHub上贡献代码的流程开源社区支持与GitHub上贡献代码的流程:携手共创软件未来1.开源社区支持的意义2.如何在GitHub上找到合适的项目3.贡献代码的流程3.1.Fork与Clone3.2.创建分支3.3.修改代码3.4.提交与推送3.5.创建PullRequest......
  • 高效BUG管理:定级、分类和处理流程
    高效BUG管理:定级、状态跟踪与处理全流程前言一、BUG的定义二、BUG的定级三、BUG的状态四、BUG的处理流程1.BUG报告2.BUG确认3.BUG修复4.BUG验证5.BUG关闭五、常见问题与解决方案六、总结前言在测试工作中,BUG的定级和分类是一个重要环节,它直接影响到BUG修复的......
  • 代码随想录算法训练营第17天 | 二叉树04
    代码随想录算法训练营第17天找树左下角的值https://leetcode.cn/problems/find-bottom-left-tree-value/找树左下角的值代码随想录https://leetcode.cn/problems/find-bottom-left-tree-value/路径总和https://leetcode.cn/problems/path-sum/description/路径总和2https......