首页 > 其他分享 >Pytorch:合并分割

Pytorch:合并分割

时间:2024-06-19 16:11:38浏览次数:31  
标签:分割 tensor chunk torch 合并 Pytorch split out size

1 前言

记录一下Pytorch中对tensor合并分割的方法

2 合并

Pytorch中对tensor合并的方法有两种:
torch.cat()
torch.stack()

其中,torch.cat()直接将两个变量进行拼接,不会产生新的维度
torch.stack()则会将tensor堆叠,产生新的维度

tensor1 = torch.randn(2,3)
tensor2 = torch.randn(2,3)
print(tensor1)
print(tensor2)
# out:
tensor([[ 1.3124, -0.6630, -1.1289],
        [-0.0913,  0.7382,  0.4581]])
tensor([[-0.8929, -1.3781, -0.6344],
       [-0.0994,  0.5217, -2.2306]])
 
tensor_cat = torch.cat([tensor1,tensor2])
print(f"tensor_out:{tensor_cat}")
print(f"size of tensor_out:{tensor_cat.size()}")
tensor_stack = torch.stack([tensor1,tensor2])
print(f"tensor_stack:{tensor_stack}")
print(f"size of tensor_stack:{tensor_stack.size()}")

# out
tensor_out:tensor([[ 1.3124, -0.6630, -1.1289],
        [-0.0913,  0.7382,  0.4581],
        [-0.8929, -1.3781, -0.6344],
        [-0.0994,  0.5217, -2.2306]])
size of tensor_out:torch.Size([4, 3])
tensor_stack:tensor([[[ 1.3124, -0.6630, -1.1289],
         [-0.0913,  0.7382,  0.4581]],

        [[-0.8929, -1.3781, -0.6344],
         [-0.0994,  0.5217, -2.2306]]])
size of tensor_stack:torch.Size([2, 2, 3])

torch.vstack能够完成与torch.cat一样的效果
torch.vstack能够按顺序垂直(行)堆叠张量

tensor_vstack = torch.vstack([tensor1,tensor2])
print(f"tensor_vstack:{tensor_vstack}")
print(f"size of tensor_vstack:{tensor_vstack.size()}")

# out:
tensor_vstack:tensor([[ 1.3124, -0.6630, -1.1289],
        [-0.0913,  0.7382,  0.4581],
        [-0.8929, -1.3781, -0.6344],
        [-0.0994,  0.5217, -2.2306]])
size of tensor_vstack:torch.Size([4, 3])

torch.hstack则是能够按水平顺序堆叠张量(按列)

tensor_hstack = torch.hstack([tensor1,tensor2])
print(f"tensor_hstack:{tensor_hstack}")
print(f"size of tensor_hstack:{tensor_hstack.size()}")

# out:
tensor_hstack:tensor([[ 1.3124, -0.6630, -1.1289, -0.8929, -1.3781, -0.6344],
        [-0.0913,  0.7382,  0.4581, -0.0994,  0.5217, -2.2306]])
size of tensor_hstack:torch.Size([2, 6])

3 分割

Pytorch中对tensor合并的方法有两种:
torch.split()
torch.chunk()

其中,splittensor拆分为多块,每个块都是原始tensor视图

chunk则是按照dimtensor分割为chunkstensor块,返回块的元组

