错误方式
希望在进行softmax之前,如果对被mask掉的位置加上一个特别小的数字,那么softmax之后就会变成0。
pad_mask = (1 - doc_token_mask) * (-1999999) # 把原本0的位置变成一个特别小的数字
qk = qk + pad_mask # 加到原来的上面去
qk_softmax = torch.softmax(qk, dim=-1)
但是这样有两个问题:
- 在fp16的情况下,如果自己随便写的数字特别小,会发生inf
- 在计算梯度的时候,如果是加法,会影响梯度计算。
正确方式
qk = qk.masked_fill_(1-doc_token_mask, -float('inf')) # 把原本0的位置直接变成一个特别小的数字,而且 -float('inf')和精度无关
测试:
a=torch.tensor([1,2,3,4]).float()
mask=torch.tensor([1,1,0,0])
b=a.masked_fill_(1-mask, -float('inf')) # tensor([1., 2., -inf, -inf])
torch.softmax(b, dim=0) # tensor([0.2689, 0.7311, 0.0000, 0.0000])
标签:tensor,torch,mask,softmax,qk,inf
From: https://www.cnblogs.com/carolsun/p/16807755.html