首页 > 编程语言 >c++ 部署libtorch时对Tensor块的常用操作API

c++ 部署libtorch时对Tensor块的常用操作API

时间:2022-12-14 15:11:06浏览次数:79  
标签:std tensor cout dim torch c++ libtorch Tensor

一、前言

使用pytorch可以很方便地训练网络,并且pytorch的官方网站中给出了很全的python对tensor的操作接口API,但是在部署libtorch的时候,c++对tensor的操作接口API资料甚少,因此,本文旨在整理部署libtorch的时候,操作tensor块常用的接口API,有很多和python的接口类似。

二、部署libtorch常用接口

1.tensor与vector的转换方法

1.1 tensor 转 vector

at::Tensor t=at::ones({2,2},at::kInt);//建立一个2X2的tensor
vector<int> v(t.data_ptr<int>(),t.data_ptr<int>()+t.numel());//将tensor转换为vector

t是一个类型为at::kInt的tensor,其中kInt可以用其他数据类型替换如kFloat等,t.data_ptr<int>()返回int类型的指针,返回的地址是数据存储的起始位置。t.numel()返回t中的元素个数。

1.2 vector 转 tensor

vector<int> v={1,2,3,4};
at::TensorOptions opts=at::TensorOptions().dtype(at::kInt);
c10::IntArrayRef s={2,2};//设置返回的tensor的大小
at::Tensor t=at::from_blob(v.data(),s,opts).clone();

opts用来对返回的tensor作一些额外的解释,例如类型。s用来指定返回的tensor的维度。clone是为了深复制,让后面对t的操作不会收到v的影响。v.data()返回vector中的数据的指针。  

2. torch::tensor.sizes() 返回{c10::ArrayRef} 类型

获取其中元素时,需要使用下标[]操作。

3. torch::empty(), torch::ones()

torch::Tensor a = torch::empty({2, 4}, at::kCUDA); // a 在cuda上
torch::Tensor a = torch::empty({2, 4}, at::kLong); // a 数据类型long
torch::Tensor a = torch::empty({2, 4}, at::device(at::kCUDA).dtype(at::kLong())); // cuda上数据类型long
torch::Tensor a = torch::empty({2, 4}, at::device({at::kCUDA, 1})); //在gpu1傻瓜
torch::Tensor a = torch::empty({2, 4}, at::requires_grad());//需要梯度
std::cout << a << std::endl;
torch::Tensor b = torch::ones({2, 4});
std::cout << b<< std::endl;
结果:
 1.7171e+10  4.5796e-41  9.7406e-08  3.0815e-41
 0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00
[ CPUFloatType{2,4} ]
 1  1  1  1
 1  1  1  1
[ CPUFloatType{2,4} ]

4. torch::squeeze() 在指定dim增加维度

torch::squeeze(const Tensor & self, int64_t dim)

std::cout << a << std::endl;
std::cout << torch::unsqueeze(a, 1) << std::endl;

结果:
 1.0000
 1.4142
 0.7071
[ CPUFloatType{3} ]
 1.0000
 1.4142
 0.7071
[ CPUFloatType{3,1} ]

5. * 或者torch::mul 乘法操作

a与b做*乘法,原则是如果a与b的size不同,则以某种方式将a或b进行复制,使得复制后的a和b的size相同,然后再将a和b做element-wise的乘法。 torch::mul与 *用法相同。

std::cout << a << std::endl;
std::cout << b << std::endl;
std::cout << a * b <<std::endl;
结果:
 21.0000
 30.7409
[ CPUFloatType{2,1} ]
1.0000  1.4142  0.7071
[ CPUFloatType{1,3} ]
21.0000  29.6985  14.8492
30.7409  43.4741  21.7371
[ CPUFloatType{2,3} ]

6. + 或者torch:: add 乘法操作

5和6都是按位加或者乘,这里就用到了广播机制(broadcastable),1、相同size的一定可以广播;2、每个tensor至少有一个dim;3、遍历对比两个tensor尺寸大小时,两个tensor对应位置上的dim要么相同,要么其中一个是1,要么不存在,如果都不是,说明两个tensor不能广播。
例如下面:
尺寸{1, 4, 4} 和 尺寸{3, 1,4}进行遍历比较:1、1和3有一个是1;2、4和1有一个是1;3、4和4相同,因此tensor a 和 b可以广播

torch::Tensor a = torch::randint(0, 10, {1, 4, 4});
std::cout <<  a << std::endl;
torch::Tensor b = torch::randint(0, 10, {3, 1, 4});
std::cout <<  b << std::endl;
std::cout << a+b << std::endl;
//tensor a
(1,.,.) = 
  7  1  9  8
  3  5  9  6
  0  4  7  6
  0  6  2  2
