首页 > 其他分享 >生产者消费者模式下实现多batch延时推理

生产者消费者模式下实现多batch延时推理

时间:2023-10-30 18:14:03浏览次数:23  
标签:std timeout max batch queue 延时 推理 size

生产者消费者模式下实现多batch延时推理

需求分析

在实际推理过程中为了实现较高的吞吐量和较高的资源利用率,往往会使用多线程来收集多次请求,并组合形成多batch下的模型推理,一种常见的实现便是生产者和消费者模式,其需求如下:

  1. 生产者收集提交的请求,消费者对请求进行消费,并将结果返回。
  2. 资源是有限的,可运行的最大max_batch_size是有限的,队列中可等待的max_queue_num也是有限的。
  3. 邻近的资源可进行等待,最长要等待timeout的时间
  4. 资源封装完备,哪里使用就在哪里释放。

设计实现

  1. 在设计实现上,首先是需要有线程安全的生产者消费者队列,通过promise和future完成返回消息。
  2. 需要有设置各个参数的公开接口,以便于设置各个参数。
  3. 需要有一把锁和延时机制,使用条件变量condition_variable的wait_for函数来等待后续参数的传入。
  4. 资源的申请和释放都放置在消费者线程中,并使用promise和future来返回中间结果。
  5. 使用抽象类实现接口和封装,使用多态进行实现。

具体实现

  1. 头文件

这里声明了一个抽象类,关键要实现forward函数,create_infer外部函数实现多态调用。

#ifndef INFER_HPP
#define INFER_HPP
#include <memory>
#include <string>
#include <map>
#include <future>
namespace Infer {
using Tensor = std::string;
using RET_TYPE = std::map<std::string, Tensor>;
class InferInterface {
 public:
  // pic_tensor 表示输入给生产者的tensor, timeout 表示超时时间
  // wait_timeout 表示当队列满了的情况下,愿意最多等待 wait_timeout ms 来等待
  virtual std::shared_future<RET_TYPE> forward(const Tensor& pic_tensor, int wait_timeout = 10) = 0;
  virtual void set_timeout(const int timeout) = 0;
  virtual void set_max_queue_num(const int max_queue_num) = 0;
  virtual ~InferInterface() {}
  explicit InferInterface() {}
 protected:
  InferInterface(const InferInterface&) = delete;
  InferInterface(InferInterface&&) = delete;
  InferInterface& operator=(const InferInterface&) = delete;
};
std::shared_ptr<InferInterface> create_infer(const std::string& filepath, int max_bactch_size);
}
#endif // INFER_HPP
  1. cpp实现类

InferImpl是InferInterface的具体实现,是一种推理引擎的多态封装。

这里有4个原子操作,分别对应着running、超时时间timeout、最大可推理的batch_size以及队列的最大长度max_queue_num。
有一个锁,用以等待后续的请求,并合为一个batch。
有两个条件变量,cond_var_对应的是消费者,用以wait_for不为空且系统在runnning状态,并且queue的请求数量大于batch_size时取出batch_size请求。cond_queue_overflow_对应的是生产者,需要等待当前队列中的数据不满时才能进行填充。

forward函数里的实现生产者,当队列不满的时候向队列中插入请求。
worker函数是消费者的具体实现,会等待当前队列中请求数大于要推理的请求时才会组成batch一同推理,并将结果通过promise返回得到最终结果。

#include "infer.hpp"
#include <cstdio>
#include <queue>

namespace Infer
{
    using std::atomic;
    using std::future;
    using std::map;
    using std::mutex;
    using std::promise;
    using std::queue;
    using std::string;
    using std::thread;
    class InferImpl : public InferInterface
    {
    protected:
        /* data */
        using RET_PROMISE_PTR = std::shared_ptr<promise<RET_TYPE>>;
        struct Job
        {
            /* data */
            Tensor data;
            RET_PROMISE_PTR ret_promise;
        };

        atomic<bool> running;
        atomic<int> timeout;
        atomic<int> max_batch_size;
        atomic<int> max_queue_num;

        mutex lock;
        thread thread_;
        queue<Job> queue_;
        string context;

        std::condition_variable cond_var_;
        std::condition_variable cond_queue_overflow_;

    public:
        void set_timeout(const int timeout)
        {
            this->timeout = timeout;
        }
        void set_max_queue_num(const int max_queue_num)
        {
            this->max_queue_num = max_queue_num;
        }
        bool set_batch_size(const int batch_size)
        {
            if (batch_size < 1)
                return false;
            this->max_batch_size = batch_size;
            return true;
        }

