首页 > 其他分享 >hidet使用rule based调度

hidet使用rule based调度

时间:2024-05-22 17:55:48浏览次数:26  
标签:__ task based torch rule hidet input import

定义computation

整体流程类似于tvm的计算描述

定义输入、输出tensor,指定名称、数据类型和shape

a = tensor_input('a', dtype='float32', shape=[10])
b = tensor_input('b', dtype='float32', shape=[])
b = tensor_input('data', dtype='float16', shape=[1, 3, 224, 224])

使用compute定义计算,指定名称、shape、计算表达式

b = compute('copy', shape=[10], fcompute=lambda i: a[i])

语义等价于

for i1 in range(10):
	b[i1] = a[i1]

此外还有一些参数,使用这些参数可以指定reduce等计算。

封装Task

将计算封装为Task,然后使用hidet提供的rule-based调度器自动将计算生成代码

一个task由名称、tensor输入和输出组成,对应于计算的输入输出

from typing import List
import hidet
from hidet.ir.task import Task


def run_task(task: Task, inputs: List[hidet.Tensor]):
    """Run given task and print inputs and outputs"""
    from hidet.runtime import CompiledTask

    # build the task
    func: CompiledTask = hidet.drivers.build_task(task, target='cpu')

    # run the compiled task
    outputs = func.run_async(inputs)

    print('Task:', task.name)
    print('Inputs:')
    for tensor in inputs:
        print(tensor)
    print('Output:')
    for tensor in outputs:
        print(tensor)
    print()

使用build_task将task lowering到可执行函数,主要包含以下流程

  1. 根据device等信息调度task到scheduler
  2. scheduler将task下降到IRModule
  3. 优化并继续下降IRModule
  4. 根据device进行代码生成
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
from hidet.ir.expr import if_then_else

class AbsTask(Task):
    def __init__(self, input):
        out = compute(
            name='out',
            shape=input.shape,
            fcompute=lambda *indices: if_then_else(input[indices] < 0, -input[indices], input[indices])
        )

        super().__init__(
            name='abs',
            inputs=[input],
            outputs=[out]
        )

封装算子类,输入是hidet Tensor


from hidet.graph import Operator, Tensor
from hidet.graph.ops.utils import input_like

class AbsOp(Operator):
    def __init__(self, input: Tensor):
        super().__init__(
            inputs=[input],
            attributes={},
            task=AbsTask(
                input_like(input, 'input')
            )
        )

def abs(input: Tensor) -> Tensor:
    return AbsOp(input).outputs[0]

输出结果在outputs中,通过在torch compile时指定后端为hidet,就能实现runtime的算子替换。

import hidet

from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
from hidet.ir.expr import if_then_else

hidet.option.cache_dir('./outs/cache')

op_device = "cuda"

class AbsTask(Task):
    def __init__(self, input):
        out = compute(
            name='out',
            shape=input.shape,
            fcompute=lambda *indices: if_then_else(input[indices] < 0, -input[indices], input[indices])
        )

        super().__init__(
            name='abs',
            inputs=[input],
            outputs=[out]
        )

from hidet.graph import Operator, Tensor
from hidet.graph.ops.utils import input_like

class AbsOp(Operator):
    def __init__(self, input: Tensor):
        super().__init__(
            inputs=[input],
            attributes={},
            task=AbsTask(
                input_like(input, 'input')
            )
        )

import torch
from torch import nn
from hidet.graph.frontend.torch.interpreter import (
    register_function,
)

@register_function(torch.abs)
def abs_demo(input: Tensor) -> Tensor:
    return AbsOp(input).outputs[0]

class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.abs(x)
        return x

def run_demo():
    input = hidet.randn([2, 3, 3])
    print(input)

    torch_ref = torch.randn([2, 3, 3], device=op_device)
    print("input: ", torch_ref)
    m = Model()
    m = torch.compile(m, backend='hidet')
    print("output: ", m(torch_ref))
    
run_demo()
input:  tensor([[[ 0.3670,  0.3044, -0.8114],
         [-0.3331,  0.3331, -0.1969],
         [-1.4178,  0.3059,  0.4833]],

        [[-0.6634,  2.7176, -0.4525],
         [ 0.9879, -0.0581,  0.9540],
         [ 0.1877, -0.2522,  0.0652]]], device='cuda:0')
output:  tensor([[[0.3670, 0.3044, 0.8114],
         [0.3331, 0.3331, 0.1969],
         [1.4178, 0.3059, 0.4833]],

        [[0.6634, 2.7176, 0.4525],
         [0.9879, 0.0581, 0.9540],
         [0.1877, 0.2522, 0.0652]]], device='cuda:0')

