首页 > 其他分享 >TVM: VisitExpr流程分析

TVM: VisitExpr流程分析

时间:2022-10-04 22:56:33浏览次数:67  
标签:const Expr 流程 TVM override VisitExpr void op

TVM源码中涉及到表达式遍历的地方,一般是适用VisitExpr接口进行,这个接口设计TVM的visitor模式,具体分析可参考:TVM:visitor设计模式

基类tvm::relay::ExprFunctor

适用visitor遍历的起点是调用VisitExpr接口,看下基类tvm::relay::ExprFunctor中这个方法的代码:

template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
 private:
  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;

 public:
  ......
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitExpr(const Expr& n, Args... args) {
    ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
                           "have generated invalid data.";
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overriden by subclass
  virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
.......
    throw;
  }

 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
   ......
    return vtable;
  }
};

VisitExpr中调用InitVTable:

// initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
.....
    return vtable;
  }
  
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                                                    \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {     \
    return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
  });
template <typename R, typename... Args>
class NodeFunctor<R(const ObjectRef& n, Args...)> {
 private:
  /*! \brief internal function pointer type */
  typedef R (*FPointer)(const ObjectRef& n, Args...);
  /*! \brief refer to itself. */
  using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
  /*! \brief internal function table */
  std::vector<FPointer> func_;

