首页 > 其他分享 >PyTorch | torch.sum()函数的用法

PyTorch | torch.sum()函数的用法

时间:2023-02-25 21:13:29浏览次数:48  
标签:dim tensor sum torch print PyTorch 维度

torch.sum()对输入的tensor数据的某一维度求和,一共两种用法。

方法1详解

torch.sum(input, *, dtype=None) → Tensor
  • input:输入的张量
案例
x = torch.randn(2, 3)
print(x)
y = torch.sum(x)
print(y)

输出结果:

tensor([[-0.2328,  1.4580,  0.7448],
        [-0.7813,  0.3045, -1.9038]])
tensor(-0.4107)
# -0.2328+1.4580+0.7448-0.7813+0.3045-1.9038 = -0.41059999999999963

方法2详解

torch.sum(input, dim, keepdim=False, *, dtype=None) → Tensor
  • input:输入的张量
  • dim:求和的维度,可以是一个列表,也就是可以同时接收多个维度,并可同时在这些维度上进行指定操作。
  • keepdim:默认为False,若keepdim=True,则返回的Tensor除dim之外的维度与input相同。因为求和之后这个dim的元素个数为1,所以要被去掉,如果要保留这个维度,则应当keepdim=True。
案例
x = torch.arange(0, 12).view(3, 4)
print(x)
y1 = torch.sum(x, dim=1)
print(y1)
y2 = torch.sum(x, dim=0)
print(y2)
y3 = torch.sum(x, dim=0, keepdim=True)
print(y3)

输出结果:

# x
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
# y1
tensor([ 6, 22, 38])
# y2
tensor([12, 15, 18, 21])
# y3
tensor([[12, 15, 18, 21]])

对二维和三维进行举例

例子一:对二维的list进行sum()操作。

# 如果dim=1,则按行求和;如果dim=0,则按列求和
a = torch.ones((2, 3))
print(a):
tensor([[1, 1, 1],
 		[1, 1, 1]])

a1 =  torch.sum(a)
a2 =  torch.sum(a, dim=0)
a3 =  torch.sum(a, dim=1)

print(a)
print(a1)
print(a2)

输出结果:

tensor(6.)
tensor([2., 2., 2.])
tensor([3., 3.])

如果加上keepdim=True, 则会保持dim的维度不被squeeze

a1 =  torch.sum(a, dim=(0, 1), keepdim=True)
a2 =  torch.sum(a, dim=(0, ), keepdim=True)
a3 =  torch.sum(a, dim=(1, ), keepdim=True)

输出结果:

tensor([[6.]])
tensor([[2., 2., 2.]])
tensor([[3., 3.]])

例子二:对三维的list进行sum()操作。

# 32块,每块4行,每行256列
a = torch.ones((32, 4,256))
a

输出结果:

# 对第二个维度进行sum()操作
a1 = torch.sum(a,dim=1)
a1

输出结果:

对比a和a1的维度:

a1.shape
# torch.Size([32, 256])

a.shape
# torch.Size([32, 4, 256])

说明对第一个维度进行sum()操作,把每一块中的几行对应的相加起来了,然后每一块只剩下一行,所以去掉了一个维度。

说明:对更高维进行某一个或者多个维度相加,我们想要理解就按照上面这两个例子进行代码的一行行的执行,观察数据的变化和维度的改变。

标签:dim,tensor,sum,torch,print,PyTorch,维度
From: https://www.cnblogs.com/zhangxuegold/p/17155371.html

相关文章