首页 > 编程语言 >TNN推理测试demo--c++

TNN推理测试demo--c++

时间:2022-12-11 18:11:20浏览次数:42  
标签:std name -- demo c++ instance blob input tnn

//
// Created by DangXS on 2022/12/8.
//


#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <cfloat>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <sstream>
#include <string>
#include <iostream>

#include "tnn/core/common.h"
#include "tnn/core/instance.h"
#include "tnn/core/macro.h"
#include "tnn/core/tnn.h"
#include "tnn/utils/blob_converter.h"
#include "tnn/utils/cpu_utils.h"
#include "tnn/utils/data_type_utils.h"
#include "tnn/utils/dims_vector_utils.h"

#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;

#define LITETNN_DEBUG


#define DEVICE_X86     0x0010
#define DEVICE_ARM     0x0020
#define DEVICE_OPENCL  0x1000


//#define DEVICE DEVICE_X86
#define DEVICE DEVICE_ARM
//#define DEVICE DEVICE_OPENCL


// static methods.
// reference: https://github.com/Tencent/TNN/blob/master/examples/base/utils/utils.cc
// Read the entire file at `proto_or_model_path` into a std::string.
// Returns an empty string (and logs under LITETNN_DEBUG) when the file
// cannot be opened or is empty/unreadable.
std::string content_buffer_from(const char *proto_or_model_path) {
    std::ifstream file(proto_or_model_path, std::ios::binary);
    if (!file.is_open()) {
#ifdef LITETNN_DEBUG
        std::cout << "Can not open " << proto_or_model_path << "\n";
#endif
        return "";
    }
    // Determine the size, then read directly into the string's buffer:
    // no raw new[]/delete[] and no leak if the read throws.
    file.seekg(0, std::ifstream::end);
    const std::streamsize size = file.tellg();
    if (size <= 0) return "";
    std::string file_content(static_cast<size_t>(size), '\0');
    file.seekg(0, std::ifstream::beg);
    file.read(&file_content[0], size);
    return file_content;
}

// static methods.
// Look up the dims of the input blob called `name`; an empty name selects
// the first input blob. Returns {} when the instance is null or the blob
// cannot be found.
tnn::DimsVector get_input_shape(
        const std::shared_ptr<tnn::Instance> &_instance,
        const std::string &name) {
    tnn::DimsVector shape = {};
    tnn::BlobMap blob_map = {};
    if (_instance) {
        _instance->GetAllInputBlobs(blob_map);
    }

    if (name.empty()) {
        // No name given: fall back to the first input blob, if any.
        if (!blob_map.empty() && blob_map.begin()->second)
            shape = blob_map.begin()->second->GetBlobDesc().dims;
        return shape;
    }

    // Single find() instead of find() + operator[]: operator[] on a missing
    // key would default-insert a null blob into the (local) map.
    auto it = blob_map.find(name);
    if (it != blob_map.end() && it->second)
        shape = it->second->GetBlobDesc().dims;

    return shape;
}

// static methods.
// Look up the dims of the output blob called `name`; an empty name selects
// the first output blob. Returns {} when the instance is null or the blob
// cannot be found.
tnn::DimsVector get_output_shape(
        const std::shared_ptr<tnn::Instance> &_instance,
        const std::string &name) {
    tnn::DimsVector shape = {};
    tnn::BlobMap blob_map = {};
    if (_instance) {
        _instance->GetAllOutputBlobs(blob_map);
    }

    if (name.empty()) {
        // No name given: fall back to the first output blob, if any.
        if (!blob_map.empty() && blob_map.begin()->second)
            shape = blob_map.begin()->second->GetBlobDesc().dims;
        return shape;
    }

    // Single find() instead of find() + operator[]: operator[] on a missing
    // key would default-insert a null blob into the (local) map.
    auto it = blob_map.find(name);
    if (it != blob_map.end() && it->second)
        shape = it->second->GetBlobDesc().dims;

    return shape;
}

