首页 > 其他分享 >BERT的中文问答系统60

BERT的中文问答系统60

时间:2024-12-20 23:28:44浏览次数:8  
标签:BERT self attention ids 60 human input chatgpt 问答

增强GUI的用户体验,我们可以添加更多可调节的项目,并显示当前使用的显存等信息。以下是修改后的代码:

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk, simpledialog
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup
import tkcalendar
import locale

# 设置本地化为中文
locale.setlocale(locale.LC_ALL, 'zh_CN.UTF-8')

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {
     i + 1}: {
     e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {
     file_path}: {
     e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item.get('question', '')
        human_answer = item.get('human_answers', [''])[0]
        chatgpt_answer = item.get('chatgpt_answers', [''])[0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {
     idx}: {
     e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
   
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {
     e}")

    return total_loss / len(data_loader)

# 模型评估函数
def evaluate_model(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model

标签:BERT,self,attention,ids,60,human,input,chatgpt,问答
From: https://blog.csdn.net/weixin_54366286/article/details/144508978

相关文章

  • 题解:CF603A Alternative Thinking
    思路你猜这个题为什么是A题?很思维的解法。只允许翻转一次,所以最多只会在原答案上加\(2\)。所以我们来讨论仅有的三种可能:加\(2\),要有两段连续的\(0\)或\(1\)。加\(1\),要有一段连续的\(0\)或\(1\)。不加,没有连续的\(0\)或\(1\)。我们的代码模拟上面的三种......
  • Arduino LINX 实现上拉输入,并且实现对应VI以及C#调用(以MEGA2560PRO为例)
    固件部分思路:Arduino本身可以设置INPUT_PULLUP,而LINX中没有。猜测原因是LINX在具体实现中将PINMODE设置为INPUT,并且没有实现INPUT_PULLUP版本。因此只要修改LINX固件,增加PULLUP版本的实现即可。(如果不需要普通的浮空输入,直接把源代码里的INPUT改成INPUT_PULLUP即可,无须后续操作,这......
  • 深入浅出:一个 RAG问答机器人调优示例
    一、RAG基本流程为了让大模型能回答关于公司规章制度的问题,我们需要构建一个RAG应用,RAG应用的工作流程包括:前排提示,文末有大模型AGI-CSDN独家资料包哦!解析:加载公司规章制度文档(如pdf、docx等),并解析为文本形式;分段:对解析后的文档进行分段,因为大模型的输入长度是有限......
  • 6093. 不互质子序列 DP 分解质因数
    #include<iostream>#include<cstring>#include<algorithm>#include<vector>usingnamespacestd;constintN=1e5+10;intn;intf[N];//动态规划数组,用于记录以每个质因子为结尾的最长子序列长度intmain(){cin>>n;if(n==1){......
  • 【外设篇】STMG4芯片-Hal库-I2C通信AS5600编码器(基础工程)
    引言:AS5600为绝对值编码器,其接口有I2C和ADC两种,为配合FOC的10KHZ运行速率,博主使用I2C的DMA模式+高速波特率1MHZ或ADC模拟的方式读取电机电角度,并讲明绝对值编码器在PMSM电机里如何让电角度对齐正确角度,最后用STM32Cubemx和keil5实习代码。1.I2C的HAL库函数及ADC的HAL库函数......
  • 基于知识图谱的医疗问答系统(Kubernetes)
    目录一、前提准备1、创建neo4j用户,数据目录2、修改neo4j.conf配置文件二、k8s集群部署1、步骤文档2、选择k8s-master1节点打标,kube-scheduler直接将pod调度到该节点3、创建neo4j命名空间4、创建pv5、创建pvc6、创建neo4j的Deployment7、创建NodePort类型的svc(实......
  • labelme标注后的数据只剩下面积1600像素以内的小颗粒
    点击查看代码importcv2importnumpyasnpimportjsonimportosdeflist_jsons(folder_path):forfilenameinos.listdir(folder_path):iffilename.endswith(('.json',)):yieldos.path.join(folder_path,filename)defremove_spec......
  • 【内向基环树】LeetCode 2360. 图中的最长环
    题解内向基环树的一个基本特征就是总共有\(n\)个节点和\(n\)条边,且每个节点的出度至多为\(1\),因此本题符合内向基环树的特征。先使用拓扑排序,标记全部的简单环外的节点,剩余的节点就必定是环上的节点。参考代码classSolution{public:intlongestCycle(vector<int>......
  • 自然语言处理NLP——基于电影知识图谱和大型语言模型(LLM)的KBQA问答机器人(增加自然语言
    文章目录参考可视化逻辑运行演示参考https://github.com/Xiaoheizi2023/NLP_KBQA可视化逻辑提取出实体后去neo4j搜寻实体相关的图谱,然后返回数据再进行可视化可视化工具cytoscape.js提取实体逻辑:分词后比对关键词运行数据库:Mysql(保存聊天和用户和帖子信息)neo......
  • 608. 树节点 - 力扣(LeetCode)
    608.树节点-力扣(LeetCode)目标输入输入:Treetable:idp_id121314252输出输出:idtype1Root2Inner3Leaf4Leaf5Leaf分析树中的每个节点可以是以下三种类型之一:"Leaf":节点是叶子节点。"Root":节点是树的根节点。"lnner":节点既不是叶子节点也不是根节点。编写一个解决......