Defining the computation
The overall flow is similar to TVM's compute description.
Declare the input and output tensors by specifying a name, data type, and shape:
from hidet.ir.compute import tensor_input

a = tensor_input('a', dtype='float32', shape=[10])
b = tensor_input('b', dtype='float32', shape=[])  # shape=[] declares a scalar
data = tensor_input('data', dtype='float16', shape=[1, 3, 224, 224])
Use compute to define a computation, specifying a name, shape, and the compute expression:
c = compute('copy', shape=[10], fcompute=lambda i: a[i])
which is semantically equivalent to:
for i in range(10):
    c[i] = a[i]
compute also accepts additional parameters, and together with reduce these can express reductions and similar patterns.
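For instance, a row-wise sum can be written by nesting reduce inside compute (a minimal sketch; the 2-D tensor m is introduced here for illustration):

from hidet.ir.compute import tensor_input, compute, reduce

# s[i] = sum over j of m[i, j]
m = tensor_input('m', dtype='float32', shape=[10, 10])
s = compute(
    name='s',
    shape=[10],
    fcompute=lambda i: reduce(shape=[10], fcompute=lambda j: m[i, j], reduce_type='sum')
)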
Wrapping the computation in a Task
Wrap the computation in a Task; hidet's rule-based scheduler can then generate code for it automatically.
A task consists of a name plus its tensor inputs and outputs, which correspond to the inputs and outputs of the computation:
from typing import List
import hidet
from hidet.ir.task import Task

def run_task(task: Task, inputs: List[hidet.Tensor]):
    """Run the given task and print its inputs and outputs."""
    from hidet.runtime import CompiledTask

    # build the task
    func: CompiledTask = hidet.drivers.build_task(task, target='cpu')

    # run the compiled task
    outputs = func.run_async(inputs)

    print('Task:', task.name)
    print('Inputs:')
    for tensor in inputs:
        print(tensor)
    print('Output:')
    for tensor in outputs:
        print(tensor)
    print()
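As a usage sketch, a task can be assembled directly from the compute primitives introduced above and passed to run_task (the add task below is illustrative):

from hidet.ir.compute import tensor_input, compute

a = tensor_input(name='a', dtype='float32', shape=[5])
b = tensor_input(name='b', dtype='float32', shape=[5])
c = compute(name='c', shape=[5], fcompute=lambda i: a[i] + b[i])
task = Task(name='add', inputs=[a, b], outputs=[c])
run_task(task, [hidet.randn([5]), hidet.randn([5])])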
build_task lowers the task to an executable function through the following pipeline:
- dispatch the task to a scheduler based on the target device and other information
- the scheduler lowers the task to an IRModule
- optimize and further lower the IRModule
- generate code for the target device
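To inspect what each stage produces, hidet can be asked to save the lowered IR (assuming a hidet version that provides the save_lower_ir option):

import hidet

# dump the IRModule after each lowering pass into the cache directory
hidet.option.save_lower_ir(True)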
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
from hidet.ir.expr import if_then_else

class AbsTask(Task):
    def __init__(self, input: TensorNode):
        # elementwise absolute value: -x when x < 0, otherwise x
        out = compute(
            name='out',
            shape=input.shape,
            fcompute=lambda *indices: if_then_else(input[indices] < 0, -input[indices], input[indices])
        )
        super().__init__(
            name='abs',
            inputs=[input],
            outputs=[out]
        )
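The new task can be exercised in isolation through the run_task helper above (a sketch; declaring the input node via tensor_input is an assumption about the intended usage):

from hidet.ir.compute import tensor_input

x = tensor_input('x', dtype='float32', shape=[4])
run_task(AbsTask(x), [hidet.randn([4])])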
Next, wrap the task in an operator class whose input is a hidet Tensor:
from hidet.graph import Operator, Tensor
from hidet.graph.ops.utils import input_like

class AbsOp(Operator):
    def __init__(self, input: Tensor):
        super().__init__(
            inputs=[input],
            attributes={},
            task=AbsTask(
                input_like(input, 'input')  # symbolic TensorNode with the same dtype and shape
            )
        )

def abs(input: Tensor) -> Tensor:
    return AbsOp(input).outputs[0]
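With the functional wrapper in place, the operator can be invoked on hidet tensors directly (a quick sketch):

x = hidet.randn([2, 3])
y = abs(x)  # builds and runs AbsOp via hidet's rule-based scheduler
print(y)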
The result tensors are exposed in outputs. By registering the mapping and selecting hidet as the backend in torch.compile, the operator is substituted at runtime:
import hidet
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
from hidet.ir.expr import if_then_else

hidet.option.cache_dir('./outs/cache')
op_device = "cuda"

class AbsTask(Task):
    def __init__(self, input: TensorNode):
        out = compute(
            name='out',
            shape=input.shape,
            fcompute=lambda *indices: if_then_else(input[indices] < 0, -input[indices], input[indices])
        )
        super().__init__(
            name='abs',
            inputs=[input],
            outputs=[out]
        )

from hidet.graph import Operator, Tensor
from hidet.graph.ops.utils import input_like

class AbsOp(Operator):
    def __init__(self, input: Tensor):
        super().__init__(
            inputs=[input],
            attributes={},
            task=AbsTask(
                input_like(input, 'input')
            )
        )

import torch
from torch import nn
from hidet.graph.frontend.torch.interpreter import (
    register_function,
)

# route torch.abs to the hidet operator when compiling with the hidet backend
@register_function(torch.abs)
def abs_demo(input: Tensor) -> Tensor:
    return AbsOp(input).outputs[0]

class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.abs(x)
        return x

def run_demo():
    torch_ref = torch.randn([2, 3, 3], device=op_device)
    print("input: ", torch_ref)
    m = Model()
    m = torch.compile(m, backend='hidet')
    print("output: ", m(torch_ref))

run_demo()

Sample output:
input: tensor([[[ 0.3670, 0.3044, -0.8114],
[-0.3331, 0.3331, -0.1969],
[-1.4178, 0.3059, 0.4833]],
[[-0.6634, 2.7176, -0.4525],
[ 0.9879, -0.0581, 0.9540],
[ 0.1877, -0.2522, 0.0652]]], device='cuda:0')
output: tensor([[[0.3670, 0.3044, 0.8114],
[0.3331, 0.3331, 0.1969],
[1.4178, 0.3059, 0.4833]],
[[0.6634, 2.7176, 0.4525],
[0.9879, 0.0581, 0.9540],
[0.1877, 0.2522, 0.0652]]], device='cuda:0')
From: https://www.cnblogs.com/ddl789/p/18206794