如何将list转为tensor
在遇到需要将list转换为tensor的情况时,往往不能直接转换,而是需要借助
torch.cat
方法进行。为防止需要的时候找不到教程,本文给出示例进行该操作。
操作方法
问题
给定对于数据 x
和 y
。x
的形状为 (2, 3, 4)
,表示 batch_size为2,每个batch有3个向量,每个向量维数是4
, y
的形状为 (2, 4)
,表示 batch_size为2,每个batch内仅有一个4维的向量
。现在目标是将同一个 batch
内 x
的3个向量分别和 y
的1个向量进行拼接,得到一个形状为 (2, 3, 4)
的数据。
代码
首先随机地生成 x
和 y
import torch
x = torch.rand((2, 3, 4))
y = torch.rand((2, 4))
print(x)
print(y)
输出如下:
tensor([[[0.3170, 0.5800, 0.2717, 0.3887],
[0.0862, 0.4881, 0.1419, 0.1491],
[0.1860, 0.4508, 0.2637, 0.9106]],[[0.0923, 0.1211, 0.8768, 0.7573], [0.9067, 0.0651, 0.2780, 0.6712], [0.0755, 0.1534, 0.9984, 0.8169]]])
tensor([[0.1451, 0.0273, 0.5603, 0.3951],
[0.8981, 0.8639, 0.3545, 0.4461]])
第二步,拆分拼接
s = []
for xx, yy in zip(x, y):
ss = []
for i in xx:
ss.append(torch.cat((i, yy), 0).unsqueeze(0))
print(ss)
ss = torch.cat(ss, dim=0)
s.append(ss.unsqueeze(0))
s
输出如下:
[tensor([[[0.3170, 0.5800, 0.2717, 0.3887, 0.1451, 0.0273, 0.5603, 0.3951],
[0.0862, 0.4881, 0.1419, 0.1491, 0.1451, 0.0273, 0.5603, 0.3951],
[0.1860, 0.4508, 0.2637, 0.9106, 0.1451, 0.0273, 0.5603, 0.3951]]]),
tensor([[[0.0923, 0.1211, 0.8768, 0.7573, 0.8981, 0.8639, 0.3545, 0.4461],
[0.9067, 0.0651, 0.2780, 0.6712, 0.8981, 0.8639, 0.3545, 0.4461],
[0.0755, 0.1534, 0.9984, 0.8169, 0.8981, 0.8639, 0.3545, 0.4461]]])]
这一步完成了每一个 batch
中的拼接,但 batch
之间还是以 list
的方式链接的。
第三步,合成 batch
s = torch.cat(s, dim=0)
s
输出如下:
tensor([[[0.3170, 0.5800, 0.2717, 0.3887, 0.1451, 0.0273, 0.5603, 0.3951],
[0.0862, 0.4881, 0.1419, 0.1491, 0.1451, 0.0273, 0.5603, 0.3951],
[0.1860, 0.4508, 0.2637, 0.9106, 0.1451, 0.0273, 0.5603, 0.3951]],[[0.0923, 0.1211, 0.8768, 0.7573, 0.8981, 0.8639, 0.3545, 0.4461], [0.9067, 0.0651, 0.2780, 0.6712, 0.8981, 0.8639, 0.3545, 0.4461], [0.0755, 0.1534, 0.9984, 0.8169, 0.8981, 0.8639, 0.3545, 0.4461]]])
至此,目标达成!
标签:tensor,0.3545,torch,list,batch,0.4461,0.0273,转为 From: https://www.cnblogs.com/Meloniala/p/17326790.html