首页 > 其他分享 >如何使用大模型高效生产数据[含完整代码]

如何使用大模型高效生产数据[含完整代码]

时间:2024-09-27 22:49:16浏览次数:10  
标签:高效 -- 代码 openai str model 数据 模型 row

大模型出现之前我们的训练数据大都依赖人工标注、开源数据以及从线上数据中构造合适的监督数据,如果开源数据不太符合我们的业务需求(大部分情况下无法直接满足要求),且已有的线上数据也没办法抽取出符合要求的监督数据,这个时候恐怕只能依赖于人工标注了,但是人工标注又非常的耗费人力和时间。大模型出现后给我们提供了新的选择,我们可以通过构造高质量的prompt使用大模型给我们生产数据。原理其实很简单,所以本次分享的重点其实不在于原理,主要是想将本人工作中经常使用的一套代码分享出来,供大家直接使用

完整代码见:Data Generate Template | LlamaFactory

大致流程

当接到一个业务需求时跟产品对齐细节后就可以开始写prompt了(这里假设硬件资源支撑不了满足效果的例如72b模型,1.5b以及7b等模型直接用效果又无法达标)。我会先用vllm将效果好的大模型部署起来方便使用openai的sdk调用,反复调试迭代prompt差不多达到要求后我们就可以开始生产用于训练小模型的数据了。一般使用多进程加速数据生产。下面从代码层面讲讲具体的细节。

vllm部署大模型

vllm提供了非常方便的命令行部署命令:

CUDA_VISIBLE_DEVICES="0,1" python -m vllm.entrypoints.openai.api_server --served-model-name model_name  --model model_path --tensor-parallel-size 2 --port 8002

假设你的启动过程十分顺利,这时候你在终端就能看见打印出来的访问地址,一般是http://0.0.0.0:8002,这个时候你在浏览器中输入http://0.0.0.0:8002/docs就可以访问到一个可交互的文档界面,可以在这里尝试访问服务,看看是否可以正常调用。

调试prompt

vllm非常贴心的提供了一个基于gradio的示例代码供大家使用,调试prompt则会更加方便,代码在这里大家可以自取。但是我自己做了一点修改,可以在界面上直接修改prompt,这样就不用每次修改prompt后重启服务了,相对来说方便一点。修改后的版本如下:

import argparse
from collections.abc import Generator

import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(
    description="Chatbot Interface with Customizable Parameters"
)
parser.add_argument(
    "--model-url", type=str, default="http://localhost:8000/v1", help="Model URL"
)
parser.add_argument(
    "-m", "--model", type=str, default="gpt-3.5-turbo", help="Model name for the chatbot"
)
parser.add_argument(
    "--temp", type=float, default=0.8, help="Temperature for text generation"
)
parser.add_argument(
    "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
)
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

# Parse the arguments
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def predict(
    message: str, history: list[tuple[str, str]], system_message: str
) -> Generator[str, None, None]:
    # Convert chat history to OpenAI format
    history_openai_format = [{"role": "system", "content": system_message}]

    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # type: ignore  # Chat history
        # temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        extra_body={
            "repetition_penalty": 1,
            "stop_token_ids": (
                [int(id.strip()) for id in args.stop_token_ids.split(",") if id.strip()]
                if args.stop_token_ids
                else []
            ),
        },
        max_tokens=2048,
    )

    # Read and return generated text from response stream
    partial_message = ""
    for chunk in stream:
        partial_message += chunk.choices[0].delta.content or ""  # type: ignore
        yield partial_message


# Create and launch a chat interface with Gradio
gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Textbox("you are a helpful assistant", label="System Prompt"),
    ],
    additional_inputs_accordion=gr.Accordion(open=True),
).queue().launch(server_name=args.host, server_port=args.port, share=True)

服务启动过程中可能出现下面的信息,问题其实不解决也行,只是不能够分享给他人使用了,在自己本地上访问是没问题的。但由于我可能需要给到产品去体验效果,所以把这个问题修复了下。

Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: 

1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_darwin_arm64
2. Rename the downloaded file to: frpc_darwin_arm64_v0.2
3. Move the file to this location: /xxx/.venv/lib/python3.12/site-packages/gradio

修复上面问题的具体步骤如下:

wget https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_darwin_arm64
mv frpc_darwin_arm64 frpc_darwin_arm64_v0.2
chmod +x frpc_darwin_arm64_v0.2
mv frpc_darwin_arm64_v0.2 you_gradio_path_in_env

再次启动时,就会打印出两个地址,如下。第二个地址可以分享给其他人访问

Running on local URL:  http://127.0.0.1:8001
Running on public URL: https://24e925b09b9a9c337d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)

我这边使用deepseek作为示例,打开地址后就可以看到如下对话界面。此时你可以快速地在下方的System Prom中迭代,在submit输入框中输入业务数据获得模型输出,查看输出是否符合要求。

这一步完成后我们应该在超大模型的基础上获得了一个不错的效果,能够满足业务要求。

大规模蒸馏数据

这一步其实主要是将前面启对话界面的代码改写成读取待标注的数据,并使用多进程调用vllm启动的服务,主要代码如下:

# 读取输入数据
df = pd.read_json(CONFIG["INPUT_FILE"], lines=True)
# 并行处理数据
with ProcessPoolExecutor(max_workers=CONFIG["MAX_WORKERS"]) as executor:
    list(tqdm(executor.map(process_row, df.to_dict(orient="records")), total=len(df)))

