首页 > 其他分享 >Pytorch mask:上三角和下三角

Pytorch mask:上三角和下三角

时间:2022-12-02 17:07:08浏览次数:41  
标签:torch tensor 三角 mask diagonal Pytorch ones Out


上三角 triu

Pytorch上三角和下三角的调用与numpy是相同的。

np.triu(np.ones((5,5)),k=0) # k控制对角线开始的位置
Out[25]:
array([[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 0., 1.]])

构建一个上三角mask

torch.triu(torch.ones(5,5),diagonal=0)
Out[17]:
tensor([[1., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 0., 1.]])

​dianonal​​控制上三角的对角线开始位置

torch.triu(torch.ones(5,5),diagonal=1)
Out[20]:
tensor([[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0.]])

下三角 tril

torch.tril(torch.ones(5,5),diagonal=0)
Out[21]:
tensor([[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]])

torch.tril(torch.ones(5,5),diagonal=1)
Out[22]:
tensor([[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])


标签:torch,tensor,三角,mask,diagonal,Pytorch,ones,Out
From: https://blog.51cto.com/u_15899958/5907216

相关文章