        std::shared_future<RET_TYPE> forward(const Tensor &pic_tensor, int wait_timeout = 10)
        {
            Job job;
            job.data = pic_tensor;
            job.ret_promise = RET_PROMISE_PTR(new promise<RET_TYPE>());
            {
                std::unique_lock<mutex> l(lock);
                if (queue_.size() >= max_queue_num)
                {
                    if (0 == wait_timeout)
                    {
                        throw std::runtime_error("exhausted resource");
                    }
                    cond_queue_overflow_.wait_for(l, std::chrono::milliseconds(wait_timeout), [&]()
                                                 { return queue.size() < max_queue_num; });
                }

                if (queue_.size() >= max_queue_num)
                {
                    throw std::runtime_error("exhausted resource");
                }

                queue_.push(job);
            }
            cond_var_.notify_one();
            return job.ret_promise->get_future();
        }
        explicit InferImpl()
        {
            running = false;
            max_batch_size = 1;
            max_queue_num = 5;
            timeout = 0;
        }
        ~InferImpl() override {
            running = false;
            cond_var_.notify_one();
            if (thread_.joinable())
            {
                thread_.join();
            }
        }

        bool load_model(const string &filepath)
        {
            promise<bool> init_promise;
            thread_ = thread(&InferImpl::worker, this, filepath, std::ref(init_promise));
            running = true;
            return init_promise.get_future().get();
        }

    protected:
        void worker(const string &filepath, promise<bool> &init_promise)
        {
            // 加载模型
            context = filepath;
            if (context.empty())
            {
                init_promise.set_value(false);
                return;
            }

            init_promise.set_value(true);
            std::vector<Job> jobs;
            int batch_id = 0;

            while (running)
            {
                {
                    std::unique_lock<mutex> l(lock);
                    cond_var_.wait(l, [&](){
                        // true 则退出等待
                        return !queue_.empty() || !running; });
                    if (!running)
                        break;
                    if (0 != timeout)
                    {
                        cond_var_.wait(l, [&]()
                                       { return queue_.size() >= max_batch_size; });
                    }
                    for (int i = 0; !queue_.empty() && i < max_batch_size; i++)
                    {
                        jobs.emplace_back(queue_.front());
                        queue_.pop();
                    }
                }

                // 此处假装inference,得到运行结果

                int sz = jobs.size();
                for (auto& job:jobs){
                    auto bbox = job.data + "_result";
                    RET_TYPE handle_result;
                    handle_result["bbox"] = bbox;
                    job.ret_promise->set_value(handle_result);
                }

                jobs.clear();
                std::this_thread::sleep_for(std::chrono::milliseconds(1000 * sz));
                printf("batch id: %d job size: %d \n", batch_id, sz);
                cond_queue_overflow_.notify_one();
                ++batch_id;
            }

            context.clear();
            puts("context_ has cleared");
            puts("Workder done!");      
        }
    };

    std::shared_ptr<InferInterface> create_infer(const string& file_path, int max_batch_size){
        auto infer_ptr = new InferImpl();
        infer_ptr->set_batch_size(max_batch_size);

        if (! infer_ptr->load_model(file_path)){
            delete infer_ptr;
            return nullptr;
        }
        return std::shared_ptr<InferInterface>(infer_ptr);
    }
}
  1. main文件

在main文件中,创建了24个请求,每5个请求将组成一个batch进行推理。
这里有个小问题是,当最后一个请求小于5个时会一直等待,因此应当设置超时机制对剩余的请求进行一并处理。

#include <iostream>
#include <vector>
#include "infer.hpp"

using std::string;
using std::errc;

using RET_FUTURE = std::shared_future<Infer::RET_TYPE>;

int main(int argc, char const *argv[])
{
    /* code */
    string file_path = ""; 
    std::shared_ptr<Infer::InferInterface>infer_ptr = Infer::create_infer(file_path, 5);
    if (infer_ptr == nullptr){
        printf("create infer engine error\n");
        return -1;
    }
    // 每个实例推理愿意等待的时间
    infer_ptr->set_timeout(5);
    infer_ptr->set_max_queue_num(10);
    printf("create infer engine success!\n");

    std::vector<RET_FUTURE> shared_ptrs;
    char buffer[100];
    for (int i=0; i<24; ++i){
        sprintf(buffer, "%d.tensor", i);

        // 在队列满的时候,生产者愿意等1000ms
        shared_ptrs.push_back(infer_ptr->forward(buffer, 1000));
    }

    for (auto &shard_ptr:shared_ptrs){
        shard_ptr.get();
    }

    return 0;
}