process_row用于单次调用处理一条数据,具体实现可参考如下代码。主要逻辑是请求大模型获取标注结果,并保存每一条结果(当数据量较大时,这么做容错性比较高,不至于程序出错就会全部都需要重新标注),同时可以在post_process实现一定的后处理。

def process_row(row):
    try:
        user_input = USER_INPUT_TEMPLATE.format(**row)
        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": user_input},
        ]
        response = (
            client.chat.completions.create(
                model=CONFIG["MODEL_NAME"],
                messages=messages,
            )
            .choices[0]
            .message.content
        )

        post_process(row, response)
    except Exception as e:
        print(f"处理数据时出错: {e}")
        print(f"跳过数据: {row.get('id', 'unknown')}")


def post_process(row, response):
    # 在此处理模型的响应,例如输出是json,可使用json.loads(response)
    # 示例:将响应直接添加到row中
    row["model_response"] = response

    # 生成唯一ID并保存处理后的数据
    unique_id = str(uuid.uuid4())
    filename = f"{CONFIG['PROCESSED_DIR']}/{unique_id}.json"
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(row, f, ensure_ascii=False, indent=4)

整体的生成时间可以通过tqdm较好的把控观察,一般来说几十万条数据两三天就够了,不过具体还是要看你的任务数据长度。

小结

完整的代码可以在这里这里找到。以上就是本人在工作中最常用的蒸馏数据的代码,不一定是最佳实践,但是目前对于我来说够用了。本文只着眼于如何高效的产出数据,略去了具体的一些细节;例如prompt如何迭代优化、不同的场景如何生产出优质的数据,后面可以单开一篇举例聊聊。

标签:高效,--,代码,openai,str,model,数据,模型,row
From: https://blog.csdn.net/budahui/article/details/142603893

相关文章

  • 三篇文章速通JavaSE到SpringBoot框架 (中) IO 进程线程 网络编程 XML MySQL JDBC相关
    文章目录IOfile类的作用I/O的作用将上篇文章综合项目使用IO流升级所需知识点进程线程创建线程的三种方式网络编程网络编程介绍IP地址端口号网络通信协议网络通信协议的分层演示代码XMLXML的作用是什么?xml特点注解什么是注解?注解的使用注解的重要性注解的使用实例M......
  • 二级指针内存模型
    二级指针主要分成三种内存模型:1》指针数组:指针指向栈区的一段内存的首地址,并且栈区分配内存空间,每个元素又装有一个指针指向常量区的某一个地址类似于char*myArray[]={"aaaaa","cccccc","bbbbbb","11111"};应用场景名称:指针数组涉及到2个内存区:栈区和栈区 ......
  • [Java手撕]生产者消费者模型
    importjava.util.LinkedList;importjava.util.Queue;importjava.util.concurrent.locks.Condition;importjava.util.concurrent.locks.ReentrantLock;publicclassMain{publicstaticfinalQueue<Integer>message=newLinkedList<>();......
  • Spring Ioc底层原理代码详细解释
    文章目录概要根据需求编写XML文件,配置需要创建的bean编写程序读取XML文件,获取bean相关信息,类,属性,id前提知识点Dom4j根据第二步获取到的信息,结合反射机制动态创建对象,同时完成属性赋值将创建好的bean存入到Map集合,设置key-value映射提供方法从Map中通过id获取到对象的valu......
  • GPT也会玩《黑神话》?胜率还远超人类?全靠大模型实力!
    导语《黑神话:悟空》这款游戏,以其独特的东方魅力和引人入胜的剧情,在玩家和业界中引发了巨大的热潮。它不仅在界内十分火爆,更是火出了圈,可以在各处看见他的身影,包括奶茶店、咖啡店、商场超市等。这款游戏凭借其精致的画面和深入人心的角色塑造,无疑将为中国游戏产业注入新的活力,......
  • 如何让大模型更好地进行场景落地?【文末送书】
    自ChatGPT模型问世后,在全球范围内掀起了AI新浪潮。有很多企业和高校也随之开源了一些效果优异的大模型,例如:Qwen系列模型、MiniCPM序列模型、Yi系列模型、ChatGLM系列模型、Llama系列模型、Baichuan系列模型、Deepseek系列模型、Moss模型等。图片来自:ASurveyofLargeLa......
  • 这五本大模型书籍,让你从大模型零基础到精通,非常详细收藏我这一篇就够了
    大模型(LargeLanguageModels,LLMs)是近年来人工智能领域的一大热点,它们在自然语言处理、对话系统、内容生成等多个方面展现出了强大的能力。随着技术的发展,市面上出现了许多介绍大模型理论与实践的书籍,为研究人员和开发人员提供了宝贵的资源。以下是一些精选的大模型书籍推......
  • 大模型时代,新手和程序员如何转型入局AI行业?
    在近期的全国两会上,“人工智能”再次被提及,并成为国家战略的焦点。这一举措预示着在接下来的十年到十五年里,人工智能将获得巨大的发展红利。技术革命正在从“互联网+”向“人工智能+”逐步迈进,我将迎来新一轮技术革新和人才需求的增长。毫无疑问,AI工程师将是未来最紧俏的岗......
  • 代码随想录训练营第44天|最长公共子序列
    1143.最长公共子序列classSolution{public:intlongestCommonSubsequence(stringtext1,stringtext2){text1.insert(text1.begin(),'');text2.insert(text2.begin(),'');intn1=text1.length(),n2=text2.length(),m......
  • 类中静态代码块、静态属性加载顺序
     1、如果静态属性在静态代码块前面classFoo{publicFoo(){System.out.println("我是Example的静态属性foo");System.out.println("未修改的静态属性值为====>"+Example.staticVariable);Example.staticVariable=2;......