// static methods.
// Collect the names of every input blob of `_instance`.
// Returns an empty vector when the instance is null.
std::vector<std::string> get_input_names(
        const std::shared_ptr<tnn::Instance> &_instance) {
    std::vector<std::string> result;
    if (!_instance) return result;
    tnn::BlobMap inputs;
    _instance->GetAllInputBlobs(inputs);
    result.reserve(inputs.size());
    for (const auto &kv : inputs)
        result.push_back(kv.first);
    return result;
}

// static method
// Collect the names of every output blob of `_instance`.
// Returns an empty vector when the instance is null.
std::vector<std::string> get_output_names(
        const std::shared_ptr<tnn::Instance> &_instance) {
    std::vector<std::string> result;
    if (!_instance) return result;
    tnn::BlobMap outputs;
    _instance->GetAllOutputBlobs(outputs);
    result.reserve(outputs.size());
    for (const auto &kv : outputs)
        result.push_back(kv.first);
    return result;
}

// static method
// Decide the tnn::Mat type to use when fetching the output blob `name`
// (empty name = first output). Returns NC_INT32 for INT32 blobs, otherwise
// NCHW_FLOAT. Safe against a null instance, an empty blob map, and a
// missing/unknown name (the original dereferenced begin() on an empty map
// and default-inserted via operator[]).
tnn::MatType get_output_mat_type(
        const std::shared_ptr<tnn::Instance> &_instance,
        const std::string &name) {
    if (_instance) {
        tnn::BlobMap output_blobs;
        _instance->GetAllOutputBlobs(output_blobs);
        tnn::Blob *blob = nullptr;
        if (name.empty()) {
            if (!output_blobs.empty()) blob = output_blobs.begin()->second;
        } else {
            auto it = output_blobs.find(name);
            if (it != output_blobs.end()) blob = it->second;
        }
        if (blob && blob->GetBlobDesc().data_type == tnn::DATA_TYPE_INT32) {
            return tnn::NC_INT32;
        }
    }
    return tnn::NCHW_FLOAT;
}

// static method
// Report the data format (e.g. NCHW/NHWC) of the output blob `name`
// (empty name = first output). Falls back to DATA_FORMAT_NCHW when the
// instance is null, the blob map is empty, or the name is unknown
// (the original dereferenced begin() on an empty map unconditionally).
tnn::DataFormat get_output_data_format(
        const std::shared_ptr<tnn::Instance> &_instance,
        const std::string &name) {
    if (_instance) {
        tnn::BlobMap output_blobs;
        _instance->GetAllOutputBlobs(output_blobs);
        tnn::Blob *blob = nullptr;
        if (name.empty()) {
            if (!output_blobs.empty()) blob = output_blobs.begin()->second;
        } else {
            auto it = output_blobs.find(name);
            if (it != output_blobs.end()) blob = it->second;
        }
        if (blob) return blob->GetBlobDesc().data_format;
    }
    return tnn::DATA_FORMAT_NCHW;
}

// static method
// Decide the tnn::Mat type to use when feeding the input blob `name`
// (empty name = first input). Returns NC_INT32 for INT32 blobs, otherwise
// NCHW_FLOAT. Safe against a null instance, an empty blob map, and a
// missing/unknown name (the original dereferenced begin() on an empty map
// and default-inserted via operator[]).
tnn::MatType get_input_mat_type(
        const std::shared_ptr<tnn::Instance> &_instance,
        const std::string &name) {
    if (_instance) {
        tnn::BlobMap input_blobs;
        _instance->GetAllInputBlobs(input_blobs);
        tnn::Blob *blob = nullptr;
        if (name.empty()) {
            if (!input_blobs.empty()) blob = input_blobs.begin()->second;
        } else {
            auto it = input_blobs.find(name);
            if (it != input_blobs.end()) blob = it->second;
        }
        if (blob && blob->GetBlobDesc().data_type == tnn::DATA_TYPE_INT32) {
            return tnn::NC_INT32;
        }
    }
    return tnn::NCHW_FLOAT;
}

