首页 > 其他分享 >pytorch如何批量reshape,如何每batch_size进行reshape

pytorch如何批量reshape,如何每batch_size进行reshape

时间:2022-09-18 10:26:04浏览次数:76  
标签:tensor reshape pytorch batch shape print size

假设我有一个tensor,它的batch_size是2:

tensor = torch.randn([2, 6])
print(tensor.shape)

输出是

torch.Size([2, 6])

其中tensor.shape[0]代表tensor的batch_size
如果我要把其中每个Batch的数据从6转换成[2,3],怎么写?循环遍历tensor然后循环内用reshape吗?不!
看下面的操作,很简单:

tensor = torch.randn([2, 6])
    print(tensor)
    tensor = tensor.reshape(tensor.shape[0], 2, 3)  # 将每个批次的数据转换成2,3的形状
    print(tensor)
    tensor = tensor.reshape(tensor.shape[0], 6)  # 恢复原来的形状
    print(tensor)

输出是:

tensor([[-0.7920, -0.7887, -0.7362,  0.2238,  0.3442,  1.5486],
        [ 1.7589, -0.3414,  0.4499, -0.0228,  0.4032,  0.3730]])
tensor([[[-0.7920, -0.7887, -0.7362],
         [ 0.2238,  0.3442,  1.5486]],

        [[ 1.7589, -0.3414,  0.4499],
         [-0.0228,  0.4032,  0.3730]]])
tensor([[-0.7920, -0.7887, -0.7362,  0.2238,  0.3442,  1.5486],
        [ 1.7589, -0.3414,  0.4499, -0.0228,  0.4032,  0.3730]])

Process finished with exit code 0

标签:tensor,reshape,pytorch,batch,shape,print,size
From: https://www.cnblogs.com/lanhongfu/p/16704288.html

相关文章