首页 > 其他分享 >TVM Pass优化 -- InferType 类型推导

TVM Pass优化 -- InferType 类型推导

时间:2024-04-07 20:47:13浏览次数:23  
标签:resolved return -- InferType TVM expr type mod

定义(What)

InferType,类型推断,顾名思义,给表达式进行类型的推断
直接上代码

import tvm
from tvm import relay
import numpy as np

def get_demo_mod():
    a = relay.var("a", shape=(2, 3, 10), dtype="float32")
    b = relay.var("b", shape=(1, 10), dtype="float32")
    c = relay.add(a, b)
    func = relay.Function([a, b], c)
    mod = tvm.IRModule.from_expr(func)
    return mod
	
mod = get_demo_mod()

print("------before InferType------")
try:
    print(mod["main"].body.checked_type)
except Exception:
    print("can't get checked_type")

print("------after InferType------")

mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)

执行结果如下:
image

作用 (Why)

推断表达式的类型及输入输出尺寸
另:在 Relay 优化过程中, 每个 pass 都可以修改/添加/删除 op, 所以每个 pass 之后都需要重新 InferType
如,TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)对公共子表达式消除一节中FunctionPass()第四个参数就是InferType进行类型推断

怎么做(How)

这块代码主要在src/relay/transforms/type_infer.cc文件中,具体实现如下:

Pass InferType() {
  auto pass_info = PassInfo(0, "InferType", {});
  return tvm::transform::CreateModulePass(
      [=](IRModule mod, const PassContext& pass_ctx) {
	...
        AddGlobalTypes(mod);
        VLOG(1) << "AddGlobalTypes'" << PrettyPrint(mod);
        std::vector<std::pair<GlobalVar, Function>> updates;
        for (const auto& it : updated_mod->functions) {
          if (auto func = it.second.as<Function>()) {
            auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value());
            VLOG(1) << "it.first'" << PrettyPrint(it.first) << "it.second"<< PrettyPrint(it.second);

            auto updated_func = inferencer.Infer(it.first, func.value());
            VLOG(1) << "updated_func'" << PrettyPrint(updated_func);
      		...
            it.first->checked_type_ = updated_func->checked_type();

            if (!WellFormed(updated_func, pass_ctx->diag_ctx)) {
              LOG(FATAL) << "The type checked intermediate representation is malformed";
            }

            auto free_tvars = FreeTypeVars(updated_func, mod);
            ICHECK(free_tvars.size() == 0)
                << "Found unbound type variables in " << updated_func << ": " << free_tvars;
            EnsureCheckedType(updated_func);
            updates.push_back({it.first, Downcast<Function>(updated_func)});
          }
        }

        for (const auto& pair : updates) {
          updated_mod->Add(pair.first, pair.second, true);
        }

        return updated_mod;
      },
      0, "InferType", {});
}

TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); });

和公共子表达式消除的实现可发现,该算子调用的是CreateModulePass,因此它是一个模块级的优化,

模块级优化用于实现过程间优化和分析,模块级优化pass工作在tvm.IRModule对象上,将整个程序作为处理单元,几乎可以对程序执行任何操作。

其中,AddGlobalTypes 给mod添加全局参数,为后续的参数推断做准备,
真正进行推断的是TypeInferencer类的Infer()方法,实现如下:

Expr TypeInferencer::Infer(GlobalVar var, Function function) {
	...
  // Step 1: Populate the constraints.
  GetType(function);

  // Step 2: Solve the constraints.
  Solve();

  // Step 3: Attach resolved types to checked_type field.
  auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(function);
	...
  }
  return resolved_expr;
}

第一步,填充约束

  Type GetType(const Expr& expr) {
    auto it = type_map_.find(expr);
    if (it != type_map_.end() && it->second.checked_type.defined()) {
      return it->second.checked_type;
    }
    Type ret = this->VisitExpr(expr);
    ICHECK(ret.defined()) << "expression:" << std::endl << PrettyPrint(expr);
    KindCheck(ret, mod_, this->diag_ctx);
    ResolvedTypeInfo& rti = type_map_[expr];
    rti.checked_type = ret;
    return ret;
  }

会先从type_map_map表中查找该Expr,第一次执行,该map表中一般都是没有的,通常都会走到VisitExpr,并将expr添加到该map表中(里面具体怎么执行的,有待进一步研究)

第二步,解决约束

bool TypeSolver::Solve() {
  while (!update_queue_.empty()) {
    RelationNode* rnode = update_queue_.front();
    const auto& rel = rnode->rel;
    update_queue_.pop();
    ICHECK(!rnode->resolved);
    // update the relation with given evidence.
    Array<Type> args;
    for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) {
      args.push_back(Resolve(tlink->value->FindRoot()->resolved_type));
      ICHECK_LE(args.size(), rel->args.size());
    }

    // We need to set this in order to understand where unification
    // errors generated by the error reporting are coming from.
    reporter_->SetSpan(rnode->span);

    try {
      // Call the Type Relation's function.
      bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_);

      if (resolved) {
        ++num_resolved_rels_;
      }

      rnode->resolved = resolved;
    } catch (const CompileError& err) {
      this->Emit(Diagnostic::Error(rnode->span) << err.what());
      rnode->resolved = false;
    }

    // Mark inqueue as false after the function call
    // so that rnode itself won't get enqueued again.
    rnode->inqueue = false;
  }

  // This criterion is not necessarily right for all the possible cases
  // TODO(tqchen): We should also count the number of in-complete types.
  return num_resolved_rels_ == rel_nodes_.size();
}
通过调用 Solve() 方法,我们求解填充好的类型约束。解决约束的过程使用了类型约束求解器(constraint solver)来尝试找到满足约束条件的类型赋值方案。

