项目目录结构
code
XihuaChatbot/
├── data/
│ └── train_data.jsonl
├── logs/
│ └── (自动创建的日志文件)
├── models/
│ └── xihua_model.pth
│ └── bert-base-chinese/
├── icons/
│ └── icon.ico
├── src/
│ ├── init.py
│ ├── dataset.py
│ ├── model.py
│ ├── gui.py
│ └── main.py
└── requirements.txt
目录说明
XihuaChatbot/: 项目根目录。
data/: 存放训练数据文件,例如 train_data.jsonl。
logs/: 存放日志文件,自动创建。
models/: 存放模型文件,例如 xihua_model.pth 和预训练模型文件夹 bert-base-chinese。
icons/: 存放图标文件,例如 icon.ico。
src/: 存放源代码文件。
init.py: 使 src 成为一个 Python 包。
dataset.py: 定义数据集类 XihuaDataset。
model.py: 定义模型类 XihuaModel。
gui.py: 定义图形用户界面类 XihuaChatbotGUI。
main.py: 主入口文件,启动应用程序。
requirements.txt: 列出项目依赖的库。
文件内容
src/dataset.py
python
import os
import json
import jsonlines
from difflib import SequenceMatcher
class XihuaDataset:
def __init__(self, file_path, tokenizer, max_length=128):
self.tokenizer = tokenizer
self.max_length = max_length
self.data = self.load_data(file_path)
def load_data(self, file_path):
data = []
if file_path.endswith('.jsonl'):
with jsonlines.open(file_path) as reader:
for i, item in enumerate(reader):
try:
if self.validate_item(item):
data.append(item)
except jsonlines.jsonlines.InvalidLineError as e:
logging.warning(f"跳过无效行 {
i + 1}: {
e}")
elif file_path.endswith('.json'):
with open(file_path, 'r') as f:
try:
data = [item for item in json.load(f) if self.validate_item(item)]
except json.JSONDecodeError as e:
logging.warning(f"跳过无效文件 {
file_path}: {
e}")
return data
def validate_item(self, item):
required_keys = ['question', 'human_answers', 'chatgpt_answers']
if all(key in item for key in required_keys):
return True
logging.warning(f"跳过无效项: 缺少必要键 {
required_keys}")
return False
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
question = item['question']
human_answer = item['human_answers'][0]
chatgpt_answer = item['chatgpt_answers'][0]
try:
inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
except Exception as e:
logging.warning(f"跳过无效项 {
idx}: {
e
标签:py,04,max,self,机器人,item,length,聊天,data
From: https://blog.csdn.net/weixin_54366286/article/details/142759911