首页 > 其他分享 >tvm-多线程代码生成和运行

tvm-多线程代码生成和运行

时间:2022-10-23 14:22:11浏览次数:75  
标签:代码生成 task cdata tvm num env 多线程 parallel op

本文链接
https://www.cnblogs.com/wanger-sjtu/p/16818492.html

调用链

tvm搜索算子在需要多线程运行的算子,是在codegen阶段时插入TVMBackendParallelLaunch的调用。
TVMBackendParallelLaunch 是tvm的线程池并行化入口,具体如下

/*!
 * \brief The callback function to execute a parallel lambda
 * \param task_id the task id of the function. //这里实际就是线程池线程编码,对应第几个线程
 * \param penv The parallel environment backs the execution. // num_task, sync
 * \param cdata The supporting closure data.
 */
typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata);

/*!
 * \brief Backend function for running parallel jobs.
 *
 * \param flambda The parallel function to be launched.
 * \param cdata The closure data. // 可以认为时循环的变量 codegen时生成
 * \param num_task Number of tasks to launch, can be 0, means launch
 *           with all available threads. // codegen 时写入的是0,运行时根据配置写入
 *
 * \return 0 when no error is thrown, -1 when failure happens
 */
int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task);

flambda的调用在单线程和多线程下略有区别。

单线程运行时

if (num_workers == 1) {
    std::atomic<int32_t> sync_counter{0};
    TVMParallelGroupEnv env;
    env.num_task = 1;
    env.sync_handle = &sync_counter;
    (*flambda)(0, &env, cdata);
    return 0;
  }

多线程运行时

// launcher->Init(flambda, cdata, num_task, need_sync != 0);
this->cdata = cdata;
this->flambda = flambda;
this->env.num_task = num_task;

while (queue->Pop(&task, spin_count)) {
    ICHECK(task.launcher != nullptr);
    TVMParallelGroupEnv* penv = &(task.launcher->env);
    void* cdata = task.launcher->cdata;
    if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
      task.launcher->SignalJobFinish();
    } else {
      task.launcher->SignalJobError(task.task_id);
    }
  }

可以看到 待并行函数中 TVMParallelGroupEnv* penv 包含了实际的运行时线程,运行时可以根据这个确定每个线程的工作区间和步长。
cdata则是线程运行时需要变量信息,闭包变量。

总结

对要并行的函数,实际上是按照lambda表达式的方式生成的。FTVMParallelLambda 的输入参数前两个是运行时确定的,第三个是捕获的外部变量。

codegen 过程

下面验证一下上述的猜测。

codegen过程中,实际上是在遍历tir Stmt的AST,因为生成的循环都是基于For的,调用过程也比较简单了。

void CodeGenCPU::VisitStmt_(const ForNode* op)  // -> 
CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
                        op->thread_binding, op->annotations),
                    0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str());   // ->
CodeGenCPU::VisitStmt_(const ForNode* op);

当遍历到For节点时, 根据属性判断是否并行加速。这里只分析加速场景。此时parallel_env_.penv == nullptr 创建多线程调用函数,进入CreateParallelLaunch函数。
然后 再生成 For的遍历逻辑。this->VisitStmt(body); 这里的body其实还是For ,这时候就进入

} else {
      // already in parallel env.

前文的猜测也在这里得到验证。


void CodeGenCPU::VisitStmt_(const ForNode* op) {
  ICHECK(is_zero(op->min));
  if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
    CodeGenLLVM::VisitStmt_(op);
  } else if (op->kind == ForKind::kParallel) {
    if (parallel_env_.penv == nullptr) {
      CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
                               op->thread_binding, op->annotations),
                           0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
    } else {
      // already in parallel env.
      ICHECK(parallel_env_.task_id.defined());
      ICHECK(parallel_env_.num_task.defined());
      ICHECK(parallel_env_.penv != nullptr);
      DataType t = op->extent.dtype();
      PrimExpr num_task = cast(t, parallel_env_.num_task);
      PrimExpr task_id = cast(t, parallel_env_.task_id);
      ICHECK(!parallel_env_.in_parallel_loop)
          << "Nested parallel loop is not supported by threadpool, try fuse them instead";
      parallel_env_.in_parallel_loop = true;
      if (parallel_env_.stride_pattern) {
        CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task),
                        op->loop_var, op->body);
      } else {
        PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
        PrimExpr begin = min(task_id * step, op->extent);
        PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent);
        CreateSerialFor(MakeValue(begin), MakeValue(end),
                        llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
      }
      parallel_env_.in_parallel_loop = false;
      ++parallel_env_.parallel_loop_count;
    }
  } else {
    LOG(FATAL) << "cannot handle for type " << op->kind;
  }
}

