首页 > 其他分享 >完整的端到端的中文聊天机器人

完整的端到端的中文聊天机器人

时间:2024-09-23 23:23:02浏览次数:3  
标签:中文 self 机器人 tokens TOKEN 聊天 input hidden size

这段代码是一个完整的端到端的中文聊天机器人的实现,包括数据处理、模型训练、预测和图形用户界面(GUI),下面是对各个部分功能的详细说明:

1. 导入必要的库

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import random
import tkinter as tk
import jieba
import matplotlib.pyplot as plt
import os
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.amp import GradScaler, autocast

os: 用于设置环境变量和文件操作。
torch: PyTorch 库,用于构建和训练深度学习模型。
tkinter: 用于创建图形用户界面。
jieba: 用于中文分词。
matplotlib: 用于绘制损失曲线。
json: 用于读取 JSON 文件。
transformers: Hugging Face 的 Transformers 库,用于加载预训练模型和分词器。
torch.amp: 用于混合精度训练,提高训练速度和减少内存占用。

2. 定义特殊标记和词汇表

PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"

word2index = {
   PAD_TOKEN: 0, UNK_TOKEN: 1, SOS_TOKEN: 2, EOS_TOKEN: 3}
index2word = {
   0: PAD_TOKEN, 1: UNK_TOKEN, 2: SOS_TOKEN, 3: EOS_TOKEN}

特殊标记:定义了四个特殊标记,分别表示填充、未知词、句子开始和句子结束。
词汇表:初始化词汇表,将特殊标记映射到索引。

3. 中文分词

def tokenize_chinese(sentence):
    tokens = jieba.lcut(sentence)
    return tokens

功能:使用 jieba 对输入的中文句子进行分词,返回分词后的词汇列表。

4. 构建词汇表

def build_vocab(sentences):
    global word2index, index2word
    vocab_size = len(word2index)
    for sentence in sentences:
        for token in tokenize_chinese(sentence):
            if token not in word2index:
                word2index[token] = vocab_size
                index2word[vocab_size] = token
                vocab_size += 1
    return vocab_size

功能:遍历所有句子,构建词汇表,将每个词映射到一个唯一的索引。

5. 将句子转换为张量

def sentence_to_tensor(sentence, max_length=50):
    tokens = tokenize_chinese(sentence)
    indices = [word2index.get(token, word2index[UNK_TOKEN]) for token in tokens]
    indices = [word2index[SOS_TOKEN]] + indices + [word2index[EOS_TOKEN]]
    indices += [word2index[PAD_TOKEN]] * (max_length - len(indices))
    return torch.tensor(indices, dtype=torch.long), len(indices)

功能:将输入的句子转换为张量,并返回句子的实际长度。句子被加上 和 标记,并用 标记填充到指定的最大长度。

6. 读取数据

def load_data(file_path):
    if file_path.endswith('.jsonl'):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = [json.loads(line) for line in f.readlines()]
    elif file_path.endswith('.json'):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = json.load(f)
    else:
        raise ValueError("不支持的文件格式。请使用 .jsonl 或 .json。")
    
    questions = [line['question'] for line in lines]
    answers = [random.choice(line['human_answers'] + line['chatgpt_answers']) for line in lines]
    return questions, answers

功能:从指定的 JSON 或 JSONL 文件中读取数据,返回问题和答案列表。

7. 数据增强

def data_augmentation(sentence):
    tokens = tokenize_chinese(sentence)
    augmented_sentence = []
    if random.random() < 0.1:
        insert_token = random.choice(list(word2index.keys())[4:])
        insert_index = random.randint(0, len(tokens))
        tokens.insert(insert_index, insert_token)
    if random.random() < 0.1 and len(tokens) > 1:
        delete_index = random.randint(0, len(tokens) - 1)
        del tokens[delete_index]
    if len(tokens) > 1 and random.random() < 0.1:
        index1, index2 = random.sample(range(len(tokens)), 2)
        tokens[index1], tokens[index2] = tokens[index2], tokens[index1]
    augmented_sentence = ''.join(tokens)
    return augmented_sentence

功能:对输入的句子进行随机插入、删除和交换操作,以增加数据的多样性。

8. 定义数据集

class ChatDataset(Dataset):
    def __init__(self, questions, answers):
        self.questions = questions
        self.answers = answers

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        input_tensor, input_length = sentence_to_tensor(self.questions[idx])
        target_tensor, target_length = sentence_to_tensor(self.answers[idx])
        return input_tensor, target_tensor, input_length, target_length

功能:定义一个自定义的数据集类,用于存储问题和答案,并将它们转换为张量。

9. 自定义 collate 函数

def collate_fn(batch):
    inputs, targets, input_lengths, target_lengths = zip(*batch)
    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=word2index[PAD_TOKEN])
    targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=word2index[PAD_TOKEN])
    return inputs, targets, torch.tensor(input_lengths), torch.tensor(target_lengths)

功能:将一批数据进行填充,使其具有相同的长度,并返回填充后的输入、目标、输入长度和目标长度。

10. 创建数据集和数据加载器

def create_dataset_and_dataloader(questions_file, answers_file, batch_size=10, shuffle=True, split_ratio=0.8):
    questions, answers = load_data(questions_file)
    vocab_size = build_vocab(questions + answers)
    dataset = ChatDataset(questions, answers)
    
    train_size = int(split_ratio * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    
    return train_dataset, train_dataloader, val_dataset, val_dataloader, vocab_size

功能:创建训练和验证数据集及数据加载器,并返回词汇表的大小。

11. 定义模型结构

class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=False)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        return outputs, hidden