// static method
// Report the data format (e.g. NCHW/NHWC) of the input blob `name`
// (empty name = first input). Falls back to DATA_FORMAT_NCHW when the
// instance is null, the blob map is empty, or the name is unknown
// (the original dereferenced begin() on an empty map unconditionally).
tnn::DataFormat get_input_data_format(
        const std::shared_ptr<tnn::Instance> &_instance,
        const std::string &name) {
    if (_instance) {
        tnn::BlobMap input_blobs;
        _instance->GetAllInputBlobs(input_blobs);
        tnn::Blob *blob = nullptr;
        if (name.empty()) {
            if (!input_blobs.empty()) blob = input_blobs.begin()->second;
        } else {
            auto it = input_blobs.find(name);
            if (it != input_blobs.end()) blob = it->second;
        }
        if (blob) return blob->GetBlobDesc().data_format;
    }
    return tnn::DATA_FORMAT_NCHW;
}


int main(int argc, char *argv[]) {
    const char *proto_path = "./1.tnnproto";
    const char *model_path = "./1.tnnmodel";
//    const char *proto_path = "./1.param";
//    const char *model_path = "./1.bin";
    // Note, tnn:: actually is TNN_NS::, I prefer the first one.
    std::shared_ptr<tnn::TNN> net;
    std::shared_ptr<tnn::Instance> instance;
    std::shared_ptr<tnn::Mat> input_mat; // assume single input.

    std::string proto_content_buffer, model_content_buffer;
    proto_content_buffer = content_buffer_from(proto_path);
    model_content_buffer = content_buffer_from(model_path);

    tnn::ModelConfig model_config;
    model_config.model_type = tnn::MODEL_TYPE_TNN;
//    model_config.model_type = tnn::MODEL_TYPE_NCNN;
    model_config.params = {proto_content_buffer, model_content_buffer};

    // 1. init TNN net
    tnn::Status status;
    net = std::make_shared<tnn::TNN>();
    status = net->Init(model_config);
    model_config.params.clear();
    if (status != tnn::TNN_OK || !net) {
#ifdef LITETNN_DEBUG
        std::cout << "CreateInst failed!\t" << status.description().c_str() << "\n";
#endif

        std::cout << "net->Init failed!\n";
        return -100;
    }
    // 2. init device type, change this default setting
    // for better performance. such as CUDA/OPENCL/...
    auto _device = (tnn::DeviceType) DEVICE;
    tnn::DeviceType network_device_type = _device; // CPU,GPU
    tnn::DeviceType input_device_type = _device; // CPU only
    tnn::DeviceType output_device_type = _device;
    const unsigned int num_threads = 2;
    int input_batch;
    int input_channel;
    int input_height;
    int input_width;
    unsigned int input_value_size;
    tnn::DataFormat input_data_format;  // e.g DATA_FORMAT_NHWC
    tnn::MatType input_mat_type; // e.g NCHW_FLOAT
    // Actually, i prefer to hardcode the input/output names
    // into subclasses, but we just let the auto detection here
    // to make sure the debug information can show more details.
    std::string input_name; // assume single input only.
    std::vector<std::string> output_names; // assume >= 1 outputs.
    tnn::DimsVector input_shape; // vector<int>
    std::map<std::string, tnn::DimsVector> output_shapes;


    // 3. init instance
    tnn::NetworkConfig network_config;
    network_config.library_path = {""};
    network_config.device_type = network_device_type;

    instance = net->CreateInst(network_config, status);
    if (status != tnn::TNN_OK || !instance) {
#ifdef LITETNN_DEBUG
        std::cout << "CreateInst failed!" << status.description().c_str() << "\n";
#endif
        return -200;
    }
    // 4. setting up num_threads
    instance->SetCpuNumThreads((int) num_threads);
    // 5. init input information.
    input_name = get_input_names(instance).front();
    printf("input_name: %s\n", input_name.c_str());

    input_shape = get_input_shape(instance, input_name);
    printf("input_shape: \n\t");
    for (int i = 0; i < input_shape.size(); ++i) {
        printf(" %d", input_shape[i]);
    }
    printf("\n");

    if (input_shape.size() != 4) {
#ifdef LITETNN_DEBUG
        throw std::runtime_error("Found input_shape.size()!=4, but "
                                 "BasicTNNHandler only support 4 dims."
                                 "Such as NCHW, NHWC ...");
#else
        return -400;
#endif
    }
    input_mat_type = get_input_mat_type(instance, input_name);
    input_data_format = get_input_data_format(instance, input_name);

    printf("input_data_format: %d\n", input_data_format);
    input_batch = input_shape.at(0);
    input_channel = input_shape.at(1);
    input_height = input_shape.at(2);
    input_width = input_shape.at(3);

    // 6. init input_mat
    input_value_size = input_batch * input_channel * input_height * input_width;
    // 7. init output information, debug only.
    output_names = get_output_names(instance);
    int num_outputs = output_names.size();

    printf("output_names: \n\t");
    for (auto &name: output_names) {
        output_shapes[name] = get_output_shape(instance, name);
        printf("%s, ", name.c_str());
    }
    printf("\n");


    // forward
    string img_path = "./1.jpg";
    cv::Mat mat = cv::imread(img_path);
    assert(!mat.empty());

    // In TNN: x*scale + bias
    std::vector<float> scale_vals = {
            (1.0f / 0.229f) * (1.0 / 255.f),
            (1.0f / 0.224f) * (1.0 / 255.f),
            (1.0f / 0.225f) * (1.0 / 255.f)
    };
    std::vector<float> bias_vals = {
            -0.485f * 255.f * (1.0f / 0.229f) * (1.0 / 255.f),
            -0.456f * 255.f * (1.0f / 0.224f) * (1.0 / 255.f),
            -0.406f * 255.f * (1.0f / 0.225f) * (1.0 / 255.f)
    };

    // 1. make input mat
    cv::Mat mat_rs;
    cv::resize(mat, mat_rs, cv::Size(input_width, input_height));
    cv::cvtColor(mat_rs, mat_rs, cv::COLOR_BGR2RGB);

//    // Format Types: (0: NCHW FLOAT), (1: 8UC3), (2: 8UC1)
//    tnn::DataType data_type = tnn::DataType::DATA_TYPE_INT8;
//    int bytes = tnn::DimsVectorUtils::Count(input_shape) * tnn::DataTypeUtils::GetBytesSize(data_type);
//    void *mat_data = malloc(bytes);
//    input_mat = std::make_shared<tnn::Mat>(input_device_type, tnn::N8UC3, input_shape, mat_data);

    // push into input_mat (1,3,224,224)
    input_mat = std::make_shared<tnn::Mat>(input_device_type, tnn::N8UC3, input_shape, (void *) mat_rs.data);
    assert(input_mat->GetData()!= nullptr);

    // 2. set input_mat
    tnn::MatConvertParam input_cvt_param;
    input_cvt_param.scale = scale_vals;
    input_cvt_param.bias = bias_vals;

    status = instance->SetInputMat(input_mat, input_cvt_param);
    if (status != tnn::TNN_OK) {
#ifdef LITETNN_DEBUG
        std::cout << status.description().c_str() << "\n";
#endif
        return -500;
    }

    // 3. forward
    double c1 = cv::getTickCount();
    status = instance->Forward();
    double c2 = cv::getTickCount();
    auto spend_time = (c2 - c1) / cv::getTickFrequency();
    printf("forward elapsed: %f ms\n", spend_time);

    if (status != tnn::TNN_OK) {
#ifdef LITETNN_DEBUG
        std::cout << status.description().c_str() << "\n";
#endif
        return -600;
    }
    // 4. fetch.
    tnn::MatConvertParam cvt_param;
    std::vector<std::shared_ptr<tnn::Mat>> feat_mats;
    for (int i = 0; i < output_names.size(); ++i) {
        std::shared_ptr<tnn::Mat> feat_mat;
        status = instance->GetOutputMat(feat_mat, cvt_param, output_names[i], output_device_type);
        feat_mats.push_back(feat_mat);
    }

    printf("net output: \n");
    for (int i = 0; i < output_names.size(); ++i) {
        printf("name: %s\t\tfeat shape: %d %d %d %d\n", output_names[i].c_str(),
               feat_mats[i]->GetBatch(), feat_mats[i]->GetChannel(), feat_mats[i]->GetHeight(),
               feat_mats[i]->GetWidth());

        float *data = reinterpret_cast<float *>(feat_mats[i]->GetData());
        for (int j = 0; j < 10; ++j) {
            std::cout << data[j] << ", ";
        }
        std::cout << std::endl;
    }
    printf("\n");

    if (status != tnn::TNN_OK) {
#ifdef LITETNN_DEBUG
        std::cout << status.description().c_str() << "\n";
#endif
        return -700;
    }
    // release
//    free(mat_data);
    status = net->DeInit();
    if (status != tnn::TNN_OK || !net) {
#ifdef LITETNN_DEBUG
        std::cout << "DeInit failed!" << status.description().c_str() << "\n";
#endif
        std::cout << "net->DeInit failed!\n";
        return -1000;
    }

    return 0;
}

 

