定义
移除无用函数,Remove Unused Function,顾名思义,就是删除Module中定义但未用到的函数
当然,它也是一个模块级的优化,
举例子:
def get_mod():
mod = tvm.IRModule({})
fn1 = relay.Function([], relay.const(1))
fn2 = relay.Function([], relay.const(2))
fn3 = relay.Function([], relay.const(3))
g1 = relay.GlobalVar("g1")
g2 = relay.GlobalVar("g2")
g3 = relay.GlobalVar("g3")
mod[g1] = fn1
mod[g2] = fn2
mod[g3] = fn3
p = relay.var("p", "bool")
mod["main"] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
return mod
mod = get_mod()
ref_mod = get_mod()
mod = relay.transform.RemoveUnusedFunctions()(mod)
print(mod)
例子中,get_mod
定义了三个函数fn1
、fn2
、fn3
,下面并未使用fn3,因此,fn3是个无用函数,应该被移除
该例子输出结果如下:
通过RemoveUnusedFunctions函数后,fn3
已被移除
意义
删除无用代码,可以缩减IR代码,可使程序更小、编译更快、(通常)执行也更快。同时它还可以增强编译器改进代码的能力。
实现
同样,该pass会通过TVM_REGISTER_GLOBAL
在C++端进行注册,python端通过FFI接口进行访问,不过需要注意的是,该Pass的IRModule对象中一定要有main函数,否则在运行该Pass时会因找不到main方法而报错。
Python端访问:
def RemoveUnusedFunctions(entry_functions=None):
if entry_functions is None:
entry_functions = ["main"]
return _ffi_api.RemoveUnusedFunctions(entry_functions)
C++端注册:
Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
};
return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
}
TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions);
CreateModulePass
说明了该Pass是一个Module模块级别的优化
该Pass优化的真正实现是在RemoveUnusedFunctions
方法中,代码实现如下:
IRModule RemoveUnusedFunctions(const IRModule& module, Array<runtime::String> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
VLOG(2) << "RemoveUnusedFunctions:" << entry;
auto funcs = CallTracer(module).Trace(entry);
called_funcs.insert(funcs.cbegin(), funcs.cend());
}
for(auto func : called_funcs)
{
VLOG(2) << "called_funcs:" << func;
}
auto existing_functions = module->functions;
for (auto f : existing_functions) {
VLOG(2) << "existing_functions:" << f.first->name_hint;
auto it = called_funcs.find(f.first->name_hint);
if (it == called_funcs.end()) {
module->Remove(f.first);
}
}
return module;
}
该函数调用CallTracer类
(继承ExprVisitor类
)的成员函数Trace()
,获取IRModule对象中main函数
调用的所有函数,并保存在unordered_set<std::string>
类型变量called_funcs_
中,然后RemoveUnusedFunctions()
函数遍历IRModule对象中的所有函数,将不在called_funcs_
中的函数视为无用函数,并将它从IRModule对象中移除。
通过上述的调试信息,可以进行佐证:
该Pass优化还是比较简单的
Respect~
标签:Function,函数,relay,RemoveUnusedFunctions,Pass,Unused,移除,entry,mod From: https://www.cnblogs.com/whiteBear/p/18134269