首页 > 其他分享 >代码

代码

时间:2024-05-29 14:54:44浏览次数:10  
标签:代码 attention mask ids train tf input

1. bert 二分类

import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, TFBertModel
import pandas as pd

# 加载预训练的BERT模型和tokenizer
bert_model_name = './bert'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = TFBertModel.from_pretrained(bert_model_name)


# 定义输入处理函数
def encode_texts(query, title, tokenizer, max_length=128):
    encoded_dict = tokenizer.encode_plus(
        query,
        title,
        add_special_tokens=True,  # 添加 [CLS], [SEP] 等标记
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='tf'  # 返回 TensorFlow 张量
    )
    return encoded_dict['input_ids'], encoded_dict['attention_mask']


# 构建模型
def build_model(bert_model):
    input_ids = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.layers.Input(shape=(128,), dtype=tf.int32, name='attention_mask')

    bert_output = bert_model(input_ids, attention_mask=attention_mask)
    cls_output = bert_output.last_hidden_state[:, 0, :]  # 取出 [CLS] 向量

    dense = tf.keras.layers.Dense(256, activation='relu')(cls_output)
    dropout = tf.keras.layers.Dropout(0.3)(dense)
    dense2 = tf.keras.layers.Dense(128, activation='relu')(dropout)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(dense2)  # 二分类问题用 sigmoid 激活

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-07)
    model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output)
    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy',
                  metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])
    return model


# 读取数据集
def load_dataset(file_path, tokenizer, max_length=128):
    queries = []
    titles = []
    labels = []
    data = pd.read_csv(file_path, sep="\t")
    for query, title, label in zip(data['query'].tolist(), data['title'].tolist(), data["label"].tolist()):
        queries.append(query)
        titles.append(title)
        labels.append(int(label))

    input_ids_list = []
    attention_mask_list = []
    for query, title in zip(queries, titles):
        input_ids, attention_mask = encode_texts(query, title, tokenizer, max_length)
        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)

    input_ids = tf.concat(input_ids_list, axis=0)
    attention_masks = tf.concat(attention_mask_list, axis=0)
    labels = tf.convert_to_tensor(labels)

    return {'input_ids': input_ids, 'attention_mask': attention_masks}, labels


# 加载训练和测试数据
train_data, train_labels = load_dataset('train.csv', tokenizer)
test_data, test_labels = load_dataset('test.csv', tokenizer)

# 将TensorFlow张量转换为numpy数组
train_input_ids_np = train_data['input_ids'].numpy()
train_attention_masks_np = train_data['attention_mask'].numpy()
train_labels_np = train_labels.numpy()

# 将训练数据进一步划分为训练集和验证集
train_input_ids, val_input_ids, train_attention_masks, val_attention_masks, train_labels, val_labels = train_test_split(
    train_input_ids_np, train_attention_masks_np, train_labels_np, test_size=0.1, random_state=42)

# 将numpy数组转换回TensorFlow张量
train_inputs = {'input_ids': tf.convert_to_tensor(train_input_ids), 'attention_mask': tf.convert_to_tensor(train_attention_masks)}
val_inputs = {'input_ids': tf.convert_to_tensor(val_input_ids), 'attention_mask': tf.convert_to_tensor(val_attention_masks)}
train_labels = tf.convert_to_tensor(train_labels)
val_labels = tf.convert_to_tensor(val_labels)

# 模型实例化
model = build_model(bert_model)
model.summary()

# 训练模型
epochs = 3
batch_size = 8

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    history = model.fit(
        x={'input_ids': train_inputs['input_ids'], 'attention_mask': train_inputs['attention_mask']},
        y=train_labels,
        validation_data=(
            {'input_ids': val_inputs['input_ids'], 'attention_mask': val_inputs['attention_mask']},
            val_labels
        ),
        epochs=1,  # 每次只训练一个 epoch
        batch_size=batch_size,
        shuffle=True
    )

    # 基于测试数据集进行评估
    loss, accuracy, auc = model.evaluate(test_data, test_labels)
    print(f"Test loss: {loss}, Test accuracy: {accuracy}, Test AUC: {auc}")
    if epoch == 1:
        model.save(f"./bert_relevance_model", save_format='tf')

 

