1 Preface
Notes on how to concatenate and split tensors in PyTorch.
2 Concatenation
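PyTorch provides two methods for concatenating tensors:
torch.cat()
torch.stack()
torch.cat() joins the tensors along an existing dimension, so no new dimension is created. torch.stack() instead stacks the tensors along a new dimension, and it requires all inputs to have the same shape.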
import torch
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
print(tensor1)
print(tensor2)
# out:
tensor([[ 1.3124, -0.6630, -1.1289],
[-0.0913, 0.7382, 0.4581]])
tensor([[-0.8929, -1.3781, -0.6344],
[-0.0994, 0.5217, -2.2306]])
tensor_cat = torch.cat([tensor1,tensor2])
print(f"tensor_out:{tensor_cat}")
print(f"size of tensor_out:{tensor_cat.size()}")
tensor_stack = torch.stack([tensor1,tensor2])
print(f"tensor_stack:{tensor_stack}")
print(f"size of tensor_stack:{tensor_stack.size()}")
# out:
tensor_out:tensor([[ 1.3124, -0.6630, -1.1289],
[-0.0913, 0.7382, 0.4581],
[-0.8929, -1.3781, -0.6344],
[-0.0994, 0.5217, -2.2306]])
size of tensor_out:torch.Size([4, 3])
tensor_stack:tensor([[[ 1.3124, -0.6630, -1.1289],
[-0.0913, 0.7382, 0.4581]],
[[-0.8929, -1.3781, -0.6344],
[-0.0994, 0.5217, -2.2306]]])
size of tensor_stack:torch.Size([2, 2, 3])
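The calls above use the default dim=0. As a minimal sketch (the dim values and variable names here are only illustrative), passing dim changes where the tensors are joined: torch.cat extends an existing dimension, while torch.stack inserts a new dimension at that position, which is equivalent to unsqueezing every input at dim and then concatenating.

```python
import torch

# Two tensors with the same shape (2, 3), as in the example above
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

# torch.cat joins along an existing dimension; dim=1 joins column-wise
cat_dim1 = torch.cat([tensor1, tensor2], dim=1)
print(cat_dim1.size())  # torch.Size([2, 6])

# torch.stack inserts a new dimension at position `dim`
stack_dim1 = torch.stack([tensor1, tensor2], dim=1)
print(stack_dim1.size())  # torch.Size([2, 2, 3])
stack_dim2 = torch.stack([tensor1, tensor2], dim=2)
print(stack_dim2.size())  # torch.Size([2, 3, 2])

# stack(dim=1) gives the same result as unsqueeze(1) on each input followed by cat(dim=1)
same = torch.equal(
    stack_dim1,
    torch.cat([tensor1.unsqueeze(1), tensor2.unsqueeze(1)], dim=1),
)
print(same)  # True
```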
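torch.vstack() achieves the same result as torch.cat() here. It stacks tensors vertically (row-wise) in sequence, which amounts to concatenating along dim 0; 1-D inputs are first reshaped into single rows.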
tensor_vstack = torch.vstack([tensor1,tensor2])
print(f"tensor_vstack:{tensor_vstack}")
print(f"size of tensor_vstack:{tensor_vstack.size()}")
# out:
tensor_vstack:tensor([[ 1.3124, -0.6630, -1.1289],
[-0.0913, 0.7382, 0.4581],
[-0.8929, -1.3781, -0.6344],
[-0.0994, 0.5217, -2.2306]])
size of tensor_vstack:torch.Size([4, 3])
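A quick check, assuming a PyTorch version that provides torch.vstack: for 2-D inputs it matches torch.cat along dim 0, but for 1-D inputs each tensor is first treated as a single row, so the two functions no longer agree. The small integer tensors a and b are made up for the illustration.

```python
import torch

tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

# For inputs with at least 2 dimensions, vstack is cat along dim 0
same = torch.equal(torch.vstack([tensor1, tensor2]), torch.cat([tensor1, tensor2], dim=0))
print(same)  # True

# 1-D inputs differ: vstack treats each (3,) tensor as a (1, 3) row
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
v = torch.vstack([a, b])
c = torch.cat([a, b])
print(v.size())  # torch.Size([2, 3])
print(c.size())  # torch.Size([6])
```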
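torch.hstack(), by contrast, stacks tensors horizontally (column-wise) in sequence: inputs with two or more dimensions are concatenated along dim 1, and 1-D inputs along dim 0.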
tensor_hstack = torch.hstack([tensor1,tensor2])
print(f"tensor_hstack:{tensor_hstack}")
print(f"size of tensor_hstack:{tensor_hstack.size()}")
# out:
tensor_hstack:tensor([[ 1.3124, -0.6630, -1.1289, -0.8929, -1.3781, -0.6344],
[-0.0913, 0.7382, 0.4581, -0.0994, 0.5217, -2.2306]])
size of tensor_hstack:torch.Size([2, 6])
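The same kind of check for torch.hstack, again with made-up 1-D tensors a and b: for 2-D inputs it matches torch.cat along dim 1, while 1-D inputs have no dim 1 and are simply joined end to end.

```python
import torch

tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

# For 2-D (and higher) inputs, hstack concatenates along dim 1
same_h = torch.equal(torch.hstack([tensor1, tensor2]), torch.cat([tensor1, tensor2], dim=1))
print(same_h)  # True

# For 1-D inputs there is no dim 1, so hstack concatenates along dim 0
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
h = torch.hstack([a, b])
print(h)  # tensor([1, 2, 3, 4, 5, 6])
```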
3 Splitting
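PyTorch provides two methods for splitting tensors:
torch.split()
torch.chunk()
torch.split() splits a tensor into pieces of a given size (or a given list of sizes), and each piece is a view of the original tensor. torch.chunk() splits a tensor along dim into `chunks` pieces and returns them as a tuple.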
The signature and docstring of torch.split:
def split(
    tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, ...]:
    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.

    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
    be split into equally sized chunks (if possible). Last chunk will be smaller if
    the tensor size along the given dimension :attr:`dim` is not divisible by
    :attr:`split_size`.

    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
    to :attr:`split_size_or_sections`.

    Args:
        tensor (Tensor): tensor to split.
        split_size_or_sections (int) or (list(int)): size of a single chunk or
            list of sizes for each chunk
        dim (int): dimension along which to split the tensor.
    """
The documentation of torch.chunk:
torch.chunk(input, chunks, dim=0) → List of Tensors
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
Parameters:
    input (Tensor) – the tensor to split
    chunks (int) – number of chunks to return
    dim (int) – dimension along which to split the tensor
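Both docstrings say each piece is a view of the input. A minimal sketch of what that means in practice (the deterministic torch.arange tensor is an arbitrary choice so the values are easy to follow): writing into a piece also changes the original tensor, and .clone() gives an independent copy.

```python
import torch

t = torch.arange(10).reshape(5, 2)

# Each piece returned by split shares storage with `t`
first, *rest = torch.split(t, 2)
first[0, 0] = 100
print(t[0, 0])  # tensor(100)

# chunk behaves the same way: `right` is column 1 of `t`
left, right = torch.chunk(t, 2, dim=1)
right[4, 0] = -1
print(t[4, 1])  # tensor(-1)

# Call .clone() when an independent copy is needed
copy = torch.split(t, 2)[0].clone()
copy[0, 0] = 0
print(t[0, 0])  # tensor(100): the original is unchanged
```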
split:
tensor = torch.randn(10).reshape(5,2)
print(f"tensor:{tensor}")
torch.split(tensor,2)
# out:
tensor:tensor([[ 0.9619, 0.6095],
[-1.8024, -0.1534],
[ 1.7452, 0.4705],
[-0.8512, 0.3175],
[-0.0290, -0.1422]])
(tensor([[ 0.9619, 0.6095],
[-1.8024, -0.1534]]),
tensor([[ 1.7452, 0.4705],
[-0.8512, 0.3175]]),
tensor([[-0.0290, -0.1422]]))
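Note: the outputs from here on come from a different random draw of tensor than the one printed above, so the values do not match it.
torch.split(tensor, [2, 3])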
# out:
(tensor([[-1.5071, -0.0346],
[-0.6429, 0.5917]]),
tensor([[ 0.2722, 0.3824],
[ 0.6135, 0.7926],
[-0.5771, -0.4590]]))
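A slightly broader sketch of torch.split, using a deterministic torch.arange tensor in place of random values (an arbitrary choice for readability): an integer size leaves the remainder in the last piece, a list gives one piece per entry (the entries must add up to the size of the split dimension), and dim selects the dimension to cut.

```python
import torch

tensor = torch.arange(10).reshape(5, 2)

# Integer split size: pieces of 2 rows, the last piece holds the remainder
row_sizes = [p.shape[0] for p in torch.split(tensor, 2)]
print(row_sizes)  # [2, 2, 1]

# List of sizes: one piece per entry; here 1 + 4 == tensor.size(0) == 5
list_sizes = [p.shape[0] for p in torch.split(tensor, [1, 4])]
print(list_sizes)  # [1, 4]

# Splitting along dim=1 instead of the default dim=0
col_shapes = [tuple(p.shape) for p in torch.split(tensor, 1, dim=1)]
print(col_shapes)  # [(5, 1), (5, 1)]
```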
chunk:
torch.chunk(tensor, 2, dim=1)
# out:
(tensor([[-1.5071],
[-0.6429],
[ 0.2722],
[ 0.6135],
[-0.5771]]),
tensor([[-0.0346],
[ 0.5917],
[ 0.3824],
[ 0.7926],
[-0.4590]]))
torch.chunk(tensor, 2, dim=0)
# out:
(tensor([[-1.5071, -0.0346],
[-0.6429, 0.5917],
[ 0.2722, 0.3824]]),
tensor([[ 0.6135, 0.7926],
[-0.5771, -0.4590]]))
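The 3 + 2 row split above comes from chunk using a piece size of ceil(5 / 2) = 3. A consequence worth knowing, sketched below with a made-up torch.arange(6) input: when the size does not divide nicely, chunk can return fewer pieces than requested, and its result matches torch.split with that rounded-up size.

```python
import math
import torch

x = torch.arange(6)

# chunk uses a piece size of ceil(6 / 4) = 2, so only 3 pieces come back
pieces = torch.chunk(x, 4)
print(len(pieces))  # 3
print([p.tolist() for p in pieces])  # [[0, 1], [2, 3], [4, 5]]

# The same pieces come from split with that rounded-up size
size = math.ceil(x.size(0) / 4)
same_as_split = all(torch.equal(a, b) for a, b in zip(pieces, torch.split(x, size)))
print(same_as_split)  # True

# The 5-row case from above: ceil(5 / 2) = 3 rows in the first piece
tensor = torch.randn(5, 2)
chunk_rows = [p.shape[0] for p in torch.chunk(tensor, 2, dim=0)]
print(chunk_rows)  # [3, 2]
```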