首页 > 其他分享 >pytorch的unsqueeze

pytorch的unsqueeze

时间:2022-12-24 08:22:05浏览次数:41  
标签:unsqueeze tensor torch shape pytorch print import

就是在指定维度前再插入一个新的维度。

import torch import numpy as np x=np.arange(24).reshape((2,3,4)) x = torch.tensor(x) print(x) y=x.permute((2, 0, 1)) print(y.shape) print(y.unsqueeze(0).shape) print(y.unsqueeze(0).float())

tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=torch.int32)
torch.Size([4, 2, 3])
torch.Size([1, 4, 2, 3])

标签:unsqueeze,tensor,torch,shape,pytorch,print,import
From: https://www.cnblogs.com/hahaah/p/17001975.html

相关文章