首页 > 其他分享 >ResNet18实现手写数字识别

ResNet18实现手写数字识别

时间:2023-07-24 19:13:11浏览次数:38  
标签:ResNet18 nn self put 手写 识别 data model out

  • 项目结构

  •  ResNet18模型搭建
from torch import nn
from torch.nn.functional import relu


class BaseBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(BaseBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out_put = self.conv1(x)
        out_put = relu(self.bn1(out_put))
        out_put = self.conv2(out_put)
        out_put = self.bn2(out_put)
        out_put = relu(out_put + out_put)
        return out_put


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(DownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.extra = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        extra_x = self.extra(x)
        out_put = self.conv1(x)
        out_put = relu(self.bn1(out_put))
        out_put = self.conv2(out_put)
        out_put = self.bn1(out_put)
        out_put = relu(out_put + extra_x)
        return out_put

  

from torch import nn
from torch.nn.functional import relu
from .base_block import BaseBlock, DownBlock


class RestNEt18(nn.Module):

    def __init__(self, in_channels):
        super(RestNEt18, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 7, 2, 3)
        self.bn1 = nn.BatchNorm2d(64)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Sequential(BaseBlock(64, 64, 1),
                                    BaseBlock(64, 64, 1))

        self.layer2 = nn.Sequential(DownBlock(64, 128, [2, 1]),
                                    BaseBlock(128, 128, 1))

        self.layer3 = nn.Sequential(DownBlock(128, 256, [2, 1]),
                                    BaseBlock(256, 256, 1))

        self.layer4 = nn.Sequential(DownBlock(256, 512, [2, 1]),
                                    BaseBlock(512, 512, 1))

        # 二位自适应平均池化,根据输入的尺寸以及设定的输出大小计算出输出元素在输入中的感受野,然后再感受野上进行平均池化
        self.agv_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        out_put = self.conv1(x)
        # print(out_put.shape)
        out_put = relu(self.bn1(out_put))
        out_put = self.maxpool(out_put)
        # print(out_put.shape)
        out_put = self.layer1(out_put)
        # print(out_put.shape)

        out_put = self.layer2(out_put)
        # print(out_put.shape)

        out_put = self.layer3(out_put)
        # print(out_put.shape)

        out_put = self.layer4(out_put)
        # print(out_put.shape)

        out_put = self.agv_pool(out_put)
        # print(out_put.shape)
        out_put = out_put.reshape(out_put.shape[0], -1)
        # print(out_put.shape)

        out_put = self.fc(out_put)
        # print(out_put.shape)
        return out_put
  •   模型训练脚本
from abc import ABCMeta, abstractmethod


class TrainBase(metaclass=ABCMeta):

    @abstractmethod
    def _get_data_loader(self, is_train=True):
        """
        获取训练数据和测试数据
        :param is_train: 标识测试数据或训练数据
        :return: (Feature, target)
        """
        pass

    @abstractmethod
    def _train(self, my_model, loss_fn, optimizer, train_data):
        """
        训练模型
        :param my_model: 模型
        :param loss_fn: 损失函数
        :param optimizer: 优化器
        :param train_data: 训练数据
        :return:
        """

    @abstractmethod
    def _test(self, my_model, test_data):
        """
        测试模型
        :param my_model: 模型
        :param test_data: 测试数据
        :return:
        """
        pass

    @abstractmethod
    def predict(self, img_path):
        """
        使用模型预测数据
        :param img_path: 图片文件地址
        :return: 预测值
        """
        pass

    def run(self):
        """
        执行深度学习模型训练
        :return:
        """
        pass

  

import os.path

from torchvision import transforms
from torchvision import datasets
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from PIL import Image
from models.resnet.resnet18 import RestNEt18
from Itrain import TrainBase
import numpy as np
import torch


class NumPredict(TrainBase):

    def __init__(self, model_path, epoch=100, lr=0.001):
        # 定义超参数
        self.epoch = epoch
        self.lr = lr
        self.model_path = model_path

        # 固定参数
        self.current_train_step = 0
        self.acc = 0
        self.test_count = 0
        self.writer = None
        self.transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.1307, std=0.3081)
        ])

    def _get_data_loader(self, is_train=True):
        data = datasets.MNIST(root="../datasets", train=is_train, transform=self.transforms, download=True)
        return data

    def _train(self, my_model, loss_fn, optimizer, train_data):
        # 训练模型
        for data in train_data:
            imgs, target = data
            # 计算预测值
            predict = my_model(imgs)
            # 计算损失值
            loss = loss_fn(predict, target)
            # 梯度归0
            optimizer.zero_grad()
            # 反向计算梯度
            loss.backward()
            # 反向更新
            optimizer.step()

            # 输出损失值
            self.current_train_step += 1
            if self.current_train_step % 10 == 0:
                print("训练次数%s  损失值:%s" % (self.current_train_step, loss.item()))
                self.writer.add_scalar("train_loss", loss.item(), global_step=self.current_train_step)

    def _test(self, my_model, test_data):
        # 测试模型
        total_acc = 0
        for data in test_data:
            imgs, target = data
            with torch.no_grad():
                predict = my_model(imgs)
                acc = (predict.argmax(1) == target).sum()
                print("准确率:%s" % str(acc/imgs.shape[0]))
                total_acc += acc

        self.writer.add_scalar("test_acc", total_acc / self.test_count, global_step=self.current_train_step)

    def run(self):
        """

        :return:
        """
        # 准备数据
        train_data = self._get_data_loader()
        train_data = DataLoader(train_data, batch_size=64, shuffle=True)
        test_data = self._get_data_loader(is_train=False)
        self.test_count = len(test_data)
        test_data = DataLoader(test_data, batch_size=1000, shuffle=True)

        # 定义神经网络
        my_model = RestNEt18(in_channels=1)

        # 断点续训
        if os.path.exists(self.model_path):
            my_model = torch.load(self.model_path)

        # 定义损失函数
        loss_fn = nn.CrossEntropyLoss()

        # 定义优化器
        optimizer = optim.SGD(my_model.parameters(), lr=self.lr)

        self.writer = SummaryWriter(log_dir="../logs")
        for item in range(self.epoch):
            print("-----开始第%s轮测试-----" % str(item))
            # 训练模型
            my_model.train()
            self._train(my_model, loss_fn, optimizer, train_data)

            # 测试模型
            my_model.eval()
            self._test(my_model, test_data)

        # 存储模型
        torch.save(my_model, self.model_path)

    def predict(self, img_path):
        """
        使用模型预测数据
        :param img_path: 图片文件地址
        :return: 预测值
        """
        # 读取图片
        img = Image.open(img_path)

        # 图片预处理满足模型需求
        # png图片是四通道转化为3通道
        img = img.convert("RGB")
        # 转为灰度图
        img = img.convert("L")
        # 二值化
        img = img.point(lambda x:  255 if x >= 100 else 0)
        # img.show()
        # 由于训练数据以黑色为底的图片,预测图片为白色为底,需要进行转换
        img = img.point(lambda x: 255-x)
        # img.show()
        # 图片尺寸修改,转化为张量
        img = self.transforms(img)

        # 加载模型
        my_model = torch.load(self.model_path)

        # 预测
        img = torch.reshape(img, (1, 1, 224, 224))
        my_model.eval()
        with torch.no_grad():
            predict = my_model(img)
            print(predict)
            predict = predict.argmax(1)
            print(predict.item())
            return predict


