首页 > 其他分享 >TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)

TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)

时间:2024-04-06 14:55:55浏览次数:32  
标签:Subexpr relay -- expr Pass pass new EliminateCommonSubexpr 表达式

定义(What)

公共子表达式消除 就是如果表达式E的值已经计算的到了,并且自计算的到值后E的值就不再改变了,就说,表达式E在后续计算中是一个公共表达式。
简单说,该表达式上面已经执行过了,下面没必要再执行了
举个例子:

import tvm
from tvm import relay
from tvm.relay import transform

def run_opt_pass(expr, opt_pass):
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def before():
    x = relay.var("x", shape=(1, 16))
    y1 = relay.nn.relu(x)
    y2 = relay.nn.relu(x)
    y1 = relay.add(y1, relay.const(1.0, "float32"))
    y2 = relay.add(y2, relay.const(1.0, "float32"))
    y = relay.add(y1, y2)
    f = relay.Function([x], y)
    return f

z = before()
print("before")
print(z)
z = run_opt_pass(z, transform.EliminateCommonSubexpr())
print("after")
print(z)

通过print(z)打印公共子表达式消除前IRModule对象内容,如下:
image
消除之后的IRModule对象内容如下:
image
可以发现Relay图中的y2 = relay.nn.relu(x)节点被清除
因为表达式y2 = relay.nn.relu(x)在前一个表达式y1 = relay.nn.relu(x)中已经计算过了,只需要用前面计算过的表达式结果代替即可

作用 (Why)

意义就很简单了,为了避免重新计算表达式E,浪费计算资源,影响运行效率

怎么做(How)

上面的例子可看到,公共子表达式消除主要调用的是relay.transform.EliminateCommonSubexpr()接口,这个接口是对已注册的公共子表达式消除pass的封装。可见路径:python/tvm/relay/transform/transform.py

def EliminateCommonSubexpr(fskip=None):
    """Eliminate common subexpressions.

    Parameters
    ----------
    fskip: Callable
        The callback function that decides whether an expression should be
        skipped.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass that eliminates common subexpressions.
    """
    return _ffi_api.EliminateCommonSubexpr(fskip)

通过PackFunc机制,_ffi_api.EliminateCommonSubexpr接口最后会通过_LIB.TVMFuncGetGlobal函数获取到C++端注册的EliminateCommonSubexpr函数。
C++端EliminateCommonSubexpr注册代码如下:

Pass EliminateCommonSubexpr(PackedFunc fskip) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
      };
  return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
    .set_body_typed(EliminateCommonSubexpr);

上述代码,CreateFunctionPass()函数作用是生成FunctionPass对象,FunctionPass工作在Relay模块中的每一个Relay函数对象上。

  • FunctionPass() 函数的第一个参数pass_func是TypedPackedFunc对象,真正的pass优化功能由该对象调用pass函数EliminateCommonSubexpr()完成。
  • 第二个参数是优化级别(当通过pass基础架构调用该pass时,会检查pass的优化级别,只有当该pass的优化级别不低于pass上下文配置中的优化级别时,才能启用执行该pass);
  • 第三个参数是函数pass名称;
  • 第四个参数是{} 中列出了公共子表达式消除pass依赖的其他pass,如InferType,因为需要类型信息,所以参数中列出了InferType pass名称

EliminateCommonSubexpr()函数的函数体是CommonSubexprEliminator()函数,它主要通过实现遍历Relay IR,完成Relay IR中的公共子表达式消除功能。

Relay IR遍历的C++实现类是ExprFunctor类的派生类,继承关系如下:
image

CommonSubexprEliminator()类通过重载Rewrite_()方法实现公共子表达式消除功能。该方法将处理过的表达式都存储在unordered_map变量expr _map_中。在每次通过ReWrite_方法处理当前表达式时,会先从expr_map_中查找是否有相同操作类型的已处理表达式,如果有,在判断当前表达式与已处理表达式的属性和参数是否相同,如果这些条件都满足,则返回满足条件的一处理表达式。

expr_map_定义如下:

  std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;

ReWrite_()方法(src/relay/transforms/eliminate_common_subexpr.cc)实现代码如下:

Expr Rewrite_(const CallNode* call, const Expr& post) final {
    static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
    Expr new_expr = post;
    const CallNode* new_call = new_expr.as<CallNode>();
    ICHECK(new_call);
    const OpNode* op = new_call->op.as<OpNode>();
    StructuralEqual attrs_equal;

    if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
      return new_expr;
    }
    if (fskip_ != nullptr && fskip_(new_expr)) {
      return new_expr;
    }

    auto it = expr_map_.find(new_call->op);
    if (it != expr_map_.end()) {
      for (const Expr& candidate_expr : it->second) {
        if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
          bool is_equivalent = true;
          if (!attrs_equal(new_call->attrs, candidate->attrs)) {
            continue;
          }
          for (size_t i = 0; i < new_call->args.size(); i++) {
            if (!IsEquivalent(new_call->args[i], candidate->args[i])) {
              is_equivalent = false;
              break;
            }
          }
          if (!is_equivalent) continue;
          return GetRef<Call>(candidate);
        }
      }
    }
    expr_map_[new_call->op].push_back(new_expr);
    return new_expr;
  }

在python端调用时,通过CreateFunctionPass()函数返回FunctionPass对象,然后通过该对象调用算子,如上述例子中opt_pass(mod)
它会调用Pass类的__call__方法来调用算子

