"""Convert the TensorFlow frozen graph ``model/FaceDetector.pb`` into the
PyTorch ``mobilenetv2_yolov3`` model and export that model to ONNX.

Pipeline: read every ``Const`` op from the .pb graph, drop non-weight
constants, convert each array from TensorFlow to PyTorch memory layout,
rename it to the matching ``state_dict`` key, load it into the PyTorch
model, then export an ONNX file and a simplified copy of it.
"""
import numpy as np
import onnx
import onnxsim
import tensorflow as tf
import torch

from model.facedetector_model import mobilenetv2_yolov3


def extract_params_from_pb(pb_path="model/FaceDetector.pb"):
    """Return ``{op_name: ndarray}`` for the Const ops of a frozen TF graph.

    Constants whose op name contains ``"Shape"`` or ``"pred"`` are skipped.

    Args:
        pb_path: path to the serialized ``GraphDef`` (frozen ``.pb`` file).

    Returns:
        A dict mapping each kept Const op name to its evaluated value,
        in graph-operation order.
    """
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a dedicated graph so the Session below is explicitly bound
    # to it, instead of depending on whichever graph happens to be default.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")

    const_tensors = {
        op.name: op.outputs[0]
        for op in graph.get_operations()
        if op.type == "Const" and "Shape" not in op.name and "pred" not in op.name
    }
    if not const_tensors:
        return {}
    # One batched run (dict fetch -> dict result) instead of one run per op.
    with tf.compat.v1.Session(graph=graph) as sess:
        return sess.run(const_tensors)


def filter_params(constant_values):
    """Keep only the constants that look like layer weights.

    Drops 0-d constants, and 1-D constants whose last name component
    contains ``"shape"`` or ``"stack"`` (presumably reshape/stack helper
    constants rather than weights — TODO confirm against the graph).
    Prints the name, shape and element count of every kept array.

    Args:
        constant_values: mapping of TF op name -> ndarray, as returned by
            :func:`extract_params_from_pb`.

    Returns:
        A new dict with the kept ``name -> ndarray`` entries, in input order.
    """
    forbidden = ("shape", "stack")
    res = {}
    for k, v in constant_values.items():
        # np.ndim also accepts non-ndarray scalars (e.g. bytes), unlike v.ndim.
        ndim = np.ndim(v)
        if ndim < 1:
            continue
        if ndim == 1:
            token = k.split("/")[-1]
            if any(word in token for word in forbidden):
                continue
        print("{} with shape {} has {}".format(k, v.shape, v.size))
        res[k] = v
    return res


def trans_tensor_pb2pth(k, a):
    """Convert one TensorFlow weight array to PyTorch memory layout.

    * 4-D conv kernels ``(H, W, in, out)`` become ``(out, in, H, W)``.
    * 4-D depthwise kernels (name contains ``"depthwise_weights"``),
      ``(H, W, in, multiplier)``, become ``(in * multiplier, 1, H, W)``,
      the weight shape of a grouped ``Conv2d`` with ``groups=in``.
    * 2-D dense kernels ``(in, out)`` are transposed to ``(out, in)``.
    * Anything else is returned unchanged.

    Args:
        k: TF constant name; only used to detect depthwise kernels.
        a: the weight value (ndarray or array-like).

    Returns:
        A numpy array in PyTorch layout (C-contiguous for 2-D/4-D inputs).
    """
    v = np.asarray(a)
    if v.ndim == 4:
        if "depthwise_weights" in k:
            kh, kw = v.shape[:2]
            # (H, W, in, mult) -> (in, mult, H, W) -> (in*mult, 1, H, W);
            # TF orders depthwise output channels as in_idx * mult + m, which
            # matches this row-major flattening of the first two axes.
            return np.ascontiguousarray(v.transpose(2, 3, 0, 1).reshape(-1, 1, kh, kw))
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    if v.ndim == 2:
        return np.ascontiguousarray(v.T)
    return v


# Batch-norm renames shared by every layer family: TF variable names ->
# PyTorch nn.BatchNorm2d parameter/buffer names.
_BN_RENAMES = (
    ("gamma", "weight"),
    ("beta", "bias"),
    ("moving_mean", "running_mean"),
    ("moving_variance", "running_var"),
)

# Ordered (old, new) substring replacements per layer family. Order matters:
# e.g. "depthwise_weights" must be rewritten before the generic "weights".
_CONV_RENAMES = (
    ("weights", "0.weight"),
    ("BatchNorm", "1"),
) + _BN_RENAMES

