我们主要使用torch.onnx.export()函数来实现PyTorch模型到ONNX模型的转换。
import torch
import torchvision.models as models
device = "cpu"
# 加载预训练的ResNet18模型
model =models.resnet18(pretrained=True)
model.eval().to(device)
# 定义输入
input = torch.zeros(1,3,224,224).to(device)
torch.onnx.export(
model,
# 这里的args,是指输入给model的参数,需要传递tuple,因此用括号
(input,),
# 输出的onnx文件路径
"resnet18.onnx",
# 是否打印详细信息
verbose=True,
# 为输入和输出节点指定名称,方便后面查看或者操作
input_names = ["image"],
output_names = ["infer_output"],
# 这里的opset,指各类算子以何种方式导出,对应于symbolic_opset11
opset_version=11,
# 表示它有batch,height,width3个维度是动态的,在onnx中给其赋值为-1
# 通常,我们只设置batch为动态,其他避免动态
dynamic_axes ={
"image":{0:"batch"}, # 只有batch维是动态的
"infer_output":{0:"batch"} # 只有batch维是动态的
# "train_output":{0:"batch"} # 只有batch维是动态的
}
)
print("Done!")
标签:ONNX,模型,torch,batch,PyTorch,onnx,output,model
From: https://blog.csdn.net/m0_51579041/article/details/136683540