这是一个完整的例子。
使用预训练的 ResNet50(ONNX 格式)模型,经 TVM Relay 编译优化(opt_level=3,target=llvm,未进行自动调优)后在 CPU 上进行推理。最后打印的预测结果为索引 1,该索引在 ImageNet 类别中代表 goldfish(金鱼)。
import argparse
from pathlib import Path

import numpy as np
import onnx
import tvm
import tvm.relay as relay
from PIL import Image
from tvm.contrib import graph_executor
from tvm.contrib.download import download_testdata
# ResNet50-v2 (opset 7) from the ONNX model zoo. The ".../raw/..." form serves
# the model file itself; a ".../blob/..." URL returns GitHub's HTML page.
MODEL_URL = (
    "https://github.com/onnx/models/raw/main/"
    "validated/vision/classification/resnet/model/resnet50-v2-7.onnx"
)

# Default test image; on the original author's machine this resolves to
# /home/po/.tvm_test_data/data/gold-fish.jpg.
DEFAULT_IMAGE_PATH = Path.home() / ".tvm_test_data" / "data" / "gold-fish.jpg"

# Per-channel ImageNet statistics, shaped (3, 1, 1) so they broadcast over CHW.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
IMAGENET_STDDEV = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))


def load_image(path, size=(224, 224)):
    """Load the image at *path* as RGB, resized to *size* (width, height).

    Returns a float32 array in HWC layout with pixel values in [0, 255].
    """
    # convert("RGB") so grayscale / RGBA / palette images still yield 3 channels;
    # the context manager closes the underlying file handle.
    with Image.open(path) as img:
        return np.asarray(img.convert("RGB").resize(size), dtype=np.float32)


def normalize_image(hwc, dtype="float32"):
    """Turn an (H, W, 3) RGB pixel array (0-255) into a normalized NCHW batch.

    Args:
        hwc: array-like of shape (H, W, 3).
        dtype: dtype of the returned tensor; must match the model input.

    Returns:
        Array of shape (1, 3, H, W), ImageNet mean/std normalized.

    Raises:
        ValueError: if *hwc* is not a 3-channel HWC image.
    """
    hwc = np.asarray(hwc, dtype=np.float32)
    if hwc.ndim != 3 or hwc.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) RGB array, got shape {hwc.shape}")
    # The image is HWC while the ONNX model expects CHW.
    chw = np.transpose(hwc, (2, 0, 1))
    norm = (chw / 255 - IMAGENET_MEAN) / IMAGENET_STDDEV
    # The float64 mean/std upcast the result; cast back to the model's dtype.
    return np.expand_dims(norm, axis=0).astype(dtype)


def compile_model(onnx_model, shape_dict, target="llvm", opt_level=3):
    """Import *onnx_model* into Relay and build it for *target*."""
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
    with tvm.transform.PassContext(opt_level=opt_level):
        return relay.build(mod, target=target, params=params)


def run_inference(lib, input_name, img_data, target="llvm", num_classes=1000):
    """Run the built module on *img_data* and return the (1, num_classes) output."""
    dev = tvm.device(str(target), 0)
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input(input_name, img_data)
    module.run()
    return module.get_output(0, tvm.nd.empty((1, num_classes))).numpy()


def main(argv=None):
    """Compile ResNet50 with TVM for CPU, classify one image, print the class index."""
    parser = argparse.ArgumentParser(description="TVM ResNet50 ONNX inference demo")
    parser.add_argument(
        "--model",
        default=None,
        help="path to resnet50-v2-7.onnx (default: download/cache via download_testdata)",
    )
    parser.add_argument("--image", default=str(DEFAULT_IMAGE_PATH), help="input image path")
    # The input name may vary across model types; a tool like Netron shows it.
    parser.add_argument("--input-name", default="data", help="model input name")
    parser.add_argument("--target", default="llvm", help="TVM compilation target")
    args = parser.parse_args(argv)

    # Seed numpy's RNG to get consistent results.
    np.random.seed(0)

    # download_testdata reuses the cached file under ~/.tvm_test_data if present.
    model_path = (
        args.model
        if args.model is not None
        else download_testdata(MODEL_URL, "resnet50-v2-7.onnx", module="onnx")
    )
    onnx_model = onnx.load(model_path)

    img_data = normalize_image(load_image(args.image))
    shape_dict = {args.input_name: img_data.shape}

    lib = compile_model(onnx_model, shape_dict, target=args.target)
    tvm_output = run_inference(lib, args.input_name, img_data, target=args.target)
    print("predict imgnet index=", np.argmax(tvm_output))


if __name__ == "__main__":
    main()

# Tags: llvm, resnet50, img, demo, tvm, onnx, input, model, data
# From: https://www.cnblogs.com/qmjc/p/18412599