 public:
......
  /*!
   * \brief set the dispacher for type TNode
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template <typename TNode>
  TSelf& set_dispatch(FPointer f) {  // NOLINT(*)
    uint32_t tindex = TNode::RuntimeTypeIndex();
    if (func_.size() <= tindex) {
      func_.resize(tindex + 1, nullptr);
    }
    ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
    func_[tindex] = f;
    return *this;
  }
  /*!
   * \brief unset the dispacher for type TNode
   *
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template <typename TNode>
  TSelf& clear_dispatch() {  // NOLINT(*)
    uint32_t tindex = TNode::RuntimeTypeIndex();
    ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
    func_[tindex] = nullptr;
    return *this;
  }
};

InitVTable中调用NodeFunctor::set_dispatch接口,类型参数为tvm relay ir的各种表达式类型,传入set_dispatch的函数参数是lamad函数,lamad函数体中执行self->VisitExpr_()。self时传入的参数this,当从派生类中发起VisitExpr的时候,这个this将是派生类实例,而不是基类。

NodeFunctor::set_dispatch是在函数指针表func_中添加传入的lamad函数,表项索引为类型参数的id。

InitVTable在为所有类型都调用set_dispatch注册对应的visit调用后,返回了注册的NodeFunctor实例。而VisitExpr在调用InitVTablereturn vtable(n, this, std::forward<Args>(args)...)。NodeFunctor中对()进行了运算符重载

R operator()(const ObjectRef& n, Args... args) const {
    ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
                            << n->GetTypeKey();
    return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
  }

这里以传入的参数的类型id为索引,从func_表中获取对应的lamad函数体,并调用执行。也就是执行了类实例的VisitExpr_。因为一般来说发起VisitExpr调用的是以tvm::relay::ExprFunctor为基类,并在VisitExpr_中完成业务操作的类,所以这里VisitExpr_是调用的业务类中重载后的VisitExpr_方法。业务类对自己关注的类型的VisitExpr_进行重载,在其中完成自己的操作。

如果派生类不对各种类型重载VisitExpr_,就会调用到tvm::relay::ExprFunctor定义的VisitExpr_,抛出异常:

virtual R VisitExpr_(const ConstantNode* op, Args... args) { 
  return VisitExprDefault_(op, std::forward<Args>(args)...); 
};
 
virtual R VisitExpr_(const TupleNode* op, Args... args) { 
  return VisitExprDefault_(op, std::forward<Args>(args)...); 
};
 
virtual R VisitExpr_(const VarNode* op, Args... args) { 
  return VisitExprDefault_(op, std::forward<Args>(args)...); 
};
 
...
 
virtual R VisitExprDefault_(const Object* op, Args...) {  
 ::tvm::runtime::detail::LogFatal("/home/tvm/tvmsource/tvm/include/tvm/relay/expr_functor.h", 114).stream() << "Do not have a default for " << op->GetTypeKey();
    throw;
  }

派生类tvm::relay::ExprVisitor

ExprVisitor继承了ExprFunctor,并对VisitExpr和VisitExpr_进行了重载:

class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
 public:
  void VisitExpr(const Expr& expr) override;
  void VisitExpr_(const VarNode* op) override;
  void VisitExpr_(const GlobalVarNode* op) override;
  void VisitExpr_(const ConstantNode* op) override;
  void VisitExpr_(const TupleNode* op) override;
  void VisitExpr_(const FunctionNode* op) override;
  void VisitExpr_(const CallNode* op) override;
  void VisitExpr_(const LetNode* op) override;
  void VisitExpr_(const IfNode* op) override;
  void VisitExpr_(const OpNode* op) override;
  void VisitExpr_(const TupleGetItemNode* op) override;
  void VisitExpr_(const RefCreateNode* op) override;
  void VisitExpr_(const RefReadNode* op) override;
  void VisitExpr_(const RefWriteNode* op) override;
  void VisitExpr_(const ConstructorNode* op) override;
  void VisitExpr_(const MatchNode* op) override;
  virtual void VisitType(const Type& t);
  virtual void VisitClause(const Clause& c);
  virtual void VisitPattern(const Pattern& c);
  virtual void VisitSpan(const Span& span);
 
 protected:
  // Internal visiting counter
  std::unordered_map<const Object*, size_t> visit_counter_;
};
 
void ExprVisitor::VisitExpr(const Expr& expr) {
  auto it = visit_counter_.find(expr.get());
  if (it != visit_counter_.end()) {
    ++it->second;
  } else {
    using TParent = ExprFunctor<void(const Expr&)>;
    TParent::VisitExpr(expr);
    visit_counter_.insert({expr.get(), 1});
  }
}

visit_counter_表记录了每个表达式(注意不是每种)的访问历史。在VisitExpr中,如果发现该表达式已经访问过,则只是递增该表达式的访问计数,而不做实质的访问操作。如果发现表达式没有遍历过,则调用基类ExprFunctor的VisitExpr,进而调用到发起VisitExpr的某个派生类的VisitExpr_。

派生类tvm::relay::ExprMutator

派生类ExprMutator的定义跟ExprFunctor差不多:

class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
 public:
  /*!
   * \brief Mutate is alias for VisitExpr
   * \return expr.
   */
  Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
  Expr VisitExpr(const Expr& expr) override;
  Expr VisitExpr_(const VarNode* op) override;
  Expr VisitExpr_(const ConstantNode* op) override;
  Expr VisitExpr_(const GlobalVarNode* op) override;
  Expr VisitExpr_(const OpNode* op) override;
  Expr VisitExpr_(const TupleNode* op) override;
  Expr VisitExpr_(const FunctionNode* op) override;
  Expr VisitExpr_(const CallNode* call_node) override;
  Expr VisitExpr_(const LetNode* op) override;
  Expr VisitExpr_(const IfNode* op) override;
  Expr VisitExpr_(const TupleGetItemNode* op) override;
  Expr VisitExpr_(const RefCreateNode* op) override;
  Expr VisitExpr_(const RefReadNode* op) override;
  Expr VisitExpr_(const RefWriteNode* op) override;
  Expr VisitExpr_(const ConstructorNode* op) override;
  Expr VisitExpr_(const MatchNode* op) override;
 
  /*!
   * \brief Used to visit the types inside of expressions.
   *
   * Can be overloaded to transform the types in arbitrary
   * ways, one way would be to define a sub-class of type
   * visitor for types which transform them appropriately.
   */
  virtual Type VisitType(const Type& t);
  virtual Clause VisitClause(const Clause& c);
  virtual Pattern VisitPattern(const Pattern& c);
 
