首页 > 其他分享 >UNet pytorch模型转ONNX模型完整code

UNet pytorch模型转ONNX模型完整code

时间:2023-11-14 16:55:39浏览次数:42  
标签:code img onnx 模型 torch pytorch path model self

 1 import os
 2 import torch
 3 import numpy as np
 4 from Unet import UNET
 5 os.environ["CUDA_VISIBLE_DEVICE"] = ""
 6 
 7 def main():
 8     demo = Demo(model_path="/xxx.pth.tar", output="pathto/xxx.onnx")
 9     demo.inference()
10     check_onnx(onnx_pth="path to xxx.onnx")
11 
12 
13 
14 #检查onnx模型
15 def check_onnx(onnx_pth):
16     import onnx
17     #load the ONNX model
18     model = onnx.load(onnx_pth)
19     #check the IR is well formed
20     onnx.checker.check_model(model)
21     #print a human readable representation of graph
22     print(onnx.helper.printable_graph(model.graph))
23 
24 class WrappedModel(torch.nn.Module):
25     def __init__(self,model):
26         super().__init__()
27         self.model =model
28 
29     def forward(self,x):31         outs=self.model(x)
32         new_outs=torch.sigmoid(outs)
33         return new_outs
34 
35 
36 class Demo():
37     def __init__(self,model_path,output):
38         self.model_path =model_path
39         self.output_path = output
40 
41     def init_torch_tensor(self):
42         self.device = 'cpu'#torch.device('cpu')
43         torch.set_default_tensor_type('torch.FloatTensor')
44         #use gpu or not
45         # if torch.cuda.is_available():
46         #     self.device = torch.device('cuda')
47         #     torch.set_default_tensor_type('torch.FloatTensor')
48         # else:
49         #     self.device = torch.device('cpu')
50         #     torch.set_default_tensor_type('torch.FloatTensor')
51     
52     def init_model(self,in_channels,out_channels):
53         model = UNET(in_channels=in_channels, out_channels=out_channels).to(self.device)#to('cuda')
54         return model
55     
56     def resume(self, model, path):
57         if not os.path.exists(path):
58             print("Checkpoint not found:" + path)
59             return
60         states = torch.load(path, map_location=self.device)#
61         model.load_state_dict(states["state_dict"],strict=False)#states有两个key_value"state_dict","optimizer"
62         
63         model_sig = WrappedModel(model)
64         print("Resume from " + path)
65         return model_sig
66 
67     def inference(self):
68         #use gpu or cpu
69         self.init_torch_tensor()
70         #加载网络模型
71         model = self.init_model(in_channels=3,out_channels=2)
72         model_sig=self.resume(model, self.model_path)
73         #设置model的模式
74         model_sig.eval()
75         #设置输入
76         img = np.random.randint(0,255, size=(512,512,3),dtype=np.uint8)
77         img = img.astype(np.float32)
78         img = img / 255#(img / 255. - 0.5)/0.5
79         img = img.transpose((2,0,1)) #C H W 
80         img = torch.from_numpy(img).unsqueeze(0).float()
81         #img = torch.randn(1,3,512,512)
82         '''
83         设置动态可变维度
84         KEY(str) - 必须是input_names或output_names指定的名称,用来指定哪个变量需要使用到动态尺寸。
85         VALUE(dict or list) - 如果是一个dict,dict中的key是变量的某个维度,dict中的value是我们给这个维度取的名称。如果是一个list,则list中的元素都表示此变量的某个维度。
86         '''
87         dynamic_axes = {'input':{0: 'batch_size', 2: 'height', 3: 'width'},
88                         'output': {0:'batch_size', 2: 'height', 3: 'width'}}
89         with torch.no_grad():
90             img = img.to(self.device)
91             torch.onnx.export(model_sig, img, self.output_path, input_names=['input'],
92                                 output_names=['output'], dynamic_axes=dynamic_axes, keep_initializers_as_inputs=False,export_params=True,
93                                 verbose=True, opset_version=11)
94 
95 if __name__ == '__main__':
96     main()

 

