首页 > 其他分享 >BERT的中文问答系统42

BERT的中文问答系统42

时间:2024-11-24 19:57:58浏览次数:8  
标签:BERT max data self 42 length import path 问答

我们将对现有的代码进行扩展,以支持360百科的功能。这包括修改XihuaChatbotGUI类中的相关方法,以及添加一个新的搜索360百科的函数。此外,我们还需要更新历史记录的保存格式,以包含360百科的结果。

项目结构
code
project_root/

├── data/
│ └── train_data.jsonl

├── logs/
│ └── [log_files]

├── models/
│ └── xihua_model.pth

├── main.py
└── README.md
README.md
markdown

羲和聊天机器人

项目介绍

羲和聊天机器人是一个基于BERT模型的问答系统。它可以从训练数据中学习,并能够回答用户提出的问题。此外,用户可以通过界面评价机器人的回答是否准确,并提供百度百科和360百科的参考答案。

目录结构

project_root/

├── data/
│ └── train_data.jsonl

├── logs/
│ └── [log_files]

├── models/
│ └── xihua_model.pth

├── main.py
└── README.md

code

依赖

  • Python 3.7+
  • PyTorch
  • Transformers
  • Tkinter
  • Requests
  • BeautifulSoup

安装

pip install torch transformers requests beautifulsoup4

运行

python main.py

功能
用户输入问题,机器人给出回答。
用户可以评价回答是否准确。
如果回答不准确,可以选择查看百度百科或360百科的结果。
训练和重新训练模型。
查看和保存历史记录。

main.py

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {
     i + 1}: {
     e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {
     file_path}: {
     e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {
     idx}: {
     e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
   
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def 

标签:BERT,max,data,self,42,length,import,path,问答
From: https://blog.csdn.net/weixin_54366286/article/details/144011099

相关文章

  • FIN421 Econometrics for Finance
    FIN421EconometricsforFinanceGroupCoursework(worth30%ofthetotalmodulegrade)Submissiondeadline:December12,2024Learningoutcomes:nstrumentalvariableregression,modeldiagnostic,andcheckingtheweaknessofinstrumentalvariables.Ana......
  • 2024-2025-1 20241427 《计算机基础与程序设计》第9周学习总结
    作业信息这个作业属于哪个课程[2024-2025-1-计算机基础与程序设计]这个作业要求在哪里https://www.cnblogs.com/rocedu/p/9577842.html#WEEK09这个作业的目标操作系统责任、内存与进程管理、分时系统、CPU调度、文件、文件系统、文件保护、磁盘调度作业正文htt......
  • 2024-2025-1 20241421《计算机基础与程序设计》第九周学习总结
    这个作业属于哪个课程2024-2025-1-计算机基础与程序设计)这个作业要求在哪里https://www.cnblogs.com/rocedu/p/9577842.html#WEEK09这个作业的目标操作系统责任、内存与进程管理、分时系统、CPU调度、文件、文件系统、文件保护、磁盘调度作业正文本博客链接......
  • 2024-2025-1 20241423 《计算机基础与程序设计》第九周学习总结
    作业信息这个作业属于哪个课程[2024-2025-1-计算机基础与程序设计](https://edu.cnblogs.com/campus/besti/2024-2025-1-CFAP)这个作业要求在哪里2024-2025-1计算机基础与程序设计第九周作业这个作业的目标操作系统责任、内存与进程管理、分时系统、CPU调度、文件、......
  • BERT的基本理念
    BERT的基本理念BERT的基本理念:word2vec是一类生成词向量的模型的总称。这类模型多为浅层或者双层的神经网络,通过训练建立词在语言空间中的向量关系。BERT是BidirectionalEncoderRepresentationsfromTransformers的缩写,意为多Transformer的双向编码器表示法,它是由谷......
  • 2024-2025-1 20241428张雄一《计算机基础与程序设计》第九周学习总结
    学期(如2024-2025-1)学号20241428《计算机基础与程序设计》第9周学习总结作业信息这个作业属于哪个课程<班级的链接>(如2024-2025-1-计算机基础与程序设计)这个作业的目标操作系统责任、内存与进程管理、分时系统、CPU调度、文件、文件系统、文件保护、磁盘调度作业......
  • 题解:SP1442 CHAIN - Strange Food Chain
    有三种可能的假话:编号\(>n\);自己吃自己;互吃。使用扩展域并查集(种类并查集)。code:#include<bits/stdc++.h>usingnamespacestd;intn,m,c,t,F[150005];intfind(intx){ if(F[x]==x)returnx; returnF[x]=find(F[x]);}intmain(){cin>>t;while......
  • GB/T 4208-2017 外壳防护等级(IP代码)(3)—特征数字和标志
    写在前面本系列文章主要讲解外壳防护等级GB/T4208标准的相关知识,希望能帮助更多的同学认识和了解GB/T4208标准。若有相关问题,欢迎评论沟通,共同进步。(*^▽^*)外壳防护等级6.第二位特征数字第二位特征数字表示外壳防止由于进水而对设备造成有害影响的防护等级。第二位......
  • GB/T 4208-2017 外壳防护等级(IP代码)(4)—试验要求
    写在前面本系列文章主要讲解外壳防护等级GB/T4208标准的相关知识,希望能帮助更多的同学认识和了解GB/T4208标准。若有相关问题,欢迎评论沟通,共同进步。(*^▽^*)外壳防护等级11.试验一般要求11.1防水防尘试验的环境条件除非有关产品标准另有规定,试验应在GB/T2421.1—2......
  • AI模型架构如(CNN)、(RNN)(LSTM、GRU)、(如BERT、GPT等)在不同领域中的具体应用
    AI模型架构如卷积神经网络(CNN)、循环神经网络(RNN)及其变体(如LSTM、GRU)、以及基于Transformer架构的模型(如BERT、GPT等)在现实生活中有广泛的应用。以下是这些模型在不同领域中的具体应用和形态表现:一、卷积神经网络(CNN)1.图像分类应用案例:手机中的人脸识别、社交媒体中的自......