上三角 triu
Pytorch上三角和下三角的调用与numpy是相同的。
np.triu(np.ones((5,5)),k=0) # k控制对角线开始的位置
Out[25]:
array([[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 0., 1.]])
构建一个上三角mask
torch.triu(torch.ones(5,5),diagonal=0)
Out[17]:
tensor([[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 0., 1.]])
dianonal
控制上三角的对角线开始位置
torch.triu(torch.ones(5,5),diagonal=1)
Out[20]:
tensor([[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0.]])
下三角 tril
torch.tril(torch.ones(5,5),diagonal=0)
Out[21]:
tensor([[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]])
torch.tril(torch.ones(5,5),diagonal=1)
Out[22]:
tensor([[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])