def split( tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0 ) -> Tuple[Tensor, ...]: r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension :attr:`dim` is not divisible by :attr:`split_size`. If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according to :attr:`split_size_or_sections`. Args: tensor (Tensor): tensor to split. split_size_or_sections (int) or (list(int)): size of a single chunk or list of sizes for each chunk dim (int): dimension along which to split the tensor.
torch.chunk(input, chunks, dim=0) → List of Tensors
"""
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
Parameters:
    input (Tensor) – the tensor to split
    chunks (int) – number of chunks to return
    dim (int) – dimension along which to split the tensor

split:

tensor = torch.randn(10).reshape(5,2)
print(f"tensor:{tensor}")
torch.split(tensor,2)

# out:
tensor:tensor([[ 0.9619,  0.6095],
        [-1.8024, -0.1534],
        [ 1.7452,  0.4705],
        [-0.8512,  0.3175],
        [-0.0290, -0.1422]])

(tensor([[ 0.9619,  0.6095],
         [-1.8024, -0.1534]]),
 tensor([[ 1.7452,  0.4705],
         [-0.8512,  0.3175]]),
 tensor([[-0.0290, -0.1422]]))

torch.split(tensor,[2,3])

# out:
(tensor([[-1.5071, -0.0346],
         [-0.6429,  0.5917]]),
 tensor([[ 0.2722,  0.3824],
         [ 0.6135,  0.7926],
         [-0.5771, -0.4590]]))

chunk:

torch.chunk(tensor, 2 ,dim=1)

# out:
(tensor([[-1.5071],
         [-0.6429],
         [ 0.2722],
         [ 0.6135],
         [-0.5771]]),
 tensor([[-0.0346],
         [ 0.5917],
         [ 0.3824],
         [ 0.7926],
         [-0.4590]]))

torch.chunk(tensor, 2 ,dim=0)

# out:
(tensor([[-1.5071, -0.0346],
         [-0.6429,  0.5917],
         [ 0.2722,  0.3824]]),
 tensor([[ 0.6135,  0.7926],
         [-0.5771, -0.4590]]))

4 Ref

  1. https://aiaer.blog.csdn.net/article/details/125086792?spm=1001.2101.3001.6650.1&utm_medium=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~Rate-1-125086792-blog-108471904.235^v43^pc_blog_bottom_relevance_base6&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~Rate-1-125086792-blog-108471904.235^v43^pc_blog_bottom_relevance_base6&utm_relevant_index=2

标签:分割,tensor,chunk,torch,合并,Pytorch,split,out,size
From: https://www.cnblogs.com/liuliu55/p/18256456

相关文章

  • Day 26| 39. 组合总和 、 40.组合总和II 、 131.分割回文串
    组合总和本题是集合里元素可以用无数次,那么和组合问题的差别其实仅在于startIndex上的控制题目链接/文章讲解:https://programmercarl.com/0039.组合总和.html视频讲解:https://www.bilibili.com/video/BV1KT4y1M7HJ给定一个无重复元素的数组candidates和一个目标数targ......
  • 点云分割网络PointConv
    PDF:《PointConv:DeepConvolutionalNetworkson3DPointClouds》CODE:https://github.com/DylanWusee/pointconv一、大体内容PointConv是一种在非均匀采样下对3D点云进行卷积的运算,可以用来构建深度卷积网络,其将卷积核视为由权重函数和密度函数组成的三维点的局部坐标的非......
  • Win11+Miniconda3+python3.9安装pyspark+pytorch
    Win11+Miniconda3+python3.9安装pyspark+pytorch步骤1:安装Miniconda3,具体可以百度或者google步骤2:安装好Miniconda3之后,要创建虚拟环境,类似于虚拟机的样子,然后在虚拟环境安装各种python包已经装好了pytorch,具体步骤可以参考网上的一些教程,很多时候要综合多个教程,比如说先建立......
  • java freemarker实现单元格动态合并
    在Java项目中,使用FreeMarker模板引擎来动态生成Excel文件,并实现单元格的动态合并(特别是行合并)。可以通过以下步骤来完成:1.准备数据模型        需要准备一个合适的数据模型,该模型应能表示出哪些单元格需要合并。        例如,如果想要根据某一列的值来决定......
  • 如何将keil5中的bin文件合并
    前言    最近有个需求,需要把单片机中的两个bin文件合并成一个bin文件,方便板子在生产烧录代码阶段可以节约烧录次数,这两个文件一般指的是BOOT+APP文件,bin文件里面没带有地址信息,但是在单片机中的烧录文件需要定位起始地址,所以就需要特别注意它们的偏移地址。因为可能会......
  • Pytorch数据加载与使用
    前言在训练的时候通常使用Dataset来处理数据集。Dataset的作用提供一个方式获取数据内容和标签(label)。实战fromtorch.utils.dataimportDatasetfromPILimportImageimportosclassget_data(Dataset):def__init__(self,root_dir,label_dir):self.r......
  • Pytorch入门(一):MNIST-手写数字识别-搭建网络模型
    前言作为刚入门深度学习的一位初学者来说,各种各样的学习资料、视频让我看得头昏眼花。明明本来是想了解Pytorch使用方法,莫名其妙看了一个多小时的算法推理,原理逻辑,让人很是苦恼。于是在自己学习了一段时间后,打算做出这个pytorch的系列教程,就是想让大家基于项目进行实战,更......
  • PyTorch与TensorFlow模型互转指南
    在深度学习的领域中,PyTorch和TensorFlow是两大广泛使用的框架。每个框架都有其独特的优势和特性,因此在不同的项目中选择使用哪一个框架可能会有所不同。然而,有时我们可能需要在这两个框架之间进行模型的转换,以便于在不同的环境中部署或利用两者的优势。本文将详细介绍如何......
  • vue中的代码分割
    随着Web应用的日益复杂化,用户对页面加载速度的期望越来越高。在这种背景下,前端性能优化成为了开发者们必须面对的挑战。Vue.js,作为现代前端开发的首选框架之一,其轻量级和灵活性为构建高性能的Web应用提供了可能。然而,即使是使用Vue.js,如果不进行适当的优化,应用的加载时间和......
  • pytorch使用交叉熵训练模型学习笔记
    python代码:importtorchimporttorch.nnasnnimporttorch.optimasoptim#定义一个简单的神经网络模型classSimpleModel(nn.Module):def__init__(self):super(SimpleModel,self).__init__()self.fc=nn.Linear(3,2)#输入3维,输出2类......