首页 > 其他分享 >1.6.丢弃法

1.6.丢弃法

时间:2024-07-18 21:29:11浏览次数:11  
标签:1.6 nn 丢弃 self num 256 dropout

丢弃法

动机:一个好的模型需要对输入数据的扰动足够健壮,丢弃法就是在层之间加入噪音。也可以在数据中使用噪音,等价与Tikhonov正则

无偏差的加入噪音

​ 对于数据 x x x,加入噪音后的 x ′ x' x′的期望值是不变的, E [ x ′ ] = x E[x']=x E[x′]=x

​ 则我们可以构造出一个简单的期望运算 E [ x ′ ] = p ⋅ 0 + ( 1 − p ) ⋅ x i 1 − p = x i E[x']=p\cdot 0+(1-p)\cdot\frac{x_i}{1-p} =x_i E[x′]=p⋅0+(1−p)⋅1−pxi​​=xi​

​ 那么可以这样处理元素:

在这里插入图片描述

​ 其中丢弃概率是超参数。常用在多层感知机的隐藏层输出上。

通常将丢弃法作用在隐藏全连接层的输出上:
h = σ ( W 1 x + b 1 ) h ′ = d r o p o u t ( h ) o = W 2 h ′ + b 2 y = s o f t m a x ( o ) h=\sigma(W_1x+b_1)\\ h' = dropout(h)\\ o = W_2h' +b_2\\ y=softmax(o) h=σ(W1​x+b1​)h′=dropout(h)o=W2​h′+b2​y=softmax(o)
在这里插入图片描述

​ 如图本来有5个隐藏层,但丢弃函数可能取到0,那么可能会直接消失,剩下的3个隐藏层变大。

​ 丢弃项其实是正则项,只在训练中使用,他们影响模型参数的更新。

​ 在推理过程中,丢弃法直接返回输入 h = d r o p o u t ( h ) h = dropout(h) h=dropout(h),也可以保证确定性的输出

​ 实际上丢弃法的实质是每次训练中使用一个神经网络的子集来做训练, 则多次训练后得到的是多个神经网络的平均,效果自然要好一些。

​ 现在普遍将丢弃项认为是正则项,效果和正则项基本相同。

​ 在输入数据比较简单,但神经网络比较大时,dropout可能会比较有用。

​ dropout1=0.2,dropout2=0.5:

在这里插入图片描述

​ dropout1=0.dropout2=0"

在这里插入图片描述

​ 效果出乎意料的好,说明这个模型本身就没过拟合,这时候使用dropout可能效果不好。一般的小技巧是模型设大一点,然后使用dropout来进行调整。

代码实现

import torch
from torch import nn
from d2l import torch as d2l


def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1  # 丢弃概率必须在0到1之间
    if dropout == 1:
        return torch.zeros_like(X)  # 全0则全部丢弃
    if dropout == 0:
        return X  # 0则不丢弃
    mask = (torch.rand(X.shape) > dropout).float()  # rand生成0到1之间的随机数
    return mask * X / (1.0 - dropout)


num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

# dropout1, dropout2 = 0.2, 0.5
dropout1, dropout2 = 0., 0.


# 定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元,有三个线性层,最后一个是输出层
class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training=True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
        out = self.lin3(H2)
        return out


net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

'''简洁实现'''

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

标签:1.6,nn,丢弃,self,num,256,dropout
From: https://blog.csdn.net/shiki217_/article/details/140438913

相关文章

  • 【OCPP】ocpp1.6协议第5.5章节Clear Charging Profile的介绍及翻译
    目录5.5清除充电配置ClearChargingProfile-概述ClearChargingProfile请求ClearChargingProfile响应操作流程适用场景5.5清除充电配置ClearChargingProfile-原文译文5.5清除充电配置ClearChargingProfile-概述OCPP1.6协议中的第5.5章节主要讲的是“Cl......
  • 【OCPP】ocpp1.6协议第5.3章节Change Configuration的介绍及翻译
    目录5.3更改配置Changeconfiguration-概述ChangeConfigurationOperation1.概要2.ChangeConfiguration请求3.ChangeConfiguration响应4.流程说明状态说明举例总结5.3更改配置Changeconfiguration-原文译文5.3更改配置Changeconfiguration-概述在OC......
  • 第三节SHELL脚本中的变量与运算(1.6-1.7.3)
    1,6常见的系统及变量在系统中被预设变量如下变量说明PATH命令的搜索路径,以冒号作为分隔符HOME用户的家目录的路径,是cd命令的默认参数COLUMNS命令行编辑模式下可使用命令的长度HISTFILE命令历史的文件路径HISTFLESIZE命令历史中包含的最大行数HISTSIZEhistory命令输出的......
  • Windows防火墙 日志 自定义 以记录被丢弃的数据包和成功的连接日志。以下是一个示例.r
     配置注册表,以记录被丢弃的数据包和成功的连接日志 WindowsRegistryEditorVersion5.00;WindowsDefender防火墙日志记录设置[HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\SharedAccess\Parameters\FirewallPolicy];以下是针对不同配置文件的设置,例如......
  • 丢弃奇数位置上的元素
    丢弃奇数位置上的元素题目描述给定[1,n]上的连续数字,每次去掉奇数位置上的数,当最后只剩一个数字时,这个数字是多少。题解如图所示:将每个数字转换成二进制表示,可以发现,在第i轮删掉的数字,都是在二进制表示中从右->左,第i个位置为1的数字。因此,留下来的数字,肯定是从左到......
  • java编译时出现错误[ERROR] 不再支持源选项 5。请使用 6 或更高版本。[ERROR] 不再支
    当我引入一个新项目在控制台输入命令mvn  clean install -U,报错出现原因是我们下载了多个java版本,我的电脑上就有1.8和11两个版本,此时只需在引入的pom文件中指定具体的版本即可<maven.compiler.source>11</maven.compiler.source><maven.compiler.target>11</mave......
  • 【OCPP】ocpp1.6协议第4.7章节Meter Values的介绍及翻译
    目录4.7、仪表值MeterValues-概述MeterValues请求消息MeterValues响应消息使用场景1.定期报告2.事务相关报告示例MeterValues请求示例处理MeterValues响应示例代码构建和发送MeterValues请求可能的错误处理总结4.7、仪表值MeterValues-原文译文4.7、......
  • FolkMQ 1.6.0(纯血国产,适合信创)
    FolkMQ是个“新式”的消息中间件。强调:“简而强”。可内嵌,可单机,可集群(部署包为9Mb)。功能简表角色功能生产者(客户端)发布普通消息、Qos0消息、定时消息、顺序消息、可过期消息、事务消息、广播消息消费者(客户端)订阅、取消订阅。消费-ACK(自动、手动)服......
  • Typora1.6.7安装使用教程;附安装包
    一、Typora简介Typora是一款在IT领域使用频率最高的编辑器和阅读器,其界面简洁、操作简单、支持多种Markdown语法,包括代码高亮、流程图、表格、公式等,此外还支持Windows、macOS、Linux等。总的来说,Typora是一款高效、易用、支持多平台的Markdown编辑器,适合技术文档、说明书、个人......
  • 模拟集成电路设计系列博客——7.1.6 多比特SAR ADC
    7.1.6多比特SARADC我们目前讨论的逐次逼近型ADC在每个周期都通过单次的比较将搜索空间一分为二。这个搜索可以通过在每个周期进行多次比较来实现加速,每次将搜索空间切分为更小的区域。例如,如果我们想要猜测一个1到128之间的数时,我们除了提问“这个数是否大于64”,还可以同时提问......