Tensor Broadcast
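When two tensors of different shapes take part in an element-by-element arithmetic operation, the tensor with the smaller shape is broadcast to the shape of the larger one. Broadcasting is handled internally by the implementation and does not copy any data.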
For example:
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor(2)
a * b
tensor([2, 4, 6])
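The no-copy claim can be checked directly. The sketch below uses `Tensor.expand`, which returns a broadcast view, as a stand-in for what an arithmetic operation does with its operands. Treating the two as equivalent is an assumption of this illustration, and the names `view`, `same_storage` and `strides` are introduced here.

```python
import torch

a = torch.tensor([1, 2, 3])

# expand() returns a broadcast view of `a` with shape (4, 3); no data is copied.
view = a.expand(4, 3)

# Both tensors start at the same memory address.
same_storage = view.data_ptr() == a.data_ptr()

# The stretched dimension has stride 0: every row re-reads the same 3 elements.
strides = view.stride()

print(same_storage, strides)  # True (0, 1)
```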
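Broadcast rules
Compare the shapes of the two tensors dimension by dimension, from right to left:
- If the two sizes are equal, the dimension is left unchanged.
- If one of the sizes is 1, that dimension is stretched to match the other.
- If one tensor has fewer dimensions, its missing leading dimensions are treated as size 1.
- If the sizes differ and neither is 1, the operation fails with a runtime error.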
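The rules above can be sketched as a small pure-Python function that computes the resulting shape. The helper name `broadcast_shape` is hypothetical, and the sketch assumes that a missing leading dimension counts as size 1, which is how a shape such as (3,) lines up against (4, 3).

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast shape of two shapes, or raise ValueError."""
    result = []
    # Walk both shapes from the rightmost dimension to the left.
    for i in range(1, max(len(shape_a), len(shape_b)) + 1):
        # Assumption: a missing leading dimension is treated as size 1.
        dim_a = shape_a[-i] if i <= len(shape_a) else 1
        dim_b = shape_b[-i] if i <= len(shape_b) else 1
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)   # equal sizes, or b is stretched
        elif dim_a == 1:
            result.append(dim_b)   # a is stretched
        else:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
    return tuple(reversed(result))


print(broadcast_shape((4, 3), (3,)))    # (4, 3)
print(broadcast_shape((4, 1), (3,)))    # (4, 3)
print(broadcast_shape((1, 3), (2, 1)))  # (2, 3)
```

Recent PyTorch versions provide `torch.broadcast_shapes` for the same calculation on real shapes.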
Broadcast examples
a = torch.tensor([[ 0,  0,  0],
                  [10, 10, 10],
                  [20, 20, 20],
                  [30, 30, 30]])
b = torch.tensor([1, 2, 3])
a + b
tensor([[ 1,  2,  3],
        [11, 12, 13],
        [21, 22, 23],
        [31, 32, 33]])
a = torch.tensor([[ 0], [10], [20], [30]])
b = torch.tensor([1, 2, 3])
a * b
tensor([[ 0,  0,  0],
        [10, 20, 30],
        [20, 40, 60],
        [30, 60, 90]])
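To see what the stretching amounts to, the second example can be reproduced by expanding both operands to the common shape by hand. This is only a sketch: the names `a_full`, `b_full`, `implicit` and `explicit` are introduced here for illustration.

```python
import torch

a = torch.tensor([[0], [10], [20], [30]])  # shape (4, 1)
b = torch.tensor([1, 2, 3])                # shape (3,)

# Stretch both operands to the common shape (4, 3) by hand.
a_full = a.expand(4, 3)  # every column repeats the single column of `a`
b_full = b.expand(4, 3)  # every row repeats `b`

implicit = a * b            # broadcasting applied automatically
explicit = a_full * b_full  # same computation on pre-expanded views

print(implicit.shape)                   # torch.Size([4, 3])
print(torch.equal(implicit, explicit))  # True
```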
Incompatible dimensions:
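A minimal sketch of this case, with shapes (4, 3) and (4,) chosen here for illustration. The rightmost sizes are 3 and 4, so the addition is expected to raise a `RuntimeError`; adding a trailing dimension of size 1 with `unsqueeze` is one way to make the shapes line up.

```python
import torch

a = torch.zeros(4, 3)
b = torch.tensor([1.0, 2.0, 3.0, 4.0])  # shape (4,)

# Rightmost sizes are 3 and 4: not equal, and neither is 1.
try:
    a + b
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# Giving `b` a trailing dimension of size 1 makes the shapes compatible:
# (4, 3) and (4, 1) broadcast to (4, 3).
result = a + b.unsqueeze(1)
print(result.shape)  # torch.Size([4, 3])
```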
PyTorch Broadcast API
torch.broadcast_tensors()
x = torch.arange(3).view(1, 3)
y = torch.arange(2).view(2, 1)
torch.broadcast_tensors(x, y)
(tensor([[0, 1, 2],
         [0, 1, 2]]),
 tensor([[0, 0, 0],
         [1, 1, 1]]))
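The tensors returned by `torch.broadcast_tensors` are views of the inputs, so several elements can refer to the same memory location, and the PyTorch documentation advises cloning before writing to them in place. The sketch below illustrates that precaution; the names `bx`, `by`, `shares_memory` and `safe` are introduced here.

```python
import torch

x = torch.arange(3).view(1, 3)
y = torch.arange(2).view(2, 1)
bx, by = torch.broadcast_tensors(x, y)

# Both results have the common shape (2, 3).
print(bx.shape, by.shape)

# `bx` is a view: its two rows alias the single row of `x`.
shares_memory = bx.data_ptr() == x.data_ptr()

# Clone before writing in place so that each element gets its own storage.
safe = bx.clone()
safe[0, 0] = 100
print(safe)  # only the element at [0, 0] changes
print(x)     # the original input is unchanged
```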
torch.broadcast_to()
x = torch.arange(3)
torch.broadcast_to(x, (3, 3))
tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2]])
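As a hedged illustration, `torch.broadcast_to` can be compared with `Tensor.expand` for the same target shape, and the input can be reshaped first to control which axis is repeated. The names `rows`, `cols` and `same` are introduced here.

```python
import torch

x = torch.arange(3)

# broadcast_to and expand give the same result for the same target shape.
rows = torch.broadcast_to(x, (3, 3))
same = torch.equal(rows, x.expand(3, 3))

# Reshaping to (3, 1) first repeats the values along columns instead of rows.
cols = torch.broadcast_to(x.view(3, 1), (3, 3))

print(same)  # True
print(cols)
```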