标签:__,task,based,torch,rule,hidet,input,import
From: https://www.cnblogs.com/ddl789/p/18206794

相关文章

  • vue给元素添加校验rules
    1.使用validator添加校验规则:在Elemengplus(Vue3版本的ElementPlus)框架中,给el-dialog中的input框添加日期格式验证,可以使用el-form和el-form-item组件来配合实现,并通过el-input组件的v-model绑定数据,结合el-form的验证规则rules实现。以下是一个简单的例子,演示如何给日期输入框......
  • Paper Reading: Tri-objective optimization-based cascade ensemble pruning for dee
    目录研究动机文章贡献本文方法染色体编码适应度函数评估进化过程最终解选择级联剪枝框架实验结果数据集和实验设置三目标优化的效果不同集成规模的算法比较算法在不同数据集上的比较优点和创新点PaperReading是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能......
  • WPF ValidatesOnDataErrors IDataErrorInfo ValidationRule
    //xaml<Windowx:Class="WpfApp91.MainWindow"xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d="http://schemas.mic......
  • FBWF(File-Based Write Filter)是Windows操作系统中的一种功能,主要用于保护系统的存储设
    FBWF(File-BasedWriteFilter)是Windows操作系统中的一种功能,主要用于保护系统的存储设备(如硬盘)免受意外写入或恶意软件的影响。它通过将所有对存储设备的写操作重定向到一个临时缓存中,从而保护存储设备的内容不被修改。FBWF的主要优点包括:简化系统管理:可以在不影响系统运行......
  • Vue中form表单常用rules校验规则
    是否合法IP地址constcheckIPCode=(rule,value,callback)=>{ if(/^(\d|[1-9]\d|1\d{2}|2[0-4]\d|25[0-5])\.(\d|[1-9]\d|1\d{2}|2[0-4]\d|25[0-5])\.(\d|[1-9]\d|1\d{2}|2[0-4]\d|25[0-5])\.(\d|[1-9]\d|1\d{2}|2[0-4]\d|25[0-5])$/ .test(value......
  • 论文笔记-Machine learning based flow regime recognition in helically coiled tube
    对象:进行了螺旋线圈中的自动两相流模式识别方法:X射线照相的空隙率测量数据+聚类+KNN、RF、SVM目标:模式识别关注特征:结果:聚类分类:模型是随机森林(RF)分类器、KNN分类器和SVM(参见第1节)。为了优化超参数并估计分类器精度,所有模型均采用嵌套5×5交叉验证方案,如图1所示。......
  • Enhancing ID and Text Fusion via Alternative Training in Session-based Recommend
    目录概MotivationAlterRec代码LiJ.,HanH.,ChenZ.,ShomerH.,JinW.,JavariA.andTangJ.EnhancingIDandtextfusionviaalternativetraininginsession-basedrecommendation.2024.概作者“发现”多模态推荐中ID和文本模态的结合做的并不好,于是乎提出......
  • 论文笔记-Two-phase flow regime identification based on the liquid-phase velocity
    对象:液相速度信息方法:CNN、LSTM、SVM目标:实现了水平管道内两相流态识别关注特征:从速度时间序列数据中提取的统计特征:均值、均方根和功率谱密度、最大速度比和最大速度差比结果:SVM-93.1%,CNN-94%,LSTM-不佳73.3%LSTM:总共使用了300秒的速度数据,然后将其分为180秒用于训练和......
  • 实时动态规则(55)规则发布平台后端开发(5) 规则模型开发(4)rulemodel_03_涉及事件时间
    0涉及架构 注意:以下代码,都是根据一个特定规则模型: rulemodel_03_caculator 来进行开发的不同的规则模型,如下功能代码需要进行不同的开发RuleModel_03 这个规则模型的特点是:拥有事件间隔时间1规则参数结构规范{"ruleModelId":"3","ruleId":"m3-r01",......
  • 为什么自动驾驶领域发论文都是用强化学习算法,但是实际公司里却没有一家使用强化学习算
    为什么自动驾驶领域发论文都是用强化学习算法,但是实际公司里却没有一家使用强化学习算法?——(特斯拉今年年初宣布推出实际上第一款纯端到端的自动驾驶系统,全部使用强化算法,替换掉原有的30万行C++的rule-based代码)给出一个自己比较认可的答案:https://www.zhihu.com/question/54......