[ CPUFloatType{1,4,4} ]
//tensor b
(1,.,.) = 
  4  4  5  3
(2,.,.) = 
  6  4  9  9
(3,.,.) = 
  7  5  3  6
[ CPUFloatType{3,1,4} ]
//tensor a + b
(1,.,.) = 
  11   5  14  11
   7   9  14   9
   4   8  12   9
   4  10   7   5
(2,.,.) = 
  13   5  18  17
   9   9  18  15
   6   8  16  15
   6  10  11  11
(3,.,.) = 
  14   6  12  14
  10  10  12  12
   7   9  10  12
   7  11   5   8
[ CPUFloatType{3,4,4} ]

7. torch::max,tensor数据块中每个维度上,求最大值以及最大值下标

torch::Tensor a = torch::randn({3,4});
cout << a << endl;
std::tuple<torch::Tensor, torch::Tensor> max_classes = torch::max(a, 1);
auto max_1= std::get<0>(max_classes);
auto max_index= std::get<1>(max_classes);
cout << max_1 << endl;
cout << max_index << endl;
结果
tensor a 的值为:
[0.4388 -0.8234  0.3935  0.0000
0.0121  1.0354  0.0000  1.5286
 0.1590  2.7148 -0.0737 -0.5168]
 [ CPUFloatType{3,4} ]

max_1 是最大值,值为:
 0.3935
 1.5286
 2.7148 
 
max_index是最大值的下标,值为:
[ CPUFloatType{3} ]
 2
 3
 1
[ CPULongType{3} ]

8. torch::nonzero获取tensor数据块中非零数的下标,返回下标

torch::Tensor a = torch::randn({3,4});
a[2][3] = 0;
a[1][2] = 0;
cout << a << endl;
auto b = torch::nonzero(a);
 cout << b << endl;
 结果:
 **tensor a 的值为:**
-0.4388 -0.8234  0.3935  0.0000
 0.0121  1.0354  0.0000  1.5286
 0.1590  2.7148 -0.0737 -0.5168
[ CPUFloatType{3,4} ]

**tensor b 的 值是非0值对应的下标,如下:**
 0  0
 0  1
 0  2
 1  0
 1  1
 1  3
 2  0
 2  1
 2  2
 2  3
[ CPULongType{10,2} ]

9. > 大于符号,tensor数据中的每一位做比较,返回对应位置的bool值

torch::Tensor a = torch::randn({3,4});
cout << ( a > 0.1) << endl;
结果:
tensor a 的值为:
-0.4388 -0.8234  0.3935  0.0000
 0.0121  1.0354  0.0000  1.5286
 0.1590  2.7148 -0.0737 -0.5168
[ CPUFloatType{3,4} ]

 0  0  1  0
 0  1  0  1
 1  1  0  0
[ CPUBoolType{3,4} ]

10. Tensor::where

where 的作用 与 比较符号( >, =, <等)+ nonzero的作用相同

//where返回的是tensor a中数值大于10的坐标位置,where的返回结果是vector,vector的size和a的dims相同。
torch::Tensor a = torch::randint(1, 30, {8});
std::cout << a << std::endl;
std::vector<torch::Tensor> index = torch::where( a >= 10);
std::cout << index.size() << std::endl;
std::cout << index[0] <<std::endl;

a结果为:
 14
 23
 28
 21
 20
 21
 14
  1
[ CPUFloatType{8} ]

index.size() = a.dims = 1

index[0]结果为:
 0
 1
 2
 3
 4
 5
 6
[ CPULongType{7} ]

11.Tensor::slice,取tensor数据块中roi区域的数据,不复制

inline Tensor Tensor::slice(int64_t dim, int64_t start, int64_t end, int64_t step)

12. torch::stack() 增加新的维度进行堆叠

static inline Tensor stack(TensorList tensors, int64_t dim)
这里的dim是增加哪个维度,如果dim=0,结果tensor的sizes为 4x6
如果dim=1,结果tensor的sizes为6x4

