首页 > 其他分享 >[深度学习]丢弃法(drop out)

[深度学习]丢弃法(drop out)

时间:2024-04-15 17:55:06浏览次数:11  
标签:torch nn 丢弃 dropout drop num 256 self out

丢弃法(drop out)

一、介绍

1.动机

  • 一个好的模型需要对输入数据的扰动鲁棒
    • 使用有噪音的数据等价于Tikhonov正则
    • 丢弃法:在层之间加入噪音

2.丢弃法的定义

image

这里除以\(1-p\)是为了\(x_i^{'}\)与原来的\(x_i\)的期望相同。

\[ 0\times p + (1-p)\times \dfrac{x_i}{1-p} = x_i \]

3.使用丢弃法

image

其中:

  • \(h\) 为隐藏层
  • \(\sigma\) 为激活函数
  • \(o\) 为输出
  • \(y\) 将 \(o\) 经过 \(softmax\) 层得到分类结果

image

4.总结

image

二、代码部分

1.丢弃法(使用自定义)

实现dropout_layer函数,该函数以dropout的概率丢弃张量输入x中的元素

# 实现dropout_layer函数,该函数以dropout的概率丢弃张量输入x中的元素
import torch
from torch import nn
from d2l import torch as d2l

def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1 # dropout大于等于0,小于等于1,否则报错
    if dropout == 1:
        return torch.zeros_like(X) # 如果dropout为1,则X返回为全0
    if dropout == 0:
        return X # 如果dropout为1,则X返回为全原值
    mask = (torch.rand(X.shape)>dropout).float() # 取X.shape里面0到1之间的均匀分布,如果值大于dropout,则把它选出来
    #print((torch.randn(X.shape)>dropout)) # 返回的是布尔值,然后转布尔值为0、1
    return mask * X / (1.0 - dropout) 

X = torch.arange(16,dtype=torch.float32).reshape((2,8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5)) # 有百分之50的概率变为0
print(dropout_layer(X, 1.))

输出

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  0.,  4.,  6.,  0.,  0.,  0., 14.],
        [16.,  0.,  0.,  0.,  0.,  0.,  0., 30.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元

# 定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10 ,256, 256

dropout1, dropout2 = 0.2, 0.5

class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,is_training=True):       
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()
        
    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1,self.num_inputs))))
        if self.training == True: # 如果是在训练,则作用dropout,否则则不作用
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            H2 = dropout_layer(H2,dropout2)
        out = self.lin3(H2) # 输出层不作用dropout
        return out
        
net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

# 训练和测试
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(),lr=lr)
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)

image

2.丢弃法(使用框架)

import torch
from torch import nn
from d2l import torch as d2l

# 简洁实现
num_epochs, lr, batch_size = 10, 0.5, 256
dropout1, dropout2 = 0.2, 0.5
loss = nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

net = nn.Sequential(nn.Flatten(),nn.Linear(784,256),nn.ReLU(),
                   nn.Dropout(dropout1),nn.Linear(256,256),nn.ReLU(),
                   nn.Dropout(dropout2),nn.Linear(256,10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight,std=0.01)
    
net.apply(init_weights)

trainer = torch.optim.SGD(net.parameters(),lr=lr)
d2l.train_ch3(net,train_iter, test_iter, loss, num_epochs,trainer)

image

标签:torch,nn,丢弃,dropout,drop,num,256,self,out
From: https://www.cnblogs.com/nannandbk/p/18136618

相关文章

  • frp i/o timeout 解决方案
    通过防火墙,开启端口1.安装防火墙安装iptables-services:2.防火墙基本操作查看版本:firewall-cmd--version显示状态:firewall-cmd--state查看所有打开的端口:netstat-anp开启防火墙systemctlstartfirewalld关闭防火墙systemctlstopfirewalld开启防火墙servicefirew......
  • kettle从入门到精通 第五十二课 ETL之kettle Avro output
    1、上一节课我们学习了avroinput,本节课我们一起学习下avroout步骤。本节课通过jsoninput加载json文件,通过avroout生成avro二进制文件,写日志步骤打印日志。将jsoninput、avrooutput、写日志三个步骤拖到画布,然后连线,如下图所示:jsoninput步骤不在过多讲解,不了解的可以学......
  • eBPF指定网口丢弃icmp报文
    安装eBPF依赖#安装编译工具aptinstall-yllvmclang#确认内核具有BTF支持,路径存在,内核没有BTF支持,使用vmlinux.h无法通过编译ls/sys/kernel/btf#生成vmlinux.h#aptinstall-ylinux-tools-genericaptinstall-ylinux-tools-6.5.0-26-genericbpftoolbtfdump......
  • Lock wait timeout exceeded; try restarting transaction 问题分析
    问题描述在项目中有一个MySQL数据库归档程序,每天会定时跑,在归档逻辑中,会涉及到对大表的查询(根据创建时间查询,它是索引),这个过程中会锁数据(行级锁),然后我们插入新的数据就会报错:获取锁超时Causedby:com.mysql.cj.jdbc.exceptions.MySQLTransactionRollbackException:Lockwait......
  • 52 Things: Number 30: Roughly outline the BR security definition for key agreeme
    52Things:Number30:RoughlyoutlinetheBRsecuritydefinitionforkeyagreement52件事:第30件:大致概述密钥协议的BR安全定义 Thisisthelatestinaseriesofblogpoststoaddressthelistof'52ThingsEveryPhDStudentShouldKnowToDoCryptography':a......
  • VS studio上查看标准cout输出
    VSstudio上查看标准cout输出网上的方法在解决方案管理器中,单击选中项目后,点击菜单【视图】->【属性页】在生成事件->生成后事件->命令行(BuildEvents->Post-BuildEvent->Command)Line)中增加$(OutDir)$(ProjectName).exe顾名思义,这个方法是在生成结束后,使用命令行执行生成的......
  • react native layout
    官方文档:https://reactnative.dev/docs/flexbox/#absolute--relative-layout另外一片文档:https://medium.com/wix-engineering/the-full-react-native-layout-cheat-sheet-a4147802405c需要注意的是position的relative的含义:它是先计算没有设定position的时候的位置,然后基于这个......
  • 阿里邮箱网页正常登陆,outlook报错
    事件起因:某客户使用阿里邮箱办公,然又使用outlook绑定阿里邮箱;在网页端可以登录阿里邮箱,但是在outlook的登录的时候,服务器、端口均设置无误,但是就是登录不上去,死活都等登录不上去,总是弹窗让输入账号密码 解决办法:经过多方排查,在网上找寻资料,终于排查出是什么问题了在邮箱后台......
  • About After-school Classes.
    昨天语文大作文,今天英语作文。$\\$Nowadays,childrenalwayshavemanyafter-schoolclasses.MyparentsandIhavedifferentopinionsonthissituation.$\\$Myparentsbelieve,after-schoolclassescanhelpusfindgoodjobsinthefuture.Goodcompanies......
  • 52 Things: Number 13: Outline the use and advantages of projective point represe
    52Things:Number13:Outlinetheuseandadvantagesofprojectivepointrepresentation.52件事:第13件:概述投影点表示的用途和优点。 Thisisthelatestinaseriesofblogpoststoaddressthelistof '52ThingsEveryPhDStudentShouldKnow' todoCryptogr......