首页 > 其他分享 >实验五:全连接神经网络手写数字识别实验

实验五:全连接神经网络手写数字识别实验

时间:2022-11-26 17:25:52浏览次数:31  
标签:loss nn self torch batch 神经网络 train 实验 手写

【实验目的】

理解神经网络原理,掌握神经网络前向推理和后向传播方法;

掌握使用pytorch框架训练和推理全连接神经网络模型的编程实现方法。

 

【实验内容】

1.使用pytorch框架,设计一个全连接神经网络,实现Mnist手写数字字符集的训练与识别。

 

【实验报告要求】

修改神经网络结构,改变层数观察层数对训练和检测时间,准确度等参数的影响;
修改神经网络的学习率,观察对训练和检测效果的影响;
修改神经网络结构,增强或减少神经元的数量,观察对训练的检测效果的影响。

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# 准备数据集
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',
                               train=True,
                               download=True,
                               transform=transform)
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist',
                              train=False,
                              download=True,
                              transform=transform)
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         batch_size=batch_size)


# 设计模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)


model = Net()

# 构建损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# 定义训练函数
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
        # 前馈+反馈+更新
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d,%5d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0


# 定义测试函数
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set:%d %%' % (100 * correct / total))


# 实例化训练和测试
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

【结果】:

 

标签:loss,nn,self,torch,batch,神经网络,train,实验,手写
From: https://www.cnblogs.com/duyidan/p/16927784.html

相关文章

  • 安全编程技术实验四
    随机数生成算法实现及质量测试一、实验目的该实验为设计性实验,实验目的如下:1.学会如何采用软件方式设计和实现一个高质量的随机数生成算法。2.掌握常用的随机数质量测......
  • 安全编程技术实验五
    WindowsCryptoAPI的使用一、实验目的该实验为设计性实验,实验目的如下:1.熟悉WindowsCryptoAPI提供的常用函数接口。2.掌握WindowsCryptoAPI的使用。二、实验内容及步骤......
  • 安全编程技术实验一
    缓冲区溢出一、实验目的该实验为验证性实验,实验目的如下:1、掌握缓冲区溢出的基本原理。2、掌握预防缓冲区溢出的方法,并且在实际编程中严格遵循这些方法。二、实验内容......
  • 安全编程技术实验二
    Windows系统中的访问控制一、实验目的该实验为验证性实验,实验目的如下:1.掌握访问控制列表的基本原理。2.学会通过编程实现更改Windows操作系统中文件或目录的访问控制......
  • 实验5 继承和多态
    1.实验4pets.hpp1#include<iostream>2#include<string>3usingnamespacestd;4classMachinePets{5private:6stringnickname;7......
  • 实验5-类的继承
    1.pets.hpp.1#pragmaonce2#include<iostream>3#include<string>45usingnamespacestd;67classMachinePets{8public:9MachinePets......
  • 实验五:全连接神经网络手写数字识别实验
    |班级链接|https://edu.cnblogs.com/campus/czu/classof2020BigDataClass3-MachineLearning||作业链接|https://edu.cnblogs.com/campus/czu/classof2020BigDataClass3-Ma......
  • 实验四
    任务一#include<stdio.h>#defineN4intmain(){inta[N]={1,9,8,4};charb[N]={'1','9','8','4'};inti;printf("sizeof(int)=%d\n",......
  • oop 实验5 继承和多态
    task1_1程序源码task1_1.cpp1#include<iostream>2#include<map>3usingnamespacestd;4intmain(){5map<int,char>grade_dict{{1,'A'},{2,......
  • 实验5
    实验任务4:pets.hpp:#include<iostream>#include<string.h>usingnamespacestd;classMachinePets{private:stringnickname;public:MachinePets(conststr......