Tensor的维度转置方法
在搭建神经网络的时候,经常会遇到需要交换维度的时候,比如将HWCN的Tensor维度顺序变换为NCHW顺序,此时需要用到Tensor的转置方法。
一般有以下三种方法:
1、numpy.transpose
如果Tensor是由np.Array转换而来,那么可以在变量还是np.Array的时候先进行通道转置,此时可以使用np.transpose方法:
>>> import numpy as np
>>> aa = np.ndarray((1,3,3,4))
>>> aa.shape
(1,3,3,4)
>>> aa.transpose((3,1,0,2)).shape
(4,3,1,3)
>>> np.transpose(aa,(3,1,0,2)).shape
(4,3,1,3)
arr.transpose(new_shape)和np.transpose(arr,new_shape)都合法,结果完全一样。
* 如果只是二维数组转置或者只交换第一和最后两个维度,那么也可以用arr.T方法:
>>> aa.T.shape
(4,3,3,1)
2、torch.tranpose
torch.transpose方法和np.transpose方法有一个最大的区别,torch.transpose只能支持两个维度的交换,函数原型为:
torch.transpose(tensor,dim0,dim1)
如果超过两个维度,会报错。使用方式为:
>>> aaTensor = torch.from_numpy(aa)
>>> aaTensor.transpose(0,3).shape
torch.Size([4,3,3,1])
>>> torch.transpose(aaTensor,0,3).shape
torch.Size([4,3,3,1])
>>> aaTensor.transpose(3,1,0,2).shape
Traceback (most recent call last):
File "<string>", line 1, in <module>
TypeError: transpose() received an invalid combination of arguments - got (int, int, int, int), but expected one of:
* (int dim0, int dim1)
* (name dim0, name dim1)
torch.transpose方法有一个后缀格式函数tensor.transpose_(),是transpose的inplace版本,调用该函数不返回结果,直接修改原始tensor的维度:
>>> aaTensor.transpose_(3,0)
>>> aaTesor.shape
torch.Size([4,3,3,1])
3、torch.permute
torch.permute用法和numpy.transpose完全相同,接受多个指定的维度,将输入Tensor的维度按照指定的维度顺序重排:
>>> torch.permute(aaTensor,3,1,0,2).shape
torch.Size([4,3,1,3])
>>>aaTensor.permute(3,1,0,2).shape
torch.Size([4,3,1,3])
注意torch.transpose、torch.permute、arr.transpose可接受tuple、list、多个整数作为输入,而numpy.transpose只能接受tuple和list。
标签:torch,transpose,np,shape,维度,numpy,aaTensor From: https://www.cnblogs.com/lumeng199/p/17691004.html