if __name__ == "__main__":
    # 生成模型
    # pth_path = "../pth/num_predict2.pth"
    # num_predict = NumPredict(pth_path, epoch=1)
    # num_predict.run()

    # 预测数据
    img_path = "../img/9.png"
    pth_path = "../pth/num_predict2.pth"
    num_predict = NumPredict(pth_path)
    num_predict.predict(img_path)

  

标签:ResNet18,nn,self,put,手写,识别,data,model,out
From: https://www.cnblogs.com/fuchenjie/p/17578067.html

相关文章

  • 基于KNN近邻分类的情感识别算法matlab仿真
    1.算法理论概述      情感识别是自然语言处理领域中的一个重要研究方向。本文介绍了一种基于KNN近邻分类的情感识别算法,该算法使用词袋模型提取文本特征向量,计算文本特征向量之间的距离,并使用加权投票的方法确定待分类文本的情感类别。本文详细介绍了算法的数学模型和实现......
  • .NET 验证码图片识别
    .NET验证码图片识别流程作为一名经验丰富的开发者,我将向你介绍如何实现".NET验证码图片识别"这一任务。下面是整个流程的步骤:步骤操作1下载验证码图片2预处理图片3图片二值化4分割字符5训练模型6预测验证码现在,让我们逐步详细解释每个步骤需......
  • AI识别检验报告 -PaddleNLP UIE-X 在医疗领域的实战
    目录UIE-X在医疗领域的实战1.项目背景2.案例简介3.环境准备数据转换5.模型微调6.模型评估7.Taskflow一键部署UIE-X在医疗领域的实战PaddleNLP全新发布UIE-X......
  • 基于mnist手写数字数据库识别算法matlab仿真,对比SVM,LDA以及决策树
    1.算法理论概述      基于MNIST手写数字数据库识别算法,对比SVM、LDA以及决策树。首先,我们将介绍MNIST数据库的基本信息和手写数字识别的背景,然后分别介绍SVM、LDA和决策树的基本原理和数学模型,并对比它们在手写数字识别任务中的性能。 1.1、MNIST手写数字数据库   ......
  • 无法将“yarn”项识别为 cmdlet、函数、脚本文件或可运
    如何解决"无法将“yarn”项识别为cmdlet、函数、脚本文件或可运"错误引言作为一名经验丰富的开发者,你可能会遇到一些新手常见的问题。其中一个常见的问题是在使用Yarn(一个流行的包管理工具)时可能会遇到错误:“无法将“yarn”项识别为cmdlet、函数、脚本文件或可运”。这篇文章将......
  • 移动平均线Forexclub这样用,一眼识别买卖信号
    Forexclub建议使用H1时间框架和欧元/美元货币对。在该策略中,Forexclub使用了线性加权移动平均线作为主要指标,同时将其作为一个额外的过滤器。线性加权移动平均线(LWMA)的优势在于,它更重视最近的价格变动,而且长期时间框架几乎没有延迟。此外,Forexclub仅根据MA相对于价格变动的位置来......
  • android 热更新手写框架
    Android热更新手写框架实现流程热更新是指在不修改已安装应用程序的情况下,通过下载差异化的资源文件,实现应用程序的更新。在Android开发中,我们可以手动实现一个热更新框架,使得应用程序能够在不重新安装的情况下更新。下面是实现Android热更新框架的步骤:步骤描述1从服......
  • 手写一个Promise
    Promise背景JavaScript这种单线程事件循环模型,异步行为是为了优化因计算量大而时间长的操作。在JavaScript中我们可以见到很多异步行为,比如计时器、ui渲染、请求数据等等。Promise的主要功能,是为异步代码提供了清晰的抽象,支持优雅地定义和组织异步逻辑。可以用Promise表示异步......
  • Qt(5.8.0)-Cmd模拟(纯手写)
    以下是对上述Qt程序的详细博客,使用Markdown的代码块方式呈现:Qt编程:实现一个简单的命令行窗口Qt是一种跨平台的C++应用程序开发框架,可以用于开发各种类型的应用程序,包括图形界面(GUI)应用程序。本文将介绍如何使用Qt框架实现一个简单的命令行窗口,类似于Windows的运行框,用户可以在......
  • python 识别图片文本 及 位置
    Python识别图片文本及位置在处理图片时,有时候我们需要获取图片中的文本内容,并且知道文本在图片中的位置。Python提供了一些库和工具,可以帮助我们实现这一功能。本文将介绍如何使用Python识别图片中的文本,并获取文本在图片中的位置信息。1.安装依赖库在开始之前,我们需要安装一些......