逻辑推理赛道baseline代码分析与总结
前言
主要是对baseline的代码进行了代码分析和流程总结,以及个人的一点关于prompt的想法
目录
1 引入依赖包
首先,代码引入了多种Python库和模块,包括并行处理、日志记录、HTTP请求、重试机制等。以下是关键的依赖包及其作用:
from multiprocessing import Process, Manager
import json
import os
from pprint import pprint
import re
from tqdm import tqdm
import random
import uuid
import openai
import tiktoken
import numpy as np
import requests
from retry import retry
from scipy import sparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
import time
from http import HTTPStatus
import dashscope
2 设置模型和API密钥
接下来,代码设置了一个模型名称和API密钥,并配置了日志记录,这里用的大模型是qwen1.5-1.8b-chat,通过调用API的方式进行的
这里模型不知道可不可以通过huggingface下载使用,还没实验
MODEL_NAME = 'qwen1.5-1.8b-chat'
dashscope.api_key = "your_api_key_here"
logger.remove() # 移除默认的控制台输出
logger.add("logs/app_{time:YYYY-MM-DD}.log", level="INFO", rotation="00:00", retention="10 days", compression="zip")
MODEL_NAME
: 指定了使用的预训练模型名称。dashscope.api_key
: 设置了API密钥,用于访问模型。logger
: 通过loguru
库配置日志记录,包括日志文件的保存、轮换和压缩。
3 API调用和重试机制
定义了两个函数:api_retry
和call_qwen_api
。api_retry
函数实现了API调用的重试机制,call_qwen_api
函数负责实际的API调用。
def api_retry(MODEL_NAME, query):
max_retries = 5
retry_delay = 60 # in seconds
attempts = 0
while attempts < max_retries:
try:
return call_qwen_api(MODEL_NAME, query)
except Exception as e:
attempts += 1
if attempts < max_retries:
logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}")
raise
def call_qwen_api(MODEL_NAME, query):
messages = [{'role': 'user', 'content': query}]
response = dashscope.Generation.call(
MODEL_NAME,
messages=messages,
result_format='message', # set the result is message format.
)
if response.status_code == HTTPStatus.OK:
return response['output']['choices'][0]['message']['content']
else:
print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
response.request_id, response.status_code,
response.code, response.message
))
raise Exception()
api_retry
: 实现了最多5次的重试机制,每次重试间隔60秒。call_qwen_api
: 通过dashscope
库调用预训练模型,并返回结果。
4 生成Prompt和解析结果
定义了两个函数:get_prompt
和extract
。get_prompt
函数生成适合模型输入的Prompt,extract
函数解析API返回的结果。
这里prompt策略可以改进,可以尝试使用普通prompt 和 COT prompt做一下对比,感觉COT prompt效果可能会有提升,可以尝试
def get_prompt(problem, question, options):
options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options))
prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:
### 题目:
{problem}
### 问题:
{question}
{options}
"""
return prompt
def extract(input_text):
ans_pattern = re.compile(r"答案是:(.)", re.S)
problems = ans_pattern.findall(input_text)
if problems == '':
return 'A'
return problems[0]
get_prompt
: 生成一个适合模型输入的Prompt,包含问题描述、选项等信息。extract
: 使用正则表达式解析模型返回的结果,提取答案。
5 处理数据
process_datas
函数负责并行处理数据,调用API并解析结果。
def process_datas(datas, MODEL_NAME):
results = []
with ThreadPoolExecutor(max_workers=16) as executor:
future_data = {}
lens = 0
for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
problem = data['problem']
for id, question in enumerate(data['questions']):
prompt = get_prompt(problem, question['question'], question['options'])
future = executor.submit(api_retry, MODEL_NAME, prompt)
future_data[future] = (data, id)
time.sleep(0.6) # 控制每0.6秒提交一个任务
lens += 1
for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):
data = future_data[future][0]
problem_id = future_data[future][1]
try:
res = future.result()
extract_response = extract(res)
data['questions'][problem_id]['answer'] = extract_response
results.append(data)
except Exception as e:
logger.error(f"Failed to process text: {data}. Error: {e}")
return results
ThreadPoolExecutor
: 用于并行处理多个任务,提高处理效率。tqdm
: 用于显示任务提交和处理的进度。time.sleep(0.6)
: 控制每0.6秒提交一个任务,防止API请求过于频繁。
6 主函数
main
函数负责读取输入数据,调用数据处理函数并写入输出数据。
def main(ifn, ofn):
if os.path.exists(ofn):
pass
data = []
with open(ifn) as reader:
for line in reader:
sample = json.loads(line)
data.append(sample)
datas = data
return_list = process_datas(datas, MODEL_NAME)
print(len(return_list))
print("All tasks finished!")
return return_list
- 读取输入文件
ifn
中的数据,并调用process_datas
函数处理数据。 - 将处理后的结果写入输出文件
ofn
。
7 评估和过滤
定义了两个函数:evaluate
和filter_problems
。evaluate
函数评估模型结果,filter_problems
函数过滤和去重问题。
def evaluate(ofn):
data = []
with open(ofn) as reader:
for line in reader:
sample = json.loads(line)
data.append(sample)
pse = 0
cnt = 0
tot = 0
for task in data:
for question in task['questions']:
if MODEL_NAME in question:
tot += 1
cnt += question[MODEL_NAME] == question['answer']
else:
pse += 1
print(cnt, tot, cnt/tot, pse)
def filter_problems(data):
result = []
problem_set = set()
for item in data:
problem = item['problem']
if problem in problem_set:
for existing_item in result:
if existing_item['problem'] == problem:
if has_complete_answer(item['questions']):
existing_item['questions'] = item['questions']
existing_item['id'] = item['id']
break
else:
if has_complete_answer(item['questions']):
result.append(item)
problem_set.add(problem)
return result
evaluate
:
评估模型的正确率和完成度。
filter_problems
: 过滤重复问题并保留完整答案。
8 辅助函数
定义了一些辅助函数,如has_complete_answer
和find_missing_ids
来检查答案完整性和查找缺失ID。
def has_complete_answer(questions):
for question in questions:
if 'answer' not in question:
return False
return True
def find_missing_ids(dict_list):
extracted_ids = {int(d['id'][-3:]) for d in dict_list}
all_ids = set(range(500))
missing_ids = all_ids - extracted_ids
return sorted(missing_ids)
has_complete_answer
: 检查每个问题是否包含答案。find_missing_ids
: 查找缺失的ID,确保数据完整性。
知识点总结
Prompt工程-提示原则
原则一:编写清晰、具体的指令
应该通过提供尽可能清晰和具体的指令来表达希望模型执行的操作,这将引导模型给出正确的输出,并减少无关或不正确响应的可能。编写清晰的指令不意味着简短的指令,因为在许多情况下,更长的提示实际上更清晰且提供了更多上下文,这实际上可能导致更详细更相关的输出。
原则二:给模型时间去思考
如果模型匆忙地得出了错误的结论,应该尝试重新构思查询,请求模型在提供最终答案之前进行一系列相关的推理。换句话说,如果给模型一个在短时间或用少量文字无法完成的任务,它可能会猜测错误。这种情况对人来说也是一样的。如果让某人在没有时间计算出答案的情况下完成复杂的数学问题,他们也可能会犯错误。因此,在这些情况下,可以指示模型花更多时间思考问题,这意味着它在任务上花费了更多的计算资源。
提分想法
目前使用的prompt是:
prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:
### 题目:
{problem}
### 问题:
{question}
{options}
"""
考虑可以使用不同的提示策略,如Few-Shots 、 COT 、 SC、 TOT 、 Step-Back
修改提示策略感觉是最简单的上分方法之一,对于同样的模型,同样的任务,使用不同的 Prompt,输出的结果也有不小的差异,所以这也是上分思路之一。
一般来说,使用Prompt技巧的结果会比不使用任何技巧要好,对于简单的任务并不是叠加所有的技巧就会更好,到达一定结果后,再叠加技巧不会提升效果。
目前针对这个任务来讲,我觉得ToT技巧可能会是最优提示策略,还需要后续实验~~~
标签:Task1,prompt,AI,Datawhale,return,api,import,problem,data From: https://blog.csdn.net/weixin_44812944/article/details/140748451