@tvm._ffi.register_object("transform.Pass")
class Pass(tvm.runtime.Object):
    """The base class of all passes. All methods here are just simple wrappers
    that are implemented in the backend. They are defined for users to
    conveniently interact with the base class.
    """

    @property
    def info(self):
        """Get the pass meta."""
        return _ffi_transform_api.Info(self)

    def __call__(self, mod):
        """Execute the pass. Note that for sequential pass, the dependency among
        different passes will be resolved in the backend.

        Parameters
        ----------
        mod : tvm.IRModule
            The module that a certain optimization is performed on.

        Returns
        -------
        mod : tvm.IRModule
            The updated module after applying this pass.
        """
        return _ffi_transform_api.RunPass(self, mod)

src/ir/transform.cc中ransform.RunPass注册代码如下:

TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) {
  return pass(std::move(mod));
});

此处的pass就是通过CreateFunctionPass()创建的对象,此处会调用pass中operator()重载,最终会调到FunctionPassNode类中的operator()方法,该实现会调到CreateFunctionPass()时保存的真正公共子表达式消除的代码的实现pass_func

总体,该算子优化还算是比较简单

respect~
致敬

标签:Subexpr,relay,--,expr,Pass,pass,new,EliminateCommonSubexpr,表达式
From: https://www.cnblogs.com/whiteBear/p/18117317

相关文章

  • G. Rudolf and Subway
     原题的无向图等价于上图所示的联通图,此时我们要求的就是起始位置到终止位置最少要经过几个有颜色的结点。code #include<bits/stdc++.h>usingnamespacestd;constintN=4e5+5;intvis[N];intmain(){//freopen("input.txt","r",stdin);intt;cin>>t;......
  • CNCKAD数冲激光编程排版软件介绍
    CNCKAD是一种集成了CAD和CNC的软件,它可以导入各种图形和CAD文件格式,比如DXF、DWG和IGES,帮助用户创建复杂的3D模型、独特的几何形状和特殊的设计要求。CNCKAD的核心功能是自动化的刀具路径生成,它可以通过多种方式生成刀具路径,包括手动、自动、半自动和机器学习等方法,这些技术可以确......
  • Photoshop混合模式的底层原理
        Photoshop虽然不是什么高手,但平时工作中难免会用到,处理部分需求还是可以胜任的。接触PS这么多年,对PS中图层的混合模式(BlendMode)一直就处于懵懂状态,即使是看了教材及视频后,有了一点感性认识,但在实际操作中仍旧无法运用起来。直至某一天,我在B站看到韩世麟的《把PS......
  • C#词法分析自动生成器
    C#词法分析自动生成器前言在做编译原理实验时,要求使用自动生成器生成词法分析器,老师推荐的是用flex,但用flex只会生成C代码,自己项目用的又是C#,本来想使用C代码直接生成dll库并用C#调用,但非常麻烦。干脆找了个能生成C#代码的生成器。配置相关的生成器很多,但我能找到的且能成功......
  • 劫持TLS绕过canary pwn89
    劫持TLS绕过canarypwn88首先了解一下这个东西的前提条件和原理前提:溢出字节够大,通常至少一个page(4K)创建一个线程,在线程内栈溢出原理:在开启canary的情况下,当程序在创建线程的时候,会创建一个TLS(ThreadLocalStorage),这个TLS会存储canary的值,而TLS会保存在stack高地址......
  • OccNet 栅格占据网络:重建智能驾驶场景表征
    随着高阶智能驾驶的发展,长尾障碍物感知成为智驾发力的关键点。驾驶场景中常见的行人、车、障碍物,能够通过3D物体检测等方式实现其位置、大小的估计。而现实世界城区的交通路况中,还存在海量长尾场景问题:如异形车辆、路上的石子、掉落的树叶等障碍物,以3D检测框、点云等传统表......
  • linux - GPG 非对称加密工具
    GNUPrivacyGuard(GPG)是一种主要设计用于使用公钥加密技术对数据进行加密和签名的工具。然而,它还包含仅使用用户提供的密码来加密数据的能力,并且支持多种加密算法。1.查看gpg支持的算法gpg--version2.生成密钥#使用默认选择gpg--generate-key#更灵活的算法选择g......
  • 150行Python代码模拟太阳系行星运转
    今天我们用Python来模拟一下太阳系行星运动轨迹~先上成品图(运行效果含音乐的呦)想要实现这样的效果并不难准备材料首先我们需要准备这样一些材料宇宙背景图背景透明的行星图 编写代码代码分块详解导入需要的模块import pygame  import sys ......
  • Arm架构下麒麟操作系统安装配置Mariadb数据库
    1、安装配置JDK(1)检查机器是否已安装JDK执行java-version命令查看机器是否安装JDK,一般麒麟操作系统默认安装openjdk1.8。  (2)安装指定版本JDK如果麒麟操作系统默认安装的openjdk1.8不符合需求的话,可以卸载机器安装的openjdk1.8并按需安装所需的openjdk版本,此步骤本文不......
  • Python哪种方式循环最快,或许颠覆你的认知!
    众所周知,Python不是一种执行效率较高的语言。此外在任何语言中,循环都是一种非常消耗时间的操作。假如任意一种简单的单步操作耗费的时间为1个单位,将此操作重复执行上万次,最终耗费的时间也将增长上万倍。while 和 for 是Python中常用的两种实现循环的关键字,它们的运行......