首页 > 编程语言 >onnx转engine工具(包含量化) python脚本

onnx转engine工具(包含量化) python脚本

时间:2024-08-09 10:59:13浏览次数:11  
标签:engine python onnx self cache batch im np size

量化工具在网上搜索五花八门,很多文章没有说明使用的版本导致无法复现,这里参考了一些写法实现量化,并转为engine。具体实现代码见下方,欢迎各位小伙伴批评指正。

tensorrt安装

参考windows11下安装Tensor RT,并在conda虚拟环境下使用_tensor rt 免费吗-CSDN博客

pycuda安装

参考GPU编程(基于Python和CUDA)(一)——零基础安装pycuda-CSDN博客

代码

版本说明:

        tensorrt:8.5.3.1

        cuda:11.7

不同trt版本,有些api不一样

import cv2
import glob
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

import os
import numpy as np
from PIL import Image


def preprocess_input(image):
    """
    图像预处理
    """
    image = image / 255.0
    image = image - np.array([0.485, 0.456, 0.406])
    image = image / np.array([0.229, 0.224, 0.225])

    # image -= np.array([0.5, 0.5, 0.5])
    # image /= np.array([0.5, 0.5, 0.5])
    return image


class YOLOXEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, args, files_path='data', cache_file='YOLOX.cache'):
        trt.IInt8EntropyCalibrator2.__init__(self)

        self.cache_file = cache_file

        self.batch_size = args.batch_size
        self.Channel = args.channel
        self.Height = args.height
        self.Width = args.width

        # 获取数据集中图像的路径列表
        self.imgs = glob.glob(os.path.join(files_path, '*.jpg'))

        # 初始化内存
        self.batch_idx = 0
        self.max_batch_idx = len(self.imgs) // self.batch_size
        self.data_size = trt.volume([self.batch_size, self.Channel, self.Height, self.Width]) * trt.float32.itemsize
        self.device_input = cuda.mem_alloc(self.data_size)

    def __resize_pic(self, im0, auto=False):
        '''
            图片读取,resize,padding
        :param im: 图像数组
        :return: resize,padding后的图像数组,只resize的图像shape H,W,原始图像shape,  padding的像素 H,W
        '''
        h, w, _ = im0.shape
        r = max(h / self.Height, w / self.Width)
        new_h = int(round(h / r))
        new_w = int(round(w / r))
        if auto:
            ph, pw = np.mod(self.Height - new_h, 32) / 2, np.mod(self.Width - new_w, 32) / 2  # 最小填充
        else:
            ph, pw = (self.Height - new_h) / 2, (self.Width - new_w) / 2  # 填充到正方形

        pt = int(round(ph - 0.1))
        pb = int(round(ph + 0.1))
        pl = int(round(pw - 0.1))
        pr = int(round(pw + 0.1))
        im = cv2.resize(im0, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        im = cv2.copyMakeBorder(im, pt, pb, pl, pr, cv2.BORDER_CONSTANT, value=(114, 114, 114))
        return im, (new_h, new_w), im0.shape, (pt, pb, pl, pr)

    def per_process(self, img_path):
        '''
            前处理函数,读取图片,经过padding及归一化处理为模型的输入形式 B,C,H,W
        :param img_path: 图片路径
        :return: 输入数据(ndarray)
        '''
        im = cv2.imread(img_path)
        im, img1_shape, im0_shape, pad = self.__resize_pic(im)
        im = im[..., ::-1]
        im = np.ascontiguousarray(im, dtype=np.float32)
        im = np.transpose(im, (2, 0, 1))
        im = im / 255

        return im

    def next_batch(self):
        """
        读取一个batch的图像数据
        """
        if self.batch_idx < self.max_batch_idx:
            # ***********读取一个batch的文件**************#
            batch_files = self.imgs[self.batch_idx * self.batch_size: (self.batch_idx + 1) * self.batch_size]

            batch_imgs = np.zeros((self.batch_size, self.Channel, self.Height, self.Width), dtype=np.float32)
            for i, f in enumerate(batch_files):
                img = self.per_process(f)
                # 判断字节是否与缓冲区对齐
                assert (img.nbytes == self.data_size / self.batch_size), 'not valid img!' + f
                batch_imgs[i] = img
            self.batch_idx += 1
            print("batch:[{}/{}]".format(self.batch_idx, self.max_batch_idx))
            return np.ascontiguousarray(batch_imgs)
        else:
            return np.array([])

    def get_batch_size(self):
        """
        获取batch大小
        """
        return self.batch_size

    def get_batch(self, names, p_str=None):
        """
        获取一个batch的图像数据,并拷贝到device内存中
        """
        try:
            batch_imgs = self.next_batch()
            if batch_imgs.size == 0 or batch_imgs.size != self.batch_size * self.Channel * self.Height * self.Width:
                return None
            cuda.memcpy_htod(self.device_input, batch_imgs.astype(np.float32))
            return [int(self.device_input)]
        except Exception as e:
            print("发生异常,异常为:{}".format(e))
            return None

    def read_calibration_cache(self):
        """
        读取缓存数据
        """
        # 如果存在校准集的缓存,则使用现有缓存,否则返回空值
        if os.path.exists(self.cache_file):
            print("succeed finding cache file:{}".format(self.cache_file))
            with open(self.cache_file, "rb") as f:
                return f.read()
        else:
            print("failed finding cache!")
            return

    def write_calibration_cache(self, cache):
        """
        写入缓存数据
        """
        with open(self.cache_file, "wb") as f:
            f.write(cache)
        print("succeed saving cache!")


import tensorrt as trt
import argparse

# 显式配置batch size
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)


