首页 > 其他分享 >关于深度学习模型不收敛问题解决办法

关于深度学习模型不收敛问题解决办法

时间:2024-11-03 15:00:50浏览次数:3  
标签:解决办法 nn 示例 模型 图像 ReLU 学习 深度 收敛

1. 问题重现

笔者在训练Vgg16网络时出现不收敛问题,具体描述为训练集准确率和测试集准确率一直稳定于某一值,如下图所示。

image

2. 可能的原因

2.1 数据问题

  • 噪声数据。不平衡的数据集、含有噪声或异常值的数据可能导致模型难以学习,尝试更换数据集,出现这种问题比较难办。

  • 数据预处理。确保数据质量,包括数据清洗、标准化、归一化等。
    示例:
    transforms.Normalize(mean=0.4,std=0.1)、
    transforms.Resize(size=(228,228),interpolation=InterpolationMode.BICUBIC)

  • 数据增强。可以使用以下方法扩充数据集。
    随机旋转(RandomRotation):随机旋转图像一定角度。
    随机裁剪(RandomCrop):从图像中随机裁剪出指定大小的子图。
    水平翻转(HorizontalFlip):水平翻转图像。
    垂直翻转(VerticalFlip):垂直翻转图像。
    颜色变换(ColorJitter):随机改变图像的色彩。
    调整亮度(AdjustBrightness):调整图像的亮度。
    调整对比度(AdjustContrast):调整图像的对比度。
    调整饱和度(AdjustSaturation):调整图像的饱和度。

2.2 模型设计问题

  • 更换模型:模型过于复杂或过于简单都可能导致不收敛。读者可以尝试将复杂的模型改为简单的模型,例如下图所示。笔者并不建议直接更换整个网络框架,因为这样工作量太大并且可能会出现新的问题。
    image

  • 简化:简化是指化简模型的某一部分而整体框架不变。★ 例如在Vgg16中(上图左半部分),第一个Convolution+ReLU后输出的特征通道数为64,接着第二个Convolution+ReLU,.....等等,那么是否可以简化网络使每个Convolution+ReLU输出的特征图通道变成32?或16?或8?,这样或许有效。 ★ Vgg16每次降采样后特征图大小都是原来的1/2,那么是否可以采用降采样更大的倍数呢?即使用更大的步长(stride)。

2.3 激活函数问题

不同的模型采用不同的激活函数可能会有不同的效果,若出现不收敛状况后,不妨改变一下激活函数?
例如将Sigmod激活函数改成ReLu?或LeakyReLU?
image

2.4 超参数问题

  • 调整DataLoader的batch_size参数,例如16、32、64、128、256等等,都尝试一下看看有没有效果。试着将DataLoader的shuffle参数改为True?
    示例:
    DataLoader(dataset=data,batch_size=16,shuffle=True,drop_last=False)

  • 学习率过高或过低。★ 笔者经常使用的学习率一般为0.05、0.01、0.005、0.001、0.0005、0.0001,尝试改变试试,或许有效果呢? ★ 尝试使用动态学习率,例如余弦退火学习率CosineAnnealingLR(下图)或其他?
    image
    示例:
    torch.optim.lr_scheduler.CosineAnnealingLR(optimer, T_max=20)

  • 分析学习率。可以结合学习率变化和验证集准确性调整学习率上下界,具体分析如下图所示。
    image

2.5 梯度问题

梯度消失或爆炸是导致模型不收敛的常见原因。通常做法是在卷积层(Convolution)后添加批量归一化(BatchNorm)再添加激活函数(例如ReLU等)。
示例:
nn.Conv2d(in_channels=,out_channels=,kernel_size=,stride=,padding=),
nn.BatchNorm2d(num_features=),
nn.ReLU(inplace=True),

2.6 优化器选择不当

不同的优化算法适用于不同类型的问题,错误的选择可能会阻碍模型的学习过程。改变不同的优化器,例如:

  1. 随机梯度下降法(Stochastic Gradient Descent,SGD)
  2. SGDM(带动量的SGD:SGD with momentum)
  3. 加速梯度(Nesterov Accelerated Gradient,NAG)
  4. 自适应动量优化(Adaptive Moment Estimation,Adam)
  5. ..等等

示例:
torch.optim.Adam(model.parameters(),lr=LEARN_RATE)

