首页 > 其他分享 >torch.stack 堆叠函数帮助理解多维数组

torch.stack 堆叠函数帮助理解多维数组

时间:2024-08-19 10:51:58浏览次数:7  
标签:dim torch 张量 堆叠 维度 stack

概论

在 PyTorch 中,torch.stack 函数用于在指定的维度上将一组张量堆叠起来。这个操作会在指定维度上创建一个新的维度,并将输入张量在该维度上进行堆叠。假设有两个形状相同的张量 ab,它们的形状都是 (2, 3, 4),那么在不同的 dim 参数下使用 torch.stack 会产生不同的结果。

以下是对这四种情况的解释:

  1. c = torch.stack([a, b], dim=0)

    • dim=0 的位置上创建一个新的维度。
    • 原始张量的形状为 (2, 3, 4),堆叠后形状变为 (2, 3, 4) 前加上一个新的维度,形状变为 (2, 2, 3, 4)
    • 堆叠后张量 c 的形状为 (2, 3, 4)
    • 可以理解为把 ab 堆叠在第一个维度上,结果的第一个维度表示堆叠的张量数目。
  2. d = torch.stack([a, b], dim=1)

    • dim=1 的位置上创建一个新的维度。
    • 原始张量的形状为 (2, 3, 4),堆叠后形状变为 (2, 2, 3, 4)
    • 堆叠后张量 d 的形状为 (2, 2, 3, 4)
    • 可以理解为在第二个维度上插入一个新的维度,使每个原始张量的第一维度内的每个元素都变为包含两个子元素的张量。
  3. e = torch.stack([a, b], dim=2)

    • dim=2 的位置上创建一个新的维度。
    • 原始张量的形状为 (2, 3, 4),堆叠后形状变为 (2, 3, 2, 4)
    • 堆叠后张量 e 的形状为 (2, 3, 2, 4)
    • 这表示在第三个维度上创建新的维度,每个原始张量的前两个维度内的每个元素都变为包含两个子元素的张量。
  4. f = torch.stack([a, b], dim=3)

    • dim=3 的位置上创建一个新的维度。
    • 原始张量的形状为 (2, 3, 4),堆叠后形状变为 (2, 3, 4, 2)
    • 堆叠后张量 f 的形状为 (2, 3, 4, 2)
    • 这表示在第四个维度上创建新的维度,每个原始张量的前三个维度内的每个元素都变为包含两个子元素的张量。

总结来说,torch.stack 会在指定的 dim 维度上插入一个新的维度,使得原始张量在这个维度上堆叠起来。新的张量的形状将会比原始张量多一个维度,且堆叠方向对应于 dim 所指定的位置。

一言以蔽之

在指定维度 dim 上使用 torch.stack 堆叠时,会在该维度插入一个新的维度,使得原始张量在 dim 之前的所有维度的每个元素都变为包含堆叠张量数目的子元素的张量。

解读

这句话的意思是:

当你在指定的维度 dim 上使用 torch.stack 函数时,它会在张量的这个位置插入一个新的维度。这个新的维度会将原始张量沿着 dim 之前的所有维度中的每个元素扩展,使得这些元素现在在新增的维度上包含多个(等于你堆叠的张量个数)子元素。

具体来说,假设你有多个形状相同的张量,当你在某个维度 dim 上堆叠它们时,堆叠后的新张量在 dim 之前的每一个维度上的每个元素都会新增一个维度,用来存放你堆叠的张量。这些子元素的数量就是你堆叠的张量数目。

例如:

  • dim=0:在第一个维度上堆叠,那么结果张量的第一个维度的大小就是堆叠的张量个数,每个子张量会在这个新的维度中排列。
  • dim=1:在第二个维度上堆叠,那么结果张量的第二个维度的大小就是堆叠的张量个数,原来第一维度的每个元素现在会包含多个子元素。
  • 以此类推,堆叠的位置决定了新的维度插入在哪里,以及原张量如何被扩展。

再探讨

若一个张量 a 形状是 (2, 3, 4),堆叠代码:torch.stack([a, a], dim=?)

当 dim=0 时,由于是第0维度,前面没有了,故把整个张量看作一个元素堆叠。

当 dim=1 时,前面有维度 2, 在这里插入新维度会把 (3, 4) 看作一个元素,进行堆叠。

当 dim=2 时,前面有维度(2, 3),这时会把(4, )看作一个元素,进行堆叠。

当 dim=3 时, 前面有维度(2, 3, 4),这是会把每一个元素看做一个元素,进行堆叠。

看代码,再读上文。

