
[note] A comparison of several pytorch dimension-manipulation operations


pre

While reading some code today I started wondering about the difference between torch.unbind + torch.cat and torch.reshape; intuitively, reshape seems more convenient.

chatgpt

Question

x is a tensor of three dimensions, what is the difference between torch.cat(torch.unbind(x, dim=1), dim=0) and torch.reshape(x, (x.shape[0]*x.shape[1], -1))

Answer

Both torch.cat(torch.unbind(x, dim=1), dim=0) and torch.reshape(x, (x.shape[0]*x.shape[1], -1)) are operations that reshape and reorganize a 3-dimensional tensor x, but they achieve this in slightly different ways.

  1. torch.cat(torch.unbind(x, dim=1), dim=0):

    • This operation involves two steps. First, it unbinds the tensor x along its second dimension (dim=1), which effectively splits the tensor into a list of tensors along that dimension. Each tensor in the list corresponds to one "column" from the original 3D tensor.
    • Then, it uses torch.cat to concatenate these split tensors along the specified dimension (dim=0), effectively stacking them on top of each other to create a new tensor with a reshaped layout.
    • The final shape of the tensor will be (x.shape[0] * x.shape[1], x.shape[2]).
  2. torch.reshape(x, (x.shape[0]*x.shape[1], -1)):

    • This operation directly reshapes the entire tensor x without splitting or concatenating it.
    • The desired shape is specified as a tuple (x.shape[0]*x.shape[1], -1), where -1 in the second dimension means that PyTorch will automatically calculate the size of that dimension to preserve the total number of elements in the tensor.
    • The final shape of the tensor will also be (x.shape[0] * x.shape[1], x.shape[2]).

In summary, both operations reshape the tensor x into a 2-dimensional tensor with the same number of elements, but they approach it differently. The first approach involves splitting and concatenating along dimensions, while the second approach directly reshapes the tensor. The end result will be the same in terms of shape and arrangement of elements.
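The shape part of this answer is easy to verify; whether the element arrangement also matches is exactly what the test below probes. A minimal sketch of mine, with an arbitrarily chosen (3, 2, 4) shape:

import torch

# arbitrary example shape, not from the original question
x = torch.arange(24).reshape(3, 2, 4)

a = torch.cat(torch.unbind(x, dim=1), dim=0)
b = torch.reshape(x, (x.shape[0] * x.shape[1], -1))

print(a.shape)  # torch.Size([6, 4])
print(b.shape)  # torch.Size([6, 4]) -- the shape claim holds; element order is tested below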

Test

I initially believed ChatGPT, but the more I thought about it the less right it seemed, so I ran the following experiment:

import torch
from einops import rearrange

x = torch.randint(0, 100, (3, 2, 1))
print(x.squeeze())
# compare four different dimension-manipulation approaches
a = torch.cat(torch.unbind(x, dim=1), dim=0)
b = torch.reshape(x, (x.shape[0]*x.shape[1], -1))
b2 = torch.reshape(x, (-1, x.shape[-1]))
c = x.view(x.shape[0]*x.shape[1], -1)
c2 = x.view(-1, x.shape[-1])
d = rearrange(x, 'b p c -> (b p) c')

# cat+unbind gives a different result from the other three approaches
print('a =', a.squeeze())
print('b =', b.squeeze())
print('b2 =', b2.squeeze())
print('c =', c.squeeze())
print('c2 =', c2.squeeze())
print('d =', d.squeeze())

# unlike c, the cat+unbind result (a) cannot be turned back into x with rearrange
x2 = rearrange(c, '(b p) c -> b p c', b=3, p=2)
print(f'x==x2 = {(x==x2).squeeze()}')

Output:

tensor([[43, 84],
        [90, 80],
        [59, 23]])
a = tensor([43, 90, 59, 84, 80, 23])
b = tensor([43, 84, 90, 80, 59, 23])
b2 = tensor([43, 84, 90, 80, 59, 23])
c = tensor([43, 84, 90, 80, 59, 23])
c2 = tensor([43, 84, 90, 80, 59, 23])
d = tensor([43, 84, 90, 80, 59, 23])
x==x2 = tensor([[True, True],
        [True, True],
        [True, True]])
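A side note on why b/b2 (reshape) and c/c2 (view) agree here: x is contiguous, so view can reinterpret the same storage without copying. On a non-contiguous tensor, view raises an error while reshape falls back to copying. A small sketch of mine, not part of the original experiment:

import torch

x = torch.randint(0, 100, (3, 2, 1))
xt = x.transpose(0, 1)  # swaps strides only; the result is non-contiguous

print(xt.is_contiguous())      # False
print(xt.reshape(6, 1).shape)  # torch.Size([6, 1]): reshape copies when it must
try:
    xt.view(6, 1)              # view cannot reinterpret these strides
except RuntimeError as e:
    print('view failed:', e)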

ep

In summary, suppose x = [[97, 14], [0, 16], [55, 62]]. torch.cat(torch.unbind(x, dim=1), dim=0) splits x column by column and then concatenates, giving [97, 0, 55, 14, 16, 62]; reshape/view/rearrange instead split x row by row and concatenate, giving [97, 14, 0, 16, 55, 62], which is identical to the result of torch.cat(torch.unbind(x, dim=0), dim=0).
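Put another way, the unbind-then-cat along dim=1 amounts to transposing the first two dimensions before flattening them. This reformulation is my own addition rather than something from the test above:

import torch

x = torch.randint(0, 100, (3, 2, 5))

a = torch.cat(torch.unbind(x, dim=1), dim=0)
# transpose(0, 1) swaps the first two dims; reshape then copies,
# since the transposed tensor is no longer contiguous
a2 = x.transpose(0, 1).reshape(x.shape[0] * x.shape[1], -1)

print(torch.equal(a, a2))  # True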

From: https://www.cnblogs.com/Stareven233/p/17662708.html

    推荐:使用NSDT场景编辑器助你快速搭建3D应用场景介绍您是否曾经花费数小时调试机器学习模型,但似乎找不到准确性没有提高的原因?你有没有觉得一切都应该完美地工作,但由于某种神秘的原因,你没有得到模范的结果?好吧,没有了。作为初学者探索PyTorch可能会令人生畏。在本文中,您将探索......