class Decoder(nn.Module):
    def __init__(self, output_size, hidden_size, num_layers=1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_step, hidden, encoder_outputs):
        embedded = self.embedding(input_step)
        gru_output, hidden = self.gru(embedded, hidden)
        output = self.softmax(self.out(gru_output.squeeze(1)))
        return output, hidden

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, tokenizer):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.tokenizer = tokenizer

    def forward(self, input_tensor, target_tensor, input_lengths, target_lengths, teacher_forcing_ratio=0.5):
        batch_size = input_tensor.size(0)
        max_target_len = max(target_lengths)
        vocab_size = self.decoder.out.out_features
        outputs = torch.zeros(batch_size, max_target_len, vocab_size).to(self.device)
        encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)
        decoder_input = torch.tensor([[word2index[SOS_TOKEN]] * batch_size], device=self.device).transpose(0, 1)
        decoder_hidden = encoder_hidden
        for t in range(max_target_len<

标签:中文,self,机器人,tokens,TOKEN,聊天,input,hidden,size
From: https://blog.csdn.net/weixin_54366286/article/details/142468857

相关文章

  • Creo 11.0百度云资源中文版+详细安装教程下载
    如大家所熟悉的,Creo是一款计算机辅助设计(CAD)应用程序,由PTC开发。该套件由应用程序组成,每个应用程序都为产品开发中的用户角色提供一组独特的功能。Creo在Windows系统上运行并兼容,提供用于3DCAD参数化特征实体建模、3D直接建模、2D正交视图、有限元分析和仿真、原理图设计......
  • qt mvsc编译器中文乱码
    qtmvsc编译器中文乱码1.问题mvsc编译对中文不太友好,设置ui界面时显示中文会乱码。2.解决办法方法1修改qtcreator文件编码格式工具->选项->文本编辑在pro文件里添加如下代码msvc{QMAKE_CFLAGS+=/utf-8QMAKE_CXXFLAGS+=/utf-8}添加完成点击重新构建,就可......
  • lazarus使用中文拼音首字母实现中文变量等快速代码补全
    在lazarus使用中文变量等代码补全功能基础上,按以下方法就可以实现输入中文拼音首字母就可以快速代码补全功能。代码补全功能:Ctrl+w 打开\lazarus\ide\wordcompletion.pp找到 procedureAddIfMatch(constALine,ALineUp:string;constAFirstPos,ALength:Integer);(lazarus......
  • DevExpress WPF中文教程:如何解决行焦点、选择的常见问题?
    DevExpressWPF拥有120+个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpressWPF能创建有着强大互动功能的XAML基础应用程序,这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。无论是Office办公软件的衍伸产品,还是以数据为中心......
  • 为什么大多数开发人员都避免在MySQL表名和列名中使用中文?
    大多数开发人员避免在MySQL表名和列名中使用中文,主要有以下几个原因:1.兼容性问题不同的数据库系统和工具对字符集的支持各不相同。使用中文可能导致在不同平台或工具间的数据迁移和兼容性问题。2.编码和显示问题在某些开发环境或工具中,中文可能会出现乱码,导致调试和维护......
  • [机器人仿真]WEBOTS中创建轮腿机器人模型-并联闭环机构的创建和使用
    想着做个轮腿的机器人玩玩,但是如果光用PID做算法,对于轮子加腿的结构似乎效果并不好,为了实现轮腿本身能够飞坡在一定高度下能够跳跃,我想着上个仿真模型来调试和学习LQR算法机器人仿真的软件似乎挺多,我查到比较常用的有ROS套件的一个,还有就是webots本着界面简单,开源(还有校园网方便......
  • 2024睿抗机器人开发者大赛CAIP-编程技能赛-本科组(省赛) RC-u5 工作安排详解
    本文参考https://www.cnblogs.com/Kescholar/p/18306136这一题可能对高手来说就能轻而易举的看出是个01背包,但是对于我这种小白还是要经过详细的分析才可以理解。我们题目要求的是获得的最大报酬,题目的影响因素有三个:工作时长、工作截止时间、对应的报酬,那么怎么样合理的去......
  • Abaqus 2024百度云下载:附中文安装包+教程
    正如大家所熟知的,Abaqus是一款有限元分析软件,能够高效的配合工程师完成创作。它可以高精度地实现包括金属、橡胶、高分子材料、复合材料、钢筋混凝土、可压缩超弹性泡沫材料以及土壤和岩石等地质材料的工程仿真计算。“Abaqus”不仅具有出色的仿真计算能力,由于其基于Python开......
  • 使用java做一个微信机器人
    微信机器人这个功能,目前在市面上运用的还是不是很多,每个人实现机器人的目的也不一样,有的为了自动加好友;有的为了自动拉群:也有的为了机器人对话聊天等等一系列。想必大家对微信机器人感兴趣的伙伴,也在aithub上面搜索了很多吧,但是大多数走到一半遇到各种bug就没有继续坚持走下去,原......
  • 解决vsc中文乱码
    关于vscode使用coderunner运行python代码出现中文乱码的解决办法_coderunner运行乱码-CSDN博客CodeRunner插件设置"setPYTHONIOENCODING=utf8&&python-u"  ......