标签:std,name,--,demo,c++,instance,blob,input,tnn
From: https://www.cnblogs.com/dxscode/p/16974069.html

相关文章

  • P4902 乘积 题解
    乘积给出\(A\),\(B\),求下面的式子的值.\[\prod_{i=A}^{B}\prod_{j=1}^{i}\left(\frac{i}{j}\right)^{\left\lfloor\frac{i}{j}\right\rfloor} \pmod{19260817}\]包含\(T\)组......
  • ToDesk使用
    现在的终端产品种类非常的多,常见的包括tablet,手机,笔记本 ,ipod...等等,这些终端带屏产品连同台式机,智能电视等固定设备占据了我们的工作和生活中的大部分时间,不知道你发现......
  • Android手机应用开发之手机GPS定位
    最近在做Android手机应用开发,还是很有意思的。其实如果只是做简单手机应用开发而不是手机游戏开发的话,还是很简单的。把主要的控件掌握了,就可以开发简单的应用了。下面主要......
  • Microsoft 365 开发:如何通过Powershell更新OneDrive的管理员
    Blog链接:​​​https://blog.51cto.com/13969817​​又到了员工离职潮了,很多用户离开了组织另谋高就了,那么对于企业Office365的管理员而言,需要快速的将离开公司的员工信息......
  • 2021 届 字节跳动 校招提前批开始啦~~~
    又是一年校招季~~字节跳动的2021届校招提前批已经开始了~~因为疫情,今年各家公司的大规模校园宣讲多多少少都会受到影响。所以大家一定要特别关注校招的线上宣传活动,公众号......
  • 乔布斯语录:领袖和跟风者的区别在于创新
    资料图:​​乔布斯​​1、领袖和跟风者的区别就在于创新。创新无极限!只要敢想,没有什么不可能,立即跳出思维的框框吧。如果你正处于一个上升的朝阳行业,那么尝试去寻......
  • 脚本之一键安装单节点elasticsearch
    #!/bin/bashES_VERSION=7.17.5#ES_VERSION=7.9.3#ES_VERSION=7.6.2UBUNTU_URL="https://mirrors.tuna.tsinghua.edu.cn/elasticstack/7.x/apt/pool/main/e/elasticsearch/el......
  • 向乔布斯学习如何把用户体验做到极致
    北京时间10月6日消息,苹果董事会主席、联合创始人史蒂夫·乔布斯周三辞世,享年56岁。乔布斯的辞世,引起了IT界名人的关注。上海自然道公司总裁,有“多普达之父”之称的杨兴平(......
  • 不到30岁就挣下亿万身家的创业者们
    最近科技网站BusinessInsider对一些令人印象深刻的年轻创业者进行了盘点,他们发现,不少创业者在不满30岁身家就超过了1亿美元,同样有不少创业者在不满30岁就已经拥有一......
  • Zabbix 6 系列学习 04:容器方式安装
    本文会以两种环境介绍此安装方式,一种是基于Docker的方式,第二种是基于Podman的方式。Docker本文环境系统:Ubuntu22.04容器:Docker安装Dockersudoaptinstalldockerdoc......