首页 > 其他分享 >MindSpore手写数字识别体验

MindSpore手写数字识别体验

时间:2022-10-22 10:03:43浏览次数:78  
标签:nn self num MindSpore ms import 手写 识别 mindspore

今天带大家体验一下 MindSpore 这个 AI 框架来完成手写数字识别的任务

1. 环境准备

使用Anaconda创建虚拟环境:

conda create -n mindspore python=3.8

请添加图片描述 创建完成后会显示以下的图像界面 请添加图片描述 这样我们的虚拟环境mindspore就创造完成

2. 安装minspore及其套件

mindspore 的安装可以参考: https://mindspore.cn/install 请添加图片描述

conda install mindspore-cpu=1.8.1 -c mindspore -c conda-forge

验证安装是否成功:

python -c "import mindspore;mindspore.run_check()"

如果输出下方内容,就成功了

MindSpore version: 版本号
The result of multiplication calculation is correct, MindSpore has been installed successfully!

接下来继续安装其他依赖

pip install mindvision jupyterlab

安装 mindvision 是为了使用MindSpore Vision套件,其提供了用于下载并处理MNIST数据集的Mnist模块。 安装jupyterlab使用jupyter-lab来编写程序

安装完成后激活jupyter-lab环境

jupyter-lab

3. 程序撰写

from mindvision.dataset import Mnist

# 下载并处理MNIST数据集
download_train = Mnist(path="./mnist", split="train", batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)

download_eval = Mnist(path="./mnist", split="test", batch_size=32, resize=32, download=True)

dataset_train = download_train.run()
dataset_eval = download_eval.run()

import mindspore.nn as nn

class LeNet5(nn.Cell):
    """
    LeNet-5网络结构
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # 卷积层,输入的通道数为num_channel,输出的通道数为6,卷积核大小为5*5
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        # 卷积层,输入的通道数为6,输出的通道数为16,卷积核大小为5*5
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        # 全连接层,输入个数为16*5*5,输出个数为120
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        # 全连接层,输入个数为120,输出个数为84
        self.fc2 = nn.Dense(120, 84)
        # 全连接层,输入个数为84,分类的个数为num_class
        self.fc3 = nn.Dense(84, num_class)
        # ReLU激活函数
        self.relu = nn.ReLU()
        # 池化层
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        # 多维数组展平为一维数组
        self.flatten = nn.Flatten()

    def construct(self, x):
        # 使用定义好的运算构建前向网络
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

network = LeNet5(num_class=10)

import mindspore.nn as nn

# 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# 定义优化器函数
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

import mindspore as ms

# 设置模型保存参数,模型训练保存参数的step为1875。
config_ck = ms.CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)

# 应用模型保存参数
ckpoint = ms.ModelCheckpoint(prefix="lenet", directory="./lenet", config=config_ck)

from mindvision.engine.callback import LossMonitor
import mindspore as ms

# 初始化模型参数
model = ms.Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'accuracy'})

# 训练网络模型,并保存为lenet-1_1875.ckpt文件
model.train(10, dataset_train, callbacks=[ckpoint, LossMonitor(0.01, 1875)])

acc = model.eval(dataset_eval)

print("{}".format(acc))


import mindspore as ms

# 加载已经保存的用于测试的模型
param_dict = ms.load_checkpoint("./lenet/lenet-1_1875.ckpt")
# 加载参数到网络中
ms.load_param_into_net(network, param_dict)

import numpy as np
import mindspore as ms
import matplotlib.pyplot as plt

mnist = Mnist("./mnist", split="train", batch_size=6, resize=32)
dataset_infer = mnist.run()
ds_test = dataset_infer.create_dict_iterator()
data = next(ds_test)
images = data["image"].asnumpy()
labels = data["label"].asnumpy()

plt.figure()
for i in range(1, 7):
    plt.subplot(2, 3, i)
    plt.imshow(images[i-1][0], interpolation="None", cmap="gray")
plt.show()

# 使用函数model.predict预测image对应分类
output = model.predict(ms.Tensor(data['image']))
predicted = np.argmax(output.asnumpy(), axis=1)

# 输出预测分类与实际分类
print(f'Predicted: "{predicted}", Actual: "{labels}"')

请添加图片描述 程序的输出和实际的一致,说明本次的模型训练和预测是很成功的。

4. 总结

整个流程下来,从模型的设计导训练,整个流程都比较清晰,优化器的设置和参数等的定义都很直观。整体上模型的体验都是不错的,但在jupyter-lab运行时候的warning无法消除,在观感上有点不大好。

标签:nn,self,num,MindSpore,ms,import,手写,识别,mindspore
From: https://blog.51cto.com/CANGYE0504/5785656

相关文章

  • 手写基于Java RMI的RPC框架
    留给读者其中最大的区别就是ZooKeeper注册中心,注册中心可以有读写监听器,这是一个优势,可以用来实现订阅通知,也能做数据的同步,甚至可以做基于读写分离的RPC框架,而且它是基......
  • AI安全帽识别/人脸识别智能分析网关何在EasyCVR配置告警信息推送
    智能分析网关是TSINGSEE青犀视频研发的智能硬件设备,部署了全新嵌入式多算法框架软件,可支持AI视频智能分析功能,包括人脸识别、车辆检测及识别、烟火识别、物体识别、行为识......
  • 华为云OCR文字识别服务,助力智能化疫情防控​
    ​让10000条防疫信息实现秒级录入与核验!​——华为云OCR文字识别服务助力智能化疫情防控​人工智能实现效率千倍提升,为基层抗疫工作减负提效​——华为云实现万条防疫信息秒......
  • 如何精准识别主数据?
    编 辑:彭文华​彭友们好,我是老彭啊。最近有彭友跟我欲言又止,吞吞吐吐不知道要干啥。我明白,他是有问题了。绕了半天,他才问:怎么才能精准识别主数据呢?我一看这个问法,肯定是遇到......
  • Blazor组件自做十 : 光学字符识别 OCR 组件
    光学字符识别OCR组件演示地址https://blazor.app1.es/ocr使用方法手机或者电脑点击拍照OCR可启动相机拍照,或者点击文件OCR选择文件,稍等片刻即可获得OCR结果.......
  • 你真的了解命名实体识别吗
    抛出几个问题,可以先思考一下,如果你都能完美的回答,请和我做朋友。1、命名实体识别如何构造大量有监督数据集?2、命名实体识别的常用解码思路有什么?3、如果命名实体识别任......
  • 用Transformer实现OCR字符识别!
     Datawhale干货 作者:安晟、袁明坤,Datawhale成员在CV领域中,transformer除了分类还能做什么?本文将采用一个单词识别任务数据集,讲解如何使用transformer实现一个简单的OCR文......
  • 张海腾:语音识别实践教程
    作者:张海腾,标贝科技,Datawhale优秀学习者作为智能语音交互相关的从业者,今天以天池学习赛:《零基础入门语音识别:食物声音识别》为例,带大家梳理一些自动语音识别技术(ASR)关的知识......
  • 直播软件开发,手写导航栏切换页面
    直播软件开发,手写导航栏切换页面 <style>  .container{    width:100%;    letter-spacing:1px;  }  .navBox{    width:1200......
  • 手写flat 与 flatMap
    今天又收获一个生产故障,原因是测试过程中在浏览器里测的,浏览器版本较高,然后这个项目是内嵌在客户端里面,客户端内的浏览器版本稍微低一点,不支持flat方法和flatMap方法,所以。......