首页 > 其他分享 >pytorch 自定义 dataloader 维度不对齐+广播机制导致不易察觉 bug

pytorch 自定义 dataloader 维度不对齐+广播机制导致不易察觉 bug

时间:2024-07-18 15:32:05浏览次数:14  
标签:torch target 自定义 dataloader reduction pytorch 维度 input size

很简单,自定义了一个 dataloader,出现以下不易察觉 bug
inputs 维度:[bs, 4],这个没问题
labels 维度:正确应该是 [bs, 1],但是 dataloader 出来是 [bs]

模型的 outputs 维度:[bs, 1]

如果用 torch.mean(torch.abs(labels - outputs)) 计算 L1 Loss / MAE

由于 pytorch 的广播机制,torch.abs(labels - outputs) 变成了一个 bs * bs 的矩阵,然后计算了这个矩阵的均值,直接变成对比学习了

怎么解决这个问题呢:

在传入 dataloader 时,就令 labels 的维度是 [n, 1],而不是 [n]

怎么避免这个问题呢:

使用 F.l1_loss(),会发现这样的情况会直接报错,因为它会检查两个 tensor 的维度是一致的:

所以在自己手搓 Loss 时,需要检查维度一致,避免广播。

def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
    # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
    r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor

    Function that takes the mean element-wise absolute value difference.

    See :class:`~torch.nn.L1Loss` for details.
    """
    if not torch.jit.is_scripting():
        tens_ops = (input, target)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce,
                reduction=reduction)
    if not (target.size() == input.size()):
        warnings.warn("Using a target size ({}) that is different to the input size ({}). "
                      "This will likely lead to incorrect results due to broadcasting. "
                      "Please ensure they have the same size.".format(target.size(), input.size()),
                      stacklevel=2)
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    if target.requires_grad:
        ret = torch.abs(input - target)
        if reduction != 'none':
            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    else:
        expanded_input, expanded_target = torch.broadcast_tensors(input, target)
        ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
    return ret

标签:torch,target,自定义,dataloader,reduction,pytorch,维度,input,size
From: https://www.cnblogs.com/coldchair/p/18309613

相关文章

  • Qt实现仪表盘-自定义控件
            仪表盘在很多汽车和物联网相关的系统中很常用,本文就来介绍一下Qt 仪表盘的实现示例,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下。一、简述         使用Qt绘制一个仪表盘,用来显示当前的温度,绘制刻度、绘制数字......
  • QT利用QPainter实现自定义圆弧进度条组件
               在可视化应用中,弧形进度条应用也比较广泛,本文示例封装了一个可复用、个性化的弧形进度条组件。本文示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下。主要结构就是外围一圈圆角进度,中间加上标题和对应进度的百分比,进度条的起始角......
  • 2024-07-18 给vue项目添加自定义路由守卫
    要配置路由守卫要使用到vue-router,它是Vue.js官方的路由管理器,主要用于帮助开发者构建单页面应用(SinglePageApplication,简称SPA)。步骤一:新建路由文件,文件名随意,建议叫router.ts,规范一点//router.tsimport{createRouter,createWebHashHistory}from"vue-router";i......
  • 使用Spring Boot AOP和自定义注解优雅实现操作日志记录
    使用SpringBootAOP和自定义注解优雅实现操作日志记录大家好,今天我们来聊聊如何在SpringBoot项目中,通过AOP(面向切面编程)和自定义注解,优雅地实现操作日志记录。操作日志对于系统的可维护性和安全性至关重要,它能帮助我们追踪用户行为,排查问题。什么是AOP?AOP,全称Aspect-Oriented......
  • 自定义转换器
    我们要自定义转换器就要声明一个类,然后继承父类的BaseConverter需要用正则表达式的需要重写父类的regex代码实现: fromflaskimportFlaskfromwerkzeug.routingimportBaseConverterapp=Flask(__name__)classCustomConverter(BaseConverter):#自定义转换器要继承......
  • pytorch学习(四)绘制loss和correct曲线
    这一次学习的时候静态绘制loss和correct曲线,也就是在模型训练完成后,对统计的数据进行绘制。以minist数据训练为例子importtorchfromtorchimportnnfromtorch.utils.dataimportDataLoaderfromtorchvisionimportdatasetsfromtorchvision.transformsimportToTen......
  • Netcode for Entities如何添加自定义序列化,让GhostField支持任意类型?以int3为例(1.2.3
    一句话省流:很麻烦也很抽象,能用内置支持的类型就尽量用。首先看文档。官方文档里一开头就列出了所有内置的支持的类型:GhostTypeTemplates其中Entity类型需要特别注意一下:在同步这个类型的时候,如果是刚刚Instantiate的Ghost(也就是GhostId尚未生效,上一篇文章里说过这个问题),那么客......
  • pytorch|找不到 fbgemm.dll 问题处理
    问题现象运行逻辑:importtorch报错如下:Traceback(mostrecentcalllast):File"C:\scaffold\metasequoia-tyc\ner_address\test_torch.py",line1,in<module>importtorchFile"D:\py\Python310\lib\site-packages\torch\__init__.......
  • windows11 使用pytorch transformers运行Qwen2-0.5B-Instruct模型 (基于anaconda pyth
    吾名爱妃,性好静亦好动。好编程,常沉浸于代码之世界,思维纵横,力求逻辑之严密,算法之精妙。亦爱篮球,驰骋球场,尽享挥洒汗水之乐。且喜跑步,尤钟马拉松,长途奔袭,考验耐力与毅力,每有所进,心甚喜之。 吾以为,编程似布阵,算法如谋略,需精心筹谋,方可成就佳作。篮球乃团队之艺,协作共进,方显力......
  • 关于在vue2中使用LogicFlow自定义节点
    主要参考LogicFlow官方文档在基础流程图搭建起来后,我们想要构建自己的需求风格,例如:那么该如何对节点进行自定义设定呢?文档当中有着详细的解释,本文以实际需求为例大体介绍:import{RectNode,RectNodeModel,h}from"@logicflow/core";classCustomNodeViewextendsR......