首页 > 其他分享 >pytorch-合并与分割

pytorch-合并与分割

时间:2023-07-31 10:24:40浏览次数:25  
标签:dim 分割 rand 32 torch 合并 shape pytorch Size

Merge or split
▪ Cat(合并)
▪ Stack(合并)
▪ Split(拆分)
▪ Chunk(拆分)

合并

cat

这个就是合并两个tensor
比如说有两个班级的成绩单,一个是1-4班的,一个是5-9班的,我们现在需要合并这两份成绩单。
▪ Statistics about scores
▪ [class1-4, students, scores]
▪ [class5-9, students, scores]

torch.cat([a,b],dim)
这个就是合并a,b两个tensor再第dim个维度上,需要注意的是除了dim这个维度,剩下的维度都要shape相等

image

a=torch.rand(4,32,8)
b=torch.rand(5,32,8)
torch.cat([a,b],dim=0).shape
# torch.Size([9, 32, 8])


a1=torch.rand(4,3,32,32)
a2=torch.rand(5,3,32,32)
torch.cat([a1,a2],dim=0).shape
# torch.Size([9, 3, 32, 32])

如果

a1=torch.rand(4,3,32,32)
a2=torch.rand(4,1,32,32)
torch.cat([a1,a2],dim=0).shape

由于除了dim=0以外dim=1的维度不相同,所以不行
image
但是假如说

torch.cat([a1,a2],dim=1).shape
# torch.Size([4, 4, 32, 32])
a1=torch.rand(4,3,16,32)
a2=torch.rand(4,3,16,32)
torch.cat([a1,a2],dim=2).shape
#  torch.Size([4, 3, 32, 32])
#  这个时候就是两张图片,上下拼接上

stack

这个也是合并

stack([a,b],dim)
不过这个合并和上一个不一样,这个合并会创造一个新的维度,比如[32,8]和[32,8],在dim=0维度进行合并的话是[2,32,8]。然后res[0,:,:]是第一个,res[1,:,:]是第二个。

a1=torch.rand(4,3,16,32)
a2=torch.rand(4,3,16,32)
torch.cat([a1,a2],dim=2).shape
# torch.Size([4, 3, 32, 32])


torch.stack([a1,a2],dim=2).shape
# torch.Size([4, 3, 2, 16, 32])

然后我们下面有一个具体的场景,就是一共有俩个班级,每个班级一共有32个学生,每个学生有8门课,进行合并。这个时候我们就不能利用cat了,因为cat合并的结果为[64,8]。而stack合并的结果为[2,32,8],这样更符合要求

a=torch.rand(32,8)
b=torch.rand(32,8)
torch.stack([a,b],dim=0).shape
# torch.Size([2, 32, 8])

但是需要注意的是除了要合并的那个维度,其余的维度都要相等。

拆分

split

这个是通过长度进行拆分的。

.split([,],dim)
[,]里面的参数是最终的得到的tensor的dim上的维度,比如说c=[3,32,8],
aa,bb=c.split([2,1],dim=0),
那么aa.shape=[2,32,8],bb.shape=[1,32,8]。返回的该维度上的值的和要等于目标函数在该维度上的值
或者说也可以直接给顶一个len,这个就是返回的最终函数的维度都一样,比如说c=[2,32,8],也可以用c.split(1,dim=0).

c=torch.randn(2,32,8)
a,b=c.split(1,dim=0)
a.shape,b.shape
#(torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))



cc=torch.rand(3,32,8)
aa,bb=cc.split([2,1],dim=0)
aa.shape,bb.shape
#(torch.Size([2, 32, 8]), torch.Size([1, 32, 8]))

但是需要注意的是这个函数只能拆分成两个,不如说c=[3,32,8],aa,bb,cc=c.split(1,dim=0)这个是不行的
image

chunk

这个是按照数量进行拆分,这个可以返回多个

---=.chunk(num,dim)
这个就是最终拆分成num个,比如说[6,32,8],num=2,那就是最终拆分成两个,返回[3,32,8],[3,32,8]、如果num=3,那就是最终拆分成三个,返回[2,32,8],[2,32,8],[2,32,8]

c=torch.randn(6,32,8)
aa,bb=c.chunk(2,dim=0)
aa.shape,bb.shape
# (torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))