 protected:
  /*! \brief Internal map used for memoization. */
  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo_;
};
 
Expr ExprMutator::VisitExpr(const Expr& expr) {
  auto it = this->memo_.find(expr);
  if (it != this->memo_.end()) {
    return it->second;
  } else {
    Expr new_expr = ExprFunctor::VisitExpr(expr);
    memo_[expr] = new_expr;
    return new_expr;
  }
}

这里需要注意的是,ExprMutator的VisitExpr和VisitExpr_都是有返回值的,调用将返回遍历到的表达式,这样可以在VisitExpr_外对表达式做操作,比如说修改。

Codegen内存申请时的visitor模式使用

GraphPlanMemory分配流程中涉及的类关系图如下所示:
image

在该流程中分别从StorageAllocInit和StorageAllocator里面调用Run接口,Run接口调用VisitExpr,这个时候调用的是ExprVisitor::VisitExpr。而VisitExpr_则是调用的StorageAllocaBaseVisitor和DeviceAwareExprVisitor中重载的。

从这里也可以看到,ExprFunctor和ExprVisitor是纯粹作为visitor模式的实现而设计,具体的业务在各业务实现类中。

参考:
VisitExpr流程分析

标签:const,Expr,流程,TVM,override,VisitExpr,void,op
From: https://www.cnblogs.com/whiteBear/p/16754737.html

相关文章

  • TVM:visitor设计模式
    visitor模式,因为它在编译器的框架中应用的广泛,在TVM中也是无处不在。visitor模式介绍Visitor(访问者)模式的定义:将作用于某种数据结构中的各元素的操作分离出来封装成独立......
  • webpack打包思路与流程解析
    一:创建一个新的工程,项目初始化npminit-y二:搭建项目框架 三:编写main.js文件内容,在index.js中引入,在把index.js引入到index.html中例:exportdefault()=>{fun......
  • 118-22-ZooKeeper 基础设施详解 和 服务启动流程源码分析_ev
         ......
  • nginx&http 第三章 ngx http 框架处理流程
    1.nginx 连接结构 ngx_connection_t 这个连接表示是客户端主动发起的、Nginx服务器被动接受的TCP连接,我们可以简单称其为被动连接。同时,在有些请求的处理过程中,Nginx会试......
  • TVM:Object家族
    Object.h概述命名空间:TVM::runtime文件中包含的结构:1.结构体TypeIndex2.类Object3.类ObjectPtr4.类ObjectRef5.结构体ObjectPtrHash6.结构体ObjectPtrEqual7.......
  • etcd写流程
    0)整体结构1)server->etcdRaft,处理协程生成msgWithResult  2)etcdRaft模块,从proc取出msgWithResult调用step驱动状态机 封装成Ready实例给server发送:readyc<......
  • 笔记一:机器学习工作流程
    目录1理解问题和背景1.1目的1.2工作环境1.3获取数据2探索性数据分析(EDA)3数据预处理3.1数据清理3.2特征选择3.3特征工程3.4特征缩放4模型探索根据Geron(2019)......
  • Unity同项目导出不同的UnityPackage流程
    Unity3D中的root项目导出为一个package;新建项目并将package导入进来;在新建的项目里更改里面的各种信息(脚本、模型、贴图等等);重命名所有的更改了的资源名称;将所有的prefab都......
  • 2022-10-03-SpringMVC执行流程梳理及结合源码断点调试过程源码分析
    SpringMVC执行流程梳理接口方式控制器实现流程分析控制器层代码实现控制器配置SpringMVC.xml配置文件客户端浏览器发起请求,按回车前端控制器拦截所有请求/......
  • [架构之路-8]:架构师 - 必须熟悉的组织内的软硬件研发流程和几大研发系统
    目录​​前言:​​​​一、系统架构部门在组织内软硬件生产中的位置上​​​​二、软、硬件研发的几大系统​​​​三、软件开发流程与DevOps工具​​​​附录:组织公司的主要......