标签:code,img,onnx,模型,torch,pytorch,path,model,self
From: https://www.cnblogs.com/zzc-Andy/p/17832001.html

相关文章

  • Python字符的编码encode和解码decode
    相关阅读:字符集(CharacterSet)和编码(Encoding)的历史演化 Python字符的编码encode和解码decode进行编码str.encode("编码") 进行解码bytes.decode("编码")  s="周杰伦"bs1=s.encode("gbk")#b'xxxx'bytes类型bs2=s.encode("utf-8"......
  • Codeforces Round 809 (Div. 2) D1. Chopping Carrots (Easy Version) 题解
    题意CodeforcesRound809(Div.2)D1.ChoppingCarrots(EasyVersion)给两个整数\(n,k\),一个数组\(a\),要求构造一个同样长度的数组\(p\),使得\(\max\limits_{1\lei\len}\left(\left\lfloor\frac{a_i}{p_i}\right\rfloor\right)-\min\limits_{1\lei\l......
  • Codeforces Round 906 (Div. 2)
    A.简单题B.简单题C.比赛时没做出来,赶着回宿舍,过了几天来补发现很简单秒掉D.Doremy'sConnectingPlan给定n个结点的图,每个点有一个权值a[i],开始时图上没有边,如果与点i相邻的点(包括点i)的权值的和记为Sum_i.给定一个常数c,如果Sum_i+Sum_j>=ijc,则可以在i和j上......
  • vscode下载慢
    官网下载链接https://az764295.vo.msecnd.net/stable/6c3e3dba23e8fadc360aed75ce363ba185c49794/VSCodeUserSetup-x64-1.81.1.exe1.81.1版本镜像源的下载链接https://vscode.cdn.azure.cn/stable/6c3e3dba23e8fadc360aed75ce363ba185c49794/VSCodeUserSetup-x64-1.81.1.ex......
  • 解决 keras 首次装载预训练模型VGG16 时下载失败问题
    解决:Exception:URLfetchfailureonhttps://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5:None--[Errno104]Connectionresetbypeer解决方案:1、先将数据集单独下载下来:models/vgg16_weights_tf_d......
  • 浅谈JVM Instruction Set (Opcode)
    浅谈JVMInstructionSet(Opcode)1.背景日常开发中,遇到一个潜藏bug的java代码,借此简单回顾一下JVMInstructionSet(Opcode)知识。问题demo代码如下:publicclassBugDemo{publicstaticvoidmain(String[]args){//模拟用户输入(具有不可预测性),假设......
  • Xcode 展示failed to prepare the device for development
    首先打开链接找到https://gitee.com/Han0/iOSDeviceSupport 找到对应版本,解压其次打开终端输入 open/Applications/Xcode.app/Contents/Developer/Platforms/iPhoneOS.platform/DeviceSupport然后将解压后的文件夹放进去,即可重启xcode......
  • 结合大语言模型与亚马逊云科技基础服务,构建知识库智能搜索问答方案
     背景 本篇主要介绍LangChain和开源大语言模型集成,结合亚马逊云科技的云基础服务,构建基于企业知识库的智能搜索问答方案。  LangChain介绍 LangChain是一个利用大语言模型的能力开发各种下游应用的开源框架,它的核心理念是为各种大语言模型应用实现通用的接口,简化大语言模型应......
  • 为啥Decoder-Only这条路线效果最好?
    https://arxiv.org/pdf/2304.13712.pdf这篇论文中有个现代大型语言模型(LLM)的演变树,可以看出:同一分支上的模型关系更为紧密。图说明:基于Transformer模型以非灰色显示:decoder-only模型在蓝色分支,encoder-only模型在粉色分支,encoder-decoder模型在绿色分支。模型在......
  • VSCode 中 Json 文件介绍
    VisualStudioCode官方文档1.Json配置文件EditingJSONwithVisualStudioCodesettings.json分类defaultsettings.json:只读格式,相当于官方的参考文档;settings.json:自定义形式,优先级大于默认的settings.json文件,ctrl+shift+o查看默认提供的格式,而后自定......