a = torch.randn(2, 3, 4)
a
tensor([[[ 0.4964, -0.2426, -0.4883, -0.9112],
         [ 0.2928,  1.8061, -0.0770, -0.2761],
         [-0.1384,  0.5872,  0.1957,  1.4741]],

        [[-1.1077,  1.0878,  0.4793,  0.9741],
         [ 2.0282,  0.7055, -0.0954, -0.3203],
         [-0.7217, -1.1332,  0.0738, -0.8602]]])
b= torch.stack([a, a], dim=0)
b # 把整个a看作一个元素,进行堆叠
tensor([[[[ 0.4964, -0.2426, -0.4883, -0.9112],
          [ 0.2928,  1.8061, -0.0770, -0.2761],
          [-0.1384,  0.5872,  0.1957,  1.4741]],

         [[-1.1077,  1.0878,  0.4793,  0.9741],
          [ 2.0282,  0.7055, -0.0954, -0.3203],
          [-0.7217, -1.1332,  0.0738, -0.8602]]],


        [[[ 0.4964, -0.2426, -0.4883, -0.9112],
          [ 0.2928,  1.8061, -0.0770, -0.2761],
          [-0.1384,  0.5872,  0.1957,  1.4741]],

         [[-1.1077,  1.0878,  0.4793,  0.9741],
          [ 2.0282,  0.7055, -0.0954, -0.3203],
          [-0.7217, -1.1332,  0.0738, -0.8602]]]])

b= torch.stack([a, a], dim=1)
b # 把a的(3, 4)部分看作一个元素,进行堆叠
tensor([[[[ 0.4964, -0.2426, -0.4883, -0.9112],
          [ 0.2928,  1.8061, -0.0770, -0.2761],
          [-0.1384,  0.5872,  0.1957,  1.4741]],

         [[ 0.4964, -0.2426, -0.4883, -0.9112],
          [ 0.2928,  1.8061, -0.0770, -0.2761],
          [-0.1384,  0.5872,  0.1957,  1.4741]]],


        [[[-1.1077,  1.0878,  0.4793,  0.9741],
          [ 2.0282,  0.7055, -0.0954, -0.3203],
          [-0.7217, -1.1332,  0.0738, -0.8602]],

         [[-1.1077,  1.0878,  0.4793,  0.9741],
          [ 2.0282,  0.7055, -0.0954, -0.3203],
          [-0.7217, -1.1332,  0.0738, -0.8602]]]])

b= torch.stack([a, a], dim=2)
b #把a的(4,)部分看作一个元素,进行堆叠
tensor([[[[ 0.4964, -0.2426, -0.4883, -0.9112],
          [ 0.4964, -0.2426, -0.4883, -0.9112]],

         [[ 0.2928,  1.8061, -0.0770, -0.2761],
          [ 0.2928,  1.8061, -0.0770, -0.2761]],

         [[-0.1384,  0.5872,  0.1957,  1.4741],
          [-0.1384,  0.5872,  0.1957,  1.4741]]],


        [[[-1.1077,  1.0878,  0.4793,  0.9741],
          [-1.1077,  1.0878,  0.4793,  0.9741]],

         [[ 2.0282,  0.7055, -0.0954, -0.3203],
          [ 2.0282,  0.7055, -0.0954, -0.3203]],

         [[-0.7217, -1.1332,  0.0738, -0.8602],
          [-0.7217, -1.1332,  0.0738, -0.8602]]]])
b= torch.stack([a, a], dim=3)
b # 把a的每个元素看作一个元素进行堆叠
tensor([[[[ 0.4964,  0.4964],
          [-0.2426, -0.2426],
          [-0.4883, -0.4883],
          [-0.9112, -0.9112]],

         [[ 0.2928,  0.2928],
          [ 1.8061,  1.8061],
          [-0.0770, -0.0770],
          [-0.2761, -0.2761]],

         [[-0.1384, -0.1384],
          [ 0.5872,  0.5872],
          [ 0.1957,  0.1957],
          [ 1.4741,  1.4741]]],


        [[[-1.1077, -1.1077],
          [ 1.0878,  1.0878],
          [ 0.4793,  0.4793],
          [ 0.9741,  0.9741]],

         [[ 2.0282,  2.0282],
          [ 0.7055,  0.7055],
          [-0.0954, -0.0954],
          [-0.3203, -0.3203]],

         [[-0.7217, -0.7217],
          [-1.1332, -1.1332],
          [ 0.0738,  0.0738],
          [-0.8602, -0.8602]]]])

再总结