def ONNX2TRT(args, calib=None):
    '''
    :brief:  convert onnx to tensorrt engine, use mode of ['fp16', 'int8']
    :return: trt engine
    '''

    # 判断模式是否可用
    assert args.mode.lower() in ['fp16', 'int8'], "mode should be in ['fp16', 'int8']"

    G_LOGGER = trt.Logger(trt.Logger.WARNING)
    with trt.Builder(G_LOGGER) as builder, \
            builder.create_network(EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, G_LOGGER) as parser, \
            builder.create_builder_config() as config, \
            trt.Runtime(G_LOGGER) as runtime:

        # 配置tensorrt的推理缓冲区大小,即构建阶段可用的显存大小
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4096 * (1 << 30))

        if args.mode.lower() == 'int8':
            assert (builder.platform_has_fast_int8 == True), "not support int8"
            # 配置int8量化所需的参数及校准集
            config.set_flag(trt.BuilderFlag.INT8)
            config.int8_calibrator = calib

        elif args.mode.lower() == 'fp16':
            assert (builder.platform_has_fast_fp16 == True), "not support fp16"
            # 配置fp16模式下的参数
            config.set_flag(trt.BuilderFlag.FP16)
            # config.fp16

        # 加载onnx模型,并解析
        print('Loading ONNX file from path {}...'.format(args.onnx_file_path))
        with open(args.onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            if not parser.parse(model.read()):  # parser是tensorrt的onnx解析类,声明位置见20行
                print("ERROR: Failed to parse the ONNX file.")
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        print(network.get_input(0).shape)
        # network.get_input(0).shape = [1, 3, 640, 960]
        print('Completed parsing of ONNX file')

        # 构建序列化引擎文件
        print('Building an engine from file {}; this may take a while...'.format(args.onnx_file_path))
        # 根据配置及解析的网络构建序列化会话
        engine = builder.build_serialized_network(network, config)

        with open(args.engine_file_path, "wb") as f:
            f.write(engine)
        print("Completed creating Engine")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pytorch2TensorRT args")
    parser.add_argument("--batch_size", type=int, default=1, help='batch_size')
    parser.add_argument("--channel", type=int, default=3, help='input channel')
    parser.add_argument("--height", type=int, default=640, help='input height')
    parser.add_argument("--width", type=int, default=640, help='input width')
    parser.add_argument("--cache_file", type=str, default='YOLOX.cache', help='cache_file')
    parser.add_argument("--mode", type=str, default='int8', help='fp16 or int8')
    parser.add_argument("--onnx_file_path", type=str, default='yolov8s.onnx', help='onnx_file_path')
    parser.add_argument("--engine_file_path", type=str, default='yolov8s_int8.engine', help='engine_file_path')
    args = parser.parse_args()
    calib = YOLOXEntropyCalibrator(args)
    ONNX2TRT(args, calib)

标签:engine,python,onnx,self,cache,batch,im,np,size
From: https://blog.csdn.net/Meoyou/article/details/141052052

相关文章

  • 20:Python函数
    #Python3函数#函数是组织好的,可重复使用的,用来实现单一,或相关联功能的代码段。#函数能提高应用的模块性,和代码的重复利用率。你已经知道Python提供了许多内建函数,比如print()。#但你也可以自己创建函数,这被叫做用户自定义函数。#定义一个函数#你可以定义一个由自己想要功能......
  • 使用python做页面,测试数据库连通性!免费分享!测试通过~
    免费分享刚刚写的一个小程序,测试通过没问题,解BUG也就花了半小时吧有更好的方法欢迎评论区推给我谢谢。importtkinterastkfromtkinterimportmessageboximportpymysqldefget_db_info(db_source):ifdb_source=='database1':hostname=e1.get()......
  • Python面试宝典第30题:找出第K大元素
    题目        给定一个整数数组nums,请找出数组中第K大的数,保证答案存在。其中,1<=K<=nums数组长度。        示例1:输入:nums=[3,2,1,5,6,4],K=2输出:5        示例2:输入:nums=[50,23,66,18,72],K=3输出:50快速选择算法......
  • 使用Python和Flask框架实现简单的RESTful API
    目录环境准备创建Flask应用运行Flask应用测试API注意事项在当今的Web开发领域,RESTfulAPI因其简洁性和高效性而备受欢迎。本文将引导你使用Python的Flask框架来创建一个简单的RESTfulAPI,用于增删改查(CRUD)用户信息。环境准备在开始之前,请确保你的Python环境中已经安......
  • nodejs语言,MySQL数据库;springboot的个性化资讯推荐系统66257(免费领源码)计算机毕业设计
    摘 要随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,个性化资讯推荐系统当然也不能排除在外。个性化资讯推荐系统是以实际运用为开发背景,运用软件工程原理和开发方法,采用springboot技术构建的一个管理系统。整......
  • c#语言,SQL server数据库;基于Web的社区人员管理系统的设计与实现36303(免费领源码)计算机
    目 录摘要1绪论1.1慨述1.2课题意义1.3B/S体系结构介绍1.4ASP.NET框架介绍2 社区人员管理系统分析2.1可行性分析2.2系统流程分析2.2.1数据增加流程2.2.2数据修改流程52.2.3数据删除流程52.3系统功能分析62.3.1功能性分析62.3.2非功能性......
  • Python多种接口请求方式示例
    发送JSON数据如果你需要发送JSON数据,可以使用json参数。这会自动设置Content-Type为application/json。importrequestsimportjsonurl='http://example.com/api/endpoint'data={"key":"value","another_key":"another_value"......
  • 【优秀python毕设案例】基于python django的新媒体网络舆情数据爬取与分析
    摘   要如今在互联网时代下,微博成为了一种新的流行社交形式,是体现网络舆情的媒介之一。现如今微博舆论多带有虚假不实、恶意造谣等负面舆论,为了营造更好的网络环境,本设计提出了基于新媒体的网络舆情数据爬取与分析,主要对微博热点话题进行处理。本设计首先以Python为环......
  • 在 Rust 中嵌入 Python 来调用外部 Python 库
    我正在尝试学习如何将Python嵌入到Rust应用程序中。出于学习目的,我想创建一个运行永远循环的Rust脚本/应用程序。该循环会休眠设定的时间间隔,醒来后,它使用Pythonrequests库从互联网时间服务器获取当前时间。虽然这不是一个实际应用程序,但我的目标是了解如何从Rust调用......
  • 如何从我的 Python 应用程序更新我的 Facebook Business 令牌?
    我有一个使用FacebookBusiness库的Python应用程序。因此,我需要使用Facebook提供的令牌来访问我的见解并操纵它们。但是,这个令牌有一个很长的到期日期,但我想知道是否有办法自动更新这个令牌在我的应用程序中,这样它就不会停止运行。当然可以!你可以使用Facebook提......