首页 > 其他分享 >BERT的中文问答系统31

BERT的中文问答系统31

时间:2024-11-17 13:17:13浏览次数:3  
标签:__ BERT data 31 length path import 问答 self

为了实现需求,我们对现有的30代码进行一些扩展,增加网络搜索功能,并在大模型无法提供满意答案时调用网络搜索。以下是详细的代码和文件结构说明:
文件结构

project_root/
│
├── data/
│   ├── train_data.jsonl
│   └── test_data.jsonl
│
├── logs/
│   ── (log files will be saved here)
│
├── models/
│   ── xihua_model.pth
│
├── main.py
├── xihua_chatbot_gui.py
└── README.md

main.py
这是主入口文件,用于启动GUI。

import tkinter as tk
from xihua_chatbot_gui import XihuaChatbotGUI

if __name__ == "__main__":
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

xihua_chatbot_gui.py
这是包含GUI和模型逻辑的文件。

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {
     i + 1}: {
     e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {
     file_path}: {
     e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {
     idx}: {
     e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
   
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)

标签:__,BERT,data,31,length,path,import,问答,self
From: https://blog.csdn.net/weixin_54366286/article/details/143760552

相关文章

  • 20222312 2024-2025-1 《网络与系统攻防技术》实验六实验报告
    1.实验内容及要求本实践目标是掌握metasploit的用法。指导书参考Rapid7官网的指导教程。https://docs.rapid7.com/metasploit/metasploitable-2-exploitability-guide/下载官方靶机Metasploitable2,完成下面实验内容。(1)前期渗透①主机发现(可用Aux中的arp_sweep,search一下就可......
  • 2024-2025-1 学号20241315《计算机基础与程序设计》第八周学习总结
    作业信息这个作业属于哪个课程2024-2025-1-计算机基础与程序设计这个作业要求在哪里<作业要求的链接>https://www.cnblogs.com/rocedu/p/9577842.html#WEEK08这个作业目标功能设计与面向对象设计面向对象设计过程面向对象语言三要素汇编、编译、解释、执行作......
  • 20222311 2024-2025-1 《网络与系统攻防技术》实验六实验报告
    1.实验内容1.1本周学习内容回顾使用了Metasploit框架,其是一个功能强大的渗透测试框架。在使用的过程当中,Metasploit提供了种类繁多的攻击模块,涵盖了远程代码执行、服务拒绝、提权等多种攻击方式,支持对多种操作系统和应用程序进行测试。除了漏洞利用,它还具备强大的后渗透功能,如......
  • 20222315 2024-2025-1 《网络与系统攻防实验六实验》实验六实验报告
    1、实验内容本实践目标是掌握metasploit的用法。指导书参考Rapid7官网的指导教程。https://docs.rapid7.com/metasploit/metasploitable-2-exploitability-guide/下载官方靶机Metasploitable2,完成下面实验内容。(1)前期渗透①主机发现(可用Aux中的arp_sweep,search一下就可以use......
  • 20222310 2024-2025-1 《网络与系统攻防技术》实验六实验报告
    一、实验内容学习掌握Metasploit工具的使用。下载靶机Metasploitable2,完成以下实验内容。1.前期渗透(1)主机发现(可用Aux中的arp_sweep,search一下就可以use)(2)端口扫描(可以直接用nmap,也可以用Aux中的portscan/tcp等)(3)扫描系统版本,漏洞等2.Vsftpd源码包后门漏洞(21端口)3.SambaMS-R......
  • 大模型实战项目:基于大模型+知识图谱的知识库问答 (附项目)
    今天给大家介绍一个git开源的宝藏项目—基于大模型+知识图谱的知识库问答,这里还搭配了一个演示dome给大家,如需要此项目练手的,我已经打包好了放在文末~基于大模型+知识图谱的知识库问答系统项目整体流程介绍项目整体包含5个部分:数据重构、图谱构建、图谱补全、对话......
  • [Codeforces Round 987 (Div. 2)](https://codeforces.com/contest/2031)解题报告
    CodeforcesRound987(Div.2)太好了是阳间场,我们有救了感觉脑子生锈了qwq,F题做不出来A分析知如果有\(i<j\)且\(a_i>a_j\)的情况出现则\(i\)和\(j\)一定至少改一个。所以答案即为\(n-cnt\),\(cnt\)为众数个数。B发现一个数离自己原本的位置距离不会超过\(1\),有......
  • CF2031
    A题意给一个单调不增序列,每次操作可以单点修,问把序列变为单调不减序列需要的最小操作次数。分析注意到事实上我们需要修改的数字非常多。考虑一个中间点\(x\),我们将所有小于\(x\)的数提升至\(x\),所有大于\(x\)的数减少至\(x\)。模拟这个过程是\(O(n^2)\)的,但我们发现......
  • springboot在线问答系统-毕业设计源码76418
    摘 要随着互联网趋势的到来,各行各业都在考虑利用互联网将自己推广出去,最好方式就是建立自己的互联网系统,并对其进行维护和管理。在现实运用中,应用软件的工作规则和开发步骤,采用Java技术建设在线问答系统。本设计主要实现集人性化、高效率、便捷等优点于一身的在线问答系......
  • 2024-2025-1 20241318 《计算机基础与程序设计》第八周学习总结
    这个作业属于哪个课程https://edu.cnblogs.com/campus/besti/2024-2025-1-CFAP(如[2024-2025-1-计算机基础与程序设计])这个作业要求在哪里https://www.cnblogs.com/rocedu/p/9577842.html#WEEK08这个作业的目标加入云班课,参考本周学习资源自学教材计算机科学概......