首页 > 其他分享 >keras模型转换onnx模型

keras模型转换onnx模型

时间:2024-07-03 14:29:37浏览次数:20  
标签:Path keras onnx 模型 graph input output model path

1.keras一般先转换为tensorflow的pb格式,然后再使用tf2onnx转换。
2.tensorflow转换为onnx

keras转换tensorflow:
参考:https://github.com/amir-abdi/keras_to_tensorflow

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from pathlib import Path
from absl import app
from absl import flags
from absl import logging
import keras
from keras import backend as K
from keras.models import model_from_json, model_from_yaml

K.set_learning_phase(0)
FLAGS = flags.FLAGS

flags.DEFINE_string('input_model', None, 'Path to the input model.')
flags.DEFINE_string('input_model_json', None, 'Path to the input model '
                                              'architecture in json format.')
flags.DEFINE_string('input_model_yaml', None, 'Path to the input model '
                                              'architecture in yaml format.')
flags.DEFINE_string('output_model', None, 'Path where the converted model will '
                                          'be stored.')
flags.DEFINE_boolean('save_graph_def', False,
                     'Whether to save the graphdef.pbtxt file which contains '
                     'the graph definition in ASCII format.')
flags.DEFINE_string('output_nodes_prefix', None,
                    'If set, the output nodes will be renamed to '
                    '`output_nodes_prefix`+i, where `i` will numerate the '
                    'number of of output nodes of the network.')
flags.DEFINE_boolean('quantize', False,
                     'If set, the resultant TensorFlow graph weights will be '
                     'converted from float into eight-bit equivalents. See '
                     'documentation here: '
                     'https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms')
flags.DEFINE_boolean('channels_first', False,
                     'Whether channels are the first dimension of a tensor. '
                     'The default is TensorFlow behaviour where channels are '
                     'the last dimension.')
flags.DEFINE_boolean('output_meta_ckpt', False,
                     'If set to True, exports the model as .meta, .index, and '
                     '.data files, with a checkpoint file. These can be later '
                     'loaded in TensorFlow to continue training.')

flags.mark_flag_as_required('input_model')
flags.mark_flag_as_required('output_model')

def load_model(input_model_path, input_json_path=None, input_yaml_path=None):
    if not Path(input_model_path).exists():
        raise FileNotFoundError(
            'Model file `{}` does not exist.'.format(input_model_path))
    try:
        model = keras.models.load_model(input_model_path)
        return model
    except FileNotFoundError as err:
        logging.error('Input mode file (%s) does not exist.', FLAGS.input_model)
        raise err
    except ValueError as wrong_file_err:
        if input_json_path:
            if not Path(input_json_path).exists():
                raise FileNotFoundError(
                    'Model description json file `{}` does not exist.'.format(
                        input_json_path))
            try:
                model = model_from_json(open(str(input_json_path)).read())
                model.load_weights(input_model_path)
                return model
            except Exception as err:
                logging.error("Couldn't load model from json.")
                raise err
        elif input_yaml_path:
            if not Path(input_yaml_path).exists():
                raise FileNotFoundError(
                    'Model description yaml file `{}` does not exist.'.format(
                        input_yaml_path))
            try:
                model = model_from_yaml(open(str(input_yaml_path)).read())
                model.load_weights(input_model_path)
                return model
            except Exception as err:
                logging.error("Couldn't load model from yaml.")
                raise err
        else:
            logging.error(
                'Input file specified only holds the weights, and not '
                'the model definition. Save the model using '
                'model.save(filename.h5) which will contain the network '
                'architecture as well as its weights. '
                'If the model is saved using the '
                'model.save_weights(filename) function, either '
                'input_model_json or input_model_yaml flags should be set to '
                'to import the network architecture prior to loading the '
                'weights. \n'
                'Check the keras documentation for more details '
                '(https://keras.io/getting-started/faq/)')
            raise wrong_file_err

def main(args):
    # If output_model path is relative and in cwd, make it absolute from root
    output_model = FLAGS.output_model
    if str(Path(output_model).parent) == '.':
        output_model = str((Path.cwd() / output_model))

    output_fld = Path(output_model).parent
    output_model_name = Path(output_model).name
    output_model_stem = Path(output_model).stem
    output_model_pbtxt_name = output_model_stem + '.pbtxt'

    # Create output directory if it does not exist
    #Path(output_model).parent.mkdir(parents=True, exist_ok=True)

    if FLAGS.channels_first:
        K.set_image_data_format('channels_first')
    else:
        K.set_image_data_format('channels_last')

    model = load_model(FLAGS.input_model, FLAGS.input_model_json, FLAGS.input_model_yaml)

    # TODO(amirabdi): Support networks with multiple inputs
    orig_output_node_names = [node.op.name for node in model.outputs]
    if FLAGS.output_nodes_prefix:
        num_output = len(orig_output_node_names)
        pred = [None] * num_output
        converted_output_node_names = [None] * num_output

        # Create dummy tf nodes to rename output
        for i in range(num_output):
            converted_output_node_names[i] = '{}{}'.format(
                FLAGS.output_nodes_prefix, i)
            pred[i] = tf.identity(model.outputs[i],
                                  name=converted_output_node_names[i])
    else:
        converted_output_node_names = orig_output_node_names
    logging.info('Converted output node names are: %s',
                 str(converted_output_node_names))

    sess = K.get_session()
    if FLAGS.output_meta_ckpt:
        saver = tf.train.Saver()
        saver.save(sess, str(output_fld / output_model_stem))

    if FLAGS.save_graph_def:
        tf.train.write_graph(sess.graph.as_graph_def(), str(output_fld),
                             output_model_pbtxt_name, as_text=True)
        logging.info('Saved the graph definition in ascii format at %s',
                     str(Path(output_fld) / output_model_pbtxt_name))

    if FLAGS.quantize:
        from tensorflow.tools.graph_transforms import TransformGraph
        transforms = ["quantize_weights", "quantize_nodes"]
        transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [],
                                               converted_output_node_names,
                                               transforms)
        constant_graph = graph_util.convert_variables_to_constants(
            sess,
            transformed_graph_def,
            converted_output_node_names)
    else:
        constant_graph = graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(),
            converted_output_node_names)

    graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
                         as_text=False)
    logging.info('Saved the freezed graph at %s',
                 str(Path(output_fld) / output_model_name))