aaa,bbb,ccc=c.chunk(3,dim=0)
aaa.shape,bbb.shape,ccc.shape
# (torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([2, 32, 8]))

image

标签:dim,分割,rand,32,torch,合并,shape,pytorch,Size
From: https://www.cnblogs.com/lipu123/p/17592238.html

相关文章

  • IJCAI 2023 | 腾讯优图实验室入选论文解读,含小样本学习方法、玻璃物体分割、RSI变化检
    前言 近日,IJCAI2023(InternationalJointConferenceonArtificialIntelligence)国际人工智能联合大会公布了录用结果。本届会议共有4566篇投稿,接收率为15%。作为当前全球最负盛名的AI学术会议之一,IJCAI将于今年8月在澳门举行。本文转载自腾讯优图仅用于学术分享,若侵权请联......
  • ChatGPT炒股:自动批量提取股票公告中的表格并合并数据
    在很多个股票公告中,都有同样格式的“日常性关联交易”的表格,如何合并到一张Excel表格中呢?首先,在ChatGPT中输入提示词:写一段Python代码:F盘文件夹“新三板2023年日常性关联交易20230704”中很多个PDF文件,用Tabula提取这些PDF文件中第1页中的第2个表格,然后保存到表格文件中,文件标题......
  • pytorch索引与切片
    indexinga=torch.randn(4,3,28,28)a[0].shape#torch.Size([3,28,28])a[0,0].shape#torch.Size([28,28])a[0,0,2,4]#tensor(0.6574)selectfirst/lastN这个a可以看成一个图片:[batch,RBG,h,w]a.shape#torch.Size([4,3,28,28])a[:2].shape#torch.Size......
  • 因子分解机介绍和PyTorch代码实现
    因子分解机(FactorizationMachines,简称FM)是一种用于解决推荐系统、回归和分类等机器学习任务的模型。它由SteffenRendle于2010年提出,是一种基于线性模型的扩展方法,能够有效地处理高维稀疏数据,并且在处理特征组合时表现出色。它是推荐系统的经典模型之一,并且模型简单、可解释性强,......
  • 图注意力网络论文详解和PyTorch实现
    前言 图神经网络(gnn)是一类功能强大的神经网络,它对图结构数据进行操作。它们通过从节点的局部邻域聚合信息来学习节点表示(嵌入)。这个概念在图表示学习文献中被称为“消息传递”。本文转载自P**nHub兄弟网站作者|EbrahimPichka仅用于学术分享,若侵权请联系删除欢迎关注公......
  • PyTorch的数据类型
    python和pytorch中的类型对比:我们可以发现pytorch中每中类型后面都有一个Tensor。但是很遗憾PyTorch没有String类型。我们在做NLP的时候会遇到String类型处理的问题,我们会将string转化问数值:one-hot[0,1,0,0,....]Embeddingword2vecglove1Datatype我们需要注......
  • 基于wsl2在container中利用conda安装pytorch环境
    ###一、利用conda创建一个新的环境参考命令condacreate-nENV_NAMEpython=X.X•-nENV_NAME指定环境名称•python=X.X指定要创建的Python版本,比如python=3.6使用命令:condacreate-npytorch1.13python=3.8参考资料-Anacondaconda常用命令:从入门到精通:https://......
  • PyTorch基础知识-新手笔记
    NumPy与TensorTensor为神经网络界的NumPy,与NumPy相似。相同之处:二者均可共享内存,它们之间的转换非常方便和高效。不同之处:NumPy会把ndarray放在CPU中加速。  Tensor会把ndarray放在GPU中加速。PyTorch中的Tensor可以是零维(又称为标量或一个数)、一维、二维及多维的数组。标量(s......
  • PYTHON mysql形成分割文件
    importrandom,string,re,time,sys,hashlib,pymysql,requestsf=open("aa.txt","w")connect=pymysql.connect(user='root',password='123456',db='new',host......
  • LeetCode 热题 100 之 56. 合并区间
    题目以数组intervals表示若干个区间的集合,其中单个区间为intervals[i]=[starti,endi]。请你合并所有重叠的区间,并返回 一个不重叠的区间数组,该数组需恰好覆盖输入中的所有区间 。示例1:输入:intervals=[[1,3],[2,6],[8,10],[15,18]]输出:[[1,6],[8,10],[15,18]]解......