PyTorch中对tensor的很多操作如sum,softmax等都可以设置dim参数用来指定操作在哪一维进行。PyTorch中的dim类似于numpy中的axis。
dim与方括号的关系
创建一个矩阵
a = torch.tensor([[1, 2], [3, 4]]) print(a)
输出:
tensor([[1, 2], [3, 4]])
因为a是一个矩阵,所以a的左边有2个括号
括号之间是嵌套关系,代表了不同的维度。从左往右数,两个括号代表的维度分别是0和1,在第0维遍历得到向量,在
第1维遍历得到标量
同样地,对于3维tensor
b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]]) print(b)
输出
tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
则3个括号代表的维度从左往右分别为0,1,2,在第0维遍历得到矩阵,在第1维遍历得到向量,在第2维遍历得到标量
更详细一点
在指定的维度上进行操作
在某一维度求和(或者进行其他操作)就是对该维度中的元素进行求和。
对于矩阵a
a = torch.tensor([[1, 2], [3, 4]]) print(a)
输出
tensor([[1, 2], [3, 4]])
求a在第0维的和,因为第0维代表最外边的括号,括号中的元素为向量 [1,2] , [3,4],第0维的和就是第0维中的元素相加,也就是两个向量 [1,2] , [3,4] 相加,所以结果为
[1 , 2 ] + [3 , 4 ] = [4 , 6]
s = torch.sum(a, dim=0) print(s)
输出
tensor([4, 6])
可以看到,a是2维矩阵,而相加的结果为1维向量,可以使用参数keepdim = True来保证维度数目不变。
s = torch.sum(a, dim=0, keepdim=True) print(s)
输出
tensor([[4, 6]])
在a的第0维求和,就是对第0维中的元素(向量)进行相加。同样的,对a第1维求和,就是对a第1维中的元素(标量)进行相加,a的第1维元素为标量1,2和3,4,则结果为
[1 + 2 ] = [3] ,[ 3 + 4 ] = [7]
s = torch.sum(a, dim=1) print(s)
输出
tensor([3, 7])
保持维度不变
s = torch.sum(a, dim=1, keepdim=True) print(s)
输出
tensor([[3], [7]])
对3维tensor的操作也是这样
b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]]) print(b)
输出
tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
将b在第0维相加,第0维为最外层括号,最外层括号中的元素为矩阵[ [3 , 2], [1 , 4] ]和[ [5, 6] ,[7, 8] ]。在第0维求和,
就是将第0维的元素(矩阵)相加
s = torch.sum(b, dim=0) print(s)
输出
tensor([[ 8, 8], [ 8, 12]])
求b在第1维的和,就是将b第1维中的元素[ 3, 2] 和[ 1 , 4 ],[ 5 , 6]和 [7 , 8 ]相加,所以
s = torch.sum(b, dim=1) print(s)
输出
tensor([[ 4, 6], [12, 14]])
则在b的第2维求和,就是对标量3和2,1和4,5和6,7和8求和
s = torch.sum(b, dim=2) print(s)
结果为
tensor([[ 5, 5], [11, 15]])
除了求和,其他操作也是类似的,如求b在指定维度上的最大值
m = torch.max(b, dim=0) print(m)
b在第0维的最大值是第0维中的元素(两个矩阵[[3,2],[1,4]]和[[5,6],[7,8]])的最大值,取矩阵对应位置最大值即可
结果为
torch.return_types.max( values=tensor([[5, 6], [7, 8]]), indices=tensor([[1, 1], [1, 1]]))
b在第1维的最大值就是第1维元素(4个(2对)向量)的最大值
m = torch.max(b, dim=1) print(m)
输出为
torch.return_types.max( values=tensor([[3, 4], [7, 8]]), indices=tensor([[0, 1], [1, 1]]))
b在第0维的最大值就是第0维元素(8个(4对)标量)的最大值
m = torch.max(b, dim=2) print(m)
输出
torch.return_types.max( values=tensor([[3, 4], [6, 8]]), indices=tensor([[0, 1], [1, 1]]))
总结
在tensor的指定维度操作就是对指定维度包含的元素进行操作,如果想要保持结果的维度不变,设置参数keepdim = True即可。
原文链接:https://www.cnblogs.com/flix/p/11262606.html
标签:dim,tensor,sum,torch,维度,PyTorch,print From: https://www.cnblogs.com/lusiqi/p/17181959.html