
The Subgraph Compilation Cache Mechanism in CINN

Date: 2023-05-06 09:57:07
Tags: CINN, std, cache, const, target, auto, compilation, key, graph

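Notes from reading the CINN open-source framework, recorded in question-and-answer form.

Q: Where is the entry point for subgraph compilation in CINN?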

  for (const auto& node_vec : clusters) {  // <------- iterate over the subgraphs (clusters) one by one
    // Classify var node to inputs, outputs, and internals.
    GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());

    GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
    AnalyseClusterVariables(cluster_set,
                            deny_var_set,
                            &cluster_inputs,
                            &cluster_outputs,
                            &cluster_internals,
                            is_inference_stage,
                            all_skip_gc_vars);

    auto subgraph = CreateNewSubGraph(
        cluster_set, cluster_internals, cluster_inputs, cluster_outputs);

    if (graph->Has(kSkipGcVarNames)) {
      auto& sub_skip_gc_vars =
          subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
      sub_skip_gc_vars = all_skip_gc_vars;
    }
    auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));  // <------ register the subgraph (it may still contain -1 dynamic-shape dims)
    VLOG(4) << "Compilation Key:\n"
            << cinn_compiler->ReadableKey(compilation_key);

    // Replace the found cluster to a new cinn op node
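    ReplaceSubGraphWithCinnOpNode(cluster_set,     // <------- replace the cluster with a cinn_launch op carrying the key; compilation and caching happen later, when that op runs
                                  cluster_inputs,
                                  cluster_outputs,
                                  cluster_internals,
                                  compilation_key,
                                  graph);
  }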

Q: What does AddGraph do?

int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
  int64_t graph_key = std::hash<Graph *>()((&(*graph)));
  graphs_[graph_key] = std::move(graph);  // <------ the original static graph, as seen at this stage, may contain -1 dims
  return graph_key;
}
// After a graph is added, the original subgraph is replaced in step by a single [cinn_launch] op.
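
To make the consequence of the address-based key concrete, here is a minimal sketch. ToyGraph and ToyRegistry are hypothetical stand-ins, not CINN types; only the pattern of hashing the object's address is taken from AddGraph above. Two graphs with identical content still receive different keys, which suggests this key identifies a graph instance rather than its structure.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for ir::Graph; only a name is kept.
struct ToyGraph {
  std::string name;
};

// Hypothetical registry that mimics the shape of CinnCompiler::AddGraph.
class ToyRegistry {
 public:
  // The key is derived from the object's address, not from its content.
  std::int64_t AddGraph(std::unique_ptr<ToyGraph> graph) {
    auto key = static_cast<std::int64_t>(std::hash<ToyGraph*>()(graph.get()));
    graphs_[key] = std::move(graph);
    return key;
  }

  const ToyGraph& Find(std::int64_t key) const { return *graphs_.at(key); }
  std::size_t Size() const { return graphs_.size(); }

 private:
  std::unordered_map<std::int64_t, std::unique_ptr<ToyGraph>> graphs_;
};
```

Calling AddGraph twice with two separately allocated ToyGraph{"g"} objects yields two entries, since both objects are alive at once and therefore live at different addresses.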

Q: Can subgraph compilation results be reused across different Programs in CINN? Is the hash key coupled to var_name?

size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
  // sort grad node by name and id.
  auto compare = [](ir::Node* n1, ir::Node* n2) {
    return (n1->Name() == n2->Name()) ? (n1->id() < n2->id())
                                      : (n1->Name() < n2->Name());
  };

  // graph.Nodes() return unordered_set, here using set to avoid the same graph
  // may return different result
  std::set<ir::Node*, bool (*)(ir::Node*, ir::Node*)> node_set(compare),
      output_set(compare);
  node_set.insert(graph.Nodes().begin(), graph.Nodes().end());

  std::string hash_str;
  for (ir::Node* n : node_set) {
    hash_str.append(n->Name());

    output_set.clear();
    output_set.insert(n->outputs.begin(), n->outputs.end());
    for (auto* out : output_set) {
      hash_str.append(out->Name()); // <------ couples the key to the var_names in the graph
    }
  }

  VLOG(1) << "The hash graph:\n" << hash_str;

  size_t hash_val = std::hash<std::string>()(hash_str);
  VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
  return hash_val;
}

A concrete example from BERT of the string that gets hashed: cumsumcumsum_0.tmp_0cumsum_0.tmp_0elementwise_subelementwise_subtmp_0feedinput_idsfetchfill_any_likefull_like_0.tmp_0full_like_0.tmp_0cumsumelementwise_subinput_idsfill_any_liketmp_0fetch
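
The coupling to variable names can be reproduced with a toy version of the same concatenation scheme. This is a sketch under assumptions: ToyGraph (a sorted map from node name to sorted output names) is a hypothetical replacement for ir::Graph, and node ids are ignored. It shows that two graphs with the same topology but one renamed variable produce different strings, and therefore in general different hash values.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <set>
#include <string>

// Hypothetical toy graph: node name -> names of its output nodes.
// std::map / std::set keep both levels sorted by name, playing the role of
// the sorted std::set used in HashGraph.
using ToyGraph = std::map<std::string, std::set<std::string>>;

// Concatenate every node name followed by the names of its outputs.
std::string BuildHashString(const ToyGraph& graph) {
  std::string hash_str;
  for (const auto& node : graph) {
    hash_str.append(node.first);
    for (const auto& out : node.second) {
      hash_str.append(out);  // variable names end up inside the key
    }
  }
  return hash_str;
}

std::size_t HashByStructure(const ToyGraph& graph) {
  return std::hash<std::string>()(BuildHashString(graph));
}
```

For a graph feed -> x -> relu -> out_0, renaming x to y changes the concatenated string even though the operators and their wiring are unchanged. Note also that plain concatenation has no separators, so in principle distinct graphs could yield the same string.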

size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
  std::ostringstream has_str;

  for (const auto& name_shape : key.input_shapes_) {  // <------- input shape information
    has_str << name_shape.first;
    has_str << std::hash<phi::DDim>()(name_shape.second);
  }

  has_str << key.graph_hash_val_;   // graph structure information
  has_str << key.arch_str_;        // target information
  return std::hash<std::string>()(has_str.str());
}
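
The three ingredients of the key (input shapes, graph hash, target) can be illustrated with a simplified stand-in. ToyCacheKey, KeyString, and HashKey are hypothetical names; phi::DDim is replaced by a plain vector, and separators are added for readability, which the original does not do. The sketch shows that changing either the input shape or the target yields a different key, so each combination would be compiled and cached separately.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical simplified key; phi::DDim is replaced by std::vector<int64_t>.
struct ToyCacheKey {
  std::map<std::string, std::vector<std::int64_t>> input_shapes;
  std::size_t graph_hash_val = 0;
  std::string arch_str;
};

// Serialize the three ingredients in a fixed order: shapes, graph, target.
std::string KeyString(const ToyCacheKey& key) {
  std::ostringstream os;
  for (const auto& name_shape : key.input_shapes) {
    os << name_shape.first << '[';
    for (std::int64_t dim : name_shape.second) os << dim << ',';
    os << ']';
  }
  os << '|' << key.graph_hash_val << '|' << key.arch_str;
  return os.str();
}

std::size_t HashKey(const ToyCacheKey& key) {
  return std::hash<std::string>()(KeyString(key));
}
```

The arch strings used when trying this out (for example "NVGPU" or "X86") are illustrative values only.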

Q: When does the main framework trigger "compilation"?

template <typename DeviceContext, typename T>
class CinnLaunchOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
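    const auto& compilation_key = ctx.template Attr<int64_t>(kCompilationKey);
    // ... (setup of inputs_name2tensor, target and stream omitted in this excerpt)
    // Compilation is triggered here from the shapes of the actual input tensors;
    // at this point dims recorded as -1 (dynamic) are resolved to concrete values.
    const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile(
        compilation_key, inputs_name2tensor, target, stream);
  }
};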

Q: How does CINN eliminate dynamic shapes?

void CinnGraphSymbolization::RunOp(const CinnOpDesc& op_desc,
                                   const OpMapperContext& ctx) const {
  const auto& op_type = op_desc.Type();
  auto* kernel = ::cinn::frontend::OpMapperRegistry::Global()->Find(op_type);
  VLOG(4) << "Running Op " << op_type;
  kernel->Run(op_desc, ctx);  // dispatched via NetBuilder->build() to the concrete API, which calls infer_shape
}
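
The idea of resolving -1 dims from the tensors seen at run time can be sketched as follows. This is not CINN's implementation: ResolveShape is a hypothetical helper, and the assumption is simply that every dynamic dim takes the corresponding runtime value while static dims must agree. Once the input shapes are concrete, shape inference of the downstream ops can proceed with fixed values.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<std::int64_t>;

// Replace every -1 in the declared shape with the dimension of the tensor
// actually fed at run time; static dimensions must match exactly.
Shape ResolveShape(const Shape& declared, const Shape& runtime) {
  if (declared.size() != runtime.size()) {
    throw std::invalid_argument("rank mismatch");
  }
  Shape resolved(declared.size());
  for (std::size_t i = 0; i < declared.size(); ++i) {
    if (declared[i] == -1) {
      resolved[i] = runtime[i];  // dynamic dim becomes concrete
    } else if (declared[i] == runtime[i]) {
      resolved[i] = declared[i];
    } else {
      throw std::invalid_argument("static dimension mismatch");
    }
  }
  return resolved;
}
```

With a declared shape {-1, 128} and a runtime tensor of shape {8, 128}, the resolved shape is {8, 128}; a later run with batch size 16 would resolve to {16, 128} and, given the shape-dependent cache key, would be treated as a new compilation.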

Q: Where inside CINN is the cache mechanism triggered?

const CinnCompiledObject &CinnCompiler::Compile(
    const Graph &graph,
    const std::map<std::string, const phi::DenseTensor *> &input_tensors,
    const Target &target,
    void *stream) {
  VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
  CinnCacheKeyByAddress cur_key_by_address(
      graph, input_tensors, target.arch_str());   // first look up by graph pointer + shapes + target?
  CinnCacheKeyByStructure cur_key_by_struct;      // on a miss, fall back to graph info + shapes + target

  if (!cache_by_address_.count(cur_key_by_address)) {
    // generate the structure cache key
    cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
    if (!cache_by_struct_.count(cur_key_by_struct)) {
      std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
      auto compiled_res =
          CompileGraph(graph, input_tensors, target, compiled_num, stream); // the core work is delegated to CompileGraph
      std::unique_lock<std::mutex> guard(lock_);
      // double check cache_by_struct_
      if (!cache_by_struct_.count(cur_key_by_struct)) {
        cache_by_struct_[cur_key_by_struct] = compiled_num;
        index2cache_.emplace(compiled_num, std::move(compiled_res));
      }
      // double check cache_by_address_
      if (!cache_by_address_.count(cur_key_by_address)) {
        cache_by_address_[cur_key_by_address] =
            cache_by_struct_.at(cur_key_by_struct);
      }
    } else {
      std::unique_lock<std::mutex> guard(lock_);
      // double check cache_by_address_
      if (!cache_by_address_.count(cur_key_by_address)) {
        cache_by_address_[cur_key_by_address] =
            cache_by_struct_.at(cur_key_by_struct);
      }
    }
  }
  return *index2cache_.at(cache_by_address_.at(cur_key_by_address));
}
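
The two-level lookup above can be condensed into a small sketch. TwoLevelCache is a hypothetical class with string keys and string "compiled results"; unlike the original, it holds the lock for the whole call (including compilation) instead of compiling outside the lock and double-checking afterwards, which keeps the example short at the cost of concurrency. It shows the key property: a second graph instance with the same structure and shapes misses the address level, hits the structure level, and reuses the compiled result.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical two-level cache: an address-based key and a structure-based
// key both map to an index into the list of compiled results.
class TwoLevelCache {
 public:
  // compile() is invoked only when both levels miss.
  template <typename CompileFn>
  const std::string& GetOrCompile(const std::string& addr_key,
                                  const std::string& struct_key,
                                  CompileFn compile) {
    std::lock_guard<std::mutex> guard(lock_);
    auto by_addr = by_address_.find(addr_key);
    if (by_addr != by_address_.end()) return *results_[by_addr->second];

    std::size_t index = 0;
    auto by_struct = by_struct_.find(struct_key);
    if (by_struct != by_struct_.end()) {
      index = by_struct->second;  // same structure seen via another address
    } else {
      index = results_.size();
      results_.push_back(std::make_unique<std::string>(compile()));
      by_struct_[struct_key] = index;
    }
    by_address_[addr_key] = index;  // make the next lookup a fast-path hit
    return *results_[index];
  }

  std::size_t compiled_num() const { return results_.size(); }

 private:
  std::mutex lock_;
  std::unordered_map<std::string, std::size_t> by_address_;
  std::unordered_map<std::string, std::size_t> by_struct_;
  std::vector<std::unique_ptr<std::string>> results_;
};
```

Results are stored behind unique_ptr so that references handed out earlier stay valid when the vector grows.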

Q: What is the core responsibility of CompileGraph, and is there any further caching inside it?

std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
    const ir::Graph &graph,
    const std::map<std::string, const phi::DenseTensor *> &input_tensors,
    const Target &target,
    std::int64_t compiled_num,
    void *stream) const {
  CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
  auto frontend_program = symbol();
  auto fetch_ids = symbol.GetFetchIds();
  VLOG(4) << "All fetch var ids in CINN: "
          << string::join_strings(fetch_ids, ',');

  auto cinn_graph = Optimize(&frontend_program, fetch_ids, target); // done only once for a given ir::Graph
  VLOG(4) << "-- The " << compiled_num << "-th compilation ("
          << target.arch_str() << "), and its related graph:\n"
          << cinn_graph->Visualize();

  auto scope = BuildScope(target, cinn_graph);
  auto graph_compiler =
      std::make_unique<GraphCompiler>(target, scope, cinn_graph); // GraphCompiler does one-off work, but is kept alive by compiled_obj
  GraphCompiler::CompileOptions options;
  options.with_instantiate_variables = false;
  if (!FLAGS_enable_pe_launch_cinn) {
    options.with_buffer_handle_instruction_inserted = true;
  }
  std::unique_ptr<AutoTuner> auto_tuner;
  if (FLAGS_enable_cinn_auto_tune) {
    VLOG(4) << "Compile with auto-tune";
    auto_tuner = std::make_unique<AutoTuner>(target, cinn_graph.get());
    auto_tuner->Initialize(AutoTuner::Config(), graph_compiler.get());
    ::cinn::auto_schedule::TuningOptions tuning_options;
    tuning_options.num_measure_trials = 0;
    auto tuning_result = auto_tuner->Tune(tuning_options);
    options.Apply(tuning_result);
  }
  auto compiled_res =
      graph_compiler->Build(options, std::move(fetch_ids), stream);
  auto compiled_obj = std::make_unique<CinnCompiledObject>();
  *compiled_obj = {std::move(graph_compiler),
                   std::move(auto_tuner),
                   std::move(compiled_res.runtime_program),
                   scope,
                   symbol.var_model_to_program_map()};  // <------ corresponds to paddle2cinn_varmap
  compiled_obj->cached_index = compiled_num;
  compiled_obj->launch_context =
      std::make_unique<operators::details::CinnLaunchContext>(graph,
                                                              *compiled_obj);
  CheckCompiledValid(graph, input_tensors, *compiled_obj);
  return compiled_obj;
}

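Q: GraphCompiler hands all of its compile-and-link work to backends::Compiler. Does this backend Compiler have a compilation cache of its own?

A: The host module appears to hold mainly function declarations and call logic, while the device module holds mainly the function definitions.

The source generated by CodeGen (shown in a figure in the original post, not reproduced here) is written to a file and passed to the compilation engine. When there are multiple functions, they are placed in the same file and compiled and linked together.

From the code, my understanding is that each CINN subgraph gets its own GraphCompiler, which compiles it into a function whose name follows the pattern fn_xxx_yyy_zzz:

  • It describes the overall computation of all ops in the subgraph
  • After operator decomposition, optimization, and similar steps, it may yield multiple sub-functions
  • The sub-functions go into one host file and one CUDA file, which are compiled and linked together into a single function pointer
  • To be confirmed: so there is no cache at the lower_func level?

The figure mentioned above was captured while constructing engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));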

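Appendix: how compilation is implemented in TVM

Q: What plays a role similar to GraphCompiler in TVM?

A: After a quick review of the TVM source, it seems to be TECompilerImpl, which inherits from TECompilerNode and provides the following core interfaces: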

  // Lower the function.
  CachedFunc Lower(const CCacheKey& key) {
    return LowerInternal(key, global_var_supply_)->cached_func;
  }
  
  // For now, build one module per function.
  PackedFunc JIT(const CCacheKey& key) final {
    CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
    if (value->packed_func != nullptr) {
      return value->packed_func;
    }
    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));   // <------ here m is a runtime::Module object
    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
    return value->packed_func;
  }

  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
    return LowerShapeFuncInternal(key)->cached_func;
  }

Notably, TECompilerImpl contains two cache-related data structures:

  /*! \brief internal compiler cache */
  std::unordered_map<CCacheKey, CCacheValue> cache_;
  
  /*! \brief internal compiler cache for shape funcs */
  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
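
The pattern suggested by JIT and cache_ above, where a lowered entry is memoized per key and the executable function is filled in lazily on first use, can be sketched like this. ToyCompiler and ToyCacheValue are hypothetical; string keys replace CCacheKey, and a std::function replaces PackedFunc. The sketch does not claim to match TVM's internals beyond that pattern.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical cache entry: the lowered form is always present, the
// executable function is filled in lazily on the first JIT request.
struct ToyCacheValue {
  std::string lowered;                  // stands in for cached_func
  std::function<int(int)> packed_func;  // empty until JIT builds it
};

class ToyCompiler {
 public:
  std::string Lower(const std::string& key) {
    return LowerInternal(key)->lowered;
  }

  std::function<int(int)> JIT(const std::string& key) {
    std::shared_ptr<ToyCacheValue> value = LowerInternal(key);
    if (value->packed_func) return value->packed_func;  // already built
    ++build_count_;
    value->packed_func = [](int x) { return x + 1; };  // pretend "build"
    return value->packed_func;
  }

  int lower_count() const { return lower_count_; }
  int build_count() const { return build_count_; }

 private:
  // Lowering is memoized per key.
  std::shared_ptr<ToyCacheValue> LowerInternal(const std::string& key) {
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;
    ++lower_count_;
    auto value = std::make_shared<ToyCacheValue>();
    value->lowered = "lowered(" + key + ")";
    cache_[key] = value;
    return value;
  }

  std::unordered_map<std::string, std::shared_ptr<ToyCacheValue>> cache_;
  int lower_count_ = 0;
  int build_count_ = 0;
};
```

Calling Lower and then JIT twice with the same key lowers once and builds once; a new key triggers both steps again.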

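Q: What is the build() method above for? Does it play the same role as Paddle's backends::Compiler?

A: I believe it does. Moreover, the runtime::Module it returns can apparently be understood as the counterpart of RuntimeProgram in Paddle's CINN?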

// Build for heterogeneous execution when targets are specified as
// objects.  This wrapper around the internal API is maintained for
// backwards compatibility.
runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host) {
  return TIRToRuntime(input, target_host);
}

runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
                             const Target& target_host_arg) {   // <------- the implementation
  std::vector<runtime::Module> device_modules;
  Map<Target, IRModule> inputs = inputs_arg;
  Target target_host = target_host_arg;

  // Fetch previous defined target host in targets
  CheckAndUpdateHostConsistency(&inputs, &target_host);

  if (!target_host.defined()) {
    for (const auto& it : inputs) {
      if (it.first->GetTargetDeviceType() == kDLCPU ||
          it.first->GetTargetDeviceType() == kDLMicroDev) {
        target_host = it.first;
        break;
      }
    }
  }

  if (!target_host.defined()) {
    target_host = DefaultTargetHost(target_host);
  }

  // Update target host for all targets
  CheckAndUpdateHostConsistency(&inputs, &target_host);

  // Take the attrs from the first module so the eventual modules have them.
  // Ideally this would just be one unified module all the way through;
  IRModule first_module = (*inputs.begin()).second;
  IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

  ICHECK(mhost_all.defined()) << "The host module must be defined";

  for (const auto& it : inputs) {
    if (it.second.defined()) {
      const Target& target = it.first;
      const IRModule& ir_module = it.second;
      auto pair = SplitMixedModule(ir_module, target, target_host);
      auto& host_mod = pair.first;
      auto& device_mod = pair.second;

      ICHECK(host_mod.defined()) << "The split host module must be defined";

      ICHECK(mhost_all.defined()) << "The host module must be defined";

      // We don't want library modules going back into host codegen
      // unless they're supposed to. Here if we overrode the target host
      // to allow lowering previously we check that it's meant to be placed
      // back into the host Module.
      bool overrides_host_target =
          target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
      bool non_host_target_kind = target->kind != target_host->kind;
      if (overrides_host_target && non_host_target_kind) {
        device_modules.push_back(codegen::Build(host_mod, it.first));
      } else {
        mhost_all->Update(host_mod);
      }

      if (device_mod->functions.size() != 0) {
        device_modules.push_back(codegen::Build(device_mod, it.first));
      }
    }
  }

  runtime::Module mhost = codegen::Build(mhost_all, target_host);   // <----- is this where compilation happens?
  for (const auto& it : device_modules) {
    if (it.operator->()) {
      mhost.Import(it);
    }
  }

  return mhost;
}

runtime::Module Build(IRModule mod, Target target) {
  if (transform::PassContext::Current()
          ->GetConfig<Bool>("tir.disable_assert", Bool(false))
          .value()) {
    mod = tir::transform::SkipAssert()(mod);
  }

  auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
  if (target_attr_map.count(target->kind)) {
    return target_attr_map[target->kind](mod, target);
  }

  // the build function.
  std::string build_f_name = "target.build." + target->kind->name;
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  ICHECK(bf != nullptr) << build_f_name << " is not enabled";
  return (*bf)(mod, target);
}
TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);

runtime::Module BuildCUDA(IRModule mod, Target target) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenCUDA cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(f);
  }

  std::string code = cg.Finish();

  if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  const auto* f_enter = Registry::Get("target.TargetEnterScope");
  (*f_enter)(target);
  if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
    ptx = (*f)(code).operator std::string();
    // Dirty matching to check PTX vs cubin.
    // TODO(tqchen) more reliable checks
    if (ptx[0] != '/') fmt = "cubin";
  } else {
    ptx = NVRTCCompile(code, cg.need_include_path());
  }
  const auto* f_exit = Registry::Get("target.TargetExitScope");
  (*f_exit)(target);
  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
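
The name-based dispatch in Build, which looks up a builder registered under "target.build." plus the target kind, can be sketched with a plain map-backed registry. BuildRegistry, RegisterBuild, and this Build are hypothetical simplifications: sources and results are strings, and the real runtime::Registry and TVM_REGISTER_GLOBAL machinery is not used.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

// A builder turns source text into some built artifact (a string here).
using BuildFn = std::function<std::string(const std::string&)>;

// Function-local static avoids initialization-order problems.
std::unordered_map<std::string, BuildFn>& BuildRegistry() {
  static std::unordered_map<std::string, BuildFn> registry;
  return registry;
}

// Register a builder under the conventional name "target.build.<kind>".
void RegisterBuild(const std::string& kind, BuildFn fn) {
  BuildRegistry()["target.build." + kind] = std::move(fn);
}

// Look the builder up by name and fail loudly when the kind is unknown.
std::string Build(const std::string& source, const std::string& kind) {
  const std::string name = "target.build." + kind;
  auto it = BuildRegistry().find(name);
  if (it == BuildRegistry().end()) {
    throw std::runtime_error(name + " is not enabled");
  }
  return it->second(source);
}
```

Registering a "cuda" builder and calling Build with kind "cuda" invokes it; asking for an unregistered kind raises an error naming the missing entry, analogous to the ICHECK in the excerpt.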

Q: Where is execution invoked in TVM?

A: I came across a data structure called GraphExecutor.

From: https://www.cnblogs.com/CocoML/p/17376062.html
