示例:
import torch box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000], [0.6000, 0.6000, 0.9000, 0.9000], [0.1000, 0.1000, 0.2000, 0.2000]], [[0.1000, 0.2000, 0.5000, 0.3000], [0.6000, 0.6000, 0.9000, 0.9000], [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32) wh = torch.tensor([[[200.], [400.], [200.], [400.]], [[200.], [400.], [200.], [400.]]]).to(torch.float32) print(box.shape) # (2, 3 ,4) print(wh.shape) # (2, 4, 1) result = box @ wh print(result.shape) # (2, 3, 1) print(result) # tensor([[[320.], # [900.], # [180.]], # [[320.], # [900.], # [180.]]])
下面这个示例用到了广播机制:
import torch box = torch.tensor([[[0.1000, 0.2000, 0.5000, 0.3000], [0.6000, 0.6000, 0.9000, 0.9000], [0.1000, 0.1000, 0.2000, 0.2000]], [[0.1000, 0.2000, 0.5000, 0.3000], [0.6000, 0.6000, 0.9000, 0.9000], [0.1000, 0.1000, 0.2000, 0.2000]]]).to(torch.float32) wh = torch.tensor([[[200.], [400.], [200.], [400.]]]).to(torch.float32) print(box.shape) # (2, 3 ,4) print(wh.shape) # (1, 4, 1) 注意这个wh的第0维度的大小是1 result = box @ wh # 这里在第0维度会使用广播机制 print(result.shape) # (2, 3, 1) print(result) # tensor([[[320.], # [900.], # [180.]], # [[320.], # [900.], # [180.]]])
标签:相乘,torch,张量,0.2000,print,pytorch,0.6000,0.9000,0.1000 From: https://www.cnblogs.com/picassooo/p/18503922