std::cout << a << std::endl;
std::cout << b << std::endl;
torch::Tensor c = torch::stack({a, b}, 1);
std::cout << c << std::endl;
[ 21.0000
 29.6985
 14.8492
 30.7409
 43.4741
 21.7371
[ CPUFloatType{6} ]
 21.0000
 14.8492
 29.6985
 30.7409
 21.7371
 43.4741
[ CPUFloatType{6} ]
 21.0000  21.0000
 29.6985  14.8492
 14.8492  29.6985
 30.7409  30.7409
 43.4741  21.7371
 21.7371  43.4741
[ CPUFloatType{6,2} ]

13. torch::cat() 对数据沿着某一维度进行拼接。cat后数据的总维数不变.

在下面的例子中,只能在dim=0上进行cat,因为a和b只有一个维度,dim=1将报错,因为cat不能增加维度。
这里的dim是沿着哪个维度进行拼接。

std::cout << a << std::endl;
std::cout << b << std::endl;
torch::Tensor c = torch::cat({a, b}, 0);
std::cout << c << std::endl;

[ 21.0000
 29.6985
 14.8492
 30.7409
 43.4741
 21.7371
[ CPUFloatType{6} ]
 21.0000
 14.8492
 29.6985
 30.7409
 21.7371
 43.4741
 [ CPUFloatType{6} ]
 21.0000
 29.6985
 14.8492
 30.7409
 43.4741
 21.7371
 21.0000
 14.8492
 29.6985
 30.7409
 21.7371
 43.4741
[ CPUFloatType{12} ]

14. permute,交换维度, 适合多维数据,更灵活的transpose

torch::Tensor x = torch::randn({2,3,4});
std::cout << x.sizes() << std::endl;
torch::Tensor x_p = x.permute({1,0,2}); //将原来第1维变为0维,同理,0→1,2→2
std::cout << x_p.sizes() << std::endl;

[2, 3, 4]
[3, 2, 4]

15. toType Tensor数据类型转换

cx = cx.toType(torch::kFloat);

16. split将某个维度进行拆分,select的批量操作

std::cout << a << std::endl;
std::vector<torch::Tensor>> b = torch::split(a, 1, 1)
for(int i = 0; i < b.size(); i++)
{
  std::cout << b[i] <<std::endl;
 }
 
a 的结果为:
 159.2578   96.9697  346.3506  260.4342
 159.6481   96.2015  345.8403  261.2210
[ CUDAFloatType{2,4} ]

b中的元素依次为:
 159.2578
 159.6481
[ CUDAFloatType{2,1} ]
 96.9697
 96.2015
[ CUDAFloatType{2,1} ]
 346.3506
 345.8403
[ CUDAFloatType{2,1} ]
 260.4342
 261.2210
[ CUDAFloatType{2,1} ]

  

  

参考:

c++ 部署libtorch 时对Tensor块的常用操作API (例如vector转换为torch、squeeze、select、select_indxe、max等)

 

  

  

  

  

  

  

  

  

  

  

  

 

 

标签:std,tensor,cout,dim,torch,c++,libtorch,Tensor
From: https://www.cnblogs.com/carsonzhu/p/16982238.html

相关文章

  • c++字符串和字符数组互相转
    voidmain(){//字符串转字符数组stringname="ddddd";charbuf[]={0};strcpy(buf,name.c_str());//字符串转字符数组,使用strcpycout<<n......
  • C++基础篇之什么是数据类型
       ......
  • C++获取文件信息
    官方文档:https://learn.microsoft.com/zh-cn/cpp/c-runtime-library/reference/stat-functions?view=msvc-170推荐使用_tstat方法,会根据编译环境选择适合的方法。该方法会......
  • C++高校人员信息管理系统
    C++高校人员信息管理系统题目1高校人员信息管理系统设计1问题描述某高校有四类员工:教师、实验员、行政人员、教师兼行政人员:共有的信息包括编号、姓名、性别、年龄......
  • 「REMAKE C++」Day 0
    Day0继「REMAKEDP」以后居然是「REMAKEC++」QWQ前言今天(2022.12.13)突然想起了一句话,programmer学习就应该边做边学,而不是学了再做。因此开坑此系列,直到完结......
  • 哈夫曼树课程实验报告--C++
    主要内容:  设计一个哈夫曼树,建立函数输入二叉树,并输出哈夫曼树。实现哈夫曼树的初始化并打印,进一步理解哈夫曼树的逻辑结构和存储结构,进一步提高使用理论知识指导解决实际......
  • C++ sort
    sort是C++STL(Standardtemplatelibrary)提供的排序函数,它可以为普通函数或者容器中指定范围内的元素进行排序。默认为升序排序,排序的区间左闭右开。使用标准库(STL)提......
  • C++ 捕获程序异常奔溃minidump
    文章目录​​1、简介​​​​1.1MiniDumpWriteDump函数​​​​1.2VisualStudio分析小型转储​​​​1.3使用Microsoft公共符号服务器​​​​1.4使用WinDbg调试小......
  • ASCII码对照表(C++17 实现ANSI、UTF8、Unicode字符编码互转)
    文章目录​​1、简介​​​​1.1ANSI​​​​1.2Unicode​​​​1.3UTF8​​​​1.4Littleendian/Bigendian​​​​2、C++标准库​​​​2.1MultiByteToWideChar......
  • C++ OpenSSL库编译及使用(VS2017,Python)
    文章目录​​1、简介​​​​2、下载openssl​​​​2.1下载openssl源码(github)​​​​2.2下载openssl源码(官网压缩包)​​​​2.3下载openssl二进制库​​​​2.4下载pe......