if __name__ == "__main__":
    app.run(main)

标签:Path,keras,onnx,模型,graph,input,output,model,path
From: https://blog.csdn.net/zhilaizhiwang/article/details/140151122

相关文章

  • 使用中转API访问OpenAI大模型的指南
    在现代AI技术的迅速发展中,大语言模型(LLM)如OpenAI的GPT-4变得越来越重要。然而,由于中国地区的网络限制,直接访问海外的API可能会遇到困难。为了方便开发者使用这些强大的工具,我们可以通过中转API地址http://api.wlai.vip来进行访问。本文将介绍如何使用中转API访问OpenAI大模......
  • 多模态大模型入门指南
    作者:林夕,阿里巴巴集团高级算法工程师,专注多模态大模型研究。声明:本文只做分享,版权归原作者,侵权私信删除!原文:https://zhuanlan.zhihu.com/p/682893729内容总结,本篇综述主要介绍和分析了以下几个方面:•概述了MM-LLMs的设计形式,将模型架构分为5个部分:模态编码器、输入......
  • AI绘画Stable Diffusion 3.0大模型解锁AIGC写实摄影:摄影构图与视角关键提示,SD3模型最
    大家好,我是设计师阿威在现实摄影领域中,创作出优秀的摄影图像会涉及很多关键技术要素,如:光影效果、摄影构图(摄影机位置:相机与主体的距离)和摄影角度(相机相对于主体的位置)等的选择。这些核心要素对于AIGC绘图(StableDiffusion1.5/XL、Playground、Midjourney)创作也极为重要......
  • AI预测福彩3D采取888=3策略+和值012路或胆码测试7月3日新模型预测第23弹
            今天咱们继续验证新模型的8码定位=3,重点是预测8码定位=3+和值012+胆码。有些朋友看到我最近几篇文章没有给大家提供缩水后的预测详情,在这里解释下:其实我每篇文章中既有8码定位,也有和值012路,也有胆码排序,这些条件如果命中的话,其实大家完全可以自行使用一些免费的......
  • AI预测体彩排3采取888=3策略+和值012路或胆码测试7月3日升级新模型预测第18弹
            根据前面的预测效果,我对模型进行了重新优化,因为前面的模型效果不是很好。熟悉我的彩友比较清楚,我之前的主要精力是对福彩3D进行各种模型的开发和预测,排三的预测也就是最近1个月才开始搞的。3D的预测,经过对模型的多次修改和完善,最新的模型命中率有了大幅提高,大......
  • 昇思25天学习打卡营第8天|模型权重与 MindIR 的保存加载
    目录导入Python库和模块创建神经网络模型保存和加载模型权重保存和加载MindIR导入Python库和模块        上一章节着重阐述了怎样对超参数予以调整,以及如何开展网络模型的训练工作。在网络模型训练的整个进程当中,事实上我们满怀期望能够留存中间阶段以及最......
  • 利用中转API部署本地大语言模型
    前言在本文中,我们将展示如何使用中转API地址(http://api.wlai.vip)来部署本地的大语言模型(LLM)。我们将使用Llama2聊天模型作为示例,但该方法适用于所有受支持的LLM模型。本文将包括以下几个步骤:安装所需库、启动本地模型、创建索引和查询引擎,并附上示例代码和可能遇到......
  • 外网爆火!真正内行人必看的大模型神书!
    ......
  • The Forest Enemy Pack(2D动画角色游戏模型)
    这个包包含14个适用于platformer和2drpg游戏的动画角色。动画总帧数:1785用于动画的所有精灵都具有透明背景,并准备有1500x1200和750x600两种尺寸。对于每个角色,你也可以找到具有单独身体部位的精灵表,这样你就可以轻松地制作自己的动画。它们有PNG和PSD格式。示例场景包含......
  • 手把手教你:如何在51建模网免费下载3D模型?
    作为国内领先的3D互动展示平台,51建模网不仅汇聚了庞大的3D模型资源库,供用户免费下载,更集成了在线编辑、格式转换、内嵌展示及互动体验等一站式功能,为3D创作者及爱好者搭建起梦想与现实的桥梁。如何在51建模网免费下载3D模型?按照下面几个步骤即可轻松下载:01、筛选可下载模型......