第三步,

Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap,
           TypeSolver* solver)
      : tmap_(tmap), solver_(solver) {}
	  
Expr MixedModeMutator::VisitExpr(const Expr& expr) {
  auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
  auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
  if (memo_.count(expr)) {
    return memo_[expr];
  } else {
    ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
    return memo_[expr];
  }
}

使用 Resolver 类的实例来将解析后的类型信息附加到已解析的表达式的checked_type 字段上。Resolver 类是负责类型解析和处理的工具类。它通过访问表达式的结构,并使用之前求解出的类型信息来确定每个表达式的准确类型。

这里面的东西还是有点深的,后面再研究吧
如有其他友友,可沟通交流

respect~

标签:resolved,return,--,InferType,TVM,expr,type,mod
From: https://www.cnblogs.com/whiteBear/p/18119763

相关文章

  • Java并发(二十四)----wait、notify、notifyAll、join区别与联系
    1、join是调用者轮询检查线程alive状态,执行后线程进入阻塞状态。如在线程B中调用线程A的join(),那线程B会进入到阻塞队列,直到join结束或中断线程B才开始进入阻塞队列。可以实现一个线程的顺序执行。t1.join();等价于下面的代码synchronized(t1){  //调用者线程进入t1......
  • L3-008 喊山
    DFS。#include<bits/stdc++.h>usingnamespacestd;constintinf=0x3f3f3f3f;vector<vector<int>>vec;intvis[10003];intmain(){ intn,m,k; cin>>n>>m>>k; vec.resize(n+10); while(m--){ inta,b; cin>>a&......
  • Vscode+gcc-arm+openocd搭建STM32开发环境
    1简介尝试使用Vscode搭建STM32开发环境,自己记录一下详细的配置过程2工具下载设计到的相关软件以及资源包括Vscode软件、STM32CubeMX、mingw64以及openocd,相应的软件介绍以及下载链接如下:Vscode软件:宇宙第一编辑器,开源,插件丰富CubeMx:初始化代码生产器,HAL库mingw64:因......
  • 全屋光纤(Fiber-to-the-Home,FTTH)
     全屋光纤(Fiber-to-the-Home,FTTH)的起源可以追溯到光纤通信技术的发展历程。光纤通信是一种利用光纤作为传输介质,将信息转换成光信号进行传输的通信技术,具有高速、大带宽、低损耗等优势。20世纪末,随着互联网的普及和数字化技术的发展,人们对网络带宽的需求越来越高。传统的电......
  • make编译报错:fatal error: filesystem: 没有那个文件或目录 #include <filesystem>
    报错:fatalerror:filesystem:没有那个文件或目录#include(filesystem)解决方法一:修改头文件#include<experimental/filesystem>添加依赖在编译时,后面添加:-lstdc++fs编译通过。解决方法二:升级gcc升级到gcc-8或8以上问题即可解决:添加PPA存储库首先,您需要添加Ub......
  • 探究MySQL8.0驱动的加载
    探究MySQL8.0驱动的加载大家在连接mysql的时候,启动项目,会警告你推荐使用com.mysql.cj.jdbc.Driver而不是com.mysql.jdbc.Driver那么这两者到底有什么区别呢本质区别:com.mysql.jdbc.Driver是mysql-connector-java5中的,需要手动加载驱动com.mysql.cj.jdbc.Driver是mysql......
  • C++:类的继承
    基类的构造函数和析构函数不会被继承,但是当子类对象初始化时则会自动调用基类的构造函数和析构函数(默认)如果想要调用基类的重载构造函数,需要在构造函数后加上“:<重载的构造函数>{};”,如下classFATHER{public:FATHER();~FATHER();FATHER(inta)//重载......
  • 探究MySQL8.0驱动的加载
    探究MySQL8.0驱动的加载大家在连接mysql的时候,启动项目,会警告你推荐使用com.mysql.cj.jdbc.Driver而不是com.mysql.jdbc.Driver那么这两者到底有什么区别呢本质区别:com.mysql.jdbc.Driver是mysql-connector-java5中的,需要手动加载驱动com.mysql.cj.jdbc.Driver是mysql-......
  • Matlab 安装及添加 SPM 12
    Matlab安装及添加SPM12因为课题需要,需要学习Matlab的使用,又开始学习一个新的知识!快乐的(bcd)研究生~~~下载安装Matlab首先我从网上一些资源那里下载了Matlab安装压缩包(从百度网盘下的,压缩包都12.02G!下了我好久啊!!!)下载完成后就是这样子的啦~然后进行解压安装(这里解压也......
  • 事件循环
    事件循环单线程是异步产生的原因事件循环是异步的实现方式浏览器的进程模型何为进程?程序运行需要有它自己专属的内存空间,可以把这块内存空间简单的理解为进程每个应用至少有一个进程,进程之间相互独立,即使要通信,也需要双方同意何为线程?有了进程后,就可以运行程序的代码了......