首页 > 其他分享 >TVM:visitor设计模式

TVM:visitor设计模式

时间:2022-10-04 22:14:31浏览次数:69  
标签:tvm include const relay visitor TVM VisitExpr 设计模式 op

visitor模式,因为它在编译器的框架中应用的广泛,在TVM中也是无处不在。

visitor模式介绍

Visitor(访问者)模式的定义:将作用于某种数据结构中的各元素的操作分离出来封装成独立的类,使其在不改变数据结构的前提下可以添加作用于这些元素的新的操作,为数据结构中的每个元素提供多种访问方式

GoF 设计模式这本书中提出Visitor模式的时候就是以编译器作为例子的。在编译器中,一般会用抽象语法树来构建中间表示。然后会有一些优化pass是基于抽象语法树来做的,例如类型检查,常量折叠,代码优化等等

有一种实现方法是如下图,对每个数据类型都实现Pass相关的功能。这就导致了一个问题,每增加一个优化Pass,都要重新修改现有的数据结构从而增加新的方法。
image

如果对象结构中数据元素的类型相对稳定的情况下(不经常增加新的数据类型),可以考虑使用下图中的方法。这样在新增加一个优化Pass的时候,就只要新增一个Visitor类继承自Visitor类就可以,然后添加相应的Visit方法,而不需要修改原有的对象结构中数据元素
image

visitor模式在TVM中的应用

image

上图是TVM中Vistor的几个基础的类。这儿只列举了CallNode和FunctionNode这两个数据类型,其它的还有IfNode,LetNode等,大家可以查看源码。

ExprFunctor::VisitExpr(const Expr& n, Args... args) 会根据Expr& n的类型进行分发,从而调用相应的VisitExpr_接口,例如n的类型是Function,那就会调用VisitExpr_(const FunctionNode* op, Args... args)接口

ExprVisitor 类不会改变数据,ExprMutator会改变数据。

下面解读下ExprVisitor的源码。

void ExprVisitor::VisitExpr(const Expr& expr) {
  auto it = visit_counter_.find(expr.get());
  if (it != visit_counter_.end()) {
    ++it->second;
  } else {
    using TParent = ExprFunctor<void(const Expr&)>;
    TParent::VisitExpr(expr);
    visit_counter_.insert({expr.get(), 1});
  }
}

ExprVisitor 重写了ExprFunctor的VisitExpr函数,成员变量visit_counter用来记录已经访问过的expr,防止再次访问。最终还是会调用ExprFunctor的VisitExpr函数,用来做分发功能。如果这时候的expr类型是Call类型,那么紧接着就会调用ExprVisitor::VisitExpr_(const CallNode* op) 函数。

void ExprVisitor::VisitExpr_(const CallNode* op) {
 this->VisitSpan(op->span);
 this->VisitExpr(op->op);

 for (auto ty_arg : op->type_args) {
 this->VisitType(ty_arg);
  }

 for (auto arg : op->args) {
 this->VisitExpr(arg);
  }
}

可以看到它会对args调用VisitExpr函数,又会trigger ExprFunctor的分发功能,判断args的具体类型,然后去调用相应的VisitExpr_方法。

添加Pass打印包含的Relay op

添加新的Pass,如果想要遍历expr,只要添加一个新类继承自ExprVisitor 或者ExprMutator就可以。

下面是我添加的一个打印Call op的pass,在调用CallNode的时候打印下op的类型,所以只需要重写VisitExpr_(const CallNode* call)函数就可以。文件添加在tvm/src/relay/transform 目录下

#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/object.h>

namespace tvm {
namespace relay {

class TestCallExpr: public ExprVisitor {
 public:

 private:
  std::unordered_map<Expr, bool, ObjectPtrHash, ObjectPtrEqual> memo_;

  void VisitExpr_(const CallNode* call) final {
    ExprVisitor::VisitExpr_(call);
    std::cout << "call :" << std::endl;
    std::cout << AsText(call->op, false, nullptr) << std::endl;
  }
};

Expr TestCallParse(const Expr& expr, const IRModule& mod) {
  TestCallExpr().VisitExpr(expr);
  return expr;
}

namespace transform {
Pass  TestCallParse() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(TestCallParse(f, m));
      };
  return CreateFunctionPass(pass_func, 2, "TestCallParse", {});
}

TVM_REGISTER_GLOBAL("relay._transform.TestCallParse").set_body_typed(TestCallParse);
}
}
}