torcha.stack 把dim=? 指定插入维度后,把原有维度以插入维度为起点,看作一个整体,做为一个堆叠元素,进行堆叠。
例如,有 a 形状为 (2, 5, 8, 3) ,当dim=1时,对a进行切片以(5,8,3)为一个元素,进行堆叠。
torch.stack([a, a], dim=1) 的结果形状为:(2,2,5,8,3)

标签:dim,torch,张量,堆叠,维度,stack
From: https://www.cnblogs.com/litifeng/p/18366887

相关文章

  • 深度学习-pytorch-basic-001
    importtorchimportnumpyasnptorch.manual_seed(1234)<torch._C.Generatorat0x21c1651e190>defdescribe(x):print("Type:{}".format(x.type()))print("Shape/Size:{}".format(x.shape))print("Values:{}"......
  • PyTorch深度学习实战(18)—— 可视化工具
    在训练神经网络时,通常希望能够更加直观地了解训练情况,例如损失函数曲线、输入图片、输出图片等信息。这些信息可以帮助读者更好地监督网络的训练过程,并为参数优化提供方向和依据。最简单的办法就是打印输出,这种方式只能打印数值信息,不够直观,同时无法查看分布、图片、声音等......
  • 零基础学习人工智能—Python—Pytorch学习(五)
    前言上文有一些文字打错了,已经进行了修正。本文主要介绍训练模型和使用模型预测数据,本文使用了一些numpy与tensor的转换,忘记的可以第二课的基础一起看。线性回归模型训练结合numpy使用首先使用datasets做一个数据X和y,然后结合之前的内容,求出y_predicted。#pipinstallmatp......
  • PyTorch--双向长短期记忆网络(BiRNN)在MNIST数据集上的实现与分析
    文章目录前言完整代码代码解析1.导入库2.设备配置3.超参数设置4.数据集加载5.数据加载器6.定义BiRNN模型7.实例化模型并移动到设备8.损失函数和优化器9.训练模型10.测试模型11.保存模型常用函数前言本代码实现了一个基于PyTorch的双向长短期记忆网络(BiRNN),用于对MNI......
  • HarmonyOS 层叠布局:(Stack)打造灵活多变的UI界面
    在应用开发中,布局设计是用户体验的关键之一。而在HarmonyOS中,层叠布局(Stack)是一种极为灵活的布局方式。它允许我们在同一个区域内放置多个组件,并根据需求将它们叠加起来,形成丰富的视觉效果。无论是广告展示还是卡片叠加效果,层叠布局都能轻松胜任。今天,我将带大家深入了解Stack......
  • 用pytorch实现LeNet-5网络
     上篇讲述了LeNet-5网络的理论,本篇就试着搭建LeNet-5网络。但是搭建完成的网络还存在着问题,主要是训练的准确率太低,还有待进一步探究问题所在。是超参数的调节有问题?还是网络的结构有问题?还是哪里搞错了什么1.库的导入dataset:datasets.MNIST()函数,该函数作用是导入MNIST数......
  • 面试题:在Java中,JVM(Java虚拟机)的内存模型是如何设计的?请详细解释堆(Heap)、栈(Stack)、方法
    面试题:在Java中,JVM(Java虚拟机)的内存模型是如何设计的?请详细解释堆(Heap)、栈(Stack)、方法区(MethodArea)以及程序计数器(ProgramCounterRegister)的作用和它们之间的关系。更多答案在这里,手机或电脑浏览器就可以打开,面霸宝典【全拼音】.com这里可以优化简历,模拟面试,企业项......
  • EVAT: Electric Vehicle Adoption Tools - Tech Stack Overview
    EVAT:ElectricVehicleAdoptionTools-TechStackOverviewIntroductionWelcometotheEVATproject!ThisdocumentprovidesanoverviewofthetechnologiesweareusingtodevelopourElectricVehicleAdoptionTools.Ourgoalistocreateauser-friendlyp......
  • PyTorch--实现循环神经网络(RNN)模型
    文章目录前言完整代码代码解析导入必要的库设备配置超参数设置数据集加载数据加载器定义RNN模型实例化模型并移动到设备损失函数和优化器训练模型测试模型保存模型小改进神奇的报错ValueError:LSTM:Expectedinputtobe2Dor3D,got4Dinstead前言首先,这篇......
  • 掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
    PyTorch提供了几种张量乘法的方法,每种方法都是不同的,并且有不同的应用。我们来详细介绍每个方法,并且详细解释这些函数有什么区别:1、torch.matmultorch.matmul是PyTorch中用于矩阵乘法的函数。它能够处理各种不同维度的张量,并根据张量的维度自动调整其操作方式。torch......