首页 > 其他分享 >机器(深度)学习中的 Dropout

机器(深度)学习中的 Dropout

时间:2023-02-10 11:02:08浏览次数:50  
标签:机器 训练 Dropout 网络 神经网络 深度 dropout 神经元

这篇文章中,我将主要讨论神经网络中 dropout 的概念,特别是深度网络,然后进行实验,通过在标准数据集上实施深度网络并查看 dropout 的影响,看看它在实践中实际影响如何。

1. Dropout是什么?

术语“dropout”是指在神经网络中丢弃单元(包括隐藏的和可见的)。

简单来说,dropout 是指随机选择的某组神经元在训练阶段忽略单元(即神经元)。 “忽略”是指在特定的前向或后向传递过程中不考虑这些单元。

详细的就是,在每个训练阶段,单个节点要么以 1-p 的概率退出网络,要么以 p 的概率保留,这样就剩下一个缩小的网络;也删除了到丢弃节点的传入和传出边。

2. 为什么需要Dropout?

鉴于我们对 dropout 有所了解,一个问题出现了——为什么我们需要 dropout?为什么我们需要关闭神经网络的某些部分?

这些问题的答案是“防止过拟合”。

全连接层占据了大部分参数,因此,神经元在训练过程中相互依赖,这抑制了每个神经元的个体能力,导致训练数据过拟合。

3. 重新审视 Dropout

现在我们对 dropout 和动机有了一些了解,让我们来详细了解一下。如果你只是想了解神经网络中的 dropout,那么以上两节就足够了。在本节中,我将涉及更多技术细节。

在机器学习中,正则化是防止过度拟合的方法。正则化通过向损失函数添加惩罚来减少过度拟合。通过添加这个惩罚,模型被训练成不学习相互依赖的特征权重集。了解逻辑回归的人可能熟悉 L1(拉普拉斯)和 L2(高斯)惩罚。

Dropout 是一种神经网络正则化方法,有助于减少神经元之间的相互依赖学习。

4. 训练阶段

训练阶段:对于每个隐藏层,对于每个训练样本,对于每次迭代,忽略(清零)节点(和相应的激活)的随机分数 p。

5. 测试阶段

使用所有激活,但将它们减少一个因子 p(以解决训练期间丢失的激活)。

Srivastava, Nitish, et al.

6. 作用

  1. Dropout 迫使神经网络学习更强大的特征,这些特征与其他神经元的许多不同随机子集结合使用时很有用。
  2. Dropout 使收敛所需的迭代次数加倍。然而,每个时期的训练时间较少。
  3. 有 H 个隐藏单元,每个隐藏单元都可以被丢弃,我们有2^H 个可能的模型。在测试阶段,考虑整个网络,每次激活都减少一个因子 p。

7. 实际效果

让我们在实践中试试这个理论。为了了解 dropout 的工作原理,我在 Keras 中构建了一个深层网络,并尝试在 CIFAR-10 数据集上对其进行验证。构建的深度网络具有三个大小为 64、128 和 256 的卷积层,然后是两个大小为 512 的密集连接层和一个大小为 10 的输出层密集层(CIFAR-10 数据集中的类数)。

我将 ReLU 作为隐藏层的激活函数,将 sigmoid 作为输出层的激活函数(这些是标准,并没有在改变这些方面做太多实验)。另外,我使用了标准的分类交叉熵损失。

最后,我在所有层中使用了 dropout,并将 dropout 的比例从 0.0(根本没有 dropout)增加到 0.9,步长为 0.1,并将每个层运行到 20 个 epoch。结果如下所示:

从上图中我们可以得出结论,随着 dropout 的增加,在趋势开始下降之前,验证准确率有所提高,损失最初有所下降。

如果 dropout fraction 为 0.2,趋势下降可能有两个原因:

  1. 0.2 是此数据集、网络和使用的设置参数的实际最小值
  2. 需要更多的时期来训练网络。

本文由mdnice多平台发布

标签:机器,训练,Dropout,网络,神经网络,深度,dropout,神经元
From: https://www.cnblogs.com/swindler/p/17108195.html

相关文章

  • 297个机器学习彩图知识点(10)
    导读本系列将持续更新20个机器学习的知识点。1.深度学习的动机2.多元逻辑回归3.自然对数4.神经元5.没有免费的午餐6.噪声修正线性单元7.非参数方法......
  • 深度学习笔记——线性回归
    #导入相关包importtensorflowastfimportnumpyasnpimportpandasaspdimportmatplotlib.pyplotasplt%matplotlibinline#读取数据data=pd.read_csv('......
  • 花了半个小时基于 ChatGPT 搭建了一个微信机器人
    相信大家最近被ChatGPT刷屏了,其实在差不多一个月前就火过一次,不会那会好像只在程序员的圈子里面火起来了,并没有被大众认知到,不知道最近是因为什么又火起来了,而且这次搞的......
  • C语言--函数参数深度剖析
    函数定义时参数没有具体值,函数调用时指定参数初始值函数参数在函数内部等同于普通变量在C语言中,数组作为函数参数传递时,大小信息丢失在函数内部修改数组形参,将影响数组......
  • 企业微信集成openai实现ChatGPT机器人
    背景:现在网上查资料,痛点太多了,什么广告,什么重复的,对于程序员的我来说,简直是无语最近接触到ChatGpt,问了些技术问题,答的比某度好,甚至可以写代码,真的太棒了因此想写个专门......
  • 深度复盘-重启 etcd 引发的异常
    作者信息:唐聪、王超凡,腾讯云原生产品中心技术专家,负责腾讯云大规模TKE集群和etcd控制面稳定性、性能和成本优化工作。王子勇,腾讯云专家级工程师,腾讯云计算产品技术......
  • m分别使用Dijkstra算法和Astar算法进行刚体机器人最短路径搜索和避障算法的matlab仿真
    1.算法描述Dijkstra(迪杰斯特拉)算法是典型的最短路径路由算法,用于计算一个节点到其他所有节点的最短路径。主要特点是以起始点为中心向外层层扩展,直到扩展到终点为止(BFS、pr......
  • m分别使用Dijkstra算法和Astar算法进行刚体机器人最短路径搜索和避障算法的matlab仿真
    1.算法描述       Dijkstra(迪杰斯特拉)算法是典型的最短路径路由算法,用于计算一个节点到其他所有节点的最短路径。主要特点是以起始点为中心向外层层扩展,直到扩展到......
  • MYSQL——真实生产环境的数据库机器配置
    摘要介绍真实项目中数据库配置选型,机器压测指标等,以帮助项目设计构建MYSQL集群能够符合你的业务。以下是个人在真实生产环境总结的的相关的参数,仅仅供大家参考。一、生产环......
  • 0基础搭建基于OpenAI的ChatGPT钉钉聊天机器人
    前言:以下文章来源于我去年写的个人公众号。最近chatgpt又开始流行,顺便把原文内容发到博客园上遛一遛。注意事项和指引:注册openai账号,需要有梯子进行访问,最好是欧美国家的......