首页 > 其他分享 >将pb模型参数提取转成torch模型

将pb模型参数提取转成torch模型

时间:2022-10-11 10:12:23浏览次数:57  
标签:sess constant name graph 模型 torch pb import op

  1 import tensorflow as tf
  2 import onnx
  3 import onnxsim
  4 import numpy as np
  5 import torch
  6 from model.facedetector_model import mobilenetv2_yolov3
  7 
  8 #提取pb模型中的参数
  9 def extract_params_from_pb():
 10     constant_values = {}
 11     with tf.compat.v1.Session() as sess:
 12         with tf.io.gfile.GFile('model/FaceDetector.pb', 'rb') as f:
 13             graph_def = tf.compat.v1.GraphDef()
 14             graph_def.ParseFromString(f.read())
 15         sess.graph.as_default()
 16         tf.import_graph_def(graph_def, name='')
 17         # # input
 18         # input_x = sess.graph.get_tensor_by_name('input/input_data:0')
 19         # # output
 20         # output = sess.graph.get_tensor_by_name('pred_bbox/Reshape:0')
 21         # sess.run(output, feed_dict={'input/input_data:0': inputimage})
 22 
 23         constant_ops = [op for op in sess.graph.get_operations()]#[op for op in sess.graph.get_operations() if op.type == "Const"]
 24         for constant_op in constant_ops:
 25             if constant_op.op_def.name == "Const":
 26                 if "Shape" in constant_op.name or "pred" in constant_op.name:
 27                     continue
 28                 constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
 29     return constant_values
 30 
 31 #过滤提取出来的params
 32 def filter_params(constant_values):
 33     total = 0
 34     prompt = []
 35     res = {}
 36     forbidden = ['shape','stack']
 37     
 38     for k,v in constant_values.items():
 39         # filtering some by checking ndim and name
 40         if v.ndim<1: continue
 41         if v.ndim==1:
 42             token = k.split(r'/')[-1]
 43             flag = False
 44             for word in forbidden:
 45                 if token.find(word)!=-1:
 46                     flag = True
 47                     break
 48             if flag:
 49                 continue
 50 
 51         shape = v.shape
 52         cnt = 1
 53         for dim in shape:
 54             cnt *= dim
 55         prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
 56         res[k] = v
 57         print(prompt[-1])
 58         total += cnt
 59     prompt.append('totaling {}'.format(total))
 60     # print(prompt[-1])
 61     return res
 62 
 63 #将Tensorflow的张量转换成PyTorch的张量
 64 def trans_tensor_pb2pth(k,a):
 65  
 66     v = tf.convert_to_tensor(a).numpy()
 67     # tensorflow weights to pytorch weights
 68     if len(v.shape) == 4:
 69         if "depthwise_weights" in k:#防止深度可分离卷积
 70             return np.ascontiguousarray(v.transpose(2,3,0,1))
 71         return np.ascontiguousarray(v.transpose(3,2,0,1))
 72     elif len(v.shape) == 2:
 73         return np.ascontiguousarray(v.transpose())
 74     return v
 75 
 76 #将pb的对应params名字转换为pth对应参数名
 77 def trans_name_pb2pth(trans_weights):
 78     model_dict = {}
 79     for name,para in trans_weights.items():
 80         name = name.replace('/',".")
 81         
 82         if "MobilenetV2.Conv" in name:#处理MobilenetV2.Conv
 83             name = name.replace('weights',"0.weight")
 84             name = name.replace('BatchNorm',"1")
 85             name = name.replace('gamma',"weight")
 86             name = name.replace('beta',"bias")
 87             name = name.replace('moving_mean',"running_mean")
 88             name = name.replace('moving_variance',"running_var")
 89         elif "MobilenetV2.expanded_conv." in name:#处理MobilenetV2.expanded_conv.
 90             name = name.replace('depthwise.',"0.")
 91             name = name.replace('project',"1")
 92             name = name.replace('depthwise_weights',"0.weight")
 93             name = name.replace('weights',"0.weight")
 94             name = name.replace('BatchNorm',"1")
 95             name = name.replace('gamma',"weight")
 96             name = name.replace('beta',"bias")
 97             name = name.replace('moving_mean',"running_mean")
 98             name = name.replace('moving_variance',"running_var")
 99         elif "MobilenetV2.expanded_conv_" in name:#处理MobilenetV2.expanded_conv_*