标签:std,timeout,max,batch,queue,延时,推理,size
From: https://www.cnblogs.com/wildkid1024/p/17783838.html

相关文章

  • 2023-10-21:用go语言,一共有三个服务A、B、C,网络延时分别为a、b、c 并且一定有:1 <= a <= b
    2023-10-21:用go语言,一共有三个服务A、B、C,网络延时分别为a、b、c并且一定有:1<=a<=b<=c<=10^9但是具体的延时数字丢失了,只有单次调用的时间一次调用不可能重复使用相同的服务,一次调用可能使用了三个服务中的某1个、某2个或者全部3个服务比如一个调用的时间,T=100100的延时......
  • 自编码器AE全方位探析:构建、训练、推理与多平台部署
    本文深入探讨了自编码器(AE)的核心概念、类型、应用场景及实战演示。通过理论分析和实践结合,我们详细解释了自动编码器的工作原理和数学基础,并通过具体代码示例展示了从模型构建、训练到多平台推理部署的全过程。关注TechLead,分享AI与云服务技术的全维度知识。作者拥有10+年互联......
  • 技术分享| anyRTC低延时直播优化
    直播系统就是把活动现场的音频或视频信号经数字压缩后,传送到直播多媒体服务器(CDN)上,在互联网上供广大网友或授权特定人群收听或收看。而随着技术的日益更新,人民对于直播的互动性,实时性要求更高了,传统的直播少则几十秒,多则几分钟的时延很难满足现在的很多直播场景。今天我们就从播......
  • 【Azure Batch】在中国区批处理服务(Mooncake Batch Account)上实验自动池(Auto Pool)
    问题描述在AzureBatch的介绍文档中,提出了自动池的概念,它可以在任务完成后,自动删除Pool资源,详细介绍:https://docs.azure.cn/zh-cn/batch/nodes-and-pools#autopools& https://learn.microsoft.com/zh-cn/rest/api/batchservice/job/add?tabs=HTTP#autopoolspecification自动池是......
  • 让大模型真正学会1+1=2!谷歌教会模型自动学习推理规则,大模型的幻觉有救了
    作者|谢年年在初学算术加法或乘法时,我们通过数小棍的方式逐步从1+1=2,1+2=3等例子中得出1+3=4,这是一种依赖记忆中的数学表格进行演绎推理的过程。后来老师告诉我们前辈们总结了一套完备的求和或乘法表,只要背住,做简单算术题根本不成问题,也不需要数小棍啦!这样一套完备的求和或乘法表......
  • 推理成本增加10倍?对文心大模型4.0的一些猜想
    作者|卖萌酱大家好,我是卖萌酱。相信不少小伙伴这几天都听到了消息,在期待下周即将发布的文心大模型4.0。我们的几个读者群里也发生了相关的讨论:讨论的核心主要围绕以下两个话题展开:文心4.0能不能打过GPT-4文心4.0会不会收费作为AI从业者,卖萌酱将基于目前得到的一些有限的消息,来展......
  • RTMP流媒体服务器LiteCVR支持在iOS播放WebRTC低延时视频流
    视频监控设备是安防行业的细分专业领域,近年来,视频监控业务正在向其他领域加速渗透。众所周知,iOS系统支持HLS流,但是HLS流延时高,无法满足实时流的要求;而WebRTC播放延时低,因此,很多用户希望能在iOS系统上播放Webrtc视频流。针对用户的这一需求,LiteCVR平台灵活的视频能力,可以完全满足。......
  • Kotlin 协程Job 代替 Handler执行延时任务 带取消
    privatevalhandler=Handler(Looper.getMainLooper())varrunnable=Runnable{dismissProgressDialog()}......handler.postDelayed(runnable,(10*1000).toLong())......//取消任务handler.removeCallbacks(runnable)privatevarjob:Job?=null......job......
  • 模型推理batch inference速度无明显提升、耗时线性增长问题排查
    模型推理batchinference速度无明显提升、耗时线性增长问题排查现象描述当模型在推理阶段使用batchinference时,推理速度并无明显提升,相比单帧多次推理收益不大。如笔者在Xavier上测试某模型结果batchsize推理时间ms折算耗时ms/img111.2311.23220.3910.20......
  • [926] Batch Script - Commands
    Inthischapter,wewilllookatsomeofthefrequentlyusedbatchcommands.S.NoCommands&Description1VERThisbatchcommandshowstheversionofMS-DOSyouareusing.2ASSOCThisisabatchcommandthatassociatesanextensionwithaf......