首页 > 其他分享 >BERT的中文问答系统30

BERT的中文问答系统30

时间:2024-11-14 11:19:48浏览次数:3  
标签:BERT logging data self 30 human import chatgpt 问答

为了完善代码并实现评估功能,我们对现有的代码进行一些调整和扩展。以下是具体的改进:
评估功能:添加评估模型的功能,计算模型在测试集上的准确率。
GUI改进:优化GUI界面,使其更加用户友好。
日志记录:增强日志记录,确保每个步骤都有详细的记录。

  1. 评估功能
    首先,我们需要添加一个评估函数,该函数将计算模型在测试集上的准确率。
# 评估函数
def evaluate(model, data_loader, device):
    model.eval()
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            human_preds = (torch.sigmoid(human_logits) > 0.5).float()
            chatgpt_preds = (torch.sigmoid(chatgpt_logits) > 0.5).float()

            correct_predictions += (human_preds == human_labels).sum().item()
            correct_predictions += (chatgpt_preds == chatgpt_labels).sum().item()
            total_predictions += human_labels.size(0) + chatgpt_labels.size(0)

    accuracy = correct_predictions / total_predictions
    return accuracy
  1. GUI改进
    在GUI中添加一个按钮来启动评估功能,并显示评估结果。
class XihuaChatbotGUI:
    # ... 其他方法保持不变 ...

    def create_widgets(self):
        # ... 其他组件保持不变 ...

        self.evaluate_button = tk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, font=("Arial", 12))
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)

        self.history_button = tk.Button(bottom_frame, text="查看历史记录", command=self.view_history, font=("Arial", 12))
        self.history_button.grid(row=3, column=1, padx=10, pady=10)

        self.save_history_button = tk.Button(bottom_frame, text="保存历史记录", command=self.save_history, font=("Arial", 12))
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def evaluate_model(self):
        test_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/test_data.jsonl'), self.tokenizer, batch_size=8, max_length=128)
        accuracy = evaluate(self.model, test_data_loader, self.device)
        logging.info(f"模型评估准确率: {
     accuracy:.4f}")
        self.log_text.insert(tk.END, f"模型评估准确率: {
     accuracy:.4f}\n")
        self.log_text.see(tk.END)
        messagebox.showinfo("评估结果", f"模型评估准确率: {
     accuracy:.4f}")
  1. 日志记录
    确保每个步骤都有详细的日志记录,以便于调试和跟踪。
def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

完整代码
以下是完整的代码,包括评估功能和改进的GUI:

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging

标签:BERT,logging,data,self,30,human,import,chatgpt,问答
From: https://blog.csdn.net/weixin_54366286/article/details/143759004

相关文章

  • 大模型神书《HuggingFace自然语言处理详解——基于BERT中文模型的任务实战》读完少走
    这几年,自然语言处理(NLP)绝对是机器学习领域最火的方向。那么今天给大家带来一本《HuggingFace自然语言处理详解——基于BERT中文模型的任务实战》这本大模型书籍资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费】★内容全面本......
  • 10.30
     实验4:抽象工厂模式本次实验属于模仿型实验,通过本次实验学生将掌握以下内容:1、理解抽象工厂模式的动机,掌握该模式的结构;2、能够利用抽象工厂模式解决实际问题。 [实验任务一]:人与肤色使用抽象工厂模式,完成下述产品等级结构:  实验要求:  源代码packageorg.e......
  • 0.1+0.2=0.30000000000000004
    看下效果这个网站能找到你想要的答案https://0.30000000000000004.com/十进制转二进制十进制整数转换为二进制整数采用"除2取余,逆序排列"法。具体做法是:用2整除十进制整数,可以得到一个商和余数;再用2去除商,又会得到一个商和余数,如此进行,直到商为小于1时为止然后把先得到......
  • 基于neo4j的英语四六级知识图谱问答系统
    大家好!今天我要和你们分享一个让英语学习变得更加高效有趣的科技新作—基于neo4j的英语四六级知识图谱问答系统!这个系统不仅仅是一个普通的学习工具,它是如何通过最新技术帮助我们更深入理解和掌握英语知识的一个典范。......
  • 大数据新视界 -- 大数据大厂之 Impala 性能提升:高级执行计划优化实战案例(下)(18/30)
           ......
  • 30 秒!用通义灵码画 SpaceX 星链发射流程图
    不想读前人“骨灰级”代码,不想当“牛马”程序员,想像看图片一样快速读复杂代码和架构?来了,灵码又加新buff!!通义灵码支持代码逻辑可视化,可以把你的每段代码画成流程图。你可以把它当成一个超级脑图工具,帮你快速画出代码逻辑和框架!接下来我们秀一下!今天我们就拿GitHub上开......
  • 30 秒!用通义灵码画 SpaceX 星链发射流程图
    不想读前人“骨灰级”代码,不想当“牛马”程序员,想像看图片一样快速读复杂代码和架构?来了,灵码又加新buff!!通义灵码支持代码逻辑可视化,可以把你的每段代码画成流程图。你可以把它当成一个超级脑图工具,帮你快速画出代码逻辑和框架!接下来我们秀一下!今天我们就拿GitHub上开......
  • shell脚本30个案例(一)
    通过一个多月的shell学习,总共写出30个案例,分批次进行发布,这次总共发布了5个案例,希望能够对大家的学习和使用有所帮助,更多案例会在下一次进行发布。案例一、备份指定目录下的文件到另一个目录1.问题在服务器环境中,需要定期备份特定目录(如/var/www/html)中的文件到备份目录(如/b......
  • 空气开关(空气断路器)根据额定电流的不同,可以选择不同规格的开关。家用230V电路中,常见的
    空气开关(空气断路器)根据额定电流的不同,可以选择不同规格的开关。家用230V电路中,常见的额定电流规格有6A、10A、16A、20A、25A、32A、40A、50A、63A等。这些规格的空气开关主要区别在于它们适应的电流负荷大小,从而保护不同功率的家用电器和电路。以下是这些常见规格的比较表格:......
  • L0G3000作业-Git基础知识
    一、闯关任务1任务要求:破冰之自我介绍首先fork一下GitHub-InternLM/Tutorial:LLM&VLMTutorial该项目到自己的账号,注意不要勾选下图的“Copythecamp4branchonly”。来到vscode启动虚拟环境,然后输入下面命令将仓库克隆到本地gitclonehttps://github.com/HuHu1226......