100             name = name.replace('expand.',"0.")
101             name = name.replace('depthwise.',"1.")
102             name = name.replace('project',"2")
103             name = name.replace('depthwise_weights',"0.weight")
104             name = name.replace('weights',"0.weight")
105             name = name.replace('BatchNorm',"1")
106             name = name.replace('gamma',"weight")
107             name = name.replace('beta',"bias")
108             name = name.replace('moving_mean',"running_mean")
109             name = name.replace('moving_variance',"running_var")
110         elif "yolo-v3" in name:
111             if "bbox" in name:
112                 continue
113             name = name.replace('yolo-v3',"yolo_v3")
114             name = name.replace('weight',"0.weight")
115             name = name.replace('kernel',"weight")
116             name = name.replace('batch_normalization',"1")
117             name = name.replace('gamma',"weight")
118             name = name.replace('beta',"bias")
119             name = name.replace('moving_mean',"running_mean")
120             name = name.replace('moving_variance',"running_var")
121         print(name)
122         model_dict[name] = torch.Tensor(para)
123     return model_dict
124 
125 #将pb参数copy给pth模型
126 def copy_pbParams2pthParams():
127     constant_values = extract_params_from_pb()
128     TF_weights = filter_params(constant_values)
129     trans_weights = {k:trans_tensor_pb2pth(k,v) for (k, v) in TF_weights.items() }
130 
131     #创建pytorch模型
132     PyTorchModel = mobilenetv2_yolov3()
133     model_dict = trans_name_pb2pth(trans_weights)
134     # model_dict = PyTorchModel.state_dict()
135     # for name in model_dict.keys():
136     #     print(name)
137     PyTorchModel.load_state_dict(model_dict)
138     PyTorchModel.cuda().eval()
139     dummy_input = torch.rand(1,1,224,224,device="cuda").float() 
140     # out = PyTorchModel(dummy_input)
141     torch.onnx.export(PyTorchModel,dummy_input,"P3mNet.onnx",verbose = True,opset_version = 11)
142     print("====> Simplifying...")
143     model_opt,_ = onnxsim.simplify("P3mNet.onnx")
144     onnx.save(model_opt, 'P3mNet_sim.onnx')
145     print("onnx model simplify Ok!")
146 copy_pbParams2pthParams()

 

标签:sess,constant,name,graph,模型,torch,pb,import,op
From: https://www.cnblogs.com/zukang/p/16778294.html

相关文章

  • 利用依存句法分析和GCN网络进行情感分析的模型
    1.Aspect-LevelSentimentAnalysisViaConvolutionoverDependencyTree(EMNLP2019)模型将句子的依存树进行输入,然后经过Bi-LSTM进行编码,之后再经过GCN网络进一步增强,目......
  • 【 云原生 | kubernetes 】资源对象 - 控制器模型之Deployment
    Deployment概述Kubernetes中的一个控制器模式,最常用于部署无状态服务的方式。Deployment控制器实际操纵的是ReplicaSet对象,而不是Pod对象。保证系统中的Pod数量永......
  • 学习常用模型及算法1.模拟退火算法
    title:学习常用模型及算法1.模拟退火算法excerpt:学习数学建模常用模型及算法tags:[数学建模,matlab]categories:[学习,数学建模]index_img:https://picture-......
  • 学习常用模型及算法4.元胞自动机
    title:学习常用模型及算法4.元胞自动机excerpt:学习数学建模常用模型及算法tags:[数学建模,matlab]categories:[学习,数学建模]index_img:https://picture-st......
  • PyTorch学习笔记
    #########################################################有关PyTorch一些学习笔记,目前笔记并不全面,只是针对性记录一些对应地铁预测中运用的的原理,函数,方法(有些没有使......
  • Class4 隐马尔科夫模型HMM
    title:Class4隐马尔科夫模型HMMexcerpt:HMM的两个基本假设!tags:[语音识别,ASR,HMM,ForwaedAlgorithm,BackwaedAlgorrithm,ViterbiAlgorrithm]categories:......
  • 28. JS DOM(文档对象模型)
    1.前言文档对象模型(DocumentObjectModel,简称DOM),是一种与平台和语言无关的模型,用来表示HTML或XML文档。文档对象模型中定义了文档的逻辑结构,以及程序访问和操作文......
  • 【动态规划】数字三角形模型
    纯数字三角形模型状态表示:(i,j)表示到达当前位置的最后一步如何,共有两种状态f[i-1][j],f[i][j]状态转移:f[i][j]=max(f[i-1][j-1]+q[i][j],f[i][j......
  • 深圳华锐视点:车展场景3d模型获得汽车全面信息
    作为中国经济与金融中心,上海市在各方面发展都稳居国家前列,上海进出口更是位居世界城市之首,受疫情影响,采购、供应商之间多采用线上联系,必然要借助全面立体、逼真的3d模......
  • 混合高斯模型与帧差法结合的算法
    针对的对象:背景图像中有物体的运动状态发生改变(一般是速度突变,比如从静止到运动,使得背景变化较大),改进算法前后对比背景图像的更新速度问题。文献举例:比如在文献1中P37图3.5(a......