/*
    const Stmt& body  For 循环的statement
    int num_task, 这里设置的是0,根据运行时参数确定使用线程
    std::string name
*/
void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::string name) {
  // closure data
  llvm::Function* f =
      llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage,
                             "__tvm_parallel_lambda", module_.get());
  SetTargetAttributes(f);

  // allocate and setup the closure, call the closure. //For 循环内部变量。这里需要声明一下
  Array<Var> vfields = tir::UndefinedVars(body, {});
  uint64_t nbytes;
  TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name); // 可以认为时循环的变量
#if TVM_LLVM_VERSION >= 90
  auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
  auto launch_callee = RuntimeTVMParallelLaunch();
#endif
  llvm::BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
      launch_callee,
      {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)}));
  // Setup the closure function.
  auto* lambda_entry =
      llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f);
  builder_->SetInsertPoint(lambda_entry);
  auto it = f->arg_begin();
  llvm::Value* task_id = &(*it++);
  task_id->setName("task_id");
  llvm::Value* penv = &(*it++);
  cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
  // setup new variable map, swap it with current var context.
  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
  UnpackClosureData(cdata, vfields, &new_vmap);
  // setup parallel env
  ParallelEnv par_env;
  par_env.task_id = Var("task_id", DataType::Int(32));
  par_env.num_task = Var("num_task", DataType::Int(32));
  new_vmap[par_env.task_id.get()] = task_id;
  new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
      t_int32_,
      builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}),
      "num_task");
  par_env.penv = penv;
  auto new_analyzer = std::make_unique<arith::Analyzer>();
  std::swap(function_, f);
  std::swap(parallel_env_, par_env);
  std::swap(analyzer_, new_analyzer);
  std::swap(var_map_, new_vmap);
  this->VisitStmt(body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(var_map_, new_vmap);
  std::swap(analyzer_, new_analyzer);
  std::swap(parallel_env_, par_env);
  std::swap(function_, f);
  ICHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch";
  builder_->SetInsertPoint(par_launch_end);
}

标签:代码生成,task,cdata,tvm,num,env,多线程,parallel,op
From: https://www.cnblogs.com/wanger-sjtu/p/16818492.html

相关文章

  • 使用多线程优化for循环请求http接口
    packagecom.test.list;importcom.alibaba.fastjson.JSON;importcom.google.common.util.concurrent.ThreadFactoryBuilder;importjava.util.*;importjava.util.concur......
  • Java多线程(2):线程关键字
    您好,我是湘王,这是我的博客园,欢迎您来,欢迎您再来~ Java中和线程相关的关键字就两:volatile和synchronized。volatile以前用得较少,以后会用得更少(后面解释)。它是一种非常轻......
  • java线程例题-类/对象/实例化/声明/多线程/同步
    packageA_ShangGuiGu.Thread.ThreadTest;importjava.util.concurrent.locks.ReentrantLock;////////////////////////////classzhanghu{//账户类,定义一个余额属性。......
  • 郁金香 -多线程创建
    #include<stdio.h>#include<Windows.h>//创建线程函数//开辟线程//BUG解决让这两个线程可以长期存在免得无法观察DWORDWINAPI线程函数1(LPVOIDarg){whil......
  • devexpress中grid控件教程 多线程异步加载数据,进度条展示
    devexpress中最强大的控件,要数它的Grid了。几乎任务数据都可以展示,但今天要用它做另一个功能。假设我们开发这样一款软件:视频编辑软件。里面有个功能,提取视频中的音频。一......
  • 多线程
    线程简介多任务现实中太多同时做多件事情的例子,看起来是多个任务都在做,其实本质上我们的大脑同时只做了一件事多线程原来是一条路,慢慢因为车太多了,道路阻塞,效率极......
  • UEC++ 多线程(二) AsyncTask
    AsyncTaskAsyncTask系统实现的多线程与自己实现继承的FRunnable实现的原理相似,还可以利用UE4提供的线程池。当使用多线程不满意时也可以调用StartSynchronousTask改成主线......
  • 单核多线程可见性问题
    背景学习群上有个同学提出问题,如下截图这里可以看到分歧点,我认为JMM协议规定了工作内存,那么即使是单核,JAVA虚拟机也会保证线程本地内存变量的私有性,所以会存在不可见。......
  • 3_linux多线程
    3_linux多线程编程基本概念程序执行的最小单位进程是线程的容器,不是基本执行单位,是线程容器线程是进程中的不同执行路径,有独立的堆栈、局部变量(因为线程需要线程函数)......
  • Java多线程(1):线程生命周期
    您好,我是湘王,这是我的博客园,欢迎您来,欢迎您再来~ 从事Java开发这些年来,如果要问我Java当中最难的部分是什么?最有意思的部分是什么?最多人讨论的部分是什么?那我会毫不犹豫......