首页 > 其他分享 >torch:针对mask掉的位置不进行softmax

torch:针对mask掉的位置不进行softmax

时间:2022-10-19 21:01:25浏览次数:43  
标签:tensor torch mask softmax qk inf

错误方式

希望在进行softmax之前,如果对被mask掉的位置加上一个特别小的数字,那么softmax之后就会变成0。

pad_mask = (1 - doc_token_mask) * (-1999999)  # 把原本0的位置变成一个特别小的数字
qk = qk + pad_mask  # 加到原来的上面去
qk_softmax = torch.softmax(qk, dim=-1)

但是这样有两个问题:

  • 在fp16的情况下,如果自己随便写的数字特别小,会发生inf
  • 在计算梯度的时候,如果是加法,会影响梯度计算。

正确方式

qk = qk.masked_fill_(1-doc_token_mask, -float('inf'))  # 把原本0的位置直接变成一个特别小的数字,而且 -float('inf')和精度无关

测试:

a=torch.tensor([1,2,3,4]).float()
mask=torch.tensor([1,1,0,0])
b=a.masked_fill_(1-mask, -float('inf'))  # tensor([1., 2., -inf, -inf])

torch.softmax(b, dim=0)  # tensor([0.2689, 0.7311, 0.0000, 0.0000])

标签:tensor,torch,mask,softmax,qk,inf
From: https://www.cnblogs.com/carolsun/p/16807755.html

相关文章

  • 《PyTorch深度学习实践》-刘二大人 第三讲
    #梯度下降法frommatplotlibimportpyplotasplt#preparethetrainingsetx_data=[1.0,2.0,3.0]y_data=[2.0,4.0,6.0]#initialguessofweightw=......
  • 《PyTorch深度学习实践》-刘二大人 第二讲
    刘二大人的Pytorch保姆式教程。我觉得算0基础学Pytorch吧,从我现在的基础看就是比较easy的程度,正和我意~课堂练习:importnumpyasnpimportmatplotlib.pyplotasplt......
  • 安装Pytorch
    下面三种需求都是可以尝试的:错误1:AssertionError:TorchnotcompiledwithCUDAenabled错误2:torch.cuda.is_available() 输出false需求3:就是想安装Pytorch 请......
  • pytorch安装gpu版本
    pipinstalltorch==1.8.1+cu111torchvision==0.9.1+cu111torchaudio==0.8.1-fhttps://download.pytorch.org/whl/torch_stable.html-ihttps://pypi.tuna.tsinghua.e......
  • 组件Mask(超出部分隐藏)
    组件使用mask的地方只需要两个组件:自带的Node和渲染组件里的Mask使用组件首先要有一个使用mask组件的图层,设置好宽高。他下面有张图片。层级结构如图:mask_exam的暂定为......
  • 安装pytorch
    https://blog.csdn.net/love_respect/article/details/124681233?spm=1001.2101.3001.6650.1&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7......
  • Pytorch 多卡并行 torch.nn.DistributedDataParallel (DDP)
    PyTorch分布式训练简明教程 (知乎,推荐)PyTorch分布式DPP启动方式(包含完整用例) (csdn) ......
  • 打开pytorch里的jupyter
    解决所有问题后打开Jupyter后发现没有创建的虚拟环境,只有python3。解决办法:打开AnacondaPrompt,进入pytorch环境,输入如下命令:1.激活pytorchactivatepy362.输如,比如......
  • Mac OS安装 pytorch方法
    1、Pytorch介绍PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序它主要由Facebook的人工智能小组开发,不仅能够实现强大的GPU加速,同时还支持动态神......
  • window10系统下Pytorch安装教程
    1、Pytorch介绍PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序它主要由Facebook的人工智能小组开发,不仅能够实现强大的GPU加速,同时还支持动态神......