首页 > 其他分享 >torch.cat() 与 torch.stack() 的区别

torch.cat() 与 torch.stack() 的区别

时间:2022-12-28 18:12:56浏览次数:35  
标签:tensor torch cat print stack size

目录



1. torch.cat()

torch.cat(tensors, dim=0)

在给定维度中拼接张量序列。

参数:

  • tensors:张量序列。
  • dim:拼接张量序列的维度。
import torch

a = torch.rand(2, 3)
b = torch.rand(2, 3)
c = torch.cat((a, b))
print(a.size(), b.size(), c.size())
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([4, 3])

可以看出,\(a、b、c\) 都是二维。


张量序列必须具有相同大小:

d = torch.rand(2, 4)
print(torch.cat((a, d)))
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 4 for tensor number 1 in the list.

具体拼接:

print(a)
print(torch.cat((a, a, a), dim=0))
print(torch.cat((a, a, a), dim=1))
tensor([[0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186]])
tensor([[0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186],
        [0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186],
        [0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186]])
tensor([[0.2381, 0.7100, 0.8150, 0.2381, 0.7100, 0.8150, 0.2381, 0.7100, 0.8150],
        [0.5190, 0.5829, 0.9186, 0.5190, 0.5829, 0.9186, 0.5190, 0.5829, 0.9186]])



2. torch.stack()

torch.stack(tensors, dim=0)

沿新维度拼接张量。

参数:

  • tensors:张量序列
  • dim:要插入的维度。
import torch

a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))
print(a.size(), b.size(), c.size())
torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 2, 3])

可以看出,\(a、b\) 是二维,\(c\) 是三维。


张量序列必须具有相同大小:

d = torch.rand(2, 4)
print(torch.stack((a, d)))
RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [2, 4] at entry 1

具体拼接:

x = torch.arange(1, 7).reshape((3, 2))
y = torch.arange(10, 70, 10).reshape((3, 2))
z = torch.arange(100, 700, 100).reshape((3, 2))
print(x)
print(y)
print(z)
tensor([[1, 2],
        [3, 4],
        [5, 6]])
tensor([[10, 20],
        [30, 40],
        [50, 60]])
tensor([[100, 200],
        [300, 400],
        [500, 600]])
m = torch.stack((x,y,z))
print(m)
tensor([[[  1,   2],
         [  3,   4],
         [  5,   6]],

        [[ 10,  20],
         [ 30,  40],
         [ 50,  60]],

        [[100, 200],
         [300, 400],
         [500, 600]]])
n = torch.stack((x,y,z), 1)
print(n)
tensor([[[  1,   2],
         [ 10,  20],
         [100, 200]],

        [[  3,   4],
         [ 30,  40],
         [300, 400]],

        [[  5,   6],
         [ 50,  60],
         [500, 600]]])
h = torch.stack((x,y,z), 2)
print(h)
tensor([[[  1,  10, 100],
         [  2,  20, 200]],

        [[  3,  30, 300],
         [  4,  40, 400]],

        [[  5,  50, 500],
         [  6,  60, 600]]])


标签:tensor,torch,cat,print,stack,size
From: https://www.cnblogs.com/keye/p/17010928.html

相关文章