首页 > 其他分享 >基于 keras-js 快速实现浏览器内的 CNN 手写数字识别

基于 keras-js 快速实现浏览器内的 CNN 手写数字识别

时间:2023-02-01 22:45:16浏览次数:57  
标签:keras train 000 js 252 test CNN 253

https://zhuanlan.zhihu.com/p/33313340

在这篇文章中,我会快速地介绍如何使用 keras 训练一个简单的识别 MNIST(一个手写数字数据集)的 CNN(卷积神经网络),并且把训练好的网络应用到 web 浏览器内。

DEMO 地址:https://starkwang.github.io/keras-js-demo/dist/

 

动图封面  

 


零、准备工作

首先需要给你的电脑安装 keras,具体安装的步骤请参考 keras 官方文档


一、快速入门

首先十分推荐阅读 tensorflow 官方文档中的 MNIST For ML Beginners,这里是极客学院的中文翻译

MNIST 是一个很流行的入门级机器学习/计算机视觉数据集,它包含 0 - 9 的各种手写数字图片:

 

 

每张图片的尺寸均为 28 * 28,用一个 28 * 28 的二维数组来表示,换句话说,每张图片都是由 784 个像素点组成,每个像素点的值在 0 - 255 之间。

比如下面就是一个 "3" 的数据:

(知乎web移动端代码强制换行,简直有毒)

000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 038 043 105 255 253 253 253 253 253 174 006 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 043 139 224 226 252 253 252 252 252 252 252 252 158 014 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 178 252 252 252 252 253 252 252 252 252 252 252 252 059 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 109 252 252 230 132 133 132 132 189 252 252 252 252 059 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 004 029 029 024 000 000 000 000 014 226 252 252 172 007 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 085 243 252 252 144 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 088 189 252 252 252 014 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 091 212 247 252 252 252 204 009 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 032 125 193 193 193 253 252 252 252 238 102 028 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 045 222 252 252 252 252 253 252 252 252 177 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 045 223 253 253 253 253 255 253 253 253 253 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 031 123 052 044 044 044 044 143 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 015 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 086 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 005 075 009 000 000 000 000 000 000 098 242 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 061 183 252 029 000 000 000 000 018 092 239 252 252 243 065 000 000 000 000 000 000 000 000 
000 000 000 000 000 208 252 252 147 134 134 134 134 203 253 252 252 188 083 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 208 252 252 252 252 252 252 252 252 253 230 153 008 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 049 157 252 252 252 252 252 217 207 146 045 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 007 103 235 252 172 103 024 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 

使用 keras,可以很方便地导入 MNIST 数据集:

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data() 

