为了完善代码并实现评估功能,我们对现有的代码进行一些调整和扩展。以下是具体的改进:
评估功能:添加评估模型的功能,计算模型在测试集上的准确率。
GUI改进:优化GUI界面,使其更加用户友好。
日志记录:增强日志记录,确保每个步骤都有详细的记录。
- 评估功能
首先,我们需要添加一个评估函数,该函数将计算模型在测试集上的准确率。
# Evaluation helper: human-vs-chatgpt binary classification accuracy.
def evaluate(model, data_loader, device):
    """Compute the classification accuracy of ``model`` over ``data_loader``.

    The model is used as a one-logit binary classifier: texts from the
    ``human_*`` tensors are expected to be classified as 1 and texts from
    the ``chatgpt_*`` tensors as 0 (threshold: sigmoid(logit) > 0.5).

    Args:
        model: callable module; ``model(input_ids, attention_mask)`` returns
            a logits tensor of shape (batch, 1).
        data_loader: iterable yielding batch dicts that contain the
            ``human_input_ids`` / ``human_attention_mask`` and
            ``chatgpt_input_ids`` / ``chatgpt_attention_mask`` tensors.
        device: torch device (or device string) tensors are moved to.

    Returns:
        float: fraction of correct predictions in [0, 1]; 0.0 when the
        loader yields no samples (previously raised ZeroDivisionError).
    """
    model.eval()  # inference mode for dropout/batch-norm layers
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():  # no autograd graph needed during evaluation
        for batch in data_loader:
            # NOTE: the original also moved batch['input_ids'] /
            # batch['attention_mask'] to the device but never used them;
            # those dead transfers are removed here.
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)
            # Ground truth: human -> 1, chatgpt -> 0.
            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)
            human_preds = (torch.sigmoid(human_logits) > 0.5).float()
            chatgpt_preds = (torch.sigmoid(chatgpt_logits) > 0.5).float()
            correct_predictions += (human_preds == human_labels).sum().item()
            correct_predictions += (chatgpt_preds == chatgpt_labels).sum().item()
            total_predictions += human_labels.size(0) + chatgpt_labels.size(0)
    # Guard against an empty data loader (fix: avoid division by zero).
    if total_predictions == 0:
        return 0.0
    return correct_predictions / total_predictions
- GUI改进
在GUI中添加一个按钮来启动评估功能,并显示评估结果。
class XihuaChatbotGUI:
    # ... other methods unchanged ...

    def create_widgets(self):
        # ... other widgets unchanged ...
        # Bottom action row: all three buttons share grid row 3.
        button_specs = [
            ('evaluate_button', "评估模型", self.evaluate_model, 0),
            ('history_button', "查看历史记录", self.view_history, 1),
            ('save_history_button', "保存历史记录", self.save_history, 2),
        ]
        for attr_name, label, handler, column in button_specs:
            button = tk.Button(bottom_frame, text=label, command=handler, font=("Arial", 12))
            button.grid(row=3, column=column, padx=10, pady=10)
            setattr(self, attr_name, button)

    def evaluate_model(self):
        """Evaluate the model on the held-out test set and report accuracy.

        The result is written to the application log, appended to the
        on-screen log widget, and shown in a message box.
        """
        test_path = os.path.join(PROJECT_ROOT, 'data/test_data.jsonl')
        test_data_loader = get_data_loader(test_path, self.tokenizer, batch_size=8, max_length=128)
        accuracy = evaluate(self.model, test_data_loader, self.device)
        message = f"模型评估准确率: {accuracy:.4f}"
        logging.info(message)
        self.log_text.insert(tk.END, message + "\n")
        self.log_text.see(tk.END)  # keep the newest log line visible
        messagebox.showinfo("评估结果", message)
- 日志记录
确保每个步骤都有详细的日志记录,以便于调试和跟踪。
def setup_logging():
    """Send INFO-level log records to a timestamped file and to the console."""
    timestamp_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt')
    log_file = os.path.join(LOGS_DIR, timestamp_name)
    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
        handlers=[file_handler, console_handler],
    )
完整代码
以下是完整的代码,包括评估功能和改进的GUI:
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
# Project root: the directory containing this script.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
# Logging configuration: all run logs are collected under <project>/logs,
# created at import time if it does not exist yet.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)
def setup_logging():
    """Configure root logging: INFO level, timestamped file plus console.

    The log file name embeds the current date/time (pattern
    'YYYY-MM-DD_HH-MM-SS_羲和.txt') and is placed inside LOGS_DIR.
    """
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),  # persistent per-run log file
            logging.StreamHandler()  # mirror records to the console
        ]
    )
# Import-time side effect: configure logging before any other step runs.
setup_logging()
# 数据集类
class XihuaDataset(Dataset):
def __init__(self, file_path, tokenizer, max_length=128):
self.tokenizer = tokenizer
self.max_length = max_length
self.data = self.load_data(file_path)
def load_data(self, file_path):
data = []
if file_path.endswith('.jsonl'):
with jsonlines.open(file_path) as reader:
for i, item in enumerate(reader):
try:
data.append(item)
except jsonlines.jsonlines.InvalidLineError as e:
logging
标签:BERT,logging,data,self,30,human,import,chatgpt,问答
From: https://blog.csdn.net/weixin_54366286/article/details/143759004