首页 > 其他分享 >深度学习--实战 LeNet5

深度学习--实战 LeNet5

时间:2023-04-24 16:02:50浏览次数:36  
标签:实战 __ LeNet5 nn -- 32 torch label size

深度学习--实战 LeNet5

数据集

数据集选用CIFAR-10的数据集,Cifar-10 是由 Hinton 的学生 Alex Krizhevsky、Ilya Sutskever 收集的一个用于普适物体识别的计算机视觉数据集,它包含 60000 张 32 X 32 的 RGB 彩色图片,总共 10 个分类。其中,包括 50000 张用于训练集,10000 张用于测试集。

模型实现

模型需要继承nn.module

import torch
from torch import  nn


class Lenet5(nn.Module):
    """
    for cifar10 dataset.
    """
    def __init__(self):
        super(Lenet5,self).__init__()

        self.conv_unit = nn.Sequential(
            #input:[b,3,32,32] ===> output:[b,6,x,x]
            #Conv2d(Input_channel:输入的通道数,kernel_channels:卷积核的数量,输出的通道数,kernel_size:卷积核的大小,stride:步长,padding:边缘补足)
            nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),

            #池化
            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),

            #卷积层
            nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),

            #池化
            nn.AvgPool2d(kernel_size=2,stride=2,padding=0)

            #output:[b,16,5,5]
        )

        #flatten

        #Linear层
        self.fc_unit=nn.Sequential(
            nn.Linear(16*5*5,120),
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10)
        )

        #测试卷积输出到全连接层的输入
        #tmp = torch.rand(2,3,32,32)
        #out = self.conv_unit(tmp)
        #print("conv_out:",out.shape)

        #Loss评价  Cross Entropy Loss  分类  在其中包含一个softmax()操作
        #self.criteon = nn.MSELoss()  回归
        #self.criteon = nn.CrossEntropyLoss()

    def forward(self,x):
        """

        :param x:[b,3,32,32]
        :return:
        """
        batchsz = x.size(0)
        #[b,3,32,32]=>[b,16,5,5]
        x = self.conv_unit(x)
        #[b,16,5,5]=>[b,16*5*5]
        x = x.view(batchsz,16*5*5)
        #[b,16*5*5]=>[b,10]
        logits = self.fc_unit(x)

        return logits

        # [b,10]
        # pred = F.softmax(logits,dim=1)  这步在CEL中包含了,所以不需要再写一次
        #loss = self.criteon(logits,y)




def main():
    net = Lenet5()
    tmp = torch.rand(2,3,32,32)
    out = net(tmp)
    print("lenet_out:",out.shape)

if __name__ == '__main__':
    main()

训练与测试

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from lenet5 import Lenet5
import torch.nn.functional as F
from torch import  nn,optim

def main():

    batch_size = 32
    epochs = 1000
    learn_rate = 1e-3

    #导入图片,一次只导入一张
    cifer_train = datasets.CIFAR10('cifar',train=True,transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]),download=True)

    #加载图
    cifer_train = DataLoader(cifer_train,batch_size=batch_size,shuffle=True)

    #导入图片,一次只导入一张
    cifer_test = datasets.CIFAR10('cifar',train=False,transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]),download=True)

    #加载图
    cifer_test = DataLoader(cifer_test,batch_size=batch_size,shuffle=True)

    #iter迭代器,__next__()方法可以获得数据
    x, label = iter(cifer_train).__next__()
    print("x:",x.shape,"label:",label.shape)
    #x: torch.Size([32, 3, 32, 32]) label: torch.Size([32])


    device = torch.device('cuda')
    model = Lenet5().to(device)
    print(model)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr=learn_rate)


    for epoch in range(epochs):
        model.train()
        for batchidx,(x,label) in enumerate(cifer_train):
            x,label = x.to(device),label.to(device)

            logits = model(x)
            #logits:[b,10]

            loss = criteon(logits,label)

            #backprop
            optimizer.zero_grad()  #梯度清零
            loss.backward()
            optimizer.step()  #梯度更新
        #
        print(epoch,loss.item())

        model.eval()
        with torch.no_grad():
            #test
            total_correct = 0
            total_num = 0
            for x,label in cifer_test:
                x,label = x.to(device),label.to(device)
                #[b,10]
                logits = model(x)
                #[b]
                pred =logits.argmax(dim=1)

                #[b] vs [b] => scalar tensor
                total_correct += torch.eq(pred,label).float().sum().item()
                total_num += x.size(0)

        acc = total_correct/total_num
        print("epoch:",epoch,"acc:",acc)


