1.使用tf2onnx工具,把saved model转换为tf的graph def(不带function,也就是tf1的计算图)
https://github.com/onnx/tensorflow-onnx/blob/v1.9.3/tf2onnx/tf_loader.py
# -*- coding: utf-8 -*-
import os
import multiprocessing
from typing import List, Dict

try:
    from tf2onnx import tf_loader
except ImportError:
    # install tf2onnx
    # NOTE(review): this installs with sudo into /usr/bin/python3, which is not
    # necessarily the interpreter running this script -- confirm before reuse.
    import subprocess
    subprocess.call(["sudo", "/usr/bin/python3", "-m", "pip", "install", "tf2onnx==1.9.3"])
    from tf2onnx import tf_loader

from tensorflow.core.protobuf import meta_graph_pb2, config_pb2
from tensorflow.python.grappler import tf_optimizer
from google.protobuf import text_format
from tensorflow.core.protobuf import rewriter_config_pb2
import tensorflow as tf

DEFAULT_OPTIMIZERS = ('dependency',)


def run_graph_grappler(graph, inputs, outputs, optimizers=DEFAULT_OPTIMIZERS):
    """Run the given Grappler optimizers over a graph def and return the result.

    Args:
        graph: the GraphDef to optimize.
        inputs: names of the graph inputs.
        outputs: names of the graph outputs.
        optimizers: names of the Grappler optimizers to enable.

    Returns:
        The optimized graph def produced by ``tf_optimizer.OptimizeGraph``.

    Note:
        Eager execution is disabled globally as a side effect of this call.
    """
    tf.compat.v1.disable_eager_execution()

    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.optimizers.extend(optimizers)
    # A single meta-optimizer iteration.
    rewrite_options.meta_optimizer_iterations = rewriter_config_pb2.RewriterConfig.ONE

    meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph)

    # Register inputs and outputs in the 'train_op' collection; presumably this
    # is what makes Grappler treat them as fetches and keep them -- TODO confirm.
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(inputs)
    fetch_collection.node_list.value.extend(outputs)
    meta_graph.collection_def['train_op'].CopyFrom(fetch_collection)

    graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)
    return graph_def


def is_control_dependency(node_name: str) -> bool:
    """Return True if an input name denotes a control dependency ("^name")."""
    return node_name.startswith("^")


def is_saved_model_control_node(node: tf.compat.v1.NodeDef) -> bool:
    """Return True if ``node`` is a saved-model input/output control node.

    A control node looks like:

        node {
          name: "Func/StatefulPartitionedCall/input_control_node/_0"
          op: "NoOp"
          input: "^deep_fm4_1024"
          input: "^deep_fm4_1552"
        }

    Such nodes should be removed if we need to run inference on the subgraph.
    """
    if node.op != "NoOp":
        return False
    if "input_control_node" not in node.name and "output_control_node" not in node.name:
        return False
    # Every input must itself be a control dependency (vacuously true if none).
    return all(is_control_dependency(input_name) for input_name in node.input)


def fix_saved_model_control_dependency(graph_def: tf.compat.v1.GraphDef):
    """Drop control-dependency inputs that point at saved-model control nodes.

    The control nodes themselves are left in the graph; only the "^name"
    references to them are removed from every node's input list. The relative
    order of the remaining inputs is preserved.

    Args:
        graph_def: the graph to fix; it is modified in place.

    Returns:
        The same ``graph_def`` object.
    """
    # Control inputs are spelled "^<node name>" in NodeDef.input.
    control_inputs = {
        "^" + node.name
        for node in graph_def.node
        if is_saved_model_control_node(node)
    }
    if not control_inputs:
        return graph_def

    for node in graph_def.node:
        # Walk backwards so deleting an entry never shifts an unvisited index.
        for i in reversed(range(len(node.input))):
            if node.input[i] in control_inputs:
                del node.input[i]
    return graph_def


def fix_output_name(graph_def: tf.compat.v1.GraphDef, outputs: List[str], alias_map: Dict[str, str]):
    """Rename output nodes to their aliases so serving needs no alias mapping.

    ``outputs`` looks like:
        ['Identity:0', 'Identity_1:0', 'Identity_2:0', 'Identity_3:0']
    ``alias_map`` looks like:
        {'Identity:0': 'logit_dislike', 'Identity_1:0': 'logit_like',
         'Identity_2:0': 'logit_play', 'Identity_3:0': 'logit_staytime'}

    Only the ":0" tensor of a node is considered. Outputs that have no entry
    in ``alias_map`` keep their original name. Any node input that refers to a
    renamed node is rewritten to the new name so the graph stays consistent.

    Args:
        graph_def: the graph to fix; it is modified in place.
        outputs: output tensor names.
        alias_map: mapping from tensor name to alias.

    Returns:
        The same ``graph_def`` object.
    """
    output_names = set(outputs)

    # Pass 1: rename the output nodes and remember old name -> new name.
    renamed = {}
    for node in graph_def.node:
        tensor_name = node.name + ":0"
        if tensor_name not in output_names:
            continue
        alias = alias_map.get(tensor_name)
        if alias is None or alias == node.name:
            continue
        renamed[node.name] = alias
        node.name = alias

    if not renamed:
        return graph_def

    # Pass 2: repoint consumers. Inputs are "name", "name:port" or "^name".
    for node in graph_def.node:
        for i, input_name in enumerate(node.input):
            prefix = "^" if is_control_dependency(input_name) else ""
            base, sep, port = input_name[len(prefix):].partition(":")
            if base in renamed:
                node.input[i] = prefix + renamed[base] + sep + port
    return graph_def


def convert_saved_model_to_graph_def(export_dir):
    """Convert the saved model in ``export_dir`` to a text graph def.

    Loads the model with ``tf_loader.from_saved_model``, runs Grappler, removes
    the saved-model control dependencies, applies the output aliases, and
    writes the result to ``<export_dir>/graph.pbtxt``.

    Args:
        export_dir: directory containing ``saved_model.pb``.

    Raises:
        FileNotFoundError: if ``<export_dir>/saved_model.pb`` does not exist.
    """
    print("Start to convert saved model to graph def pbtxt", flush=True)

    saved_model_file = os.path.join(export_dir, "saved_model.pb")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(saved_model_file):
        raise FileNotFoundError(saved_model_file)

    frozen_graph_def, inputs, outputs, alias_map = tf_loader.from_saved_model(
        export_dir,
        input_names=None,
        output_names=None,
        return_tensors_to_rename=True)

    # remove trivial Identity and control dependency for readability
    frozen_graph_def = run_graph_grappler(frozen_graph_def, inputs=inputs, outputs=outputs)
    frozen_graph_def = fix_saved_model_control_dependency(frozen_graph_def)
    frozen_graph_def = fix_output_name(frozen_graph_def, outputs, alias_map)

    graph_def_file = os.path.join(export_dir, "graph.pbtxt")
    with open(graph_def_file, 'w') as f:
        f.write(text_format.MessageToString(frozen_graph_def))

    print("Convert saved model to graph def success", flush=True)
----2022.09.28补充--------------
通过阅读tf_loader的源码,发现在转换成graph的时候,已经做了grappler的优化,使用的优化器是constfold和dependency。如果启用constfold,会导致中间节点被折叠起来;不想被折叠的话,禁用constfold优化方法就可以了。但是这需要修改tf_loader.py的源码(目前没找到只替换已导入模块中某个函数的方法)
标签:node,name,graph,savedmodel,Identity,input,Tensorflow,def From: https://www.cnblogs.com/deepllz/p/16257194.html