# The un-suffixed "expanded_conv" block is mapped without an "expand" entry,
# so its project layer gets index 1 (vs. 2 for "expanded_conv_N").
_EXPANDED_CONV_RENAMES = (
    ("depthwise.", "0."),
    ("project", "1"),
    ("depthwise_weights", "0.weight"),
    ("weights", "0.weight"),
    ("BatchNorm", "1"),
) + _BN_RENAMES

_EXPANDED_CONV_N_RENAMES = (
    ("expand.", "0."),
    ("depthwise.", "1."),
    ("project", "2"),
    ("depthwise_weights", "0.weight"),
    ("weights", "0.weight"),
    ("BatchNorm", "1"),
) + _BN_RENAMES

_YOLO_RENAMES = (
    ("yolo-v3", "yolo_v3"),
    ("weight", "0.weight"),
    ("kernel", "weight"),
    ("batch_normalization", "1"),
) + _BN_RENAMES


def _apply_renames(name, renames):
    """Apply each ``(old, new)`` substring replacement to *name*, in order."""
    for old, new in renames:
        name = name.replace(old, new)
    return name


def trans_name_pb2pth(trans_weights):
    """Map TF constant names to PyTorch ``state_dict`` keys.

    ``/`` separators become ``.``; then a layer-family-specific rename table
    is applied (first match wins, in the order tested below). Names under
    ``yolo-v3`` that contain ``"bbox"`` are dropped. Names matching no
    family are kept with only the separator change. Each resulting key is
    printed.

    Args:
        trans_weights: mapping of TF name -> array already in PyTorch layout.

    Returns:
        A dict of PyTorch key -> float32 ``torch.Tensor``.
    """
    model_dict = {}
    for tf_name, para in trans_weights.items():
        name = tf_name.replace("/", ".")
        if "MobilenetV2.Conv" in name:
            name = _apply_renames(name, _CONV_RENAMES)
        elif "MobilenetV2.expanded_conv." in name:
            name = _apply_renames(name, _EXPANDED_CONV_RENAMES)
        elif "MobilenetV2.expanded_conv_" in name:
            name = _apply_renames(name, _EXPANDED_CONV_N_RENAMES)
        elif "yolo-v3" in name:
            if "bbox" in name:
                continue
            name = _apply_renames(name, _YOLO_RENAMES)
        print(name)
        model_dict[name] = torch.tensor(para, dtype=torch.float32)
    return model_dict


def copy_pbParams2pthParams(
    pb_path="model/FaceDetector.pb",
    onnx_path="P3mNet.onnx",
    sim_onnx_path="P3mNet_sim.onnx",
    input_shape=(1, 1, 224, 224),
):
    """Load the .pb weights into ``mobilenetv2_yolov3`` and export to ONNX.

    Args:
        pb_path: frozen TensorFlow graph to read weights from.
        onnx_path: where to write the exported ONNX model.
        sim_onnx_path: where to write the onnxsim-simplified ONNX model.
        input_shape: shape of the random dummy input used for tracing.

    Raises:
        RuntimeError: if onnxsim reports that the simplified model failed
            its validity check.
    """
    constant_values = extract_params_from_pb(pb_path)
    tf_weights = filter_params(constant_values)
    trans_weights = {k: trans_tensor_pb2pth(k, v) for k, v in tf_weights.items()}

    # Build the PyTorch model; load_state_dict is strict, so any key or
    # shape mismatch from the name/layout mapping raises here.
    pytorch_model = mobilenetv2_yolov3()
    model_dict = trans_name_pb2pth(trans_weights)
    pytorch_model.load_state_dict(model_dict)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pytorch_model.to(device).eval()
    dummy_input = torch.rand(*input_shape, device=device).float()
    torch.onnx.export(pytorch_model, dummy_input, onnx_path, verbose=True, opset_version=11)

    print("====> Simplifying...")
    model_opt, check_ok = onnxsim.simplify(onnx_path)
    if not check_ok:
        raise RuntimeError("onnxsim could not validate the simplified model")
    onnx.save(model_opt, sim_onnx_path)
    print("onnx model simplify Ok!")


if __name__ == "__main__":
    copy_pbParams2pthParams()
# Tags: sess, constant, name, graph, 模型 (model), torch, pb, import, op
# Source: https://www.cnblogs.com/zukang/p/16778294.html