2.7 损失函数选择不当

  • 二分类使用nn.BCELoss(),多分类使用nn.CrossEntropyLoss() 等等。
  • 正则化技术。如L1/L2正则化、Dropout等可以防止过拟合。

3. 总结

以上是笔者学习过程中的经验,一般使用一种或多种便可解决不收敛问题,需要活学活用。

标签:解决办法,nn,示例,模型,图像,ReLU,学习,深度,收敛
From: https://www.cnblogs.com/hello-nullptr/p/18523210

相关文章

  • SpringBoot技术栈:在线试题库系统深度开发
    摘要使用旧方法对作业管理信息进行系统化管理已经不再让人们信赖了,把现在的网络信息技术运用在作业管理信息的管理上面可以解决许多信息管理上面的难题,比如处理数据时间很长,数据存在错误不能及时纠正等问题。这次开发的精品在线试题库系统有管理员,教师,学生三个角色。管理......
  • 深度学习模型综述:基础、架构及应用实例(有代码哦~)
    深度学习是机器学习领域的重要分支,基于多层神经网络模拟人类大脑的神经结构,能自动提取数据特征并在图像识别、自然语言处理等任务中取得了出色的成绩。本文将从深度学习的基础、主要模型架构及其典型应用展开,深入探讨深度学习模型的设计、训练与应用。一、深度学习的基本概念......
  • 基于django框架在线图书推荐系统的设计与实现 python个性化图书/书籍/电子书推荐系统
    基于django框架在线图书推荐系统的设计与实现python个性化图书/书籍/电子书推荐系统平均加权混合推荐热门推荐协同过滤算法推荐爬虫排行榜数据可视化分析机器学习深度学习大数据一、项目简介1、开发工具和使用技术Pycharm、Python3及以上版本,Django3.6及以上版......
  • 深度讲解-互联网算法备案指南和教程
    随着人工智能和大数据技术的迅猛发展,互联网算法在内容推荐、用户画像、智能客服等领域发挥着越来越重要的作用。然而,算法的广泛应用也带来了潜在的安全风险和合规挑战。为了规范互联网算法的开发与应用,国家互联网信息办公室等相关部门发布了《互联网算法备案管理规定》,要求具备......
  • 深度学习周报(10.28-11.3)
    目录摘要Abstract1卷积神经网络(ConvolutionaNeuralNetwork,CNN)1.1什么是卷积神经网络1.2感受野(ReceptiveField)1.3参数共享(ParameterSharing)1.4卷积层(ConvolutionalLayer)1.5 池化(Pooling)1.6 CNN整体流程2CNN实例——手写数字识别2.1 数据集......
  • 基于YOLO11/v10/v8/v5深度学习的危险驾驶行为检测识别系统设计与实现【python源码+Pyq
    《博主简介》小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~......
  • 信息抽取与知识图谱在医疗行业的融合:AI技术的深度应用案例
    一、系统概述在现代医疗行业,信息的碎片化与数据孤岛问题日益严重,导致医疗服务和研究效率的降低。思通数科针对这一现状,推出了一款开源免费的信息抽取与知识图谱平台,旨在将医疗数据的深度分析与智能化服务结合起来。二、应用场景在一家大型医疗中心,信息技术部门面临着整合各科......
  • 【C++动态规划】有效括号的嵌套深度
    本文涉及知识点C++动态规划LeetCode1111.有效括号的嵌套深度有效括号字符串定义:对于每个左括号,都能找到与之对应的右括号,反之亦然。详情参见题末「有效括号字符串」部分。嵌套深度depth定义:即有效括号字符串嵌套的层数,depth(A)表示有效括号字符串A的嵌套深度。详......
  • 计算机视觉基石:深度解析数据标注技能与工具
    在计算机视觉的浩瀚宇宙中,数据标注犹如一颗璀璨的星辰,为机器学习模型照亮了前行的道路。它不仅是构建和训练高效视觉模型的基础,更是初学者踏入这一领域的必经之路。本文将带你深入探索数据标注的奥秘,从基础概念到实战工具,再到数据采集的精髓,全方位解析这一关键技能。一、数据......
  • 深度学习(tensorboard使用)
    在做深度学习的时候,尤其是在没有界面的服务器上训练时,可以利用tensorboard工具输出各种曲线或中间图像。下面代码将曲线和图像输出到run目录下临时文件中。fromtensorboardXimportSummaryWriterfromPILimportImageimportnumpyasnpimporttorchvisionimporttorch......