在python/tvm/relay/transform/transform.py中添加如下代码

def TestCallParse():
    return _ffi_api.TestCallParse()

添加测试代码

import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import tvm.testing
import numpy as np

dshape = (1, 16)
x = relay.var("x", shape=dshape)
y = relay.add(x, relay.const(1, "float32"))
z = relay.multiply(y, relay.const(1, "float32"))
print(z)
zz = run_opt_pass(z, transform.TestCallParse()

它对应的语法树为:

image

打印的结果为:

free_var %x: Tensor[(1, 16), float32];
%0 = add(%x, 1f);
multiply(%0, 1f)

call :
#[version = "0.0.5"]
add
call :
#[version = "0.0.5"]
multiply

可以看出,先打印的是add,因为在遍历multiply(%0, 1f)的时候,需要遍历%0,因为%0是multiply(%0, 1f)的一个arg。遍历%0的时候 就会打印出op类型add。遍历完 multiply(%0, 1f)之后才会打印 multiply。

参考:
TVM之设计模式解读(一)--visitor模式

标签:tvm,include,const,relay,visitor,TVM,VisitExpr,设计模式,op
From: https://www.cnblogs.com/whiteBear/p/16754600.html

相关文章

  • 设计模式:访问者模式
    访问者模式诞生的思维过程访问者模式难理解、难实现,应用它会导致代码的可读性、可维护性变差,所以,访问者模式在实际的软件开发中很少被用到,在没有特别必要的情况下,建议你不......
  • 设计模式-单例模式
    单例模式的英文叫做singleton模式,我先说一下,单例模式是怎么回事,就是,在你的系统里,你要判断一下,如果有一些类,只需要一个实例就可以了,那就给那个类,做成单例的模式。实际上我......
  • TVM:Object家族
    Object.h概述命名空间:TVM::runtime文件中包含的结构:1.结构体TypeIndex2.类Object3.类ObjectPtr4.类ObjectRef5.结构体ObjectPtrHash6.结构体ObjectPtrEqual7.......
  • 桥接模式【Java设计模式】
    桥接模式【Java设计模式】​​前言​​​​推荐​​​​桥接模式​​​​介绍​​​​实现​​​​最后​​前言2022/9/2313:34路漫漫其修远兮,吾将上下而求索本文是根据袁......
  • 设计模式---适配器模式
    简述类型:结构型目的:解决接口不兼容问题。话不多说,看个案例吧。优化案例最初版v0在真实的开发场景中,系统的每个模块都是分配给不同的团队或个人来开发的。这使得事......
  • 设计模式概述
    GOF-23模式分类从目的来看:创建型(Creational)模式:将对象的部分创建工作延迟到子类或者其他对象,从而应对需求变化为对象创建时具体类型实现引来的冲击。结构型(Structural)模式:......
  • Java设计模式 —— 建造者模式
    8建造者模式8.1建造者模式概述BuilderPattern:将一个复杂对象的构建与它的表示分离,使得同样的构建过程可以创建不同的表示。建造者模式可以将部件本身和它们的组......
  • 抽象类及模板设计模式
    1基本介绍当父类的某些方法,需要声明,但是又不确定如何实现时,可以将其声明为抽象方法,那么这个类就是抽象类当父类的一些方法不能确定时,可以用abstract关键字来修饰该方......
  • Go设计模式学习准备——下载bilibili合集视频
    需求前段时间面试,被问到设计模式。说实话虽然了解面向对象、多态,但突然被问到设计模式,还要说清解决什么问题,自己是有些懵的,毕竟实习主要工作是在原项目基础进行CRUD,自己还......
  • Java设计模式 —— 原型模式
    7原型模式7.1原型模式概述PrototypePattern:使用原型实例指定待创建对象的类型,并且通过复制这个原型来创建新的对象。原型模式的工作原理:将一个原型对象传给创建......