
Implementing dropout with torch.nn in a multi-class classification experiment

Date: 2022-10-24 13:23:06
Tags: drop num nn dropout torch train ls self

10. Implementing dropout with torch.nn in a multi-class classification experiment

# Import the required packages
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import DataLoader
# Load the data (download=False assumes MNIST is already in ./data; set download=True otherwise)
mnist_train = datasets.MNIST(root='./data', train=True, download=False, transform=transforms.ToTensor())
mnist_test = datasets.MNIST(root='./data', train=False, download=False, transform=transforms.ToTensor())

batch_size = 256
train_iter = DataLoader(
    dataset=mnist_train,
    shuffle=True,
    batch_size=batch_size,
    num_workers=0
)
test_iter = DataLoader(
    dataset=mnist_test,
    shuffle=False,
    batch_size=batch_size,
    num_workers=0
)
# Define the model
class LinearNet(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob1, drop_prob2):
        super(LinearNet, self).__init__()
        self.linear1 = nn.Linear(num_inputs, num_hiddens1)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop_prob1)  # nn provides a ready-made Dropout layer; only the drop probability is needed
        self.linear2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.drop2 = nn.Dropout(drop_prob2)
        self.linear3 = nn.Linear(num_hiddens2, num_outputs)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.flatten(x)  # (batch, 1, 28, 28) -> (batch, 784)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.drop1(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.drop2(x)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax itself,
        # so the output layer should not be passed through ReLU
        y = self.linear3(x)
        return y
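# Define the training function
def train(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        net.train()  # enable dropout while training
        ls, count = 0, 0
        for X, y in train_iter:
            l = loss(net(X), y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            ls += l.item() * y.shape[0]  # loss is a batch mean, so weight it by the batch size
            count += y.shape[0]
        train_ls.append(ls / count)  # average training loss per sample
        net.eval()  # disable dropout while evaluating
        ls, count = 0, 0
        with torch.no_grad():  # no gradients are needed on the test set
            for X, y in test_iter:
                l = loss(net(X), y)
                ls += l.item() * y.shape[0]
                count += y.shape[0]
        test_ls.append(ls / count)  # average test loss per sample
        if (epoch + 1) % 5 == 0:
            print('epoch: %d, train loss: %f, test loss: %f' % (epoch + 1, train_ls[-1], test_ls[-1]))
    return train_ls, test_ls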
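# Initialise the hyperparameters and set the number of hidden units
num_inputs, num_hiddens1, num_hiddens2, num_outputs = 784, 256, 256, 10
num_epochs = 20
lr = 0.1
# Sweep the dropout probability from 0 to 1 in steps of 0.1 (11 runs) to see how it affects training.
# At drop_prob = 1.0 every hidden activation is zeroed, so that run cannot learn from the input.
drop_probs = np.arange(0, 1.1, 0.1)
Train_ls, Test_ls = [], []
# Start training
for drop_prob in drop_probs:
    drop_prob = float(drop_prob)  # pass a plain Python float to nn.Dropout
    net = LinearNet(num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob, drop_prob)
    for param in net.parameters():
        nn.init.normal_(param, mean=0, std=0.01)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr)
    train_ls, test_ls = train(net, train_iter, test_iter, loss, num_epochs, batch_size, net.parameters(), lr, optimizer)
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)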
# Visualise the training results
x = np.arange(1, len(train_ls) + 1)  # epoch numbers 1..num_epochs
plt.figure(figsize=(10, 8))
for i in range(len(drop_probs)):
    plt.plot(x, Train_ls[i], label='drop_prob=%.1f' % drop_probs[i], linewidth=1.5)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc=2, bbox_to_anchor=(1.05, 1.0), borderaxespad=0.)
plt.title('train loss with dropout')
plt.show()
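
The experiment above relies on nn.Dropout behaving differently in training and evaluation mode. The following is a minimal sketch, not part of the original post, that shows this on a dummy tensor. The probability p = 0.5, the tensor of ones, and the fixed seed are illustrative assumptions. In training mode the layer zeroes each element with probability p and scales the survivors by 1/(1-p); in evaluation mode it returns its input unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the random mask is reproducible

p = 0.5                # illustrative dropout probability (assumption)
drop = nn.Dropout(p)
x = torch.ones(1000)   # dummy activations, all equal to 1

drop.train()           # training mode: random mask plus rescaling
y_train = drop(x)
kept = y_train != 0    # positions that survived the mask

drop.eval()            # evaluation mode: dropout is a no-op
y_eval = drop(x)

print('fraction kept in train mode:', kept.float().mean().item())
print('values of kept elements:', y_train[kept].unique().tolist())
print('eval output equals input:', torch.equal(y_eval, x))
```

This is why a network containing dropout layers is usually switched with net.train() before the training loop and net.eval() before measuring test loss; otherwise the test loss is computed on randomly masked activations.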

From: https://www.cnblogs.com/cyberbase/p/16821149.html
