Notes from studying the CINN open-source framework, recorded in a question-and-answer format.
Q: Where is the entry point of subgraph compilation in CINN?
for (const auto& node_vec : clusters) { // <------- iterate over each subgraph (cluster)
// Classify var node to inputs, outputs, and internals.
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
AnalyseClusterVariables(cluster_set,
deny_var_set,
&cluster_inputs,
&cluster_outputs,
&cluster_internals,
is_inference_stage,
all_skip_gc_vars);
auto subgraph = CreateNewSubGraph(
cluster_set, cluster_internals, cluster_inputs, cluster_outputs);
if (graph->Has(kSkipGcVarNames)) {
auto& sub_skip_gc_vars =
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
sub_skip_gc_vars = all_skip_gc_vars;
}
auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph)); // <------ register the subgraph (its shapes may still contain dynamic -1 dims)
VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key);
// Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set, // <------- replace the cluster with a cinn_launch op node; compilation itself is deferred to runtime
cluster_inputs,
cluster_outputs,
cluster_internals,
compilation_key,
graph);
Q: What does AddGraph do?
int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
int64_t graph_key = std::hash<Graph *>()((&(*graph)));
graphs_[graph_key] = std::move(graph); // <------ the stored compile-time static graph may still contain -1 shapes
return graph_key;
}
// After a graph is added, the original subgraph in the main graph is replaced by a [cinn_launch] op.
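As a rough illustration (not the actual ReplaceSubGraphWithCinnOpNode implementation; the struct below is hypothetical), the replacement node conceptually carries the cluster's external inputs/outputs plus the compilation_key that CinnLaunchOpKernel later uses to look up the cached subgraph:
#include <cstdint>
#include <string>
#include <vector>

struct CinnLaunchOpNode {
  std::string type = "cinn_launch";
  std::vector<std::string> inputs;    // var names from cluster_inputs
  std::vector<std::string> outputs;   // var names from cluster_outputs
  int64_t compilation_key = 0;        // the key returned by CinnCompiler::AddGraph
};

CinnLaunchOpNode MakeReplacementNode(const std::vector<std::string>& in,
                                     const std::vector<std::string>& out,
                                     int64_t key) {
  return CinnLaunchOpNode{"cinn_launch", in, out, key};
}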
Q: Can subgraph compilation results be reused across different Programs in CINN? Does the hash key couple with var_name?
size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
// sort grad node by name and id.
auto compare = [](ir::Node* n1, ir::Node* n2) {
return (n1->Name() == n2->Name()) ? (n1->id() < n2->id())
: (n1->Name() < n2->Name());
};
// graph.Nodes() return unordered_set, here using set to avoid the same graph
// may return different result
std::set<ir::Node*, bool (*)(ir::Node*, ir::Node*)> node_set(compare),
output_set(compare);
node_set.insert(graph.Nodes().begin(), graph.Nodes().end());
std::string hash_str;
for (ir::Node* n : node_set) {
hash_str.append(n->Name());
output_set.clear();
output_set.insert(n->outputs.begin(), n->outputs.end());
for (auto* out : output_set) {
hash_str.append(out->Name()); // <------ couples the key with var_name in the graph
}
}
VLOG(1) << "The hash graph:\n" << hash_str;
size_t hash_val = std::hash<std::string>()(hash_str);
VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
return hash_val;
}
A concrete hash_str sample from a BERT model: cumsumcumsum_0.tmp_0cumsum_0.tmp_0elementwise_subelementwise_subtmp_0feedinput_idsfetchfill_any_likefull_like_0.tmp_0full_like_0.tmp_0cumsumelementwise_subinput_idsfill_any_liketmp_0fetch
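A self-contained illustration (not Paddle code) of this coupling: because the hash string is just a concatenation of node and output names, two structurally identical graphs with renamed variables hash differently, so the structure cache cannot be shared across programs that merely rename vars.
#include <functional>
#include <iostream>
#include <string>

int main() {
  // Same op structure, different variable names, following the name-concatenation scheme above.
  std::string g1 = std::string("elementwise_add") + "x" + "y" + "tmp_0";
  std::string g2 = std::string("elementwise_add") + "a" + "b" + "c";
  std::cout << std::hash<std::string>()(g1) << "\n"   // the two hashes almost surely differ
            << std::hash<std::string>()(g2) << "\n";
}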
size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
std::ostringstream has_str;
for (const auto& name_shape : key.input_shapes_) { // <------- input shape info
has_str << name_shape.first;
has_str << std::hash<phi::DDim>()(name_shape.second);
}
has_str << key.graph_hash_val_; // graph structure info
has_str << key.arch_str_; // target info
return std::hash<std::string>()(has_str.str());
}
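A sketch of the consequence (simplified: a plain string stands in for std::hash<phi::DDim>): the same graph hash combined with different input shapes yields different cache keys, so each concrete shape that reaches the kernel triggers its own compilation.
#include <functional>
#include <iostream>
#include <sstream>
#include <string>

size_t MakeKey(size_t graph_hash, const std::string& shapes, const std::string& arch) {
  std::ostringstream os;
  os << shapes << graph_hash << arch;  // same composition idea as the operator() above
  return std::hash<std::string>()(os.str());
}

int main() {
  size_t g = 0x1234;  // pretend graph structure hash
  std::cout << MakeKey(g, "x:[32,128]", "nvgpu") << "\n"
            << MakeKey(g, "x:[64,128]", "nvgpu") << "\n";  // different batch size -> different key
}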
Q: When does the main framework trigger compilation?
template <typename DeviceContext, typename T>
class CinnLaunchOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto& compilation_key = ctx.template Attr<int64_t>(kCompilationKey);
// Triggered with the runtime input tensor shapes, which resolves the dynamic -1 shape values
const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile(
compilation_key, inputs_name2tensor, target, stream);
}
Q: How does CINN eliminate dynamic shapes?
void CinnGraphSymbolization::RunOp(const CinnOpDesc& op_desc,
const OpMapperContext& ctx) const {
const auto& op_type = op_desc.Type();
auto* kernel = ::cinn::frontend::OpMapperRegistry::Global()->Find(op_type);
VLOG(4) << "Running Op " << op_type;
kernel->Run(op_desc, ctx); // dispatched to the concrete NetBuilder API; build() then invokes infer_shape
}
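A hypothetical, much-simplified sketch of this dispatch (the registry and types below are mine, not CINN's): op_type selects a registered mapper, and the mapper's shape inference runs on shapes that are already concrete because the feed tensors resolved the -1 dims.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct OpDescLite { std::string type; std::vector<int> input_shape; };
using OpMapperFn = std::function<void(const OpDescLite&)>;

std::unordered_map<std::string, OpMapperFn>& Registry() {
  static std::unordered_map<std::string, OpMapperFn> reg;
  return reg;
}

int main() {
  Registry()["cumsum"] = [](const OpDescLite& op) {
    // infer_shape analogue: cumsum's output shape equals its input shape
    std::cout << "cumsum output rank = " << op.input_shape.size() << "\n";
  };
  OpDescLite op{"cumsum", {32, 128}};  // -1 already resolved to 32 by the feed tensor
  Registry().at(op.type)(op);          // analogue of kernel->Run(op_desc, ctx)
}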
Q: Where inside CINN is the caching mechanism triggered?
const CinnCompiledObject &CinnCompiler::Compile(
const Graph &graph,
const std::map<std::string, const phi::DenseTensor *> &input_tensors,
const Target &target,
void *stream) {
VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
CinnCacheKeyByAddress cur_key_by_address(
graph, input_tensors, target.arch_str()); // first try to hit by graph pointer + input shapes + target
CinnCacheKeyByStructure cur_key_by_struct; // on a miss, fall back to graph structure + input shapes + target
if (!cache_by_address_.count(cur_key_by_address)) {
// generate the structure cache key
cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
if (!cache_by_struct_.count(cur_key_by_struct)) {
std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
auto compiled_res =
CompileGraph(graph, input_tensors, target, compiled_num, stream); // the core work is delegated to CompileGraph
std::unique_lock<std::mutex> guard(lock_);
// double check cache_by_struct_
if (!cache_by_struct_.count(cur_key_by_struct)) {
cache_by_struct_[cur_key_by_struct] = compiled_num;
index2cache_.emplace(compiled_num, std::move(compiled_res));
}
// double check cache_by_address_
if (!cache_by_address_.count(cur_key_by_address)) {
cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct);
}
} else {
std::unique_lock<std::mutex> guard(lock_);
// double check cache_by_address_
if (!cache_by_address_.count(cur_key_by_address)) {
cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct);
}
}
}
return *index2cache_.at(cache_by_address_.at(cur_key_by_address));
}
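A minimal sketch (assuming a single-level cache and a string key, unlike the two-level cache above) of the locking pattern used here: check the cache, run the expensive compilation outside the lock so other compilations are not serialized, then re-check under the lock before inserting to avoid duplicate entries.
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Compiled { std::string artifact; };

class CompileCache {
 public:
  const Compiled& GetOrCompile(const std::string& key) {
    {
      std::lock_guard<std::mutex> guard(mu_);
      auto it = cache_.find(key);
      if (it != cache_.end()) return *it->second;  // fast path: already compiled
    }
    // Slow path: compile without holding the lock.
    auto result = std::make_unique<Compiled>(Compiled{"compiled(" + key + ")"});
    std::lock_guard<std::mutex> guard(mu_);
    // Double check: another thread may have compiled the same key meanwhile.
    auto it = cache_.find(key);
    if (it == cache_.end()) it = cache_.emplace(key, std::move(result)).first;
    return *it->second;
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::unique_ptr<Compiled>> cache_;
};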
Q: What is the core responsibility of CompileGraph, and is there further caching inside it?
std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
const ir::Graph &graph,
const std::map<std::string, const phi::DenseTensor *> &input_tensors,
const Target &target,
std::int64_t compiled_num,
void *stream) const {
CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
auto frontend_program = symbol();
auto fetch_ids = symbol.GetFetchIds();
VLOG(4) << "All fetch var ids in CINN: "
<< string::join_strings(fetch_ids, ',');
auto cinn_graph = Optimize(&frontend_program, fetch_ids, target); // done only once for a given ir::Graph
VLOG(4) << "-- The " << compiled_num << "-th compilation ("
<< target.arch_str() << "), and its related graph:\n"
<< cinn_graph->Visualize();
auto scope = BuildScope(target, cinn_graph);
auto graph_compiler =
std::make_unique<GraphCompiler>(target, scope, cinn_graph); // GraphCompiler does one-shot work, but is kept alive by compiled_obj
GraphCompiler::CompileOptions options;
options.with_instantiate_variables = false;
if (!FLAGS_enable_pe_launch_cinn) {
options.with_buffer_handle_instruction_inserted = true;
}
std::unique_ptr<AutoTuner> auto_tuner;
if (FLAGS_enable_cinn_auto_tune) {
VLOG(4) << "Compile with auto-tune";
auto_tuner = std::make_unique<AutoTuner>(target, cinn_graph.get());
auto_tuner->Initialize(AutoTuner::Config(), graph_compiler.get());
::cinn::auto_schedule::TuningOptions tuning_options;
tuning_options.num_measure_trials = 0;
auto tuning_result = auto_tuner->Tune(tuning_options);
options.Apply(tuning_result);
}
auto compiled_res =
graph_compiler->Build(options, std::move(fetch_ids), stream);
auto compiled_obj = std::make_unique<CinnCompiledObject>();
*compiled_obj = {std::move(graph_compiler),
std::move(auto_tuner),
std::move(compiled_res.runtime_program),
scope,
symbol.var_model_to_program_map()}; // <------ corresponds to paddle2cinn_varmap
compiled_obj->cached_index = compiled_num;
compiled_obj->launch_context =
std::make_unique<operators::details::CinnLaunchContext>(graph,
*compiled_obj);
CheckCompiledValid(graph, input_tensors, *compiled_obj);
return compiled_obj;
}
Q: GraphCompiler hands the compilation and linking work over to backends::Compiler; does that backend Compiler have a compilation cache of its own?
A: The host module appears to contain mainly function declarations and call logic, while the device module holds the function definitions.
The CodeGen output (figure omitted) is source code that is written to a file and handed to the compilation engine; if there are multiple functions, they go into the same file and are compiled and linked together.
From the code, my understanding is that each CINN subgraph corresponds to one GraphCompiler, which compiles it into a function whose name follows the pattern fn_xxx_yyy_zzz:
- It describes the overall computation logic of all ops in the subgraph.
- After op Decompose, optimization passes, etc., it may be split into several sub-functions.
- The sub-functions are placed into one host file and one CUDA file, then compiled and linked into a single function pointer (see the sketch after this list).
- To be confirmed: so there is no cache at the lower_func level?
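A sketch under the assumption stated above (the file name and function bodies are placeholders, not actual CINN codegen output): the lowered sub-functions for one subgraph are concatenated into a single source file so the backend compiler can build and link them in one pass.
#include <fstream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> lowered = {
      "void fn_fill_constant_0(/*args*/) { /* ... */ }\n",
      "void fn_elementwise_add_1(/*args*/) { /* ... */ }\n"};
  std::ofstream out("fn_subgraph_0.cc");     // one generated file per subgraph
  for (const auto& fn : lowered) out << fn;  // all sub-functions go into the same file
}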
The figure (omitted here) shows the construction of engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
Appendix: how compilation is implemented in TVM
Q: What plays the role of GraphCompiler in TVM?
A: After a rough pass over the TVM source, it appears to be TECompilerImpl, which inherits from TECompilerNode and provides the following core interfaces:
// Lower the function.
CachedFunc Lower(const CCacheKey& key) {
return LowerInternal(key, global_var_supply_)->cached_func;
}
// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
if (value->packed_func != nullptr) {
return value->packed_func;
}
auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); // <------ here m is a runtime::Module object
value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
return value->packed_func;
}
CachedFunc LowerShapeFunc(const CCacheKey& key) final {
return LowerShapeFuncInternal(key)->cached_func;
}
Notably, TECompilerImpl contains two cache-related data structures:
/*! \brief internal compiler cache */
std::unordered_map<CCacheKey, CCacheValue> cache_;
/*! \brief internal compiler cache for shape funcs */
std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
Q: What is the build() above used for? Does it play the same role as Paddle's backends::Compiler?
A: I believe so, and the runtime::Module it returns can arguably be understood as the counterpart of RuntimeProgram in Paddle's CINN.
// Build for heterogeneous execution when targets are specified as
// objects. This wrapper around the internal API is maintained for
// backwards compatibility.
runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host) {
return TIRToRuntime(input, target_host);
}
runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
const Target& target_host_arg) { // <------- the implementation
std::vector<runtime::Module> device_modules;
Map<Target, IRModule> inputs = inputs_arg;
Target target_host = target_host_arg;
// Fetch previous defined target host in targets
CheckAndUpdateHostConsistency(&inputs, &target_host);
if (!target_host.defined()) {
for (const auto& it : inputs) {
if (it.first->GetTargetDeviceType() == kDLCPU ||
it.first->GetTargetDeviceType() == kDLMicroDev) {
target_host = it.first;
break;
}
}
}
if (!target_host.defined()) {
target_host = DefaultTargetHost(target_host);
}
// Update target host for all targets
CheckAndUpdateHostConsistency(&inputs, &target_host);
// Take the attrs from the first module so the eventual modules have them.
// Ideally this would just be one unified module all the way through;
IRModule first_module = (*inputs.begin()).second;
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);
ICHECK(mhost_all.defined()) << "The host module must be defined";
for (const auto& it : inputs) {
if (it.second.defined()) {
const Target& target = it.first;
const IRModule& ir_module = it.second;
auto pair = SplitMixedModule(ir_module, target, target_host);
auto& host_mod = pair.first;
auto& device_mod = pair.second;
ICHECK(host_mod.defined()) << "The split host module must be defined";
ICHECK(mhost_all.defined()) << "The host module must be defined";
// We don't want library modules going back into host codegen
// unless they're supposed to. Here if we overrode the target host
// to allow lowering previously we check that it's meant to be placed
// back into the host Module.
bool overrides_host_target =
target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
bool non_host_target_kind = target->kind != target_host->kind;
if (overrides_host_target && non_host_target_kind) {
device_modules.push_back(codegen::Build(host_mod, it.first));
} else {
mhost_all->Update(host_mod);
}
if (device_mod->functions.size() != 0) {
device_modules.push_back(codegen::Build(device_mod, it.first));
}
}
}
runtime::Module mhost = codegen::Build(mhost_all, target_host); // <----- compilation of the host module happens here?
for (const auto& it : device_modules) {
if (it.operator->()) {
mhost.Import(it);
}
}
return mhost;
}
runtime::Module Build(IRModule mod, Target target) {
if (transform::PassContext::Current()
->GetConfig<Bool>("tir.disable_assert", Bool(false))
.value()) {
mod = tir::transform::SkipAssert()(mod);
}
auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
if (target_attr_map.count(target->kind)) {
return target_attr_map[target->kind](mod, target);
}
// the build function.
std::string build_f_name = "target.build." + target->kind->name;
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
ICHECK(bf != nullptr) << build_f_name << " is not enabled";
return (*bf)(mod, target);
}
TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);
runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
Q: Where does TVM invoke execution from?
A: There is a GraphExecutor data structure that drives execution.
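For reference, a rough C++ usage sketch of the graph executor, adapted from TVM's deploy example (the library name "lib.so" and the input name "x" are placeholders):
#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

void RunCompiledModule() {
  DLDevice dev{kDLCPU, 0};
  // The factory module produced by relay.build / export_library.
  tvm::runtime::Module factory = tvm::runtime::Module::LoadFromFile("lib.so");
  // "default" returns a GraphExecutor module bound to the device.
  tvm::runtime::Module gmod = factory.GetFunction("default")(dev);
  tvm::runtime::PackedFunc set_input = gmod.GetFunction("set_input");
  tvm::runtime::PackedFunc run = gmod.GetFunction("run");
  tvm::runtime::PackedFunc get_output = gmod.GetFunction("get_output");
  tvm::runtime::NDArray x =
      tvm::runtime::NDArray::Empty({1, 3}, DLDataType{kDLFloat, 32, 1}, dev);
  set_input("x", x);
  run();
  tvm::runtime::NDArray y = get_output(0);
  (void)y;
}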