首页 > 其他分享 >PyTorch中的dim

PyTorch中的dim

时间:2023-03-05 22:35:07浏览次数:37  
标签:dim tensor sum torch 维度 PyTorch print

PyTorch中对tensor的很多操作如sum,softmax等都可以设置dim参数用来指定操作在哪一维进行。PyTorch中的dim类似于numpy中的axis。

dim与方括号的关系

创建一个矩阵

a = torch.tensor([[1, 2], [3, 4]])
print(a)

输出:

tensor([[1, 2],
        [3, 4]])

因为a是一个矩阵,所以a的左边有2个括号

括号之间是嵌套关系,代表了不同的维度。从左往右数,两个括号代表的维度分别是0和1,在第0维遍历得到向量,在

第1维遍历得到标量

同样地,对于3维tensor

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)

输出

tensor([[[3, 2],
         [1, 4]],

        [[5, 6],
         [7, 8]]])

则3个括号代表的维度从左往右分别为0,1,2,在第0维遍历得到矩阵,在第1维遍历得到向量,在第2维遍历得到标量

更详细一点

在指定的维度上进行操作

在某一维度求和(或者进行其他操作)就是对该维度中的元素进行求和。

对于矩阵a

a = torch.tensor([[1, 2], [3, 4]])
print(a)

输出

tensor([[1, 2],
        [3, 4]])

求a在第0维的和,因为第0维代表最外边的括号,括号中的元素为向量 [1,2] , [3,4],第0维的和就是第0维中的元素相加,也就是两个向量 [1,2] , [3,4] 相加,所以结果为

[1 , 2 ] + [3 , 4 ] = [4 , 6]

s = torch.sum(a, dim=0)
print(s)

输出

tensor([4, 6])

可以看到,a是2维矩阵,而相加的结果为1维向量,可以使用参数keepdim = True来保证维度数目不变。

s = torch.sum(a, dim=0, keepdim=True)
print(s)

输出

tensor([[4, 6]])

在a的第0维求和,就是对第0维中的元素(向量)进行相加。同样的,对a第1维求和,就是对a第1维中的元素(标量)进行相加,a的第1维元素为标量1,2和3,4,则结果为

[1 + 2 ] = [3] ,[ 3 + 4 ] = [7]

s = torch.sum(a, dim=1)
print(s)

输出

tensor([3, 7])

保持维度不变

s = torch.sum(a, dim=1, keepdim=True)
print(s)

输出

tensor([[3],
        [7]])

对3维tensor的操作也是这样

b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)

输出

tensor([[[3, 2],
         [1, 4]],

        [[5, 6],
         [7, 8]]])

将b在第0维相加,第0维为最外层括号,最外层括号中的元素为矩阵[ [3 , 2], [1 , 4] ]和[ [5, 6] ,[7, 8] ]。在第0维求和,

就是将第0维的元素(矩阵)相加

s = torch.sum(b, dim=0)
print(s)

输出

tensor([[ 8,  8],
        [ 8, 12]])

求b在第1维的和,就是将b第1维中的元素[ 3, 2] 和[ 1 , 4 ],[ 5 , 6]和 [7 , 8 ]相加,所以

s = torch.sum(b, dim=1)
print(s)

输出

tensor([[ 4,  6],
        [12, 14]])

则在b的第2维求和,就是对标量3和2,1和4,5和6,7和8求和

s = torch.sum(b, dim=2)
print(s)

结果为

tensor([[ 5,  5],
        [11, 15]])

除了求和,其他操作也是类似的,如求b在指定维度上的最大值

m = torch.max(b, dim=0)
print(m)

b在第0维的最大值是第0维中的元素(两个矩阵[[3,2],[1,4]]和[[5,6],[7,8]])的最大值,取矩阵对应位置最大值即可

结果为

torch.return_types.max(
values=tensor([[5, 6],
        [7, 8]]),
indices=tensor([[1, 1],
        [1, 1]]))

b在第1维的最大值就是第1维元素(4个(2对)向量)的最大值

m = torch.max(b, dim=1)
print(m)

输出为

torch.return_types.max(
values=tensor([[3, 4],
        [7, 8]]),
indices=tensor([[0, 1],
        [1, 1]]))

b在第0维的最大值就是第0维元素(8个(4对)标量)的最大值

m = torch.max(b, dim=2)
print(m)

输出

torch.return_types.max(
values=tensor([[3, 4],
        [6, 8]]),
indices=tensor([[0, 1],
        [1, 1]]))

总结

在tensor的指定维度操作就是对指定维度包含的元素进行操作,如果想要保持结果的维度不变,设置参数keepdim = True即可。

 

原文链接:https://www.cnblogs.com/flix/p/11262606.html

标签:dim,tensor,sum,torch,维度,PyTorch,print
From: https://www.cnblogs.com/lusiqi/p/17181959.html

相关文章

  • 机器学习日志 手写数字识别 pytorch 神经网络
    我是链接第一次用pytorch写机器学习,不得不说是真的好用pytorch的学习可以看这里,看看基本用法就行,个人感觉主要还是要看着实践代码来学习总结了几个点:1.loss出现nan这......
  • numpy深度学习常用函数及参数理解(axis, keepdims)
    axis:以axis=0为例,则沿着第0个下标(最左边的下标)变化的方向进行操作,也就是将除了第0个下标外,其他两个下标都相同的部分分成一组,然后再进行操作例如一个3*3的二维数组A(3,......
  • 安装pytorch报错 ERROR: Could not install packages due to an OSError: [Errno 28]
    windos安装,报错如下  看了不少回答,大概是缓存和内存满了我的C盘只给了70G,然后意外发现只剩下3G多了,先用系统自带的清理工具清理了一下,然后腾讯电脑管家“工具箱”中......
  • pytorch_debug
    1、报错信息1.1、出错位置1image=Image.open('./img.png')2#图像预处理3transforms=transforms.Compose([transforms.Resize(256),4......
  • 【达梦】导入导出 dexp & dimp
    导入语句:./dimpUSERID=user_name/'"password"'@127.0.0.1:5237FILE=imp_exp.dmpDIRECTORY=/home/sudoroot/dameng/00-scriptsREMAP_SCHEMA=DEV:PRELOG=dev_imp.log......
  • 01 Pytorch的数据载体张量与线性回归
    Pytorch的数据载体张量与线性回归Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html?Pytorch中文文档:https://pytorch-cn.readthedocs.io/zh/latest/1.......
  • DL 基础:PyTorch 常用代码存档
    1pandas读csvimporttorchfromtorchimportnnimportnumpyasnpimportpandasaspdfromcopyimportdeepcopydevice="cuda"iftorch.cuda.is_available()......
  • pytorch和pyG安装
    操作系统:windows10显卡:GTX1650CUDA版本11.1下载安装CUDAToolkit11.1.0新建conda环境,python3.8condacreate-nGNNpython=3.8。激活condaactivateGNNwin下CUDA......
  • FCN(语义分割) Pytorch
    https://github.com/pytorch/vision/tree/main/torchvision/models/segmentation   ......
  • PyTorch
    PyTorch查看CUDA和CUDNN版本>>>importtorch>>>print(torch)<module'torch'from'/home/fanrui/code_pytorch/pytorch.master/torch/__init__.py'>>>>torch.__ver......