标签:代码,attention,mask,ids,train,tf,input
From: https://www.cnblogs.com/qiaoqifa/p/18220301

相关文章

  • 源代码管理工具介绍——GitHub
    使用源代码管理工具的好处——提高团队的协作效率,降低开发风险,增强代码的稳定性和可维护性1.版本控制:源代码管理工具可以帮助开发团队更加有效地管理、追踪项目的不同版本,团队成员能够利用源代码管理工具方便轻松地查看以前所有的代码版本,比较更改、撤销错误或者恢复之前的代码......
  • 源略论源代码管理工具的精选介绍:聚焦TFS(TFS)
    在软件开发团队协作的生态系统中,源代码管理工具扮演着至关重要的角色,确保代码的版本控制、团队协作效率及项目管理。本文将聚焦于MicrosoftTeamFoundationServer(TFS)这一企业级的源代码管理平台,结合团队开发流程,探讨其安装配置、使用细节及如何促进团队协作。安装与配置FS2013......
  • [oeasy]python019_ 如何在github仓库中进入目录_找到程序代码_找到代码
    继续运行......
  • Aws CodeCommit代码仓储库
    1创建IAM用户IAM创建admin用户,增加AWSCodeCommitFullAccess权限2创建存储库CodePipeline->CodeCommit->存储库创建存储库3SSH1)window环境3.1.1上载SSH公有秘钥生成SSH秘钥ID3.1.2 编辑本地~/.ssh目录中名为“config”的SSH配置文件Hostgit......
  • java代码块
    Java中的代码块代码块分类静态代码块构造代码块局部代码块构造代码块怎么书写构造代码块publicclassDemo{{//构造代码块,书写位置是类中方法外}}构造代码块执行特点和作用执行特点:会在每一个构造方法执行前执行一次publicclassDemo......
  • 梯度提升机器LightGBM集成学习回归、分类、参数调优可视化实例|附数据代码
    全文链接:https://tecdat.cn/?p=36275原文出处:拓端数据部落公众号LightGradientBoostedMachine(简称LightGBM)是一个开源库,它为梯度提升算法提供了高效且有效的实现。LightGBM通过添加一种自动特征选择的方式,并专注于提升具有较大梯度的样本,来扩展梯度提升算法。这可以显著加速......
  • github源代码管理工具——使用介绍
    GitHub是一个面向开源及私有软件项目的在线代码托管平台,用户可以在GitHub上创建仓库(repository),将代码存储在仓库中,并与团队成员共享代码。并且提供了项目管理工具,如Issue跟踪、项目面板、里程碑、任务列表等,有助于团队项目的管理。除了Git代码仓库托管及基本的Web管理界面以外,还提......
  • Github——基于Git的代码托管平台
    Github是一个基于Git的代码托管平台,付费用户可以建私人仓库,我们一般的免费用户只能使用公共仓库,也就是代码要公开。Github由ChrisWanstrath,PJHyett与TomPreston-Werner三位开发者在2008年4月创办。迄今拥有59名全职员工,主要提供基于git的版本托管服务。今天,GitHub已是:一个......
  • 源代码管理工具——GitHub
    GitHub是一个面向开源及私有软件项目的托管平台,拥有超过1亿的开发人员、400万以上的组织机构和3.3亿以上的资料库。自2008年4月10日正式上线以来,GitHub已经成为管理软件开发以及发现已有代码的首选方法。它主要基于Git版本控制系统,提供了包括代码托管、问题跟踪、代码审查、代码片......
  • 【三变量联合分布函数copula】利用AIC BIC确定单变量最优拟合函数、利用AIC确定三变量
      ......