permute进行的是置换。 permute的dim需要参数表示进行置换的维度。 import torch import numpy as np x=np.arange(24).reshape((2,3,4)) x = torch.tensor(x) print(x) y=x.permute((2, 1, 0)) print(y) 由代码可知列和批进行了置换。每一列都对应了一个新批,而批又转为了列。 输出:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=torch.int32)
tensor([[[ 0, 12],
[ 4, 16],
[ 8, 20]],
[[ 1, 13],
[ 5, 17],
[ 9, 21]],
[[ 2, 14],
[ 6, 18],
[10, 22]],
[[ 3, 15],
[ 7, 19],
[11, 23]]], dtype=torch.int32)