总体来说,我们的想要得到的网络模型,是有一个固定的输入输出的:

  • 输入为一个 28 * 28 的二维整数数组
  • 输出是一个长度为 10 的数组,依次表示 0-9 的可能性(例如如果有一张图片 80% 概率为 1, 20% 概率为 7的话,那么这个数组就是 [0, 0.8, 0, 0, 0, 0, 0, 0.2, 0, 0]

二、使用 keras 训练网络

我们想要训练的模型,由以下几层网络组成:

  1. 32 个 3x3 卷积核的卷积层
  2. 64 个 3x3 卷积核的卷积层
  3. 采样因子为 (2, 2) 的池化层
  4. Dropout 层
  5. Flatten 层
  6. ReLu 全连接层
  7. Dropout 层
  8. Softmax 全连接层

用 keras 训练一个识别 MNIST 的 CNN 网络非常方便,下面是一个官方给出的例子(源码在此):

from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12

# input image dimensions
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# Save model
model.save('myMnistCNN.h5')

如果已经安装好了 keras,直接运行即可:

python mnist_cnn.py 

三、转换输出模型

获得训练好的 .h5 文件之后,模型还不能直接使用,因为我们需要对它进行转编码,keras-js 提供了一个 python 脚本来自动执行:

python ./python/encoder.py -q myMnistCNN.h5 

这个脚本会把 .h5 文件转编码为 keras-js 可读的格式,里面包含了训练好的神经网络的所有模型和参数。


四、使用 keras-js 导入模型

首先需要引入 keras-js,可以通过 script 标签直接引入:

<script src="https://unpkg.com/keras-js"></script> 

也可以通过 npm 安装后使用 webpack 构建引入,参考这里

接下来就可以直接创建一个 Model,keras-js 会自动加载对应的 bin 文件:

const model = new KerasJS.Model({
    filepath: '/path/to/mnist_cnn.bin',
    gpu: true,
    transferLayerOutputs: true
}) 

初始化完毕之后,就可以用于 MNIST 识别了,输入是一个长度为 784 的数组(包含 28*28 各个像素点的灰度值),输出是一个长度为 10 的数组(0-9的概率):

(可以使用上文中给的那个 "3" 的数据范例)

model
  .ready()
  .then(() => {
    // data 是一个长度为 784 的数组,每一项都介于 0 - 255 之间
    // 这里我们需要把数组转换为 Float32 类型
    const inputData = new Float32Array(data)
    // 识别
    return model.predict(inputData)
  })
  .then(outputData => {
    // 输出为 0-9 的概率,例如:
    // { output: [0, 0, 0, 0.8, 0, 0, 0.2, 0, 0, 0] }
  })
  .catch(err => {
    // ...
  })

五、Canvas 实现一个手写板

最后一步就是实现一个手写板,具体的代码就不放上来了,主要就是通过 mousedownmousemovemouseup 事件来绘制图形。

绘制完毕之后,调用 ctx.getImageData,就可以得到 canvas 内的像素数据,每个像素对应四个数值,依次是每个点的 rgba 值,处理之后就可以得到长度为 784 的灰度数组了。然后使用上文提到的 model.predict 即可。

 

动图封面

标签:keras,train,000,js,252,test,CNN,253
From: https://www.cnblogs.com/chinasoft/p/17084356.html

相关文章

  • vite.config.js
    import{defineConfig}from'vite'//动态配置函数import{createVuePlugin}from'vite-plugin-vue2'importvuefrom'@vitejs/plugin-vue';exportdefault()=>......
  • js实现替换对象(json)格式的键名
    某些场景下,我们拿到的键名与预期的键名不符,这个时候就需要替换键名来得到我们想要的内容letobj=[{id:1,title:'zs'},{id:2,title:'l......
  • js防抖函数
    1、使用场景:例如:搜索框搜索输入。只需用户最后一次输入完,再发送请求2、函数防抖的要点:需要一个 setTimeout 来辅助实现,延迟运行需要执行的代码。如果该方法多......
  • 关于node.js
    浏览器是JavaScript的前端运行环境。Node.js是JavaScript的后端运行环境。Node.js中无法调用DOM和BOM等浏览器内置API。基于Express框架(http://www.expres......
  • 【Frida】调试js代码
    方法一attach启动js代码动态注入app,app需要保持运行状态#coding:utf-8importsysimportfridaapp_name="猿人学APP"#app的名字js_file_path="./demo.js"#......
  • json .net 反序列化
    引用链接https://www.cnblogs.com/nice0e3/p/15294585.html#%E5%8F%8D%E5%BA%8F%E5%88%97%E5%8C%96%E6%94%BB%E5%87%BBhttps://www.anquanke.com/post/id/172920#h3-3j......
  • 书城9 - 前后端 json 数据的交互
    解析请求中的json数据,返回json数据1.加入Gson.jar包2.通过输入流读取数据,使用Gson对象解析字符串protectedvoidrequestBodyJSON(HttpServletRequestrequ......
  • 2023年JS学习记录
    2023/1/30星期一https://blog.csdn.net/Augenstern_QXL/article/details/119249534短路运算(逻辑中断)短路运算的原理:当有多个表达式(值)时,左边的表达式值可以确定结果时......
  • 数据交换格式JSON和xml
    数据交换格式,就是服务器端与客户端之间进行数据传输与交换的格式前端领域,经常提及的两种数据交换格式分别是XML和JSON。其中XML用的非常少,所以,我们重点要学习的数据......
  • JSTL常用标签choose和foreach常用标签
    JSTL的常用标签choosechoose相当于java代码中的switch语句完成数字编号对应星期几案例1、域中存储数字2、使用choose标签取出数字 相当于switch声明......