首页 > 其他分享 >如何将onnx稳定的转换为tensorflow, 甚至转换为tflite(float32/int8)

如何将onnx稳定的转换为tensorflow, 甚至转换为tflite(float32/int8)

时间:2024-06-05 14:24:17浏览次数:15  
标签:转换 tflite onnx args pypi defaults model dir

做模型部署边缘设备的时候,我们经常会遇到特定格式的要求。但常见的onnx2tf很多时候都不能满足我们的要求。因此,记录一下我的操作过程。

  • 1. 环境:(linux18.04)
  • # Name                    Version                   Build  Channel
    _libgcc_mutex             0.1                        main    defaults
    _openmp_mutex             5.1                       1_gnu    defaults
    absl-py                   0.15.0                   pypi_0    pypi
    addict                    2.4.0                    pypi_0    pypi
    altgraph                  0.17.4                   pypi_0    pypi
    array-record              0.4.0                    pypi_0    pypi
    astunparse                1.6.3                    pypi_0    pypi
    beautifulsoup4            4.11.2                   pypi_0    pypi
    ca-certificates           2024.3.11            h06a4308_0    defaults
    cachetools                5.3.3                    pypi_0    pypi
    certifi                   2024.2.2                 pypi_0    pypi
    charset-normalizer        3.3.2                    pypi_0    pypi
    click                     8.1.7                    pypi_0    pypi
    colorama                  0.4.6                    pypi_0    pypi
    defusedxml                0.7.1                    pypi_0    pypi
    dnspython                 2.6.1                    pypi_0    pypi
    editdistance              0.8.1                    pypi_0    pypi
    etils                     1.3.0                    pypi_0    pypi
    fast-ctc-decode           0.3.6                    pypi_0    pypi
    filelock                  3.14.0                   pypi_0    pypi
    flatbuffers               1.12                     pypi_0    pypi
    fsspec                    2024.5.0                 pypi_0    pypi
    future                    1.0.0                    pypi_0    pypi
    gast                      0.4.0                    pypi_0    pypi
    google-auth               2.29.0                   pypi_0    pypi
    google-auth-oauthlib      0.4.6                    pypi_0    pypi
    google-pasta              0.2.0                    pypi_0    pypi
    googleapis-common-protos  1.63.0                   pypi_0    pypi
    grpcio                    1.34.1                   pypi_0    pypi
    h5py                      3.1.0                    pypi_0    pypi
    huggingface-hub           0.23.2                   pypi_0    pypi
    hyperopt                  0.1.2                    pypi_0    pypi
    idna                      3.7                      pypi_0    pypi
    imageio                   2.34.1                   pypi_0    pypi
    importlib-metadata        7.1.0                    pypi_0    pypi
    importlib-resources       6.4.0                    pypi_0    pypi
    joblib                    1.4.2                    pypi_0    pypi
    jstyleson                 0.0.2                    pypi_0    pypi
    keras-nightly             2.5.0.dev2021032900          pypi_0    pypi
    keras-preprocessing       1.1.2                    pypi_0    pypi
    libedit                   3.1.20230828         h5eee18b_0    defaults
    libffi                    3.2.1             hf484d3e_1007    defaults
    libgcc-ng                 11.2.0               h1234567_1    defaults
    libgomp                   11.2.0               h1234567_1    defaults
    libstdcxx-ng              11.2.0               h1234567_1    defaults
    markdown                  3.6                      pypi_0    pypi
    markupsafe                2.1.5                    pypi_0    pypi
    ncurses                   6.4                  h6a678d5_0    defaults
    networkx                  2.8.8                    pypi_0    pypi
    nibabel                   5.1.0                    pypi_0    pypi
    nltk                      3.8.1                    pypi_0    pypi
    numpy                     1.19.5                   pypi_0    pypi
    oauthlib                  3.2.2                    pypi_0    pypi
    onnx                      1.13.0                   pypi_0    pypi
    opencv-python             4.5.5.64                 pypi_0    pypi
    openssl                   1.1.1w               h7f8727e_0    defaults
    openvino                  2021.4.2                 pypi_0    pypi
    openvino-dev              2021.4.2                 pypi_0    pypi
    openvino-telemetry        2024.1.0                 pypi_0    pypi
    openvino2tensorflow       1.34.0                   pypi_0    pypi
    opt-einsum                3.3.0                    pypi_0    pypi
    packaging                 24.0                     pypi_0    pypi
    pandas                    1.1.5                    pypi_0    pypi
    parasail                  1.3.4                    pypi_0    pypi
    pillow                    9.4.0                    pypi_0    pypi
    pip                       24.0             py38h06a4308_0    defaults
    progress                  1.6                      pypi_0    pypi
    promise                   2.3                      pypi_0    pypi
    protobuf                  3.20.3                   pypi_0    pypi
    psutil                    5.9.8                    pypi_0    pypi
    py-cpuinfo                9.0.0                    pypi_0    pypi
    pyasn1                    0.6.0                    pypi_0    pypi
    pyasn1-modules            0.4.0                    pypi_0    pypi
    pydicom                   2.4.4                    pypi_0    pypi
    pyinstaller               6.7.0                    pypi_0    pypi
    pyinstaller-hooks-contrib 2024.6                   pypi_0    pypi
    pymongo                   4.7.2                    pypi_0    pypi
    python                    3.8.0                h0371630_2    defaults
    python-dateutil           2.9.0.post0              pypi_0    pypi
    pytz                      2024.1                   pypi_0    pypi
    pywavelets                1.4.1                    pypi_0    pypi
    pyyaml                    6.0.1                    pypi_0    pypi
    rawpy                     0.21.0                   pypi_0    pypi
    readline                  7.0                  h7b6447c_5    defaults
    regex                     2024.5.15                pypi_0    pypi
    requests                  2.32.3                   pypi_0    pypi
    requests-oauthlib         2.0.0                    pypi_0    pypi
    rsa                       4.9                      pypi_0    pypi
    scikit-image              0.19.3                   pypi_0    pypi
    scikit-learn              1.3.2                    pypi_0    pypi
    scipy                     1.5.4                    pypi_0    pypi
    sentencepiece             0.2.0                    pypi_0    pypi
    setuptools                69.5.1           py38h06a4308_0    defaults
    shapely                   2.0.4                    pypi_0    pypi
    six                       1.15.0                   pypi_0    pypi
    soupsieve                 2.5                      pypi_0    pypi
    sqlite                    3.33.0               h62c20be_0    defaults
    tensorboard               2.11.2                   pypi_0    pypi
    tensorboard-data-server   0.6.1                    pypi_0    pypi
    tensorboard-plugin-wit    1.8.1                    pypi_0    pypi
    tensorflow                2.5.3                    pypi_0    pypi
    tensorflow-datasets       4.9.2                    pypi_0    pypi
    tensorflow-estimator      2.5.0                    pypi_0    pypi
    tensorflow-metadata       1.14.0                   pypi_0    pypi
    termcolor                 1.1.0                    pypi_0    pypi
    texttable                 1.6.7                    pypi_0    pypi
    threadpoolctl             3.5.0                    pypi_0    pypi
    tifffile                  2023.7.10                pypi_0    pypi
    tk                        8.6.14               h39e8969_0    defaults
    tokenizers                0.19.1                   pypi_0    pypi
    toml                      0.10.2                   pypi_0    pypi
    torch                     1.12.1                   pypi_0    pypi
    torchvision               0.13.1                   pypi_0    pypi
    tqdm                      4.66.4                   pypi_0    pypi
    typing-extensions         3.7.4.3                  pypi_0    pypi
    urllib3                   2.2.1                    pypi_0    pypi
    werkzeug                  3.0.3                    pypi_0    pypi
    wheel                     0.43.0           py38h06a4308_0    defaults
    wrapt                     1.12.1                   pypi_0    pypi
    xz                        5.4.6                h5eee18b_1    defaults
    yamlloader                1.4.1                    pypi_0    pypi
    zipp                      3.19.0                   pypi_0    pypi
    zlib                      1.2.13               h5eee18b_1    defaults
  • 2. 具体代码:(下面是int8量化)
    #!/usr/bin/env python
    """
    a command line tool to format onnx model from pytorch-onnx to tflite model
    
    """
    import random
    import os
    import tensorflow as tf
    import glob
    import cv2
    import numpy as np
    from tqdm import tqdm
    import argparse
    from pathlib import Path
    import shutil
    from typing import List
    
    
    def parse_args():
        parser = argparse.ArgumentParser(
            description="Formatting PyTorch models to TensorFlow models")
        parser.add_argument(
            "-i", "--input_onnx", type=str, help="an onnx file form pytorch model")
        parser.add_argument(
            "-s", "--shape", type=str, help="input image size (height, width)")
        parser.add_argument(
            "-o", "--output_dir", type=str, default="./", help="model output dir")
        parser.add_argument(
            "-t", "--tflite_file", type=str, help="output tflite file name")
        parser.add_argument(
            "-d", "--dataset", type=str, help="represent dataset")
        parser.add_argument(
            "-n", "--num_present_images", type=int, default=100,
            help="number of represent images for tflite quantization",
        )
        args = parser.parse_args()
        return args
    
    
    def convert_onnx2tensorflow(args):
        modify_xml_func = None
        if args.modify and args.modify_model == "yolox":
            modify_xml_func = mo_yolox_ov_xml
        return onnx2tensorflow(args, modify_xml_func)
    
    
    def onnx2tensorflow(args, modify_xml_func=None):
        print(f"CWD:{os.getcwd()}")
        output_dir = args.output_dir
        onnx_model = args.input_onnx
        h, w = eval(args.shape)
        input_shape = (1, 3, h, w)
        ov_dir = os.path.join(output_dir, "ov")
    
        shutil.rmtree(ov_dir, ignore_errors=True)
        ov_cmd = (
            f"mo  \
            --input_model {onnx_model} \
            --input_shape '{input_shape}' \
            --output_dir {ov_dir} \
            --progress \
            --data_type FP32"
        )
        print(ov_cmd)
        assert os.system(ov_cmd) == 0, "failed in converting onnx to openvino"
    
        ov_xml = os.path.join(ov_dir, f"{Path(onnx_model).stem}.xml")
        # add changes for certain models if needed
        if modify_xml_func:
            modify_xml_func(ov_xml)
    
        tf_model_dir = os.path.join(output_dir, "hwc")
        shutil.rmtree(tf_model_dir, ignore_errors=True)
        ov2tf_cmd = (
            f"openvino2tensorflow  \
            --model_path {ov_xml} \
            --model_output_path {tf_model_dir} \
            --output_no_quant_float32_tflite \
            --output_saved_model"
        )
        print(ov2tf_cmd)
        assert os.system(ov2tf_cmd) == 0, \
            "failed in converting openvino to tensorflow"
    
        return tf_model_dir
    
    
    def get_represent_images(path: str, num_present_images: int) -> List[str]:
        direc = Path(path)
        files = list(direc.rglob("*.jpg"))
        if not len(files) > 0:
            files = list(direc.rglob("*.png"))
        if not len(files) > 0:
            files = list(direc.rglob("*.JPEG"))
        if not len(files) > 0:
            raise TypeError("unrecognised img file type")
    
        files = random.sample(files, num_present_images)
        return files
    
    
    def tflite_quantize(args, tf_model_dir):
        """ convert a tensorflow model to tflite model with represent data
    
        :param args: related command line args
        :param tf_model_dir: where the tensorflow model saved
        :return: tflite model in current working dir
        """
        # assert args.tflite_file.split(".")[-1] == "tflite"
        # if os.path.exists(args.tflite_file):
        #     os.remove(args.tflite_file)
    
        files = get_represent_images(args.dataset, args.num_present_images)
        h, w = eval(args.shape)
    
        def representative_dataset():
            for file in tqdm(files):
                # read images in RGB format 
                # assume that the images were trained in RGB format
                # img = cv2.imread(str(file), cv2.IMREAD_GRAYSCALE)[..., ::-1]
                img = cv2.imread(str(file))[..., ::-1]
                img = cv2.resize(img, (w, h))
                # img = np.expand_dims(img, axis=-1)
                img = ((img - 127.5) / 127.5).astype(np.float32)
                img = img[None, ...]
                yield [img]
    
        converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_dir)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]    # 8bits weight quantization
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        converter.experimental_new_quantizer = False  # it must be false if TF version is not 2.4.1
        converter.representative_dataset = representative_dataset
        tflite_model = converter.convert()
    
        save_tflite_file = args.input_onnx.replace("onnx", "tflite")
        save_tflite_path = os.path.join(
                args.output_dir, os.path.basename(save_tflite_file),
        )
    
        with open(save_tflite_path, 'wb') as f:
            f.write(tflite_model)
    
        return save_tflite_path
    
    
    def mo_yolox_ov_xml(ov_xml):
        """
        special modifications for yolox model
        """
        with open(ov_xml) as f:
            xml = f.read()
        with open(ov_xml, "w") as f:
            f.write(xml.replace('<data axis="2"/>', '<data axis="1"/>'))
    
    
    if __name__ == "__main__":
        onnx_args = parse_args()
        tensorflow_model_dir = onnx2tensorflow(onnx_args)
        tflite_quantize(onnx_args, tensorflow_model_dir)
    
    """
    torch-onnx2tflite.py
    -i yolox-tiny.onnx \
    -s "(320, 320)" \
    -d yolox/datasets/COCO/val2017
    <-t yox_tiny.tflite [alt] >\
    """
  • 3. 其他:
    •   接着,我尝试使用pyinstaller将这个工具固化,生成一个exe来使用,但是似乎并不如意。因为它的工作机制是在cmd命令行输入代码,但是openvino2tensorflow需要添加到环境变量后才能使用,因此,如果不添加这个环境变量,生成的exe也没有什么用处。因此,这个exe的准备似乎没有用处。(我还没有找到解决方法,如果您有解决方法,请告诉我,不甚感激。)