if __name__ == '__main__':
    main()

标签:实战,__,LeNet5,nn,--,32,torch,label,size
From: https://www.cnblogs.com/ssl-study/p/17349754.html

相关文章

  • 4月20日
    创建Maven项目首先,需要创建一个Maven项目并导入所需的依赖库:<dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><depende......
  • Redis Plus 来了,性能炸裂!
    来源:https://developer.aliyun.com/article/7052391什么是KeyDB?KeyDB是Redis的高性能分支,专注于多线程,内存效率和高吞吐量。除了多线程之外,KeyDB还具有仅在RedisEnterprise中可用的功能,例如ActiveReplication,FLASH存储支持以及一些根本不可用的功能,例如直接备份到AWSS3。Ke......
  • 照猫画虎之WinDbg
    1.在任务管理器导出dmp文件2.使用WinDbg=>File=>OpenCrashDump...选择导出的dmp文件3.使用WinDbg=>File=>SymbolFilePath...输入srv*c:\symbols*http://msdl.microsoft.com/download/symbols;C:\Windows\symbols4. 加载SOS和CLR=> .loadbysosclr5. 加载ntdll=>.......
  • cpu监控
    1、procs进程......
  • Redis持久化机制
    Redis是内存数据库,但一旦服务器宕机,内存中的数据将全部丢失。作为缓存,虽然可以从慢速数据库重新读取数据,但是也会增加慢速数据库压力。所以选择数据持久化方式,避免从后端数据库中进行恢复3种持久化方式AOF:只追加文件(Append-OnlyFile)RDB:快照(snapshotting)RDB和AOF的混......
  • stata sfi.Data举例
    sysuseautopythonfromsfiimportDatadataraw=Data.get('foreign')datarawend//.python//-----------------------------------------------python(typeendtoexit)--------------------------------------------------------------------------......
  • skywalking自定义插件开发
    skywalking是使用字节码操作技术和AOP概念拦截Java类方法的方式来追踪链路的,由于skywalking已经打包了字节码操作技术和链路追踪的上下文传播,因此只需定义拦截点即可。这里以skywalking-8.7.0版本为例。关于插件拦截的原理,可以看我的另一篇文章:skywalking插件工作原理剖析1.......
  • 充电桩测试设备TK4860B非车载充电机检定装置
    充电桩测试设备TK4860系列是专门针对现有交流充电桩现场检测过程中接线复杂、负载笨重、现场检测效率低等问题而研制的一系列高效检测仪器,充电桩测试设备TK4860旨在更好的开展充电桩的强制检定工作。TK4860B是一款在可交流充电桩充电过程中实时检测充电电量的标准仪器,仪器以新能......
  • Linux
    Linux命令TinyMCE编辑器删除rm 文件 查看内容ls查看目录pwd   .当前目录..上一级目录cd/ 切换到顶点who 更换目录cd文件夹绝对路径一切从根目录的路径开始  /opt/zl/1.txt相对路径  ./opt/zl/1.txtmkdir创建文件夹绝对路径 /opt......
  • 把nginx的access_log以json的格式输出
    #在`nginx.conf`中添加如下配置log_formatjsonescape=json'{"@timestamp":"$time_iso8601",''"server_addr":"$server_addr",''"remote_addr":"......