矩阵和 numpy.transpose
由
文章
[Transformer源码详解(Pytorch版本) - 知乎 (zhihu.com)](https://zhuanlan.zhihu.com/p/398039366?utm_campaign=shareopn&utm_medium=social&utm_oi=1396930517548257280&utm_psn=1547199426296487936&utm_source=wechat_session)
代码
[harvardnlp/annotated-transformer: An annotated implementation of the Transformer paper. (github.com)](https://github.com/harvardnlp/annotated-transformer)
所引出的问题
部分受启发于 https://www.cnblogs.com/sunshinewang/p/6893503.html
转置有三种方式,transpose
方法、T
属性以及swapaxes
方法。
矩阵
import numpy as np
x = np.arange(24).reshape((2,3,4))
print(x)
print(x.transpose((1,0,2))) # shape(3,2,4)
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
#############################################
[[[ 0 1 2 3]
[12 13 14 15]]
[[ 4 5 6 7]
[16 17 18 19]]
[[ 8 9 10 11]
[20 21 22 23]]]
括号上 由外到内的一个层级
不要思考xyz了,直接用012的思路吧,012就是从内到外的一个矩阵层次划分,对应到矩阵表示中也是同理的
转置
主要是 numpy.transpose
主要是考虑角标? 毕竟矩阵表示中的 xyz 对应着每一个坐标的xyz,与数学中强调形状不同,计算机中矩阵的应用更强调于角标的变换,我个人的三维想象能力欠佳,所以只能以角标计算的方式理解代码中的矩阵转置
def attention(query, key, value, mask=None, dropout=None):
"Compute 'Scaled Dot Product Attention'"
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = scores.softmax(dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
第四行这个地方我目前倾向于写错了,首先这样的-1,-2的写法也需要满足维度,我用jupyter的实验中,维度不符合是不可以的。
下面是具体实验内容。
(-2,-1)
import numpy as np
x = np.arange(6).reshape(2,3)
print(x)
print("############################")
print(x.transpose(-2,-1))
[[0 1 2]
[3 4 5]]
############################
[[0 1 2]
[3 4 5]]
(-1,-2)
import numpy as np
x = np.arange(6).reshape(2,3)
print(x)
print("############################")
print(x.transpose(-1,-2))
[[0 1 2]
[3 4 5]]
############################
[[0 3]
[1 4]
[2 5]]
另外在代码中的与qkv计算相关的部分
query, key, value = [
lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for lin, x in zip(self.linears, (query, key, value))
]
标签:transpose,矩阵,np,print,query,numpy
From: https://www.cnblogs.com/CCCarloooo/p/16632623.html