标签:转换,tflite,onnx,args,pypi,defaults,model,dir
From: https://www.cnblogs.com/cainiaoxuexi2017-ZYA/p/18232936

相关文章

  • Python数据类型转换(新)
    目录Python数据类型的转换隐式类型转换显式类型转换Python数据类型的转换数据类型分为1.隐式类型转换2.显式类型转换隐式类型转换在隐式类型转换中,Python会自动将一种数据类型转换为另一种数据类型,不需要认为去干预比如在进行算术运算的时候,较低数据类型(整数)就会......
  • java 数值类型 强制转换注意
    数值类型分别为【byte】,【short】,【int】,【long】,【float】,【double】byte:最大值为127,最小值为-128;short:最大值为32767,最小值为-32768;int:最大值为2,147,483,647,最小值为-2,147,483,648;long:最大值为9,223,372,036,854,775,807,最小值为-9,223,372,036,854,7......
  • 网络字节序和本地字节序之间转换
    网络字节序和本地字节序之间转换目录网络字节序和本地字节序之间转换主机字节序网络字节序相关函数htons,htonl,ntohs,和ntohl相关函数inet_aton,inet_ntoa,inet_pton,和inet_ntop当我们与同一台计算机的进程进行通信时,一般不用考虑字节序。什么是字节序——字节序是一......
  • 进制转换
    voiddecToBinary(intdecimal){intbinary[32];intcnt=0;while(decimal>0){binary[cnt]=decimal%2;//取模放在低位decimal=decimal/2;//更新参数cnt++;}printf("Binaryequivalent:"......
  • Transgaga——人脸与猫脸之间互相转换算法解析
    1.概述虽然pix2pix作为风格转换模型被提出,但它依赖于成对的数据集。与之相比,CycleGAN通过引入循环损失,实现了无需配对数据的风格转换。不过,CycleGAN在处理需要大幅几何变化的风格转换时表现不佳,仅在如马和斑马这类颜色变化的场景中有效。2018年,MUNIT利用变分自编码器(VAE)......
  • pdf如何转换成excel文档?这3个方法免费!
    职场人士常常会遇到PDF文件格式,因为PDF便于传输且能够保持排版稳定,因此在文件分享中备受青睐。然而,PDF文件中的表格数据可能涉及到公式和函数,而PDF格式又不易编辑,这时我们就需要将其转换为Excel格式进行编辑修改。因此,掌握PDF转换成Excel的技能对职场人士来说非常有意义。幸运的......
  • C++ 强制类型转换运算符简介
    C++提供了四种强制类型转换运算符:static_cast、reinterpret_cast、const_cast和dynamic_cast。这些运算符各自具有特定的用途,适用于不同的类型转换需求。本文将详细介绍这四种运算符及其应用场景,并讨论它们在向上转换中的使用方法。1.static_caststatic_cast用于在编译时执......
  • (蕊源)代理 RY3750 SOT-23-5 1.2MHz,30V,升压转换器
    产品描述RY3750是一个升压转换器。其1.23V的反馈电压降低了功率损耗,提高了效率。优化的运行频率可以满足LC滤波器值小、低运行电流的要求。内部软启动功能可以降低冲涌电流。小包装类型为节省PCB空间和总BOM成本提供了最佳的解决方案。产品特点2.5V至5.5V的输入电压1.23V反......
  • 代码随想录算法训练营第二十三天 | 669.修剪二叉搜索树 108.将有序数组转换为二叉搜索
    669.修剪二叉搜索树题目链接文章讲解视频讲解classSolution{public:TreeNode*trimBST(TreeNode*root,intlow,inthigh){if(root==nullptr)returnnullptr;//当前值小于左边界时,当前节点的左子树全部小于左边界,所以全部删除,直接处理右子树......
  • TensorRT c++部署onnx模型
    在了解一些概念之前一直看不懂上交22年开源的TRTModule.cpp和.hpp,好在交爷写的足够模块化,可以配好环境开箱即用,移植很简单。最近稍微了解了神经网络的一些概念,又看了TensorRT的一些api,遂试着部署一下自己在MNIST手写数字数据集上训练的一个LeNet模型,识别率大概有98.9%,实现用pytor......