首页 > 其他分享 >【保姆级教程附代码】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程

【保姆级教程附代码】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程

时间:2024-04-03 21:31:17浏览次数:27  
标签:img pth sr torch TensorRT Pytorch output path model

  1. 整体流程为:.pth -> .onnx -> .plan (或.trt,二者等价)
  2. 需要的工具和包:Docker,Pytorch,ONNX,onnxruntime,TensorRT(trtexec 和 polygraphy)

.pth 到 .onnx

这里以 SwinIR (https://github.com/JingyunLiang/SwinIR) 预训练模型为例

  1. init_torch_model() 函数主要是对模型初始化,这里是根据 mian_test_swinir.py 中 define_model(args) 的模型定义函数调整的,按照需求对超参数、模型的选择来进行改写各种模型配置。
  2. torch.onnx.export() 函数则是 torch 中自带的模型转化方法,注意可以设置 dynamic_axes ,即特定维度的动态输入,具体可参考官方文档:https://pytorch.org/tutorials//beginner/onnx/export_simple_model_to_onnx_tutorial.html
import torch
from models.network_swinir import SwinIR as net

torch_model_path = '/yourpath/to/swinir/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth'

def init_torch_model():
    torch_model = net(upscale=4, 
                in_chans=3, 
                img_size=64,         
                window_size=8, 
                img_range=1., 
                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                mlp_ratio=2, 
                upsampler='nearest+conv', resi_connection='3conv')
    param_key_g = 'params_ema'

    pretrained_model = torch.load(torch_model_path)
    torch_model.load_state_dict(pretrained_model[param_key_g] 
                          if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)

    torch_model.eval()

    return torch_model

model = init_torch_model()

x = torch.randn(1, 3, 256, 256) 
# 0, 1, 2, 3 中 0, 2, 3 都是动态的
 
with torch.no_grad(): 
    torch.onnx.export(
        model, 
        x, 
        "swinir_real_sr_large_model_dynamic_20.onnx", 
        opset_version=19, 
        input_names=['input'], 
        output_names=['output'],
        dynamic_axes={'input' : {0 : 'batch_size',
                                 2 : 'height',
                                 3 : 'width'},
                      'output' : {0 : 'batch_size',
                                  2 : 'height',
                                  3 : 'width'}})
用 onnxruntime 测试 .onnx 是否能用
import cv2
import numpy as np
import torch
import time
import onnxruntime  
import os
from PIL import Image
import torchvision.transforms as transforms
from crop1_4 import crop
from combine4_1 import combine

# 全局初始化ONNX Runtime会话
def initialize_session():
    session_options = onnxruntime.SessionOptions()
    # session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED # 打出日志
    ort_session = onnxruntime.InferenceSession('/path/to/yourmodel.onnx',
                                               session_options=session_options,
                                               providers=['CUDAExecutionProvider'])
    return ort_session
   
def srxn(sr_xn, sr_input):

    ort_session = initialize_session() # 初始化ONNX Runtime会话

    save_dir = f'/path/to/outputs'
    if not os.path.exists(save_dir):
        # 如果目录不存在,则创建目录
        os.makedirs(save_dir)

    path = sr_input
    (imgname, imgext) = os.path.splitext(os.path.basename(path))

    if sr_xn == 2:
        output = main_x2(sr_input, ort_session)
    elif sr_xn == 4:
        output = main_x4(sr_input, ort_session)
    elif sr_xn == 8:
        output_mid = main_x2(sr_input)

        sr_output_mid = os.path.join(save_dir, f"mid_result.png")
        cv2.imwrite(sr_output_mid, output_mid)

        output = main_x4(sr_output_mid)
    
    saved_image_path = os.path.join(save_dir, f"final_result.png")
        
    save_success = cv2.imwrite(saved_image_path, output)

    if save_success:
        print(f"Image successfully saved at: {os.path.abspath(saved_image_path)}")
    else:
        print("Failed to save the image.")

    sr_output = saved_image_path

    return sr_output


def main_x4(sr_input, ort_session):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # read image
    path = sr_input
    (imgname, imgext) = os.path.splitext(os.path.basename(path))

    # image to HWC-BGR, float32 (NumPy)
    img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.

    # HCW-BGR to CHW-RGB
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  
    
    # CHW-RGB to NCHW-RGB
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  

    # inference
    with torch.no_grad():
        window_size = 8
        # pad input image to be a multiple of window_size
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        
        # output = test(img_lq)
        start_time = time.time() # start time

        # 假设 img_lq 是一个存储在CUDA上的Tensor (NCHW-RGB)
        if img_lq.is_cuda:
            numpy_input = img_lq.cpu().numpy()
        else:
            numpy_input = img_lq.numpy()

        # check is using GPU?
        print(onnxruntime.get_device())

        # runtime
        # ort_session = onnxruntime.InferenceSession('/home/stone/Desktop/SR/SwinIR/swinir_real_sr_large_model_dynamic.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

        # onnx 的输入是 numpy array 而非 tensor!    
        ort_inputs = {'input': numpy_input}

        ort_output = ort_session.run(['output'], ort_inputs)[0]
        
        ort_output = torch.from_numpy(ort_output)# numpy 转 torch

        output = ort_output[..., :h_old * 4, :w_old * 4]

        stop_time = time.time() # start time
        print(f'Test time: {stop_time - start_time:.2f}s')  

    # save image
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
    output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8

    return output


def read_img_from_path(img_file_path):
    # 定义图片扩展名列表
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']  # 可根据需要添加更多

    # 初始化一个列表来存储图片路径
    image_paths = []

    # 遍历sr_input目录下的所有文件
    for root, dirs, files in os.walk(img_file_path):
        for file in files:
            if os.path.splitext(file)[1].lower() in image_extensions:
                # 构建完整的文件路径并添加到列表中
                image_paths.append(os.path.join(root, file))

    return image_paths

if __name__ == '__main__':

    image_paths = read_img_from_path(img_file_path = '/path/to/inputs')
    # 遍历所有找到的图片路径
    for path in image_paths:
        sr_input = path
        sr_xn = 4
        sr_output = srxn(sr_xn, sr_input)

这里有个小坑:

  1. 初始化 ONNX Runtime 会话可能比较费时,所以 onnxruntime.InferenceSession 初始化可以和 .run 分开,初始化一次后的每次推理只需要 .run 即可,具体见上述代码。
  2. 初始化 ONNX Runtime 如果特别费时,可以通过 onnx-simplifier。
  3. 解决这个问题具体可参考:https://blog.csdn.net/weixin_44212848/article/details/137044477

.onnx 到 .plan (.trt)

本文是直接用的 TensorRT 中的 trtexec 和 polygraphy 的命令行工具,比较快捷。以下 bash 都是在 docker 的命令行中进行的,具体的 TensorRT docker 可参考 https://github.com/NVIDIA/TensorRT/blob/main/quickstart/deploy_to_triton/README.md

trtexec \
--onnx=yourmodel.onnx \
--saveEngine=yourmodel.plan \
--minShapes=input:1x3x36x36 \
--optShapes=input:2x3x512x512 \
--maxShapes=input:2x3x512x512 \
--verbose \
--fp16 \
> trtexec-result-512-2-fp16.log 2>&1

可以检测 .log 中的情况,如果没问题就 .plan 就转化好啦。

当然这里也有些坑,比如明明是显存不够错误,但日志里完全没提 oom,而是说节点问题(参考https://blog.csdn.net/weixin_44212848/article/details/137286847)

不论什么问题,可以试试 polygraphy inspect 检查一下 TensorRT 是否完全支持你的 .onnx

polygraphy inspect model modelA.onnx \
    --model-type=onnx \
    --shape-inference \
    --show layers attrs weights \
    --list-unbounded-dds \
    --verbose \
    > result-01.log

如果完全支持的话,.log 里的内容大致类似如下,重点是提到 “Graph is fully supported by TensorRT; Will not generate subgraphs.”,那么恭喜你的 .onnx 大概率是可以转化到 .plan 的!

[W] 'colored' module is not installed, will not use colors when logging. To enable colors, please install the 'colored' module: python3 -m pip install colored
[I] Loading bytes from yourmodel.onnx
[W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[W] onnx2trt_utils.cpp:400: One or more weights outside the range of INT32 was clamped
[I] Graph is fully supported by TensorRT; Will not generate subgraphs.
  • 官方关于 trtexec 的中文博客:https://developer.nvidia.com/zh-cn/blog/tensorrt-trtexec-cn/
  • 官方 trtexec 示例:https://github.com/NVIDIA/trt-samples-for-hackathon-cn/blob/master/cookbook/07-Tool/trtexec/command.sh
  • 官方 polygraph 示例:https://github.com/NVIDIA/trt-samples-for-hackathon-cn/blob/master/cookbook/07-Tool/Polygraphy-CLI/InspectExample/command.sh
  • 推荐可以看下官方 b 站教程(时间充裕的话):https://www.bilibili.com/video/BV12X4y1H7P6/?spm_id_from=333.788&vd_source=32f6f61e74ca115cbaca6bd6bb144662

标签:img,pth,sr,torch,TensorRT,Pytorch,output,path,model
From: https://blog.csdn.net/weixin_44212848/article/details/137357445

相关文章

  • YOLOV4:You Only Look Once目标检测模型在pytorch当中的实现
    文章目录概要整体架构流程技术名词解释技术细节小结源码链接:GitHub-AlexeyAB/darknet:YOLOv4/Scaled-YOLOv4/YOLO-NeuralNetworksforObjectDetection(WindowsandLinuxversionofDarknet)概要1.1模型架构YOLOv4项目实现了YOLOv4算法的网络架构,......
  • Pytorch torch.utils.data.DataLoader 用法详细介绍
    文章目录1.介绍2.参数详解3.用法4.参考1.介绍torch.utils.data.DataLoader是PyTorch提供的一个用于数据加载的工具类,用于批量加载数据并为模型提供输入。它可以将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作。PytorchDataLoader的详细官......
  • 环境搭建 ubuntu22.04+gtx1070+cuda12.0+cudnn8.8.0+TensorRT8.6
    构建基础             cuda12.0的.deb包会强制安装所依赖的nvidia-525.60.13版本驱动,但是对于ubuntu22.04来说,linux内核为6.5.0,其与该nvidia驱动不兼容,会报错,所以要先安装所支持的驱动,然后再使用runfile进行安装cuda12.0。cuda与驱动版本对应可查如下官网:1.C......
  • 大模型中常用的注意力机制GQA详解以及Pytorch代码实现
    分组查询注意力(GroupedQueryAttention)是一种在大型语言模型中的多查询注意力(MQA)和多头注意力(MHA)之间进行插值的方法,它的目标是在保持MQA速度的同时实现MHA的质量。这篇文章中,我们将解释GQA的思想以及如何将其转化为代码。GQA是在论文GQA:TrainingGeneraliz......
  • PyTorch学习(5):并行训练模型权重的本地化与加载
    1.并行训练与非并行训练        在训练深度神经网络时,我们一般会采用CPU或GPU来完成。得益于开源传统,许多算法都提供了完整的开源代码工程,便于学习和使用。随着GPU的普及,GPGPU已经占据了大部分的训练场景。        我们在这里仅以GPU训练场景做一些说明。......
  • pytorch | torchvision.transforms.CenterCrop
    torchvision.transforms.CenterCrop==>从图像中心裁剪图片transforms.CenterCroptorchvision.transforms.CenterCrop(size)功能:从图像中心裁剪图片size:所需裁剪的图片尺寸transforms.CenterCrop(196)的效果如下:(也可以写成transforms.CenterCrop((196,196)))如果裁剪......
  • pytorch在Mac上实现像cuda一样的加速
    1.参考:https://developer.apple.com/metal/pytorch/2.具体实现:2.1RequirementsMacM芯片或者AMD的GPUmacOS12.3orlaterPython3.7orlaterXcodecommand-linetools: xcode-select--install2.2准备anac......
  • Pytorch - Dataloader
    BasicallytheDataLoaderworkswiththeDatasetobject.SotousetheDataLoaderyouneedtogetyourdataintothisDatasetwrapper.Todothisyouonlyneedtoimplementtwomagicmethods:__getitem__and__len__.The__getitem__takesanindexandretu......
  • 【PyTorch 实战2:UNet 分类模型】10min揭秘 UNet 分割网络如何工作以及pytorch代码实现
    UNet网络详解及PyTorch实现一、UNet网络原理  U-Net,自2015年诞生以来,便以其卓越的性能在生物医学图像分割领域崭露头角。作为FCN的一种变体,U-Net凭借其Encoder-Decoder的精巧结构,不仅在医学图像分析中大放异彩,更在卫星图像分割、工业瑕疵检测等多个领域展现出强大的应用......
  • 使用镜像安装cuda12.1版本pytorch
    1.添加通道condaconfig--addchannelshttps://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/condaconfig--addchannelscondaconfig--addchannelshttps://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/condaconfig--addchannelshttps://mirrors.bfs......