首页 > 其他分享 >PyTorch学习(8):PyTorch中Tensor的合并于拆分(torch.cat, torch.stack, torch.trunk, torch.split)

PyTorch学习(8):PyTorch中Tensor的合并于拆分(torch.cat, torch.stack, torch.trunk, torch.split)

时间:2024-05-30 17:32:53浏览次数:33  
标签:Tensor torch 张量 PyTorch 拼接 维度 stack 函数

1. 写在前面

        在使用PyTorch执行深度学习开发时,经常会用到对Tensor的合并于拆分操作。如我们在使用CSP时,有时候会需要将Tensor拆分成两部分,其中一部分进行进行Cross Stage操作,另一部分执行多重卷积操作,这个时候我们就会用到四个典型的接口,分别是torch.cat, torch.stack, torch.trunk, torch.split。接下来将逐一进行讲解。

2. torch.cat函数

        cat函数用于在指定的维度上将一系列张量拼接在一起。所有输入张量必须在拼接维度上具有相同的大小,或者至少在拼接维度上具有相同的大小,否则会引发错误。如果输入张量在拼接维度上的大小不同,则它们必须在拼接维度上具有相同的大小,否则会引发错误。

示例:
        假设有两个形状为(2, 3)的张量a和b,可以使用以下代码将它们沿着第二维连接:

a = torch.rand(2, 3, 2)

b = torch.rand(2, 3, 2)

result = torch.cat((a, b), dim=1)

        输出结果将是一个形状为(2, 6, 2)的张量。

3. torch.stack函数

        stack函数与cat函数类似,但它在拼接的同时,在指定维度处插入一个新的维度。可以理解为stack是在指定维度处,分别为两个维度数据加上一层[]后,再进行拼接。stack函数要求所有输入张量在拼接维度上的大小必须完全相同。

示例:
        假设有两个形状为(2, 5)的张量a和b,可以使用以下代码将它们沿着第二维堆叠在一起:

a = torch.rand(2, 5)

b = torch.rand(2, 5)

result = torch.stack((a, b), dim=1)

        输出结果将是一个形状为(2, 2, 5)的张量。

4. torch.trunk函数

        在PyTorch中,trunk函数并不是一个内置函数。可能是指torch.chunk函数,它用于将一个张量分割成多个较小的张量。chunk函数在指定维度上按照平均分配的方式进行分割,如果分割的总数不能整除当前维度的大小,则最后一部分可能会小于其他部分。

示例:
        假设有一个形状为(8, 4)的张量d,可以使用以下代码将其分割为四个大小相等的张量:

d = torch.randn(8, 4)

chunks = torch.chunk(d, chunks=4, dim=0)

        输出结果将是一个包含四个张量的列表,每个张量的形状为(2, 4)。

5. torch.split函数

        split函数类似于chunk函数,但它允许按特定方案进行分割,而不仅仅是按份数均匀分割。split函数接受一个分割方案,该方案是一个列表,指示了如何将张量分割成多个部分。

示例:
        假设有一个形状为(8, 4)的张量d,可以使用以下代码按照指定的分割方案进行分割:

d = torch.randn(8, 4)

section = [1, 2, 1, 2, 2]

result = torch.split(d, section, dim=0)

        输出结果将是一个包含五个张量的列表,每个张量的大小取决于section列表中的元素。

标签:Tensor,torch,张量,PyTorch,拼接,维度,stack,函数
From: https://blog.csdn.net/tecsai/article/details/139329445

相关文章