1 广播机制介绍
矩阵运算,往往只能在两个矩阵维度相同或者相匹配时才能运算。比如加减法需要两个矩阵的维度相同,乘法需要前一个矩阵的列数与后一个矩阵的行数相等。
当参与运算的两个维度不同也不匹配的矩阵进行运算时,该机制会对数组进行扩展,使数组的shape属性值一样,这样,就可以进行矢量化运算了。通常情况下,小一点的数组会被 broadcast 到大一点的,这样才能保持大小一致。
2 广播机制的规则
2.1 广播机制的适用规则
两个张量要进行广播运算,需要遵循以下规则:
- 每个 Tensor 至少要有一个维度;
- 遍历 Tensor 所有维度时,从末尾开始遍历(即从右往左开始遍历),两个Tensor 可能存在下列情况:
- Tensor 维度相等;
- Tensor 维度不等且其中一个维度为1;
- Tensor 维度不等且其中一个维度不存在。
2.2 广播机制的使用规则
如果两个 Tensor 的维度不同,则在维度较小的 Tensor 前面(即左侧)增加维度,使它们维度相等。
对于每个维度,计算结果的维度值取两个 Tensor 中较大的那个值。
两个 Tensor 扩展维度的过程是将数值进行复制。
3 广播机制的使用
3.1 创建张量
分别创建一个1行3列和一个3行1列的Tensor。
3.2 判断是否适用广播机制
当执行 x1 + x2 对这两个张量做相加运算时,因其维度不同,需要首先判断是否适用进行广播:
- x1, x2 至少有一个维度,符合条件1;
- x1的维度是(1,3),x2的维度是(3,1) 因此答 Tensor 维度不等且其中一个维度为1 这个条件。由此可见,可以进行广播。
3.3 广播机制推演
当对x1和x2执行相加运算时,
首先取第一行最后一列,x1为3,x2为10,进行相加运算;
再取第二行最后一列时,x1中没有第二行,因此把第一行的数据复制到第二行,此时第二行最后一列为3,x2在该位置为20,进行相加运算;
当取第三行最后一列时,x1中没有第三行,再把第二行的数据复制到第三行,此时第三行最后一列为3,x2在该位置为30,进行相加运算;
然后开始取从右向左数的第二列第一行的数据,x1为2,x2中没有该列,因此把最右侧一列的数据复制到该列,此时x2在该列第一行为10,进行相加运算;
依此类推,最后x1,x2 在运算过程中,都通过广播机制,变成了相同维度的3行3列张量,最后得到的结果也是一个3行3列的张量。
计算结果
3.4 无法使用广播机制的情况
参与运算的 Tensor 维度不等,但是其中一个维度大于1。
参与运算的 Tensor 其中一个没有任何,其中的 c2 没有任何维度。
标签:运算,05,广播,x2,PyTorch,维度,x1,Tensor From: https://blog.51cto.com/u_113754/6215261