首页 > 其他分享 >在多分类任务实验中手动实现dropout

在多分类任务实验中手动实现dropout

时间:2022-10-24 13:22:51浏览次数:76  
标签:torch num 分类 dropout 手动 train ls test drop

9、在多分类任务实验中手动实现dropout

import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
#读取数据
mnist_train = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./data', train=False,download=False, transform=transforms.ToTensor())  

  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  
#初始化参数  
num_inputs,num_hiddens,num_outputs =784, 256,10
num_epochs=30
lr = 0.001
def init_param():
    W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  
    b1 = torch.zeros(1, dtype=torch.float32)  
    W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  
    b2 = torch.zeros(1, dtype=torch.float32)  
    params =[W1,b1,W2,b2]
    for param in params:  
        param.requires_grad_(requires_grad=True)  
    return W1,b1,W2,b2
#手动定义dropout函数
def dropout(X, drop_prob): #drop_porb是一个概率值,介于0-1,代表要丢弃多少比例的神经元
    X = X.float()
    assert 0 <= drop_prob <= 1
    keep_prob = 1 - drop_prob
    if keep_prob == 0:
        return torch.zeros_like(X)
    mask = (torch.rand(X.shape) < keep_prob).float()
    return mask * X / keep_prob
#定义模型
def net(X, is_training=True):
    X = X.view(-1, num_inputs)
    H1 = (torch.matmul(X, W1.t()) + b1).relu()
    if is_training:      #只有在训练模式下才需要dropout,测试情况下不需要
        H1 = dropout(H1, drop_prob1)
    return (torch.matmul(H1,W2.t()) + b2).relu()
#定义训练函数
def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None):
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        ls, count = 0, 0
        for X,y in train_iter:
            l=loss(net(X),y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            ls += l.item()
            count += y.shape[0]
        train_ls.append(ls)
        ls, count = 0, 0
        for X,y in test_iter:
            l=loss(net(X,is_training=False),y)
            ls += l.item()
            count += y.shape[0]
        test_ls.append(ls)
        if(epoch+1)%10==0:
            print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))
    return train_ls,test_ls
#定义drop从0到1,训练十次,观察不同drop对结果的影响
drop_probs = np.arange(0,1.1,0.1)
Train_ls, Test_ls = [], []
#开始训练
for drop_prob in drop_probs:
    drop_prob1 = drop_prob
    W1,b1,W2,b2 = init_param()
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)
    train_ls, test_ls =  train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer)   
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)
#结果可视化
x = np.linspace(0,len(train_ls),len(train_ls))
plt.figure(figsize=(10,8))
for i in range(0,len(drop_probs)):
    plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)
    plt.xlabel('epoch')
    plt.ylabel('loss')
# plt.legend()
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()

标签:torch,num,分类,dropout,手动,train,ls,test,drop
From: https://www.cnblogs.com/cyberbase/p/16821148.html

相关文章

  • 对二分类模型采用十折交叉验证评估
    14、对二分类模型采用十折交叉验证评估#导入必要的包importtorchimporttorch.nnasnnfromtorch.utils.dataimportTensorDataset,DataLoaderfromtorch.nnimpo......
  • 概念介绍_软件架构和概念介绍_资源分类
    概念介绍_软件架构:web概念概述Javaweb:使用iava语言开发基于互联网的项目软件架构∶1.c/s:client/server客户端/服务器端在用户本地有一个客户端程序,在远程有一个......
  • 【计算机视觉(CV)】基于图像分类网络VGG实现中草药识别(一)
    【计算机视觉(CV)】基于图像分类网络VGG实现中草药识别(一)作者简介:在校大学生一枚,华为云享专家,阿里云专家博主,腾云先锋(TDP)成员,云曦智划项目总负责人,全国高等学校计算机教学......
  • 嵌入式-C语言基础:指针是存放变量的地址,那为什么要区分类型?
    指针是存放变量的地址,那为什么要区分类型?不能所有类型的变量都用一个类型吗?下面用一个例子来说明这个问题。#include<stdio.h>intmain(){inta=0x1234;int......
  • 概念介绍-软件架构以及资源分类
    概念介绍-软件架构web概念概述:1.JavaWeb:使用JAVA语言开发基于互联网的项目2.软甲架构:1.C/S:Client/Server客户端/服务器端在用户本地有......
  • 使用KNN进行分类和回归
    一般情况下k-NearestNeighbor(KNN)都是用来解决分类的问题,其实KNN是一种可以应用于数据分类和预测的简单算法,本文中我们将它与简单的线性回归进行比较。KNN模型是一个简......
  • 基于PaddleX的树叶数据集的分类训练与安卓部署
    1.简述本次的任务是针对树叶的图片数据集进行分类。约24694个图片。185个类别。本次是用的是PaddleX的AI快速开发套件和安卓demo部署。2.数据预处理2.1.导入Paddle#......
  • 动手动脑1多层的异常捕获
    1.packagetext;publicclasstext{publicstaticvoidmain(String[]args){try{ try{ thrownewArrayIndexO......
  • mysql锁的分类
    锁的分类按照标准划分:锁:共享锁和排他锁按照加锁范围,锁分为:全局锁、表级锁、行锁。全局锁使用场景:全库逻辑备份。也就是把整库每个表都select出来存成文本。对于支持事......
  • SQL通用语法和SQL分类
    SQL通用语法1、SQL语句可以单行或多行书写,以分号结尾2、可以使用空格和缩进来增强语句的可读性3、MySQL数据库的SQL语句不区分大小写,关键字建议使用大写4、3中注释......