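// 首页 > 编程语言 > TNN推理测试demo -- c++
// (Home > Programming languages > TNN inference test demo -- C++)
//
// 时间: 2022-12-11 18:11:20  浏览次数: 42  (published 2022-12-11 18:11:20, 42 views)
// 标签: std name -- demo c++ instance blob input tnn  (tags)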

// Created by DangXS on 2022/12/8.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <cfloat>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <sstream>
#include <string>
#include <iostream>

#include "tnn/core/common.h"
#include "tnn/core/instance.h"
#include "tnn/core/macro.h"
#include "tnn/core/tnn.h"
#include "tnn/utils/blob_converter.h"
#include "tnn/utils/cpu_utils.h"
#include "tnn/utils/data_type_utils.h"
#include "tnn/utils/dims_vector_utils.h"

#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;


// Numeric device ids cast to tnn::DeviceType in main().
// NOTE(review): these are assumed to mirror the tnn::DeviceType enum values
// (DEVICE_X86 / DEVICE_ARM / DEVICE_OPENCL) -- verify against tnn/core/common.h.
#define DEVICE_X86     0x0010
#define DEVICE_ARM     0x0020
#define DEVICE_OPENCL  0x1000

// Backend used by main(). Override at build time, e.g. -DDEVICE=DEVICE_ARM.
// Without an override the demo previously failed to compile because DEVICE
// was never defined; default to x86 so a plain build works.
#ifndef DEVICE
#define DEVICE DEVICE_X86
#endif

// static methods.
// reference: https://github.com/Tencent/TNN/blob/master/examples/base/utils/utils.cc
/// Reads the entire file at `proto_or_model_path` (binary mode) into a string.
///
/// @param proto_or_model_path path of the file to load (e.g. a .tnnproto or
///        .tnnmodel file); may be nullptr.
/// @return the raw file bytes. Returns "" and prints a message to stdout when
///         the path is null or the file cannot be opened.
std::string content_buffer_from(const char *proto_or_model_path) {
    if (proto_or_model_path == nullptr) {
        // Streaming a null char* into std::cout is UB, so report it explicitly.
        std::cout << "Can not open (null path)\n";
        return "";
    }

    std::ifstream file(proto_or_model_path, std::ios::binary);
    if (!file.is_open()) {
        std::cout << "Can not open " << proto_or_model_path << "\n";
        return "";  // empty buffer
    }

    // Copy the whole stream through rdbuf() instead of tellg() + new[].
    // This avoids the int truncation of the size for files > 2 GiB, and the
    // new char[-1] when tellg() fails. No raw buffer can leak if an
    // allocation throws. An empty file simply yields "".
    std::ostringstream content;
    content << file.rdbuf();
    return content.str();
}

// static methods.
/// Returns the dims of input blob `name` of `_instance`.
///
/// @param _instance network instance; when null, `{}` is returned.
/// @param name input blob name. An empty name selects the first entry in the
///        input blob map's iteration order.
/// @return the blob's dims, or `{}` when the instance is null, the blobs
///         cannot be queried, or no (non-null) blob matches `name`.
tnn::DimsVector get_input_shape(
        const std::shared_ptr<tnn::Instance> &_instance,
        std::string name) {
    tnn::DimsVector shape = {};
    tnn::BlobMap blob_map = {};
    // The map must be populated from the instance; without this call every
    // lookup below ran against an empty map and the result was always {}.
    if (!_instance || _instance->GetAllInputBlobs(blob_map) != tnn::TNN_OK) {
        return shape;
    }

    // Empty name: fall back to the first input blob.
    if (name.empty() && !blob_map.empty() && blob_map.begin()->second) {
        shape = blob_map.begin()->second->GetBlobDesc().dims;
    }

    // Single find() instead of find() + operator[] (two lookups).
    const auto it = blob_map.find(name);
    if (it != blob_map.end() && it->second) {
        shape = it->second->GetBlobDesc().dims;
    }

    return shape;
}

// static methods.
/// Returns the dims of output blob `name` of `_instance`.
///
/// @param _instance network instance; when null, `{}` is returned.
/// @param name output blob name. An empty name selects the first entry in the
///        output blob map's iteration order.
/// @return the blob's dims, or `{}` when the instance is null, the blobs
///         cannot be queried, or no (non-null) blob matches `name`.
tnn::DimsVector get_output_shape(
        const std::shared_ptr<tnn::Instance> &_instance,
        std::string name) {
    tnn::DimsVector shape = {};
    tnn::BlobMap blob_map = {};
    // The map must be populated from the instance; without this call every
    // lookup below ran against an empty map and the result was always {}.
    if (!_instance || _instance->GetAllOutputBlobs(blob_map) != tnn::TNN_OK) {
        return shape;
    }

    // Empty name: fall back to the first output blob.
    if (name.empty() && !blob_map.empty() && blob_map.begin()->second) {
        shape = blob_map.begin()->second->GetBlobDesc().dims;
    }

    // Single find() instead of find() + operator[] (two lookups).
    const auto it = blob_map.find(name);
    if (it != blob_map.end() && it->second) {
        shape = it->second->GetBlobDesc().dims;
    }

    return shape;
}

// static methods.
/// Lists the names (blob map keys) of all input blobs of `_instance`.
///
/// @param _instance network instance; may be null.
/// @return the input blob names in the blob map's iteration order; empty when
///         the instance is null or its input blobs cannot be queried.
std::vector<std::string> get_input_names(
        const std::shared_ptr<tnn::Instance> &_instance) {
    std::vector<std::string> names;
    if (!_instance) {
        return names;
    }

    tnn::BlobMap blob_map;
    // Previously the map was never filled, so this always returned {} and
    // main()'s .front() on the result was undefined behaviour.
    if (_instance->GetAllInputBlobs(blob_map) != tnn::TNN_OK) {
        return names;
    }

    names.reserve(blob_map.size());
    for (const auto &item : blob_map) {
        names.push_back(item.first);
    }
    return names;
}

// static method
/// Lists the names (blob map keys) of all output blobs of `_instance`.
///
/// @param _instance network instance; may be null.
/// @return the output blob names in the blob map's iteration order; empty
///         when the instance is null or its output blobs cannot be queried.
std::vector<std::string> get_output_names(
        const std::shared_ptr<tnn::Instance> &_instance) {
    std::vector<std::string> names;
    if (!_instance) {
        return names;
    }

    tnn::BlobMap blob_map;
    // Previously the map was never filled, so this always returned {}.
    if (_instance->GetAllOutputBlobs(blob_map) != tnn::TNN_OK) {
        return names;
    }

    names.reserve(blob_map.size());
    for (const auto &item : blob_map) {
        names.push_back(item.first);
    }
    return names;
}

// static method
/// Picks the tnn::MatType to use when reading output blob `name`.
///
/// @param _instance network instance; may be null.
/// @param name output blob name; an empty name selects the first entry in the
///        output blob map's iteration order.
/// @return NC_INT32 when the blob's data_type is DATA_TYPE_INT32, otherwise
///         NCHW_FLOAT. NCHW_FLOAT is also returned when the instance is null,
///         the blobs cannot be queried, or no blob matches `name`.
tnn::MatType get_output_mat_type(
        const std::shared_ptr<tnn::Instance> &_instance,
        std::string name) {
    if (_instance) {
        tnn::BlobMap output_blobs;
        if (_instance->GetAllOutputBlobs(output_blobs) == tnn::TNN_OK) {
            // Look up without operator[]. operator[] inserts a null blob for
            // an unknown name, and begin() is invalid on an empty map; both
            // previously led to a null/invalid dereference.
            tnn::BlobMap::mapped_type blob = nullptr;
            if (name.empty()) {
                if (!output_blobs.empty()) {
                    blob = output_blobs.begin()->second;
                }
            } else {
                const auto it = output_blobs.find(name);
                if (it != output_blobs.end()) {
                    blob = it->second;
                }
            }
            if (blob && blob->GetBlobDesc().data_type == tnn::DATA_TYPE_INT32) {
                return tnn::NC_INT32;
            }
        }
    }
    return tnn::NCHW_FLOAT;
}

// static method
/// Returns the data_format recorded in output blob `name`'s descriptor.
///
/// @param _instance network instance; may be null.
/// @param name output blob name; an empty name selects the first entry in the
///        output blob map's iteration order.
/// @return the blob's data_format, or DATA_FORMAT_NCHW when the instance is
///         null, the blobs cannot be queried, or no blob matches `name`.
tnn::DataFormat get_output_data_format(
        const std::shared_ptr<tnn::Instance> &_instance,
        std::string name) {
    if (_instance) {
        tnn::BlobMap output_blobs;
        if (_instance->GetAllOutputBlobs(output_blobs) == tnn::TNN_OK) {
            // Look up without operator[]. It would insert and dereference a
            // null blob for an unknown name; begin() is invalid when empty.
            tnn::BlobMap::mapped_type blob = nullptr;
            if (name.empty()) {
                if (!output_blobs.empty()) {
                    blob = output_blobs.begin()->second;
                }
            } else {
                const auto it = output_blobs.find(name);
                if (it != output_blobs.end()) {
                    blob = it->second;
                }
            }
            if (blob) {
                return blob->GetBlobDesc().data_format;
            }
        }
    }
    return tnn::DATA_FORMAT_NCHW;
}

// static method
/// Picks the tnn::MatType matching input blob `name`.
///
/// @param _instance network instance; may be null.
/// @param name input blob name; an empty name selects the first entry in the
///        input blob map's iteration order.
/// @return NC_INT32 when the blob's data_type is DATA_TYPE_INT32, otherwise
///         NCHW_FLOAT. NCHW_FLOAT is also returned when the instance is null,
///         the blobs cannot be queried, or no blob matches `name`.
tnn::MatType get_input_mat_type(
        const std::shared_ptr<tnn::Instance> &_instance,
        std::string name) {
    if (_instance) {
        tnn::BlobMap input_blobs;
        if (_instance->GetAllInputBlobs(input_blobs) == tnn::TNN_OK) {
            // Look up without operator[]. operator[] inserts a null blob for
            // an unknown name, and begin() is invalid on an empty map; both
            // previously led to a null/invalid dereference.
            tnn::BlobMap::mapped_type blob = nullptr;
            if (name.empty()) {
                if (!input_blobs.empty()) {
                    blob = input_blobs.begin()->second;
                }
            } else {
                const auto it = input_blobs.find(name);
                if (it != input_blobs.end()) {
                    blob = it->second;
                }
            }
            if (blob && blob->GetBlobDesc().data_type == tnn::DATA_TYPE_INT32) {
                return tnn::NC_INT32;
            }
        }
    }
    return tnn::NCHW_FLOAT;
}

// static method
/// Returns the data_format recorded in input blob `name`'s descriptor.
///
/// @param _instance network instance; may be null.
/// @param name input blob name; an empty name selects the first entry in the
///        input blob map's iteration order.
/// @return the blob's data_format, or DATA_FORMAT_NCHW when the instance is
///         null, the blobs cannot be queried, or no blob matches `name`.
tnn::DataFormat get_input_data_format(
        const std::shared_ptr<tnn::Instance> &_instance,
        std::string name) {
    if (_instance) {
        tnn::BlobMap input_blobs;
        if (_instance->GetAllInputBlobs(input_blobs) == tnn::TNN_OK) {
            // Look up without operator[]. It would insert and dereference a
            // null blob for an unknown name; begin() is invalid when empty.
            tnn::BlobMap::mapped_type blob = nullptr;
            if (name.empty()) {
                if (!input_blobs.empty()) {
                    blob = input_blobs.begin()->second;
                }
            } else {
                const auto it = input_blobs.find(name);
                if (it != input_blobs.end()) {
                    blob = it->second;
                }
            }
            if (blob) {
                return blob->GetBlobDesc().data_format;
            }
        }
    }
    return tnn::DATA_FORMAT_NCHW;
}

int main(int argc, char *argv[]) {
    const char *proto_path = "./1.tnnproto";
    const char *model_path = "./1.tnnmodel";
//    const char *proto_path = "./1.param";
//    const char *model_path = "./1.bin";
    // Note, tnn:: actually is TNN_NS::, I prefer the first one.
    std::shared_ptr<tnn::TNN> net;
    std::shared_ptr<tnn::Instance> instance;
    std::shared_ptr<tnn::Mat> input_mat; // assume single input.

    std::string proto_content_buffer, model_content_buffer;
    proto_content_buffer = content_buffer_from(proto_path);
    model_content_buffer = content_buffer_from(model_path);

    tnn::ModelConfig model_config;
    model_config.model_type = tnn::MODEL_TYPE_TNN;
//    model_config.model_type = tnn::MODEL_TYPE_NCNN;
    model_config.params = {proto_content_buffer, model_content_buffer};

    // 1. init TNN net
    tnn::Status status;
    net = std::make_shared<tnn::TNN>();
    status = net->Init(model_config);
    if (status != tnn::TNN_OK || !net) {
        std::cout << "CreateInst failed!\t" << status.description().c_str() << "\n";

        std::cout << "net->Init failed!\n";
        return -100;
    // 2. init device type, change this default setting
    // for better performance. such as CUDA/OPENCL/...
    auto _device = (tnn::DeviceType) DEVICE;
    tnn::DeviceType network_device_type = _device; // CPU,GPU
    tnn::DeviceType input_device_type = _device; // CPU only
    tnn::DeviceType output_device_type = _device;
    const unsigned int num_threads = 2;
    int input_batch;
    int input_channel;
    int input_height;
    int input_width;
    unsigned int input_value_size;
    tnn::DataFormat input_data_format;  // e.g DATA_FORMAT_NHWC
    tnn::MatType input_mat_type; // e.g NCHW_FLOAT
    // Actually, i prefer to hardcode the input/output names
    // into subclasses, but we just let the auto detection here
    // to make sure the debug information can show more details.
    std::string input_name; // assume single input only.
    std::vector<std::string> output_names; // assume >= 1 outputs.
    tnn::DimsVector input_shape; // vector<int>
    std::map<std::string, tnn::DimsVector> output_shapes;

    // 3. init instance
    tnn::NetworkConfig network_config;
    network_config.library_path = {""};
    network_config.device_type = network_device_type;

    instance = net->CreateInst(network_config, status);
    if (status != tnn::TNN_OK || !instance) {
        std::cout << "CreateInst failed!" << status.description().c_str() << "\n";
        return -200;
    // 4. setting up num_threads
    instance->SetCpuNumThreads((int) num_threads);
    // 5. init input information.
    input_name = get_input_names(instance).front();
    printf("input_name: %s\n", input_name.c_str());

    input_shape = get_input_shape(instance, input_name);
    printf("input_shape: \n\t");
    for (int i = 0; i < input_shape.size(); ++i) {
        printf(" %d", input_shape[i]);

    if (input_shape.size() != 4) {
        throw std::runtime_error("Found input_shape.size()!=4, but "
                                 "BasicTNNHandler only support 4 dims."
                                 "Such as NCHW, NHWC ...");
        return -400;
    input_mat_type = get_input_mat_type(instance, input_name);
    input_data_format = get_input_data_format(instance, input_name);

    printf("input_data_format: %d\n", input_data_format);
    input_batch = input_shape.at(0);
    input_channel = input_shape.at(1);
    input_height = input_shape.at(2);
    input_width = input_shape.at(3);

    // 6. init input_mat
    input_value_size = input_batch * input_channel * input_height * input_width;
    // 7. init output information, debug only.
    output_names = get_output_names(instance);
    int num_outputs = output_names.size();

    printf("output_names: \n\t");
    for (auto &name: output_names) {
        output_shapes[name] = get_output_shape(instance, name);
        printf("%s, ", name.c_str());

    // forward
    string img_path = "./1.jpg";
    cv::Mat mat = cv::imread(img_path);

    // In TNN: x*scale + bias
    std::vector<float> scale_vals = {
            (1.0f / 0.229f) * (1.0 / 255.f),
            (1.0f / 0.224f) * (1.0 / 255.f),
            (1.0f / 0.225f) * (1.0 / 255.f)
    std::vector<float> bias_vals = {
            -0.485f * 255.f * (1.0f / 0.229f) * (1.0 / 255.f),
            -0.456f * 255.f * (1.0f / 0.224f) * (1.0 / 255.f),
            -0.406f * 255.f * (1.0f / 0.225f) * (1.0 / 255.f)

    // 1. make input mat
    cv::Mat mat_rs;
    cv::resize(mat, mat_rs, cv::Size(input_width, input_height));
    cv::cvtColor(mat_rs, mat_rs, cv::COLOR_BGR2RGB);

//    // Format Types: (0: NCHW FLOAT), (1: 8UC3), (2: 8UC1)
//    tnn::DataType data_type = tnn::DataType::DATA_TYPE_INT8;
//    int bytes = tnn::DimsVectorUtils::Count(input_shape) * tnn::DataTypeUtils::GetBytesSize(data_type);
//    void *mat_data = malloc(bytes);
//    input_mat = std::make_shared<tnn::Mat>(input_device_type, tnn::N8UC3, input_shape, mat_data);

    // push into input_mat (1,3,224,224)
    input_mat = std::make_shared<tnn::Mat>(input_device_type, tnn::N8UC3, input_shape, (void *) mat_rs.data);
    assert(input_mat->GetData()!= nullptr);

    // 2. set input_mat
    tnn::MatConvertParam input_cvt_param;
    input_cvt_param.scale = scale_vals;
    input_cvt_param.bias = bias_vals;

    status = instance->SetInputMat(input_mat, input_cvt_param);
    if (status != tnn::TNN_OK) {
        std::cout << status.description().c_str() << "\n";
        return -500;

    // 3. forward
    double c1 = cv::getTickCount();
    status = instance->Forward();
    double c2 = cv::getTickCount();
    auto spend_time = (c2 - c1) / cv::getTickFrequency();
    printf("forward elapsed: %f ms\n", spend_time);

    if (status != tnn::TNN_OK) {
        std::cout << status.description().c_str() << "\n";
        return -600;
    // 4. fetch.
    tnn::MatConvertParam cvt_param;
    std::vector<std::shared_ptr<tnn::Mat>> feat_mats;
    for (int i = 0; i < output_names.size(); ++i) {
        std::shared_ptr<tnn::Mat> feat_mat;
        status = instance->GetOutputMat(feat_mat, cvt_param, output_names[i], output_device_type);

    printf("net output: \n");
    for (int i = 0; i < output_names.size(); ++i) {
        printf("name: %s\t\tfeat shape: %d %d %d %d\n", output_names[i].c_str(),
               feat_mats[i]->GetBatch(), feat_mats[i]->GetChannel(), feat_mats[i]->GetHeight(),

        float *data = reinterpret_cast<float *>(feat_mats[i]->GetData());
        for (int j = 0; j < 10; ++j) {
            std::cout << data[j] << ", ";
        std::cout << std::endl;

    if (status != tnn::TNN_OK) {
        std::cout << status.description().c_str() << "\n";
        return -700;
    // release
//    free(mat_data);
    status = net->DeInit();
    if (status != tnn::TNN_OK || !net) {
        std::cout << "DeInit failed!" << status.description().c_str() << "\n";
        std::cout << "net->DeInit failed!\n";
        return -1000;

    return 0;


// Source: https://www.cnblogs.com/dxscode/p/16974069.html


// Related articles listed on the source page (unrelated to this demo):
//   • P4902 乘积 题解
//   • ToDesk使用
//     现在的终端产品种类非常的多,常见的包括tablet,手机,笔记本 ,ipod...等等,这些终端带屏产品连同台式机,智能电视等固定设备占据了我们的工作和生活中的大部分时间,不知道你发现......
//   • Android手机应用开发之手机GPS定位
//   • Microsoft 365 开发:如何通过Powershell更新OneDrive的管理员
//   • 2021 届 字节跳动 校招提前批开始啦~~~
//   • 乔布斯语录:领袖和跟风者的区别在于创新
//   • 脚本之一键安装单节点elasticsearch
//   • 向乔布斯学习如何把用户体验做到极致
//   • 不到30岁就挣下亿万身家的创业者们
//   • Zabbix 6 系列学习 04:容器方式安装